本文整理汇总了Python中dataset.DataSet.add方法的典型用法代码示例。如果您正苦于以下问题:Python DataSet.add方法的具体用法?Python DataSet.add怎么用?Python DataSet.add使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类dataset.DataSet的用法示例。
在下文中一共展示了DataSet.add方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: parse_ifttt_dataset
# 需要导入模块: from dataset import DataSet [as 别名]
# 或者: from dataset.DataSet import add [as 别名]
# NOTE(review): this listing was scraped from a web page -- the original
# indentation has been stripped and the tail of the function is omitted
# (see the "代码省略" marker that follows), so the fragment is not runnable
# as-is. Code below is kept byte-identical; only comments were added.
#
# Builds train/dev/test DataSet splits for the IFTTT semantic-parsing corpus:
# reads the parallel annotation/code files, derives a grammar and the two
# vocabularies, converts every parse tree into a sequence of APPLY_RULE
# actions, and buckets examples into splits by their integer id.
def parse_ifttt_dataset():
# tokens occurring fewer than this many times are cut from the vocabularies
WORD_FREQ_CUT_OFF = 2
# NOTE(review): hard-coded absolute paths to one developer's machine --
# these should be parameters or configuration values.
annot_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/Data/lang.all.txt'
code_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/Data/code.all.txt'
data = preprocess_ifttt_dataset(annot_file, code_file)
# build the grammar
grammar = get_grammar([e['parse_tree'] for e in data])
# flatten every example's query tokens into one stream for frequency counting
annot_tokens = list(chain(*[e['query_tokens'] for e in data]))
annot_vocab = gen_vocab(annot_tokens, vocab_size=30000, freq_cutoff=WORD_FREQ_CUT_OFF)
logging.info('annot vocab. size: %d', annot_vocab.size)
# we have no terminal tokens in ifttt
all_terminal_tokens = []
terminal_vocab = gen_vocab(all_terminal_tokens, vocab_size=4000, freq_cutoff=WORD_FREQ_CUT_OFF)
# now generate the dataset!
train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'ifttt.train_data')
dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'ifttt.dev_data')
test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'ifttt.test_data')
all_examples = []
can_fully_reconstructed_examples_num = 0
examples_with_empty_actions_num = 0
for entry in data:
idx = entry['id']
query_tokens = entry['query_tokens']
code = entry['code']
parse_tree = entry['parse_tree']
# check if query tokens are valid
query_token_ids = [annot_vocab[token] for token in query_tokens if token not in string.punctuation]
# a token id equal to annot_vocab.unk means the word is out-of-vocabulary
valid_query_tokens_ids = [tid for tid in query_token_ids if tid != annot_vocab.unk]
# remove examples with rare words from train and dev, avoid overfitting
# (77495 and 77495 + 5171 are the train/dev split boundaries used below;
# examples past that range -- the test split -- are always kept)
if len(valid_query_tokens_ids) == 0 and 0 <= idx < 77495 + 5171:
continue
rule_list, rule_parents = parse_tree.get_productions(include_value_node=True)
actions = []
can_fully_reconstructed = True
# maps each rule to the timestep (index in `actions`) at which it was applied,
# so children can record their parent's timestep
rule_pos_map = dict()
for rule_count, rule in enumerate(rule_list):
if not grammar.is_value_node(rule.parent):
assert rule.value is None
parent_rule = rule_parents[(rule_count, rule)][0]
if parent_rule:
parent_t = rule_pos_map[parent_rule]
else:
# root production: no parent, timestep defaults to 0
parent_t = 0
rule_pos_map[rule] = len(actions)
d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule}
action = Action(APPLY_RULE, d)
actions.append(action)
else:
# IFTTT trees are pure grammar productions; a value node here is a data bug
raise RuntimeError('no terminals should be in ifttt dataset!')
if len(actions) == 0:
examples_with_empty_actions_num += 1
continue
example = DataEntry(idx, query_tokens, parse_tree, code, actions,
{'str_map': None, 'raw_code': entry['raw_code']})
# NOTE(review): can_fully_reconstructed is never set to False in this
# function (no terminal/copy branch exists), so this counter ends up equal
# to the number of surviving examples.
if can_fully_reconstructed:
can_fully_reconstructed_examples_num += 1
# train, valid, test splits
if 0 <= idx < 77495:
train_data.add(example)
elif idx < 77495 + 5171:
dev_data.add(example)
else:
test_data.add(example)
all_examples.append(example)
# print statistics
max_query_len = max(len(e.query) for e in all_examples)
max_actions_len = max(len(e.actions) for e in all_examples)
# serialize_to_file([len(e.query) for e in all_examples], 'query.len')
# serialize_to_file([len(e.actions) for e in all_examples], 'actions.len')
logging.info('train_data examples: %d', train_data.count)
logging.info('dev_data examples: %d', dev_data.count)
logging.info('test_data examples: %d', test_data.count)
#.........这里部分代码省略.........
示例2: parse_hs_dataset
# 需要导入模块: from dataset import DataSet [as 别名]
# 或者: from dataset.DataSet import add [as 别名]
#.........这里部分代码省略.........
# NOTE(review): this is a FRAGMENT of parse_hs_dataset (HearthStone dataset) --
# the function header and setup code were omitted by the source page, and the
# scraper stripped all indentation, so it is not runnable as-is. Names such as
# `grammar`, `rule_list`, `query_tokens`, `terminal_vocab`, `train_data`,
# `MAX_QUERY_LENGTH` are defined in the omitted part. Code is byte-identical;
# only comments were added.
#
# Walks every production of the current example's parse tree and converts it
# into an action sequence: APPLY_RULE for grammar productions, and
# GEN/COPY actions for terminal tokens.
for rule_count, rule in enumerate(rule_list):
if not grammar.is_value_node(rule.parent):
# grammar production: emit an APPLY_RULE action
assert rule.value is None
parent_rule = rule_parents[(rule_count, rule)][0]
if parent_rule:
parent_t = rule_pos_map[parent_rule]
else:
# root production has no parent; its parent timestep defaults to 0
parent_t = 0
# remember at which timestep this rule was applied
rule_pos_map[rule] = len(actions)
d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule}
action = Action(APPLY_RULE, d)
actions.append(action)
else:
# value node: the rule carries a terminal value to be generated/copied
assert rule.is_leaf
parent_rule = rule_parents[(rule_count, rule)][0]
parent_t = rule_pos_map[parent_rule]
terminal_val = rule.value
terminal_str = str(terminal_val)
terminal_tokens = get_terminal_tokens(terminal_str)
# assert len(terminal_tokens) > 0
for terminal_token in terminal_tokens:
term_tok_id = terminal_vocab[terminal_token]
# position of the token in the query, -1 if not present (copy source)
tok_src_idx = -1
try:
tok_src_idx = query_tokens.index(terminal_token)
except ValueError:
pass
d = {'literal': terminal_token, 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}
# cannot copy, only generation
# could be unk!
if tok_src_idx < 0 or tok_src_idx >= MAX_QUERY_LENGTH:
action = Action(GEN_TOKEN, d)
# token neither in vocab nor in the query: decoding can never
# reproduce it, so the example is not fully reconstructable
if terminal_token not in terminal_vocab:
if terminal_token not in query_tokens:
# print terminal_token
can_fully_reconstructed = False
else: # copy
if term_tok_id != terminal_vocab.unk:
# token is in-vocab AND copyable: model may generate or copy
d['source_idx'] = tok_src_idx
action = Action(GEN_COPY_TOKEN, d)
else:
# out-of-vocab token: copying is the only option
d['source_idx'] = tok_src_idx
action = Action(COPY_TOKEN, d)
actions.append(action)
# terminate this terminal's token sequence with an <eos> marker
d = {'literal': '<eos>', 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}
actions.append(Action(GEN_TOKEN, d))
if len(actions) == 0:
examples_with_empty_actions_num += 1
continue
example = DataEntry(idx, query_tokens, parse_tree, code, actions, {'str_map': None, 'raw_code': entry['raw_code']})
if can_fully_reconstructed:
can_fully_reconstructed_examples_num += 1
# train, valid, test splits
# (HearthStone split boundaries: ids [0, 533) train, [533, 599) dev, rest test)
if 0 <= idx < 533:
train_data.add(example)
elif idx < 599:
dev_data.add(example)
else:
test_data.add(example)
all_examples.append(example)
# print statistics
max_query_len = max(len(e.query) for e in all_examples)
max_actions_len = max(len(e.actions) for e in all_examples)
# serialize_to_file([len(e.query) for e in all_examples], 'query.len')
# serialize_to_file([len(e.actions) for e in all_examples], 'actions.len')
# NOTE(review): under Python 2 (which this code appears to target) the
# int/int division below truncates to 0 or 1, so the %f ratio logged here
# is almost certainly wrong; it likely needs float(...) or
# `from __future__ import division`.
logging.info('examples that can be fully reconstructed: %d/%d=%f',
can_fully_reconstructed_examples_num, len(all_examples),
can_fully_reconstructed_examples_num / len(all_examples))
logging.info('empty_actions_count: %d', examples_with_empty_actions_num)
logging.info('max_query_len: %d', max_query_len)
logging.info('max_actions_len: %d', max_actions_len)
# pad/encode every split into fixed-size matrices for batched training
train_data.init_data_matrices(max_query_length=70, max_example_action_num=350)
dev_data.init_data_matrices(max_query_length=70, max_example_action_num=350)
test_data.init_data_matrices(max_query_length=70, max_example_action_num=350)
serialize_to_file((train_data, dev_data, test_data),
'data/hs.freq{WORD_FREQ_CUT_OFF}.max_action350.pre_suf.unary_closure.bin'.format(WORD_FREQ_CUT_OFF=WORD_FREQ_CUT_OFF))
return train_data, dev_data, test_data