TC-Bot Reading Notes

Overall code execution flow

  • Load the user goals used for training and split them into training, validation, and test sets
  • Initialize the reinforcement learning agent
  • Initialize the user simulator
  • Load the pre-trained NLG model and attach it to both the agent and the user simulator
  • Load the pre-trained NLU model and attach it to both the agent and the user simulator
  • Initialize the dialog manager
  • Start the conversation simulation
  • Run the warm-start simulation (warm_start_simulation)
  • Start reinforcement learning training (a sketch of this whole driver flow follows the list)
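A minimal sketch of this driver flow, reconstructed from memory of the repo's src/run.py. The module paths, constructor arguments, and helper names (`text_to_dict`, `run_episodes`, the `set_nlg_model`/`set_nlu_model` setters) are assumptions and are simplified; the real script reads all of these values from command-line flags.

```python
# Sketch of the overall driver flow above (simplified; names and signatures assumed).
import pickle

from deep_dialog.dialog_system import DialogManager, text_to_dict
from deep_dialog.agents import AgentDQN
from deep_dialog.usersims import RuleSimulator
from deep_dialog.nlu import nlu
from deep_dialog.nlg import nlg

goal_set = pickle.load(open(goal_file_path, 'rb'))     # user goals, split into train/valid/test
act_set = text_to_dict(act_set_path)                   # dialog act vocabulary
slot_set = text_to_dict(slot_set_path)                 # slot vocabulary
movie_kb = pickle.load(open(movie_kb_path, 'rb'))      # movie knowledge base
movie_dict = pickle.load(open(dict_path, 'rb'))        # movie dictionary for the user simulator

agent = AgentDQN(movie_kb, act_set, slot_set, agent_params)                        # RL agent (agt == 9)
user_sim = RuleSimulator(movie_dict, act_set, slot_set, goal_set, usersim_params)  # user simulator

nlg_model = nlg()
nlg_model.load_nlg_model(nlg_model_path)
agent.set_nlg_model(nlg_model)        # the same trained NLG model serves both sides
user_sim.set_nlg_model(nlg_model)

nlu_model = nlu()
nlu_model.load_nlu_model(nlu_model_path)
agent.set_nlu_model(nlu_model)        # the same trained NLU model serves both sides
user_sim.set_nlu_model(nlu_model)

dialog_manager = DialogManager(agent, user_sim, act_set, slot_set, movie_kb)

if warm_start == 1:
    warm_start_simulation()           # pre-fill the replay pool with rule-policy experience
run_episodes(num_episodes)            # RL training loop (signature simplified)
```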

Overall flow

(figure: overall execution flow)

warm_start_simulation

def warm_start_simulation():
    """ Warm-start simulation: run simulated dialogs with the rule policy to
    pre-fill the experience replay pool. Simulation stops once the number of
    collected experiences exceeds the given threshold; the success rate,
    average reward, and average number of turns are then computed and printed. """
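Expanded, the function body looks roughly like the sketch below. Names such as `dialog_manager`, `agent.experience_replay_pool`, `warm_start_epochs`, and `experience_replay_pool_size` follow the repo as I remember it, so treat the details as assumptions.

```python
# Rough sketch of warm_start_simulation (simplified; globals assumed from run.py).
def warm_start_simulation():
    successes, cumulative_reward, cumulative_turns = 0, 0, 0
    for episode in range(warm_start_epochs):
        dialog_manager.initialize_episode()
        episode_over = False
        while not episode_over:
            episode_over, reward = dialog_manager.next_turn()
            cumulative_reward += reward
            if episode_over:
                if reward > 0:
                    successes += 1          # a positive final reward means a successful dialog
                cumulative_turns += dialog_manager.state_tracker.turn_count
        # stop early once the replay pool holds enough rule-policy experience
        if len(agent.experience_replay_pool) >= experience_replay_pool_size:
            break
    n = episode + 1
    print('warm start %d epochs, success rate %.4f, avg reward %.4f, avg turns %.4f' %
          (n, float(successes) / n, float(cumulative_reward) / n, float(cumulative_turns) / n))
```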

Parameter list

| Parameter | Description |
| --- | --- |
| dict_path | path to the .json dictionary file |
| movie_kb_path | path to the movie KB .json file |
| act_set | path to the dialog act set; none for loading from labeled file |
| slot_set | path to the slot set; none for loading from labeled file |
| goal_file_path | a list of user goals |
| diaact_nl_pairs | path to the pre-defined dia_act/NL pairs |
| max_turn | maximum length of each dialog (default=20, 0 = no maximum length) |
| episodes | total number of episodes to run (default=1) |
| slot_err_prob | the slot error probability |
| slot_err_mode | 0 for slot_val only; 1 for three error types |
| intent_err_prob | the intent error probability |
| agt | select an agent: 0 for command-line input, 1-6 for rule-based agents |
| usr | select a user simulator; 0 is a frozen user simulator |
| epsilon | epsilon controlling the stochasticity of the epsilon-greedy agent policy |

NLG and NLU parameters

| Parameter | Description |
| --- | --- |
| nlg_model_path | path to the trained NLG model file |
| nlu_model_path | path to the trained NLU model file |
| act_level | 0 for dia_act level; 1 for NL level |
| run_mode | 0 for default NL; 1 for dia_act; 2 for both |
| auto_suggest | 0 for no auto_suggest; 1 for auto_suggest |
| cmd_input_mode | 0 for NL input; 1 for dia_act input |

Reinforcement learning agent parameters

| Parameter | Description |
| --- | --- |
| dqn_hidden_size | the hidden layer size of the DQN |
| batch_size | batch size |
| gamma | discount factor gamma for the DQN |
| predict_mode | predict mode for the DQN |
| simulation_epoch_size | the size of the validation set |
| warm_start | 0: no warm start; 1: warm start for training |
| warm_start_epochs | the number of epochs for warm start |
| trained_model_path | the path of a trained model |
| write_model_dir | directory where trained models are written to disk |
| save_check_point | save the model every this many epochs |
| success_rate_threshold | the threshold on success rate |
| split_fold | the number of folds used to split the user goals |
| learning_phase | train/test/all; default is all |
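These flags are ordinary argparse options on the entry script. The sketch below covers a handful of them; the default values and the sample invocation in the comment are assumptions, so check the repo's README and src/run.py for the real ones.

```python
# Hypothetical argparse sketch for a subset of the flags in the tables above
# (defaults are assumptions and may differ from the repo's src/run.py).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--agt', dest='agt', type=int, default=0,
                    help='select an agent: 0 cmd-line input, 1-6 rule-based, 9 DQN')
parser.add_argument('--usr', dest='usr', type=int, default=0,
                    help='select a user simulator')
parser.add_argument('--max_turn', dest='max_turn', type=int, default=20,
                    help='maximum length of each dialog')
parser.add_argument('--episodes', dest='episodes', type=int, default=1,
                    help='total number of episodes to run')
parser.add_argument('--dqn_hidden_size', dest='dqn_hidden_size', type=int, default=60,
                    help='hidden layer size of the DQN')
parser.add_argument('--gamma', dest='gamma', type=float, default=0.9,
                    help='discount factor for the DQN')
parser.add_argument('--warm_start', dest='warm_start', type=int, default=1,
                    help='0: no warm start; 1: warm start for training')
params = vars(parser.parse_args())

# A typical invocation looks something like (see the README for the exact flags):
#   python run.py --agt 9 --usr 1 --max_turn 40 --episodes 500 --warm_start 1
```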

Agent types (agt values)

| agt | Agent class |
| --- | --- |
| 0 | AgentCmd |
| 1 | InformAgent |
| 2 | RequestAllAgent |
| 3 | RandomAgent |
| 4 | EchoAgent |
| 5 | RequestBasicsAgent |
| 9 | AgentDQN |
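In run.py this mapping is just an if/elif chain over the --agt flag; roughly as below (constructor arguments are assumed to be the same for all agents and are simplified).

```python
# Rough sketch of how the --agt flag selects an agent class (simplified; the real
# chain and constructor arguments live in src/run.py).
if agt == 0:
    agent = AgentCmd(movie_kb, act_set, slot_set, agent_params)
elif agt == 1:
    agent = InformAgent(movie_kb, act_set, slot_set, agent_params)
elif agt == 2:
    agent = RequestAllAgent(movie_kb, act_set, slot_set, agent_params)
elif agt == 3:
    agent = RandomAgent(movie_kb, act_set, slot_set, agent_params)
elif agt == 4:
    agent = EchoAgent(movie_kb, act_set, slot_set, agent_params)
elif agt == 5:
    agent = RequestBasicsAgent(movie_kb, act_set, slot_set, agent_params)
elif agt == 9:
    agent = AgentDQN(movie_kb, act_set, slot_set, agent_params)
```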

DialogManager

Mediates the interaction between the agent and the customer (the user simulator).

Question:

  • In "add NL to Agent Dia_Act", what does NL refer to?? (Presumably the natural-language surface form that the NLG model generates from the agent's dia_act; see the sketch below.)
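As far as I can tell, the answer is that each structured dia_act gets an extra 'nl' field rendered by the NLG model. A sketch of that step is below; the method name `convert_diaact_to_nl` and its arguments are how I remember the repo's NLG API, so treat them as assumptions.

```python
# Sketch of attaching an NL surface form to the agent's dia_act
# (NLG method name and arguments are assumptions based on my reading of the repo).
def add_nl_to_action(agent_action):
    if agent_action['act_slot_response']:
        # render the structured act, e.g. {'diaact': 'request', 'request_slots': {'date': 'UNK'}, ...},
        # into a sentence such as "What date would you like to see the movie?"
        sentence = nlg_model.convert_diaact_to_nl(agent_action['act_slot_response'], 'agt')
        agent_action['act_slot_response']['nl'] = sentence
```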
class DialogManager:
    """ A dialog manager to mediate the interaction between an agent and a customer """

    def __init__(self, agent, user, act_set, slot_set, movie_dictionary):
        self.agent = agent
        self.user = user
        self.act_set = act_set
        self.slot_set = slot_set
        self.state_tracker = StateTracker(act_set, slot_set, movie_dictionary)
        self.user_action = None
        self.reward = 0
        self.episode_over = False

    def initialize_episode(self):
        """ Reset the manager's own members (and the state tracker, agent, and user) for a new episode. """

    def next_turn(self, record_training_data=True):
        """ Run one turn of the dialog:
        - get the agent's state from the state tracker
        - get the agent's action
        - update the state tracker with the agent's action
        - print the agent's act-slot response
        - take the last record from the dialog history
        - get the user's next action, whether the episode is over, and the dialog status
        - compute the reward from the dialog status
        - if the episode is not over, update the state tracker with the user's action and print it
        - optionally record the transition as training data """

    def reward_function(self, dialog_status):
        """ Dialog reward function. """

    def reward_function_without_penalty(self, dialog_status):
        """ Dialog reward function without the per-turn and failed-dialog penalties. """

    def print_function(self, agent_action=None, user_action=None):
        """ Print the current dialog information. """
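For context, the training loop in run.py drives the manager roughly as below, and the reward shape I recall is +2*max_turn on success, -max_turn on failure, and -1 per turn. Both the loop and the constants are assumptions simplified from the repo.

```python
# One simulated episode driven through DialogManager (simplified; bookkeeping omitted).
dialog_manager.initialize_episode()
episode_over, cumulative_reward = False, 0
while not episode_over:
    # each call runs one agent turn plus one user turn and returns the per-turn reward
    episode_over, reward = dialog_manager.next_turn()
    cumulative_reward += reward

# Assumed shape of the reward function (constants may differ in the repo):
def reward_function(dialog_status, max_turn):
    if dialog_status == dialog_config.FAILED_DIALOG:
        return -max_turn        # large penalty for a failed dialog
    elif dialog_status == dialog_config.SUCCESS_DIALOG:
        return 2 * max_turn     # large bonus for a successful dialog
    else:
        return -1               # small per-turn penalty to encourage short dialogs
```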

Dialog flow

(figure: dialog flow)

StateTracker

The dialog state tracker.

Question:

  • The constructor parameters act_set and slot_set are never used??
class StateTracker:
    """ The state tracker maintains a record of which request slots are filled and which inform slots are filled """

    def __init__(self, act_set, slot_set, movie_dictionary):
        """ Constructor for the StateTracker; takes the movie knowledge base and initializes a new episode.

        Arguments:
        act_set -- The set of all acts available
        slot_set -- The total set of available slots
        movie_dictionary -- A representation of all the available movies. Generally this object is accessed via the KBHelper class

        Class Variables:
        history_vectors -- A record of the current dialog so far in vector format (act-slot, but no values)
        history_dictionaries -- A record of the current dialog in dictionary format
        current_slots -- A dictionary that keeps a running record of which slots are filled (current_slots['inform_slots']) and which are requested but not yet filled (current_slots['request_slots'])
        action_dimension -- # TODO indicates the dimensionality of the vector representation of the action
        kb_result_dimension -- A single integer denoting the dimension of the kb_results features
        turn_count -- A running count of which turn we are at in the present dialog
        """
        self.movie_dictionary = movie_dictionary
        self.initialize_episode()
        self.history_vectors = None
        self.history_dictionaries = None
        self.current_slots = None
        self.action_dimension = 10  # TODO REPLACE WITH REAL VALUE
        self.kb_result_dimension = 10  # TODO REPLACE WITH REAL VALUE
        self.turn_count = 0
        self.kb_helper = KBHelper(movie_dictionary)

    def initialize_episode(self):
        """ Initialize the dialog state and the tracked slots for a new episode. """

    def dialog_history_vectors(self):
        """ Return the dialog history (user and agent actions) in vector representation. """

    def dialog_history_dictionaries(self):
        """ Return the dialog history in dictionary representation. """

    def kb_results_for_state(self):
        """ Query the knowledge base with the currently filled slots and return the results. """

    def get_state_for_agent(self):
        """ Build and return the agent's representation of the current dialog state. """

    def get_suggest_slots_values(self, request_slots):
        """ Get suggested values for the currently requested slots. """

    def get_current_kb_results(self):
        """ Search the knowledge base for movies matching the current constraints. """

    def update(self, agent_action=None, user_action=None):
        """ Update the state based on the latest action. """
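To make the division of labor concrete, DialogManager.next_turn consults and updates the tracker roughly as sketched below. Method names are as I remember them from the repo and the arguments are simplified, so treat the details as assumptions; NL rendering, printing, and experience recording are omitted.

```python
# Sketch of how DialogManager.next_turn drives the StateTracker (simplified).
def next_turn(self, record_training_data=True):
    state = self.state_tracker.get_state_for_agent()             # current dialog state for the agent
    self.agent_action = self.agent.state_to_action(state)        # agent policy picks an act-slot response
    self.state_tracker.update(agent_action=self.agent_action)    # record the agent turn

    sys_action = self.state_tracker.dialog_history_dictionaries()[-1]  # last system act, passed to the user
    self.user_action, self.episode_over, dialog_status = self.user.next(sys_action)
    self.reward = self.reward_function(dialog_status)

    if not self.episode_over:
        self.state_tracker.update(user_action=self.user_action)  # record the user turn
    return self.episode_over, self.reward
```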

NLU

NLG

Agent types

  • AgentCmd

    Generates actions interactively by reading input from the command line.

  • InformAgent

    Informs all slots in turn, then issues taskcomplete.

  • RequestAllAgent

    Requests all slots, then issues thanks.

  • RandomAgent

    Picks actions at random.

  • EchoAgent

    Informs whatever slot the user requests, then issues taskcomplete once the user stops making requests.

  • RequestBasicsAgent

    Requests all basic slots, then issues thanks. (A sketch of this style of rule policy follows the list.)

  • AgentDQN

    The DQN-based reinforcement learning agent.
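A RequestBasicsAgent-style policy is simple enough to sketch in a few lines. The slot list and dia_act names below are assumptions for the movie-ticket domain, and the class is simplified relative to the real implementations in deep_dialog/agents/.

```python
# Sketch of a RequestBasicsAgent-style rule policy (slot list and dia_act names are
# assumptions for the movie-ticket domain; not the repo's exact implementation).
class RequestBasicsAgentSketch:
    def __init__(self):
        # "basic" slots to request, one per turn
        self.request_set = ['moviename', 'starttime', 'city', 'date', 'theater', 'numberofpeople']
        self.current_slot_id = 0

    def initialize_episode(self):
        self.current_slot_id = 0

    def state_to_action(self, state):
        """ Request the basic slots one per turn, then thank the user. """
        if self.current_slot_id < len(self.request_set):
            slot = self.request_set[self.current_slot_id]
            self.current_slot_id += 1
            act_slot_response = {'diaact': 'request', 'inform_slots': {},
                                 'request_slots': {slot: 'UNK'}}
        else:
            act_slot_response = {'diaact': 'thanks', 'inform_slots': {}, 'request_slots': {}}
        return {'act_slot_response': act_slot_response, 'act_slot_value_response': None}
```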