当前位置: 首页>>代码示例>>Python>>正文


Python config.model方法代码示例

本文整理汇总了Python中config.model方法的典型用法代码示例。如果您正苦于以下问题：Python config.model方法的具体用法？Python config.model怎么用？Python config.model使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块config的用法示例。


在下文中一共展示了config.model方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: load_session

# 需要导入模块: import config [as 别名]
# 或者: from config import model [as 别名]
def load_session():
    """Build the model and optimizer, restoring them from ``sess_path`` when possible.

    Reads the module-level settings ``sess_path``, ``model_config``, ``device``,
    ``learning_rate`` and ``reset_optimizer``. If a saved session loads, its
    ``model_config`` (when present and different) replaces the global one, the
    saved model state is restored, and, unless ``reset_optimizer`` is set, so
    is the optimizer state. Otherwise a fresh model is built.

    Returns:
        (model, optimizer): a ``PerformanceRNN`` moved to ``device`` and an
        ``optim.Adam`` optimizer over its parameters.
    """
    global sess_path, model_config, device, learning_rate, reset_optimizer
    try:
        # map_location lets a checkpoint saved on a GPU be opened on a CPU-only
        # machine (and vice versa); load_state_dict below copies the tensors
        # into a model that already lives on ``device``.
        sess = torch.load(sess_path, map_location=device)
        if 'model_config' in sess and sess['model_config'] != model_config:
            model_config = sess['model_config']
            print('Use session config instead:')
            print(utils.dict2params(model_config))
        model_state = sess['model_state']
        optimizer_state = sess['model_optimizer_state']
        print('Session is loaded from', sess_path)
        sess_loaded = True
    except FileNotFoundError:
        print('New session')
        sess_loaded = False
    except Exception as e:
        # Still best-effort, but report why: an unreadable checkpoint must not
        # look like a missing one, since the next save would overwrite it.
        print('Could not load session from', sess_path, '-', e)
        print('New session')
        sess_loaded = False
    model = PerformanceRNN(**model_config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    if sess_loaded:
        model.load_state_dict(model_state)
        if not reset_optimizer:
            optimizer.load_state_dict(optimizer_state)
    return model, optimizer
开发者ID:djosix,项目名称:Performance-RNN-PyTorch,代码行数:24,代码来源:train.py

示例2: save_activations

# 需要导入模块: import config [as 别名]
# 或者: from config import model [as 别名]
def save_activations(model, inputs, files, layer, batch_number):
    """Compute ``layer`` activations for each input and write them to CSV.

    Each row written to ``config.activations_path`` holds a ``class`` column,
    which is the name of the directory containing the corresponding file,
    followed by that input's activations. When ``batch_number`` is 0 (or
    negative) the file is written from scratch with a header row. Later
    batches append rows without a header.

    Args:
        model: model handed to ``get_activation_function``.
        inputs: sequence of inputs, one per entry of ``files``.
        files: file paths parallel to ``inputs``.
        layer: layer handed to ``get_activation_function``.
        batch_number: index of the current batch.
    """
    af = get_activation_function(model, layer)
    all_activations = []
    ids = []
    for sample, file_path in zip(inputs, files):
        # Activations are computed for a one-element batch per input.
        all_activations.append(get_activations(af, [sample]))
        # NOTE(review): splitting on '/' assumes POSIX-style paths; the
        # parent directory name is used as the class label.
        ids.append(file_path.split('/')[-2])

    submission = pd.DataFrame(all_activations)
    submission.insert(0, 'class', ids)
    if batch_number > 0:
        submission.to_csv(config.activations_path, index=False, mode='a', header=False)
    else:
        submission.to_csv(config.activations_path, index=False)
开发者ID:Arsey,项目名称:keras-transfer-learning-for-oxford102,代码行数:18,代码来源:util.py

示例3: parse_args

# 需要导入模块: import config [as 别名]
# 或者: from config import model [as 别名]
def parse_args():
    """
    Build the prediction command-line interface and parse ``sys.argv``.

    Returns the resulting ``argparse.Namespace``.
    """
    architectures = [config.MODEL_RESNET50, config.MODEL_RESNET152,
                     config.MODEL_INCEPTION_V3, config.MODEL_VGG16]

    cli = argparse.ArgumentParser()
    cli.add_argument('--path', dest='path', help='Path to image', default=None, type=str)

    # Boolean switches in declaration order; only --accuracy has help text.
    switches = (
        ('--accuracy', 'To print accuracy score'),
        ('--plot_confusion_matrix', None),
        ('--execution_time', None),
        ('--store_activations', None),
        ('--novelty_detection', None),
    )
    for flag, description in switches:
        cli.add_argument(flag, action='store_true', help=description)

    cli.add_argument('--model', type=str, required=True, help='Base model architecture',
                     choices=architectures)
    cli.add_argument('--data_dir', help='Path to data train directory')
    cli.add_argument('--batch_size', default=500, type=int,
                     help='How many files to predict on at once')
    return cli.parse_args()
开发者ID:Arsey,项目名称:keras-transfer-learning-for-oxford102,代码行数:20,代码来源:predict.py

示例4: save_model

# 需要导入模块: import config [as 别名]
# 或者: from config import model [as 别名]
def save_model():
    """Save the model config, model state and optimizer state to ``sess_path``.

    The checkpoint is first written to ``<sess_path>.tmp`` and then moved over
    ``sess_path``. An interrupted save therefore leaves the previous checkpoint
    intact instead of a truncated file.
    """
    global model, optimizer, model_config, sess_path
    import os

    print('Saving to', sess_path)
    tmp_path = '{}.tmp'.format(sess_path)
    torch.save({'model_config': model_config,
                'model_state': model.state_dict(),
                'model_optimizer_state': optimizer.state_dict()}, tmp_path)
    # os.replace overwrites an existing target on every platform, and the
    # rename is atomic on POSIX.
    os.replace(tmp_path, sess_path)
    print('Done saving')


#========================================================================
# Training
#======================================================================== 
开发者ID:djosix,项目名称:Performance-RNN-PyTorch,代码行数:14,代码来源:train.py

示例5: get_model

# 需要导入模块: import config [as 别名]
# 或者: from config import model [as 别名]
def get_model():
    """Build the ResNet selected by ``config.model``, print its summary and return it.

    Recognised values are "resnet18", "resnet34", "resnet50", "resnet101" and
    "resnet152". Any other value falls back to ResNet-50. The network is built
    for inputs of shape
    ``(None, config.image_height, config.image_width, config.channels)``.
    """
    builders = {
        "resnet18": resnet_18,
        "resnet34": resnet_34,
        "resnet50": resnet_50,
        "resnet101": resnet_101,
        "resnet152": resnet_152,
    }
    # Construct only the requested network: no throw-away ResNet-50 is built.
    model = builders.get(config.model, resnet_50)()
    model.build(input_shape=(None, config.image_height, config.image_width, config.channels))
    model.summary()
    return model
开发者ID:calmisential,项目名称:TensorFlow2.0_ResNet,代码行数:15,代码来源:train.py

示例6: train_step

# 需要导入模块: import config [as 别名]
# 或者: from config import model [as 别名]
def train_step(images, labels):
    """Apply one gradient update on a batch and update the training metrics.

    Uses the module-level ``model``, ``loss_object``, ``optimizer``,
    ``train_loss`` and ``train_accuracy``.
    """
    with tf.GradientTape() as tape:
        preds = model(images, training=True)
        batch_loss = loss_object(y_true=labels, y_pred=preds)
    weights = model.trainable_variables
    grads = tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(grads_and_vars=zip(grads, weights))

    train_loss(batch_loss)
    train_accuracy(labels, preds)
开发者ID:calmisential,项目名称:TensorFlow2.0_ResNet,代码行数:11,代码来源:train.py

示例7: valid_step

# 需要导入模块: import config [as 别名]
# 或者: from config import model [as 别名]
def valid_step(images, labels):
    """Evaluate one batch in inference mode and update the validation metrics.

    Uses the module-level ``model``, ``loss_object``, ``valid_loss`` and
    ``valid_accuracy``.
    """
    preds = model(images, training=False)
    batch_loss = loss_object(labels, preds)

    valid_loss(batch_loss)
    valid_accuracy(labels, preds)

    # start training 
开发者ID:calmisential,项目名称:TensorFlow2.0_ResNet,代码行数:10,代码来源:train.py

示例8: save_history

# 需要导入模块: import config [as 别名]
# 或者: from config import model [as 别名]
def save_history(history, prefix):
    """Plot the accuracy and loss curves of a Keras training run to JPEG files.

    Writes ``<config.plots_dir>/<prefix>-accuracy.jpg`` and
    ``<config.plots_dir>/<prefix>-loss.jpg``, creating ``config.plots_dir``
    if needed. Returns without plotting when no accuracy metric was logged.
    Validation curves are drawn only when the matching ``val_*`` key exists.

    Args:
        history: object whose ``history`` attribute maps metric names to
            per-epoch values (e.g. what Keras ``fit`` returns).
        prefix: file-name prefix for both images.
    """
    metrics = history.history
    # Depending on the Keras version, accuracy is logged as 'acc' or 'accuracy'.
    acc_key = next((key for key in ('acc', 'accuracy') if key in metrics), None)
    if acc_key is None:
        return

    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs(config.plots_dir, exist_ok=True)

    def plot_path(kind):
        # A single format() call, so a '%' in plots_dir or prefix is kept literally.
        return os.path.join(config.plots_dir, '{}-{}.jpg'.format(prefix, kind))

    _plot_history_metric(metrics, acc_key, 'model accuracy', 'accuracy',
                         'upper left', plot_path('accuracy'))
    _plot_history_metric(metrics, 'loss', 'model loss', 'loss',
                         'upper right', plot_path('loss'))


def _plot_history_metric(metrics, key, title, ylabel, legend_loc, path):
    """Plot the train (and, if logged, validation) curve of one metric to *path*."""
    plt.plot(metrics[key])
    legend = ['train']
    val_key = 'val_' + key
    if val_key in metrics:
        plt.plot(metrics[val_key])
        legend.append('test')
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    plt.legend(legend, loc=legend_loc)
    plt.savefig(path)
    plt.close()
开发者ID:Arsey,项目名称:keras-transfer-learning-for-oxford102,代码行数:30,代码来源:util.py

示例9: get_model_class_instance

# 需要导入模块: import config [as 别名]
# 或者: from config import model [as 别名]
def get_model_class_instance(*args, **kwargs):
    """Instantiate the model wrapper selected by ``config.model``.

    Imports the module ``models.<config.model>`` and returns the result of
    calling its ``inst_class`` with all given arguments.
    """
    module_name = "models.{}".format(config.model)
    model_module = importlib.import_module(module_name)
    factory = model_module.inst_class
    return factory(*args, **kwargs)
开发者ID:Arsey,项目名称:keras-transfer-learning-for-oxford102,代码行数:5,代码来源:util.py

示例10: parse_args

# 需要导入模块: import config [as 别名]
# 或者: from config import model [as 别名]
def parse_args():
    """
    Parse the command line for novelty-detection training.

    Options: ``--model`` (required, one of the supported base architectures)
    and the boolean switch ``--use_nn``.
    """
    supported_models = [
        config.MODEL_RESNET50,
        config.MODEL_RESNET152,
        config.MODEL_INCEPTION_V3,
        config.MODEL_VGG16,
    ]
    cli = argparse.ArgumentParser()
    cli.add_argument('--model', type=str, required=True, help='Base model architecture',
                     choices=supported_models)
    cli.add_argument('--use_nn', action='store_true')
    return cli.parse_args()
开发者ID:Arsey,项目名称:keras-transfer-learning-for-oxford102,代码行数:15,代码来源:train_novelty_detection.py

示例11: parse_args

# 需要导入模块: import config [as 别名]
# 或者: from config import model [as 别名]
def parse_args():
    """Parse command-line options for transfer-learning training.

    Returns an ``argparse.Namespace`` with ``data_dir``, ``model``,
    ``nb_epoch``, ``batch_size`` and ``freeze_layers_number``.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument
    add('--data_dir', help='Path to data dir')
    add('--model', type=str, required=True, help='Base model architecture',
        choices=[config.MODEL_RESNET50, config.MODEL_RESNET152,
                 config.MODEL_INCEPTION_V3, config.MODEL_VGG16])
    add('--nb_epoch', type=int, default=1000)
    add('--batch_size', type=int, default=32)
    add('--freeze_layers_number', type=int,
        help='will freeze the first N layers and unfreeze the rest')
    namespace = cli.parse_args()
    return namespace
开发者ID:Arsey,项目名称:keras-transfer-learning-for-oxford102,代码行数:14,代码来源:train.py

示例12: train

# 需要导入模块: import config [as 别名]
# 或者: from config import model [as 别名]
def train():
    """Create the configured model wrapper from the CLI options and train it.

    Class weights are derived from ``config.train_dir``. Epoch count, batch
    size and the number of frozen layers come from the module-level ``args``.
    """
    options = dict(
        class_weight=util.get_class_weight(config.train_dir),
        nb_epoch=args.nb_epoch,
        batch_size=args.batch_size,
        freeze_layers_number=args.freeze_layers_number,
    )
    trainer = util.get_model_class_instance(**options)
    trainer.train()
    print('Training is finished!')
开发者ID:Arsey,项目名称:keras-transfer-learning-for-oxford102,代码行数:10,代码来源:train.py

示例13: parse_args

# 需要导入模块: import config [as 别名]
# 或者: from config import model [as 别名]
def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def parse_args():
    """Parse the training command line.

    Positional arguments: ``gpu_id`` (int) and ``model`` (str). Options:
    ``--restore``, ``--debug`` (boolean string, default False),
    ``--init_weights``, ``--step`` and ``--process``.
    """
    parser = argparse.ArgumentParser('train net')
    parser.add_argument('gpu_id', type=int)
    parser.add_argument('model', type=str)
    parser.add_argument('--restore', dest='restore', type=str)
    # type=bool would turn any non-empty string (including "False") into True.
    parser.add_argument('--debug', dest='debug', type=_str2bool, default=False)
    parser.add_argument('--init_weights', dest='init_weights', type=str,
                        default='ResNet-50-model.caffemodel')
    parser.add_argument('--step', dest='step', type=int, default=int(1e6))
    parser.add_argument('--process', dest='process', type=int, default=3)

    args = parser.parse_args()
    return args
开发者ID:voidrank,项目名称:FastMask,代码行数:15,代码来源:train.py

示例14: handle

# 需要导入模块: import config [as 别名]
# 或者: from config import model [as 别名]
def handle(clientsocket):
    """Serve classification requests from one connected client until it quits.

    Every message read from ``clientsocket`` is treated as a path (bytes) to
    an image on this machine.
      * Existing file: the image is classified with the global ``model``, and
        a JSON object is sent back with the top prediction's ``probability``
        and ``class``, the novelty-detection ``relativity``, and a ``top10``
        list. Values are sent as strings.
      * Processing error: ``UNKNOWN_ERROR`` is sent.
      * Path that is not a file: ``FILE_DOES_NOT_EXIST`` is sent.

    The loop ends when the client sends ``b'exit'`` or closes the connection.
    """
    while True:
        buf = clientsocket.recv(config.buffer_size)
        # recv() returns b'' once the peer has closed the connection; stop
        # instead of looping forever on empty reads.
        if not buf or buf == 'exit'.encode():
            return  # client terminated connection

        # NOTE(review): the path is client-controlled and not confined to any
        # directory, so a client can probe any readable file. Also, one recv()
        # is assumed to carry exactly one complete path; TCP does not
        # guarantee message boundaries.
        if os.path.isfile(buf):
            try:
                response = _classify_image_file(buf)
                print(response)
            except Exception as e:
                print('Error', e)
                # print_exc shows the traceback of the exception being handled.
                traceback.print_exc()
                response = UNKNOWN_ERROR
        else:
            response = FILE_DOES_NOT_EXIST

        clientsocket.sendall(response.encode())


def _classify_image_file(path):
    """Classify the image at *path* and return the JSON response string."""
    import json

    img = [model_module.load_img(path)]
    out = model.predict(np.array(img))
    scores = out[0]
    prediction = np.argmax(out)
    # Indices of the (up to) ten highest scores, best first.
    top10 = scores.argsort()[-10:][::-1]
    class_names = list(config.classes)

    try:
        acts = util.get_activations(af, img)
        predicted_relativity = novelty_detection_clf.predict(acts)[0]
        nd_class = novelty_detection_clf.__classes[predicted_relativity]
    except Exception as e:
        # Novelty detection is best-effort. Python 3 exceptions have no
        # ``.message`` attribute, so print the exception itself.
        print(e)
        nd_class = 'related'

    # json.dumps escapes quotes in class names and yields valid JSON for any
    # number of classes.
    return json.dumps({
        'probability': str(scores[prediction]),
        'class': str(class_names[prediction]),
        'relativity': str(nd_class),
        'top10': [{'probability': str(scores[t]), 'class': str(class_names[t])}
                  for t in top10],
    }, separators=(',', ':'))
开发者ID:Arsey,项目名称:keras-transfer-learning-for-oxford102,代码行数:48,代码来源:server.py

示例15: predict

# 需要导入模块: import config [as 别名]
# 或者: from config import model [as 别名]
def predict(path):
    """Run the loaded model over every file found under ``path`` and report results.

    Relies on the module-level ``args``, ``model``, ``model_module`` and
    ``classes_in_keras_format``. Files are processed in batches of
    ``args.batch_size``. Depending on the CLI flags it:
      * saves layer activations (``--store_activations``) instead of predicting;
      * prints the novelty-detection class of the first input of each batch
        (``--novelty_detection``);
      * otherwise times predictions, prints each file's predicted class, and
        optionally prints accuracy (``--accuracy``) and plots confusion
        matrices (``--plot_confusion_matrix``).
    """
    files = get_files(path)
    n_files = len(files)
    print('Found {} files'.format(n_files))

    if args.novelty_detection:
        activation_function = util.get_activation_function(model, model_module.noveltyDetectionLayerName)
        novelty_detection_clf = joblib.load(config.get_novelty_detection_model_path())

    y_trues = []
    predictions = np.zeros(shape=(n_files,))
    nb_batch = int(np.ceil(n_files / float(args.batch_size)))
    for n in range(nb_batch):
        print('Batch {}'.format(n))
        n_from = n * args.batch_size
        n_to = min(args.batch_size * (n + 1), n_files)
        batch_files = files[n_from:n_to]

        y_true, inputs = get_inputs_and_trues(batch_files)
        y_trues += y_true

        if args.store_activations:
            util.save_activations(model, inputs, batch_files, model_module.noveltyDetectionLayerName, n)

        if args.novelty_detection:
            # Only the first input of each batch is checked.
            activations = util.get_activations(activation_function, [inputs[0]])
            nd_preds = novelty_detection_clf.predict(activations)[0]
            print(novelty_detection_clf.__classes[nd_preds])

        if not args.store_activations:
            # time.clock() was removed in Python 3.8; perf_counter() is the
            # standard high-resolution timer for measuring durations.
            # Warm up the model
            if n == 0:
                print('Warming up the model')
                start = time.perf_counter()
                model.predict(np.array([inputs[0]]))
                end = time.perf_counter()
                print('Warming up took {} s'.format(end - start))

            # Make predictions
            start = time.perf_counter()
            out = model.predict(np.array(inputs))
            end = time.perf_counter()
            predictions[n_from:n_to] = np.argmax(out, axis=1)
            print('Prediction on batch {} took: {}'.format(n, end - start))

    if not args.store_activations:
        # Invert the class-name -> index mapping once instead of rebuilding two
        # lists and searching them for every file. setdefault keeps the first
        # name for an index, like list.index() did.
        index_to_class = {}
        for class_name, class_index in classes_in_keras_format.items():
            index_to_class.setdefault(class_index, class_name)

        for i, p in enumerate(predictions):
            recognized_class = index_to_class[p]
            print('| should be {} ({}) -> predicted as {} ({})'.format(y_trues[i], files[i].split(os.sep)[-2], p,
                                                                       recognized_class))

        if args.accuracy:
            print('Accuracy {}'.format(accuracy_score(y_true=y_trues, y_pred=predictions)))

        if args.plot_confusion_matrix:
            cnf_matrix = confusion_matrix(y_trues, predictions)
            util.plot_confusion_matrix(cnf_matrix, config.classes, normalize=False)
            util.plot_confusion_matrix(cnf_matrix, config.classes, normalize=True)
开发者ID:Arsey,项目名称:keras-transfer-learning-for-oxford102,代码行数:59,代码来源:predict.py


注:本文中的config.model方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。