当前位置: 首页>>代码示例>>Python>>正文


Python models.CNN属性代码示例

本文整理汇总了Python中models.CNN属性的典型用法代码示例。如果您正苦于以下问题:Python models.CNN属性的具体用法?Python models.CNN怎么用?Python models.CNN使用的例子?那么恭喜您,这里精选的属性代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在models的用法示例。


在下文中一共展示了models.CNN属性的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: build_model

# 需要导入模块: import models [as 别名]
# 或者: from models import CNN [as 别名]
def build_model(self):
    """Instantiate the shared network and its controller, then place them.

    The pair is chosen by ``self.args.network_type``: ``'rnn'`` builds
    ``models.RNN`` with ``models.Controller``; ``'micro_cnn'`` builds
    ``models.CNN`` with ``models.CNNMicroController``. When
    ``self.args.num_gpu`` equals 1, both models are moved with ``.cuda()``
    if ``torch.__version__`` is exactly ``'0.3.1'`` and with
    ``.to(self.device)`` otherwise.

    Raises:
        NotImplementedError: If the network type is neither ``'rnn'`` nor
            ``'micro_cnn'``, or if ``num_gpu`` is greater than 1.
    """
    network_type = self.args.network_type

    # Reject unsupported types up front so nothing is built for them.
    if network_type not in ('rnn', 'micro_cnn'):
        raise NotImplementedError(
            f'Network type `{network_type}` is not defined')

    if network_type == 'rnn':
        self.shared = models.RNN(self.args, self.dataset)
        self.controller = models.Controller(self.args)
    else:
        self.shared = models.CNN(self.args, self.dataset)
        self.controller = models.CNNMicroController(self.args)

    gpu_count = self.args.num_gpu
    if gpu_count == 1:
        uses_legacy_api = torch.__version__ == '0.3.1'
        # Shared model first, controller second.
        for module in (self.shared, self.controller):
            if uses_legacy_api:
                module.cuda()
            else:
                module.to(self.device)
    elif gpu_count > 1:
        raise NotImplementedError('`num_gpu > 1` is in progress')
开发者ID:kcyu2014,项目名称:eval-nas,代码行数:25,代码来源:trainer.py

示例2: build_model

# 需要导入模块: import models [as 别名]
# 或者: from models import CNN [as 别名]
def build_model(self):
    """Instantiate the shared network and the controller, then place them.

    ``self.args.network_type`` selects the shared model: ``'rnn'`` builds
    ``models.RNN`` and ``'cnn'`` builds ``models.CNN``. The controller is
    always ``models.Controller``. When ``self.args.num_gpu`` equals 1, both
    models are moved with ``.cuda()``.

    Raises:
        NotImplementedError: If the network type is neither ``'rnn'`` nor
            ``'cnn'``, or if ``num_gpu`` is greater than 1.
    """
    network_type = self.args.network_type

    # Reject unsupported types up front so nothing is built for them.
    if network_type not in ('rnn', 'cnn'):
        raise NotImplementedError(
            f'Network type `{network_type}` is not defined')

    if network_type == 'rnn':
        self.shared = models.RNN(self.args, self.dataset)
    else:
        self.shared = models.CNN(self.args, self.dataset)
    self.controller = models.Controller(self.args)

    gpu_count = self.args.num_gpu
    if gpu_count == 1:
        # Shared model first, controller second.
        for module in (self.shared, self.controller):
            module.cuda()
    elif gpu_count > 1:
        raise NotImplementedError('`num_gpu > 1` is in progress')
开发者ID:carpedm20,项目名称:ENAS-pytorch,代码行数:19,代码来源:trainer.py

示例3: CNN_train

# 需要导入模块: import models [as 别名]
# 或者: from models import CNN [as 别名]
def CNN_train(test_fold, feat):
    """
    Train a CNN on extracted features and report its test-set result.

    :param test_fold: test fold of 5-fold cross validation; also used in the
        log directory and saved-model file names
    :param feat: which feature to use; also used in the saved-model file names
    :return: the second element of ``model.evaluate`` on the test split
        (printed as "Test accuracy")
    """
    # Hyperparameters.
    num_epochs = 60
    num_class = 10
    batch_size = 32
    input_shape = (60, 65, 1)

    def scheduler(epoch):
        # Learning-rate decay: multiply the rate by 0.1 at epochs 20 and 40.
        # ``model`` is bound later in the enclosing function, before training
        # invokes this callback.
        if epoch in (20, 40):
            current_lr = K.get_value(model.optimizer.lr)
            K.set_value(model.optimizer.lr, current_lr * 0.1)
            print("lr changed to {}".format(current_lr * 0.1))
        return K.get_value(model.optimizer.lr)

    # Load the feature data for this fold.
    train_features, train_labels, test_features, test_labels = \
        esc10_input.get_data(test_fold, feat)

    # Build the CNN model.
    model = models.CNN(input_shape, num_class)

    # Callbacks: learning-rate decay, training logs, and a checkpoint that
    # keeps the model performing best on the validation split.
    lr_callback = LearningRateScheduler(scheduler)
    tensorboard_callback = TensorBoard(
        log_dir='./log/fold{}/'.format(test_fold))
    best_checkpoint = ModelCheckpoint(
        './saved_model/cnn_{}_fold{}_best.h5'.format(feat, test_fold),
        monitor='val_acc',
        verbose=1,
        save_best_only=True,
        mode='max',
        period=1)

    # Train the model, holding out 10% of the training data for validation.
    model.fit(
        train_features,
        train_labels,
        batch_size=batch_size,
        nb_epoch=num_epochs,
        verbose=1,
        validation_split=0.1,
        callbacks=[best_checkpoint, lr_callback, tensorboard_callback])

    # Save the final model.
    model.save('./saved_model/cnn_{}_fold{}.h5'.format(feat, test_fold))

    # Report how the trained model performs on the test set.
    score = model.evaluate(test_features, test_labels)
    test_accuracy = score[1]
    print('Test score:', score[0])
    print('Test accuracy:', test_accuracy)

    return test_accuracy
开发者ID:JasonZhang156,项目名称:Sound-Recognition-Tutorial,代码行数:46,代码来源:train.py

示例4: train

# 需要导入模块: import models [as 别名]
# 或者: from models import CNN [as 别名]
def train():
    """Train and periodically evaluate the model chosen by ``args.model_type``.

    Reads the module-level ``args``, ``device`` and ``dev_count``. Builds
    train and dev data generators from a ``SentaProcessor``, instantiates one
    of ``CNN``, ``BOW``, ``GRU`` or ``BiGRU``, prepares it with an Adagrad
    optimizer, cross-entropy loss and top-1 accuracy, and calls ``model.fit``
    with ``args.checkpoints`` as the save directory.

    Raises:
        ValueError: If ``args.model_type`` is not one of ``'cnn_net'``,
            ``'bow_net'``, ``'gru_net'`` or ``'bigru_net'``.
    """
    fluid.enable_dygraph(device)
    processor = SentaProcessor(
        data_dir=args.data_dir,
        vocab_path=args.vocab_path,
        random_seed=args.random_seed)
    num_labels = len(processor.get_labels())

    num_train_examples = processor.get_num_examples(phase="train")

    # NOTE(review): ``num_labels`` and ``max_train_steps`` are not used below;
    # the calls are kept in case the processor relies on them -- confirm
    # before removing.
    max_train_steps = args.epoch * num_train_examples // args.batch_size // dev_count

    train_data_generator = processor.data_generator(
        batch_size=args.batch_size,
        padding_size=args.padding_size,
        places=device,
        phase='train',
        epoch=args.epoch,
        shuffle=False)

    eval_data_generator = processor.data_generator(
        batch_size=args.batch_size,
        padding_size=args.padding_size,
        places=device,
        phase='dev',
        epoch=args.epoch,
        shuffle=False)

    if args.model_type == 'cnn_net':
        model = CNN(args.vocab_size, args.batch_size, args.padding_size)
    elif args.model_type == 'bow_net':
        model = BOW(args.vocab_size, args.batch_size, args.padding_size)
    elif args.model_type == 'gru_net':
        model = GRU(args.vocab_size, args.batch_size, args.padding_size)
    elif args.model_type == 'bigru_net':
        model = BiGRU(args.vocab_size, args.batch_size, args.padding_size)
    else:
        # Without this branch ``model`` would stay unbound and the failure
        # would surface later as an opaque UnboundLocalError.
        raise ValueError(
            "Unknown model_type {!r}; expected one of 'cnn_net', 'bow_net', "
            "'gru_net', 'bigru_net'".format(args.model_type))

    optimizer = fluid.optimizer.Adagrad(
        learning_rate=args.lr, parameter_list=model.parameters())

    inputs = [Input([None, None], 'int64', name='doc')]
    labels = [Input([None, 1], 'int64', name='label')]

    model.prepare(
        optimizer,
        CrossEntropy(),
        Accuracy(topk=(1, )),
        inputs,
        labels,
        device=device)

    model.fit(train_data=train_data_generator,
              eval_data=eval_data_generator,
              batch_size=args.batch_size,
              epochs=args.epoch,
              save_dir=args.checkpoints,
              eval_freq=args.eval_freq,
              save_freq=args.save_freq)
开发者ID:PaddlePaddle,项目名称:hapi,代码行数:59,代码来源:sentiment_classifier.py

示例5: infer

# 需要导入模块: import models [as 别名]
# 或者: from models import CNN [as 别名]
def infer():
    """Run inference with the model chosen by ``args.model_type``.

    Reads the module-level ``args`` and ``device``. Builds an ``'infer'``
    data generator from a ``SentaProcessor``, instantiates one of ``CNN``,
    ``BOW``, ``GRU`` or ``BiGRU``, loads weights from ``args.checkpoints``
    and predicts. The first prediction output is reshaped to ``(-1, 2)``.
    When ``args.output_dir`` is truthy, one JSON object per row (``index``,
    ``label``, ``probs``) is written to ``predictions.json`` in that
    directory.

    Raises:
        ValueError: If ``args.model_type`` is not one of ``'cnn_net'``,
            ``'bow_net'``, ``'gru_net'`` or ``'bigru_net'``.
    """
    fluid.enable_dygraph(device)
    processor = SentaProcessor(
        data_dir=args.data_dir,
        vocab_path=args.vocab_path,
        random_seed=args.random_seed)

    infer_data_generator = processor.data_generator(
        batch_size=args.batch_size,
        padding_size=args.padding_size,
        places=device,
        phase='infer',
        epoch=1,
        shuffle=False)

    if args.model_type == 'cnn_net':
        model_infer = CNN(args.vocab_size, args.batch_size, args.padding_size)
    elif args.model_type == 'bow_net':
        model_infer = BOW(args.vocab_size, args.batch_size, args.padding_size)
    elif args.model_type == 'gru_net':
        model_infer = GRU(args.vocab_size, args.batch_size, args.padding_size)
    elif args.model_type == 'bigru_net':
        model_infer = BiGRU(args.vocab_size, args.batch_size,
                            args.padding_size)
    else:
        # Without this branch ``model_infer`` would stay unbound and the
        # failure would surface later as an opaque UnboundLocalError.
        raise ValueError(
            "Unknown model_type {!r}; expected one of 'cnn_net', 'bow_net', "
            "'gru_net', 'bigru_net'".format(args.model_type))

    print('Do inferring ...... ')
    inputs = [Input([None, None], 'int64', name='doc')]
    model_infer.prepare(
        None, CrossEntropy(), Accuracy(topk=(1, )), inputs, device=device)
    model_infer.load(args.checkpoints, reset_optimizer=True)
    preds = model_infer.predict(test_data=infer_data_generator)
    preds = np.array(preds[0]).reshape((-1, 2))

    if args.output_dir:
        output_path = os.path.join(args.output_dir, 'predictions.json')
        with open(output_path, 'w') as w:
            for index, probs in enumerate(preds):
                result = json.dumps({
                    'index': index,
                    # np.argmax returns a NumPy integer, which json.dumps
                    # rejects with TypeError; convert to a built-in int.
                    'label': int(np.argmax(probs)),
                    'probs': probs.tolist()
                })
                w.write(result + '\n')
        print('Predictions saved at ' + output_path)
开发者ID:PaddlePaddle,项目名称:hapi,代码行数:47,代码来源:sentiment_classifier.py


注:本文中的models.CNN属性示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。