This article collects typical usage examples of config.max_len in Python. If you are wondering what config.max_len is for and how to use it, the curated code examples below may help. You can also explore other usage examples from the config module it belongs to.
The following shows 5 code examples of config.max_len, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
Example 1: pad_trunc_seq
# Required import: import config [as alias]
# Or: from config import max_len [as alias]
import numpy as np  # needed for np.zeros / np.concatenate below

def pad_trunc_seq(x, max_len):
    """Pad or truncate sequence data to a fixed length.

    Args:
      x: ndarray, input sequence data.
      max_len: integer, length to which the sequence is padded or truncated.

    Returns:
      ndarray, padded or truncated sequence.
    """
    L = len(x)
    shape = x.shape
    if L < max_len:
        pad_shape = (max_len - L,) + shape[1:]
        pad = np.zeros(pad_shape)
        x_new = np.concatenate((x, pad), axis=0)
    else:
        x_new = x[0:max_len]
    return x_new
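A minimal usage sketch for this example: pad a short feature matrix up to max_len rows and truncate a long one. The (frames, bins) shapes and the value of max_len are assumptions; it only requires that the project's config module defines max_len.

import numpy as np
from config import max_len  # assumed to define max_len, e.g. max_len = 240

short_feat = np.random.rand(100, 64)   # hypothetical (frames, bins) feature matrix
long_feat = np.random.rand(500, 64)

padded = pad_trunc_seq(short_feat, max_len)    # zero-padded up to max_len rows
truncated = pad_trunc_seq(long_feat, max_len)  # truncated down to max_len rows
print(padded.shape, truncated.shape)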
Example 2: build_samples
# Required import: import config [as alias]
# Or: from config import max_len [as alias]
def build_samples():
    word_map_zh = json.load(open('data/WORDMAP_zh.json', 'r'))
    word_map_en = json.load(open('data/WORDMAP_en.json', 'r'))

    for usage in ['train', 'valid']:
        if usage == 'train':
            translation_path_en = os.path.join(train_translation_folder, train_translation_en_filename)
            translation_path_zh = os.path.join(train_translation_folder, train_translation_zh_filename)
            filename = 'data/samples_train.json'
        else:
            translation_path_en = os.path.join(valid_translation_folder, valid_translation_en_filename)
            translation_path_zh = os.path.join(valid_translation_folder, valid_translation_zh_filename)
            filename = 'data/samples_valid.json'

        print('loading {} texts and vocab'.format(usage))
        with open(translation_path_en, 'r') as f:
            data_en = f.readlines()
        with open(translation_path_zh, 'r') as f:
            data_zh = f.readlines()

        print('building {} samples'.format(usage))
        samples = []
        for idx in tqdm(range(len(data_en))):
            sentence_zh = data_zh[idx].strip()
            seg_list = jieba.cut(sentence_zh)
            input_zh = encode_text(word_map_zh, list(seg_list))

            sentence_en = data_en[idx].strip().lower()
            tokens = [normalizeString(s) for s in nltk.word_tokenize(sentence_en) if len(normalizeString(s)) > 0]
            output_en = encode_text(word_map_en, tokens)

            if len(input_zh) <= max_len and len(output_en) <= max_len \
                    and UNK_token not in input_zh and UNK_token not in output_en:
                samples.append({'input': list(input_zh), 'output': list(output_en)})

        with open(filename, 'w') as f:
            json.dump(samples, f, indent=4)

        print('{} {} samples created at: {}.'.format(len(samples), usage, filename))
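encode_text and normalizeString are project helpers that are not shown on this page. As a hypothetical sketch only: encode_text presumably maps each token to its id in the word map and falls back to the UNK id for out-of-vocabulary tokens, which is why the example can filter samples on UNK_token. The '<unk>' key below is an assumption, not the project's actual code.

# Hypothetical sketch of the encode_text helper used above.
def encode_text(word_map, tokens):
    # word_map: dict token -> integer id, assumed to contain an '<unk>' entry
    return [word_map.get(token, word_map['<unk>']) for token in tokens]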
Example 3: build_samples
# Required import: import config [as alias]
# Or: from config import max_len [as alias]
def build_samples():
    word_map_zh = json.load(open('data/WORDMAP_zh.json', 'r'))
    word_map_en = json.load(open('data/WORDMAP_en.json', 'r'))

    for usage in ['train', 'valid']:
        if usage == 'train':
            translation_path_en = os.path.join(train_translation_folder, train_translation_en_filename)
            translation_path_zh = os.path.join(train_translation_folder, train_translation_zh_filename)
            filename = 'data/samples_train.json'
        else:
            translation_path_en = os.path.join(valid_translation_folder, valid_translation_en_filename)
            translation_path_zh = os.path.join(valid_translation_folder, valid_translation_zh_filename)
            filename = 'data/samples_valid.json'

        print('loading {} texts and vocab'.format(usage))
        with open(translation_path_en, 'r') as f:
            data_en = f.readlines()
        with open(translation_path_zh, 'r') as f:
            data_zh = f.readlines()

        print('building {} samples'.format(usage))
        samples = []
        for idx in tqdm(range(len(data_en))):
            sentence_en = data_en[idx].strip().lower()
            tokens = [normalizeString(s) for s in nltk.word_tokenize(sentence_en)]
            input_en = encode_text(word_map_en, tokens)

            sentence_zh = data_zh[idx].strip()
            seg_list = jieba.cut(sentence_zh)
            output_zh = encode_text(word_map_zh, list(seg_list))

            if len(input_en) <= max_len and len(output_zh) <= max_len \
                    and UNK_token not in input_en and UNK_token not in output_zh:
                samples.append({'input': list(input_en), 'output': list(output_zh)})

        with open(filename, 'w') as f:
            json.dump(samples, f, indent=4)

        print('{} {} samples created at: {}.'.format(len(samples), usage, filename))
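A usage sketch: assuming the WORDMAP json files already exist under data/ and that the config constants (corpus folders and filenames, max_len, UNK_token) point at your parallel corpus, running build_samples produces the two sample files, which can then be inspected.

if __name__ == '__main__':
    build_samples()  # writes data/samples_train.json and data/samples_valid.json

    import json
    with open('data/samples_train.json', 'r') as f:
        samples = json.load(f)
    # every kept pair respects max_len on both sides and contains no UNK_token
    print(len(samples), samples[0] if samples else 'no samples kept')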
Example 4: get_loader
# Required import: import config [as alias]
# Or: from config import max_len [as alias]
def get_loader(src_file, trg_file, word2idx,
               batch_size, use_tag=False, debug=False, shuffle=False):
    dataset = SQuadDatasetWithTag(src_file, trg_file, config.max_len,
                                  word2idx, debug)
    dataloader = data.DataLoader(dataset=dataset,
                                 batch_size=batch_size,
                                 shuffle=shuffle,
                                 collate_fn=collate_fn_tag)
    return dataloader
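A usage sketch, under the assumptions that SQuadDatasetWithTag caps source/target sequences at config.max_len, that word2idx is a token-to-id dict normally loaded from disk, and that the file paths below are placeholders for your own SQuAD-style data files.

word2idx = {'<pad>': 0, '<unk>': 1}  # placeholder vocabulary

train_loader = get_loader('data/src-train.txt', 'data/tgt-train.txt', word2idx,
                          batch_size=32, use_tag=True, shuffle=True)

for batch in train_loader:
    # the exact batch layout is decided by collate_fn_tag; here we only peek at one batch
    print(type(batch))
    break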
Example 5: infer_deep_model
# Required import: import config [as alias]
# Or: from config import max_len [as alias]
def infer_deep_model(model_type='cnn',
                     data_path='',
                     model_save_path='',
                     label_vocab_path='',
                     max_len=300,
                     batch_size=128,
                     col_sep='\t',
                     pred_save_path=None):
    from keras.models import load_model

    # load data content
    data_set, true_labels = data_reader(data_path, col_sep)

    # init feature:
    # the HAN model needs a [doc, sentence] feature (3-D); the other models use a [sentence] feature (2-D)
    if model_type == 'han':
        feature_type = 'doc_vectorize'
    else:
        feature_type = 'vectorize'
    feature = Feature(data_set, feature_type=feature_type, is_infer=True, max_len=max_len)
    # get data feature
    data_feature = feature.get_feature()

    # load model
    model = load_model(model_save_path)
    # predict (in Keras, predict already returns class probabilities, equivalent to predict_proba)
    pred_label_probs = model.predict(data_feature, batch_size=batch_size)

    # map label ids back to label names
    label_id = load_vocab(label_vocab_path)
    id_label = {v: k for k, v in label_id.items()}
    pred_labels = [prob.argmax() for prob in pred_label_probs]
    pred_labels = [id_label[i] for i in pred_labels]
    pred_output = [id_label[prob.argmax()] + col_sep + str(prob.max()) for prob in pred_label_probs]
    logger.info("save infer label and prob result to: %s" % pred_save_path)
    save_predict_result(pred_output, ture_labels=None, pred_save_path=pred_save_path, data_set=data_set)

    if true_labels:
        # evaluate against the gold labels
        assert len(pred_labels) == len(true_labels)
        for label, prob in zip(true_labels, pred_label_probs):
            logger.debug('label_true:%s\tprob_label:%s\tprob:%s' % (label, id_label[prob.argmax()], prob.max()))
        print('total eval:')
        try:
            print(classification_report(true_labels, pred_labels))
            print(confusion_matrix(true_labels, pred_labels))
        except UnicodeEncodeError:
            true_labels_id = [label_id[i] for i in true_labels]
            pred_labels_id = [label_id[i] for i in pred_labels]
            print(classification_report(true_labels_id, pred_labels_id))
            print(confusion_matrix(true_labels_id, pred_labels_id))
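A usage sketch, assuming a CNN model and label vocabulary were already saved by a training step; every path below is a placeholder, and max_len is taken from config.max_len so that inference matches the training-time sequence length.

import config  # assumed to define max_len (the value used at training time)

infer_deep_model(model_type='cnn',
                 data_path='data/test.txt',                 # placeholder: tab-separated input file
                 model_save_path='output/cnn_model.h5',     # placeholder: trained Keras model
                 label_vocab_path='output/label_vocab.txt', # placeholder: label vocab file
                 max_len=config.max_len,
                 batch_size=128,
                 col_sep='\t',
                 pred_save_path='output/pred_result.txt')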