

Python config.max_len Method Code Examples

This article collects typical usage examples of the Python config.max_len method. If you are wondering what config.max_len does, how to call it, or what real-world usage looks like, the curated code examples below should help. You can also browse further usage examples from the containing config module.


Five code examples of the config.max_len method are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
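
All of the examples below read max_len from a project-level config module. The config.py files themselves are not reproduced on this page; the following minimal sketch is only an assumption of what such a module might look like (the value 50 is invented for illustration):

# config.py -- hypothetical sketch; each project defines its own value
max_len = 50  # maximum allowed sequence length

# typical access patterns seen in the examples below:
#   import config              ->  config.max_len
#   from config import max_len ->  max_len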

Example 1: pad_trunc_seq

# Required import: import config [as alias]
# Or: from config import max_len [as alias]
import numpy as np  # needed below for np.zeros and np.concatenate
def pad_trunc_seq(x, max_len):
    """Pad or truncate a sequence data to a fixed length. 
    
    Args:
      x: ndarray, input sequence data. 
      max_len: integer, length of sequence to be padded or truncated. 
      
    Returns:
      ndarray, Padded or truncated input sequence data. 
    """
    L = len(x)
    shape = x.shape
    if L < max_len:
        pad_shape = (max_len - L,) + shape[1:]
        pad = np.zeros(pad_shape)
        x_new = np.concatenate((x, pad), axis=0)
    else:
        x_new = x[0:max_len]
    return x_new
    
Developer: yongxuUSTC, Project: dcase2017_task4_cvssp, Lines: 23, Source: prepare_data.py
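
As a quick illustration, here is a hypothetical usage sketch of pad_trunc_seq; the shapes and max_len values are made up for the demo:

import numpy as np

x = np.random.rand(10, 64)               # 10 frames, 64 features
padded = pad_trunc_seq(x, max_len=16)    # shape (16, 64): zero-padded at the end
truncated = pad_trunc_seq(x, max_len=4)  # shape (4, 64): extra frames dropped
print(padded.shape, truncated.shape)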

Example 2: build_samples

# Required import: import config [as alias]
# Or: from config import max_len [as alias]
import json
import os
import jieba
import nltk
from tqdm import tqdm
# encode_text, normalizeString, UNK_token and the translation path/filename
# constants are defined in the source repo's config and utility modules.
def build_samples():
    word_map_zh = json.load(open('data/WORDMAP_zh.json', 'r'))
    word_map_en = json.load(open('data/WORDMAP_en.json', 'r'))

    for usage in ['train', 'valid']:
        if usage == 'train':
            translation_path_en = os.path.join(train_translation_folder, train_translation_en_filename)
            translation_path_zh = os.path.join(train_translation_folder, train_translation_zh_filename)
            filename = 'data/samples_train.json'
        else:
            translation_path_en = os.path.join(valid_translation_folder, valid_translation_en_filename)
            translation_path_zh = os.path.join(valid_translation_folder, valid_translation_zh_filename)
            filename = 'data/samples_valid.json'

        print('loading {} texts and vocab'.format(usage))
        with open(translation_path_en, 'r') as f:
            data_en = f.readlines()

        with open(translation_path_zh, 'r') as f:
            data_zh = f.readlines()

        print('building {} samples'.format(usage))
        samples = []
        for idx in tqdm(range(len(data_en))):
            sentence_zh = data_zh[idx].strip()
            seg_list = jieba.cut(sentence_zh)
            input_zh = encode_text(word_map_zh, list(seg_list))

            sentence_en = data_en[idx].strip().lower()
            tokens = [normalizeString(s) for s in nltk.word_tokenize(sentence_en) if len(normalizeString(s)) > 0]
            output_en = encode_text(word_map_en, tokens)

            if len(input_zh) <= max_len and len(
                    output_en) <= max_len and UNK_token not in input_zh and UNK_token not in output_en:
                samples.append({'input': list(input_zh), 'output': list(output_en)})

        with open(filename, 'w') as f:
            json.dump(samples, f, indent=4)

        print('{} {} samples created at: {}.'.format(len(samples), usage, filename)) 
Developer: foamliu, Project: Machine-Translation, Lines: 42, Source: pre_process.py
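
The helper encode_text is not shown on this page. A minimal hypothetical sketch, assuming word_map maps token strings to integer ids and UNK_token is the reserved id for unknown words (both assumptions, not the repo's actual implementation):

def encode_text(word_map, tokens):
    # Look up each token's id, falling back to the unknown-word id.
    return [word_map.get(token, UNK_token) for token in tokens]

Under this reading, the UNK_token check in build_samples simply drops any sentence pair containing an out-of-vocabulary word.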

Example 3: build_samples

# Required import: import config [as alias]
# Or: from config import max_len [as alias]
# Same supporting imports and repo helpers as in Example 2
# (json, os, jieba, nltk, tqdm, encode_text, normalizeString, UNK_token, path constants).
def build_samples():
    word_map_zh = json.load(open('data/WORDMAP_zh.json', 'r'))
    word_map_en = json.load(open('data/WORDMAP_en.json', 'r'))

    for usage in ['train', 'valid']:
        if usage == 'train':
            translation_path_en = os.path.join(train_translation_folder, train_translation_en_filename)
            translation_path_zh = os.path.join(train_translation_folder, train_translation_zh_filename)
            filename = 'data/samples_train.json'
        else:
            translation_path_en = os.path.join(valid_translation_folder, valid_translation_en_filename)
            translation_path_zh = os.path.join(valid_translation_folder, valid_translation_zh_filename)
            filename = 'data/samples_valid.json'

        print('loading {} texts and vocab'.format(usage))
        with open(translation_path_en, 'r') as f:
            data_en = f.readlines()

        with open(translation_path_zh, 'r') as f:
            data_zh = f.readlines()

        print('building {} samples'.format(usage))
        samples = []
        for idx in tqdm(range(len(data_en))):
            sentence_en = data_en[idx].strip().lower()
            tokens = [normalizeString(s) for s in nltk.word_tokenize(sentence_en)]
            input_en = encode_text(word_map_en, tokens)

            sentence_zh = data_zh[idx].strip()
            seg_list = jieba.cut(sentence_zh)
            output_zh = encode_text(word_map_zh, list(seg_list))

            if len(input_en) <= max_len and len(
                    output_zh) <= max_len and UNK_token not in input_en and UNK_token not in output_zh:
                samples.append({'input': list(input_en), 'output': list(output_zh)})

        with open(filename, 'w') as f:
            json.dump(samples, f, indent=4)

        print('{} {} samples created at: {}.'.format(len(samples), usage, filename)) 
Developer: foamliu, Project: Machine-Translation-v2, Lines: 42, Source: pre_process.py
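
Examples 2 and 3 differ only in translation direction (Chinese-to-English versus English-to-Chinese); both apply the same acceptance test driven by config.max_len. Restated as a standalone predicate (a sketch, not code from either repo):

def keep_pair(input_ids, output_ids, max_len, unk_id):
    # Keep a sample only if both sides fit within max_len
    # and neither side contains an unknown-word id.
    return (len(input_ids) <= max_len and len(output_ids) <= max_len
            and unk_id not in input_ids and unk_id not in output_ids)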

Example 4: get_loader

# Required import: import config [as alias]
# Or: from config import max_len [as alias]
from torch.utils import data  # referenced below as data.DataLoader
# SQuadDatasetWithTag and collate_fn_tag are defined elsewhere in the source repo.
def get_loader(src_file, trg_file, word2idx,
               batch_size, use_tag=False, debug=False, shuffle=False):
    # Build a dataset clipped to config.max_len, then wrap it in a DataLoader
    # with the tag-aware collate function.
    # (use_tag is accepted but not referenced in this snippet.)
    dataset = SQuadDatasetWithTag(src_file, trg_file, config.max_len,
                                  word2idx, debug)
    dataloader = data.DataLoader(dataset=dataset,
                                 batch_size=batch_size,
                                 shuffle=shuffle,
                                 collate_fn=collate_fn_tag)

    return dataloader 
Developer: seanie12, Project: neural-question-generation, Lines: 12, Source: data_utils.py
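
SQuadDatasetWithTag itself is not shown here. Purely as an illustration of how a Dataset might consume config.max_len, a stripped-down hypothetical stand-in:

import torch
from torch.utils import data

class TruncatingDataset(data.Dataset):
    """Hypothetical sketch: serves (src, trg) id sequences cut to max_len."""
    def __init__(self, src_seqs, trg_seqs, max_len):
        self.src_seqs = [s[:max_len] for s in src_seqs]
        self.trg_seqs = [t[:max_len] for t in trg_seqs]

    def __len__(self):
        return len(self.src_seqs)

    def __getitem__(self, idx):
        return torch.tensor(self.src_seqs[idx]), torch.tensor(self.trg_seqs[idx])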

Example 5: infer_deep_model

# Required import: import config [as alias]
# Or: from config import max_len [as alias]
from sklearn.metrics import classification_report, confusion_matrix
# data_reader, Feature, load_vocab, save_predict_result and logger are
# project helpers defined elsewhere in the source repo.
def infer_deep_model(model_type='cnn',
                     data_path='',
                     model_save_path='',
                     label_vocab_path='',
                     max_len=300,
                     batch_size=128,
                     col_sep='\t',
                     pred_save_path=None):
    from keras.models import load_model
    # load data content
    data_set, true_labels = data_reader(data_path, col_sep)
    # init feature
    # the HAN model needs a [doc, sentence, dim] feature (3-D); other models use a [sentence, dim] feature (2-D)
    if model_type == 'han':
        feature_type = 'doc_vectorize'
    else:
        feature_type = 'vectorize'
    feature = Feature(data_set, feature_type=feature_type, is_infer=True, max_len=max_len)
    # get data feature
    data_feature = feature.get_feature()

    # load model
    model = load_model(model_save_path)
    # predict; in Keras, predict_proba is equivalent to predict
    pred_label_probs = model.predict(data_feature, batch_size=batch_size)

    # label id map
    label_id = load_vocab(label_vocab_path)
    id_label = {v: k for k, v in label_id.items()}
    pred_labels = [prob.argmax() for prob in pred_label_probs]
    pred_labels = [id_label[i] for i in pred_labels]
    pred_output = [id_label[prob.argmax()] + col_sep + str(prob.max()) for prob in pred_label_probs]
    logger.info("save infer label and prob result to: %s" % pred_save_path)
    save_predict_result(pred_output, true_labels=None, pred_save_path=pred_save_path, data_set=data_set)
    if true_labels:
        # evaluate
        assert len(pred_labels) == len(true_labels)
        for label, prob in zip(true_labels, pred_label_probs):
            logger.debug('label_true:%s\tprob_label:%s\tprob:%s' % (label, id_label[prob.argmax()], prob.max()))

        print('total eval:')
        try:
            print(classification_report(true_labels, pred_labels))
            print(confusion_matrix(true_labels, pred_labels))
        except UnicodeEncodeError:
            true_labels_id = [label_id[i] for i in true_labels]
            pred_labels_id = [label_id[i] for i in pred_labels]
            print(classification_report(true_labels_id, pred_labels_id))
            print(confusion_matrix(true_labels_id, pred_labels_id)) 
Developer: shibing624, Project: text-classifier, Lines: 51, Source: infer.py
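
A hypothetical invocation, wiring config.max_len into the max_len keyword argument (all paths below are placeholders, not the repo's actual layout):

import config

infer_deep_model(model_type='cnn',
                 data_path='data/test.txt',                 # placeholder path
                 model_save_path='output/cnn.model',        # placeholder path
                 label_vocab_path='output/label_vocab.txt', # placeholder path
                 max_len=config.max_len,                    # use the shared config value
                 pred_save_path='output/pred.txt')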


Note: the config.max_len examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by their respective authors, who retain copyright over the source code; consult each project's license before using or redistributing it, and do not republish without permission.