当前位置: 首页>>代码示例>>Python>>正文


Python Solver.train方法代码示例

本文整理汇总了Python中solver.Solver.train方法的典型用法代码示例。如果您想了解Python中Solver.train方法的具体用法、调用方式或使用示例，这里精选的代码示例或许能为您提供帮助。您也可以进一步了解该方法所在类solver.Solver的用法示例。


在下文中一共展示了Solver.train方法的12个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: train

# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import train [as 别名]
def train(param=PARAMS, sv=SOLVE, small=False):
    """Build an LSTM head on ``e_net.l3_4`` and train it with ``Solver``.

    The user is asked interactively whether this is a test run. If the answer
    contains 'no', the run name is removed from the solver settings;
    otherwise the answer is appended to a name derived from this file's path.

    Args:
        param: keyword arguments forwarded to ``Solver`` (defaults to the
            module-level ``PARAMS``). A copy is modified, never the original.
        sv: solver settings dict passed to ``Solver`` (defaults to
            ``SOLVE``). A copy is modified, never the original.
        small: unused; kept for interface compatibility.
    """
    import os

    # Work on copies: the defaults are shared module-level dicts, and mutating
    # them would leak run-specific entries (name, eval data, ctx) into every
    # later call and every other user of PARAMS / SOLVE.
    param = dict(param)
    sv = dict(sv)

    # splitext removes exactly the '.py' suffix. str.rstrip('.py') strips any
    # trailing '.', 'p' or 'y' characters (e.g. 'happy.py' -> 'ha').
    sv['name'] = os.path.splitext(__file__)[0]
    input_var = raw_input('Are you testing now? ')

    if 'no' in input_var:
        sv.pop('name')
    else:
        sv['name'] += input_var

    out = get(1)
    from my_layer import LSTM
    sym = list(LSTM(e_net.l3_4, 64*64, 1, 64, 64))
    # Replace the first output with a logistic-regression output layer.
    sym[0] = mx.sym.LogisticRegressionOutput(data=sym[0], name='softmax')
    sym = mx.symbol.Group(sym)

    param['eval_data'] = out['val']
    param['marks'] = param['e_marks'] = out['marks']
    param['ctx'] = mu.gpu(1)

    # Debug output: shape of out['train'].label[0][1].
    print(out['train'].label[0][1].shape)

    s = Solver(sym, out['train'], sv, **param)
    s.train()
    s.predict()
开发者ID:ZijiaLewisLu,项目名称:HeartDeep-Kaggle-DSB2,代码行数:28,代码来源:train_rnn.py

示例2: train

# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import train [as 别名]
def train(param=PARAMS, sv=SOLVE, small=False):
    """Train ``cnn_net`` on RNN pickle data and export the results.

    The user is asked interactively whether this is a test run. If the answer
    contains 'no', the run name is removed from the solver settings;
    otherwise the answer is appended to 'TEST'.

    Args:
        param: keyword arguments forwarded to ``Solver`` (defaults to
            ``PARAMS``). A copy is modified, never the original.
        sv: solver settings dict passed to ``Solver`` (defaults to
            ``SOLVE``). A copy is modified, never the original.
        small: unused; kept for interface compatibility.
    """
    # Copy the shared default dicts before adding run-specific entries.
    param = dict(param)
    sv = dict(sv)

    sv['name'] = 'TEST'
    input_var = raw_input('Are you testing now? ')

    if 'no' in input_var:
        sv.pop('name')
    else:
        sv['name'] += input_var

    # NOTE(review): `files` is not defined in this function; presumably a
    # module-level list of pickle paths — confirm.
    imgs, ll = load_rnn_pk(files)
    # Reshape images and labels to (N, 1, 256, 256).
    imgs = imgs.reshape((-1, 1, 256, 256))
    ll = ll.reshape((-1, 1, 256, 256))
    datas = u.prepare_set(imgs, ll)

    # out[0] is the training iterator, out[1] the evaluation iterator.
    out = u.create_iter(*datas, batch_size=5)
    net = cnn_net(use_logis=True)

    param['eval_data'] = out[1]

    s = Solver(net, out[0], sv, **param)
    s.train()
    s.predict()
    s.all_to_png()
    s.save_best_model()
    s.plot_process()
开发者ID:ZijiaLewisLu,项目名称:HeartDeep-Kaggle-DSB2,代码行数:32,代码来源:train_on_new.py

示例3: train

# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import train [as 别名]
def train(param=PARAMS, sv=SOLVE, small=False):
    """Train the network from ``make_net()`` on ``R.get(2, rate=0.05)`` data.

    Args:
        param: keyword arguments forwarded to ``Solver`` (defaults to
            ``PARAMS``). A copy is modified, never the original.
        sv: solver settings passed to ``Solver`` unchanged.
        small: unused; kept for interface compatibility.
    """
    # Copy the shared default dict before adding run-specific entries.
    param = dict(param)

    net = make_net()

    out = R.get(2, rate=0.05)
    # Named train_iter so it does not shadow this function's own name.
    train_iter = out['train']
    param['eval_data'] = out['val']
    param['marks'] = param['e_marks'] = out['marks']

    s = Solver(net, train_iter, sv, **param)
    s.train()
    s.predict()
开发者ID:ZijiaLewisLu,项目名称:HeartDeep-Kaggle-DSB2,代码行数:13,代码来源:train.py

示例4: main

# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import train [as 别名]
def main(_):
    """Build the DTN model and run pretraining, training or evaluation.

    The stage is chosen by ``FLAGS.mode``: 'pretrain', 'train', or anything
    else, which falls through to evaluation. The positional argument is
    ignored.
    """
    model = DTN(mode=FLAGS.mode, learning_rate=0.0003)
    solver = Solver(model,
                    batch_size=100,
                    pretrain_iter=20000,
                    train_iter=2000,
                    sample_iter=100,
                    svhn_dir='svhn',
                    mnist_dir='mnist',
                    model_save_path=FLAGS.model_save_path,
                    sample_save_path=FLAGS.sample_save_path)

    # Ensure both output directories exist before running.
    for out_dir in (FLAGS.model_save_path, FLAGS.sample_save_path):
        if not tf.gfile.Exists(out_dir):
            tf.gfile.MakeDirs(out_dir)

    mode = FLAGS.mode
    if mode == 'pretrain':
        stage = solver.pretrain
    elif mode == 'train':
        stage = solver.train
    else:
        stage = solver.eval
    stage()
开发者ID:ALISCIFP,项目名称:domain-transfer-network,代码行数:20,代码来源:main.py

示例5: train

# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import train [as 别名]
def train(param=PARAMS, sv=SOLVE, small=False):
    """Train the network returned by ``net()`` on ``get(6, aug=True)`` data.

    The user is asked interactively whether this is a test run. If the answer
    contains 'no', the run name is removed from the solver settings;
    otherwise the answer is appended to a name derived from this file's path.

    Args:
        param: keyword arguments forwarded to ``Solver`` (defaults to
            ``PARAMS``). A copy is modified, never the original.
        sv: solver settings dict passed to ``Solver`` (defaults to
            ``SOLVE``). A copy is modified, never the original.
        small: unused; kept for interface compatibility.
    """
    import os

    # Copy the shared default dicts before adding run-specific entries.
    param = dict(param)
    sv = dict(sv)

    # splitext removes exactly the '.py' suffix. str.rstrip('.py') strips any
    # trailing '.', 'p' or 'y' characters (e.g. 'happy.py' -> 'ha').
    sv['name'] = os.path.splitext(__file__)[0]
    input_var = raw_input('Are you testing now? ')

    if 'no' in input_var:
        sv.pop('name')
    else:
        sv['name'] += input_var

    out = get(6, aug=True)
    sym = net()

    param['eval_data'] = out['val']

    s = Solver(sym, out['train'], sv, **param)
    s.train()
    s.predict()
开发者ID:ZijiaLewisLu,项目名称:HeartDeep-Kaggle-DSB2,代码行数:20,代码来源:train_e.py

示例6: train

# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import train [as 别名]
def train(param=PARAMS, sv=SOLVE, small=False):
    """Train the unrolled LSTM from ``unroll_lstm`` on one fixed RNN pickle.

    The reshaped labels serve as both input and target. The first 8 entries
    are used for training and the rest for evaluation.

    Args:
        param: keyword arguments forwarded to ``Solver`` (defaults to
            ``PARAMS``). A copy is modified, never the original.
        sv: solver settings passed to ``Solver`` unchanged.
        small: unused; kept for interface compatibility.
    """
    # Copy the shared default dict before adding run-specific entries.
    param = dict(param)

    # prepare net
    net = unroll_lstm(10, 64*64, 1, 64, 64)

    # prepare data
    from Evol.load_e import reshape_label
    from RNN.rnn_load import load_rnn_pk
    # Only the labels are used; the images returned first are discarded.
    _, ll = load_rnn_pk(['../DATA/PK/NEW/[T10,N10]<8-11:42:11>.pk'])

    ll = reshape_label(ll)
    lt, lv = ll[:8], ll[8:]
    train_iter = UnrollIter(lt, label=lt, batch_size=2, num_hidden=64*64)
    val_iter = UnrollIter(lv, label=lv, batch_size=2, num_hidden=64*64)
    # The validation iterator used to be built and then dropped. Pass it to
    # the Solver as evaluation data, as the other training scripts do.
    param['eval_data'] = val_iter

    # train
    s = Solver(net, train_iter, sv, **param)
    print('Start Training...')
    s.train()
    s.predict()
开发者ID:ZijiaLewisLu,项目名称:HeartDeep-Kaggle-DSB2,代码行数:21,代码来源:unroll_train.py

示例7: main

# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import train [as 别名]
def main(args):
    """Construct a ``Solver`` from parsed command-line ``args`` and train it.

    Each ``Solver`` keyword argument is read from the attribute of ``args``
    with the same name.
    """
    option_names = (
        'root', 'result_dir', 'weight_dir', 'load_weight', 'batch_size',
        'test_size', 'test_img_num', 'img_size', 'num_epoch', 'save_every',
        'lr', 'beta_1', 'beta_2', 'lambda_kl', 'lambda_img', 'lambda_z',
        'z_dim',
    )
    options = {}
    for option in option_names:
        options[option] = getattr(args, option)

    Solver(**options).train()
开发者ID:Pandinosaurus,项目名称:BicycleGAN-pytorch,代码行数:22,代码来源:train.py

示例8: main

# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import train [as 别名]
def main(config):
    """Create output directories and data loaders, then train or test StarGAN.

    ``config.dataset`` selects the loader(s): 'CelebA', 'RaFD' or 'Both'.
    ``config.mode`` selects 'train' or 'test'. A single dataset uses
    ``train``/``test``; 'Both' uses ``train_multi``/``test_multi``. Any other
    combination does nothing after the Solver is built.
    """
    # For fast training.
    cudnn.benchmark = True

    # Create any output directory that does not exist yet.
    for directory in (config.log_dir, config.model_save_dir,
                      config.sample_dir, config.result_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    celeba_loader = None
    if config.dataset in ('CelebA', 'Both'):
        celeba_loader = get_loader(config.celeba_image_dir, config.attr_path,
                                   config.selected_attrs, config.celeba_crop_size,
                                   config.image_size, config.batch_size,
                                   'CelebA', config.mode, config.num_workers)

    rafd_loader = None
    if config.dataset in ('RaFD', 'Both'):
        rafd_loader = get_loader(config.rafd_image_dir, None, None,
                                 config.rafd_crop_size, config.image_size,
                                 config.batch_size, 'RaFD', config.mode,
                                 config.num_workers)

    solver = Solver(celeba_loader, rafd_loader, config)

    single_dataset = config.dataset in ('CelebA', 'RaFD')
    both_datasets = config.dataset == 'Both'
    if config.mode == 'train':
        if single_dataset:
            solver.train()
        elif both_datasets:
            solver.train_multi()
    elif config.mode == 'test':
        if single_dataset:
            solver.test()
        elif both_datasets:
            solver.test_multi()
开发者ID:JacobLee121,项目名称:StarGAN,代码行数:43,代码来源:main.py

示例9: main

# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import train [as 别名]
def main(config):
    """Build the data loader and Solver, then train or sample per ``config.mode``.

    Modes other than 'train' and 'sample' stop after setup.
    """
    cudnn.benchmark = True

    loader = get_loader(image_path=config.image_path,
                        image_size=config.image_size,
                        batch_size=config.batch_size,
                        num_workers=config.num_workers)
    solver = Solver(config, loader)

    # Output directories are created after the Solver is constructed.
    for directory in (config.model_path, config.sample_path):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # Train and sample the images
    if config.mode == 'train':
        solver.train()
        return
    if config.mode == 'sample':
        solver.sample()
开发者ID:AbhinavJain13,项目名称:pytorch-tutorial,代码行数:23,代码来源:main.py

示例10: train

# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import train [as 别名]
def train(param=PARAMS, sv=SOLVE, small=False):
    """Template: train the network returned by ``net()`` and export the results.

    The user is asked interactively whether this is a test run. If the answer
    contains 'no', the run name is removed from the solver settings;
    otherwise the answer is appended to 'TEST'.

    Args:
        param: keyword arguments forwarded to ``Solver`` (defaults to
            ``PARAMS``). A copy is modified, never the original.
        sv: solver settings dict passed to ``Solver`` (defaults to
            ``SOLVE``). A copy is modified, never the original.
        small: unused. Note that ``get`` below is always called with
            ``small=True``.
    """
    # Copy the shared default dicts before adding run-specific entries.
    param = dict(param)
    sv = dict(sv)

    sv['name'] = 'TEST'
    input_var = raw_input('Are you testing now? ')

    if 'no' in input_var:
        sv.pop('name')
    else:
        sv['name'] += input_var

    out = get(6, small=True, aug=True)
    # Bind the result to a new name. `net = net()` made `net` local to this
    # function, so the call raised UnboundLocalError before net() could run.
    sym = net()

    # NOTE(review): the other scripts read out['val'] from get(); confirm that
    # 'eval' is the right key here.
    param['eval_data'] = out['eval']

    s = Solver(sym, out['train'], sv, **param)
    s.train()
    s.predict()
    s.all_to_png()
    s.save_best_model()
    s.plot_process()
开发者ID:ZijiaLewisLu,项目名称:HeartDeep-Kaggle-DSB2,代码行数:23,代码来源:Template_train.py

示例11: train

# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import train [as 别名]
def train(base_model, param=PARAMS, sv=SOLVE, small=False):
    """Train ``rnn_net`` on RNN pickle data after a pass through a frozen CNN.

    Args:
        base_model: positional arguments for ``mx.model.FeedForward.load``
            (presumably ``(prefix, epoch)`` — confirm).
        param: keyword arguments forwarded to ``Solver`` (defaults to
            ``PARAMS``). A copy is modified, never the original.
        sv: solver settings passed to ``Solver`` unchanged.
        small: if true, load ``rnn_load.f10`` and set ``param['ctx']`` to
            ``mu.gpu(1)``; otherwise load ``rnn_load.files``.
    """
    # Copy the shared default dict. Otherwise the ctx set for a small run
    # would persist in PARAMS and affect later calls with small=False.
    param = dict(param)

    # prepare data
    if small:
        files = rnn_load.f10
        param['ctx'] = mu.gpu(1)
    else:
        files = rnn_load.files

    imgs, labels = rnn_load.load_rnn_pk(files)
    it, lt, iv, lv = mu.prepare_set(imgs, labels)
    _, T = it.shape[:2]

    # cnn process
    # NOTE(review): only the first sample is processed (range(1)), and
    # rnn_input is not used below because datas is built from the labels.
    # This looks like leftover experimentation; confirm before relying on it.
    model = mx.model.FeedForward.load(*base_model, ctx=mu.gpu(1))
    rnn_input = np.zeros_like(it)
    for n in range(1):
        rnn_input[n], imgs, labels = mu.predict_draw(model, it[n])

    # prepare params: labels are fed as both input and target.
    datas = [lt, lt, lv, lv]
    # Merge the first two axes and insert a length-1 axis ("make T become one").
    datas = [d.reshape((-1, 1) + d.shape[2:]) for d in datas]

    iters = rnn_load.create_rnn_iter(*datas, batch_size=1, num_hidden=1000)
    param['eval_data'] = iters[1]
    param['marks'] = param['e_marks'] = [1] * T
    rnet = rnn_net(begin=mx.sym.Variable('data'), num_hidden=1000)
    s = Solver(rnet, iters[0], sv, **param)

    # train
    print('Start Training...')
    s.train()
    s.predict()
开发者ID:ZijiaLewisLu,项目名称:HeartDeep-Kaggle-DSB2,代码行数:40,代码来源:freeze_cnn.py

示例12: 微调脚本（使用R_LSTM_Iter）

# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import train [as 别名]
# Fine-tuning script excerpt. `np`, `mu`, `images`, `labels`, `marks`,
# `split` and `timg` are used here but defined above this excerpt.
# Concatenate the list entries along axis 1; the last `split` entries are
# held out for validation.
tll  = np.concatenate(labels[:-split], axis=1)
print timg.shape, tll.shape
vimg = np.concatenate(images[-split:], axis=1)
vll  = np.concatenate(labels[-split:], axis=1)


# Wrap training and validation arrays in R_LSTM iterators (batch size 1).
from r_lstm import R_LSTM_Iter
train = R_LSTM_Iter(timg, label=tll, num_hidden=3, batch_size=1)
val =   R_LSTM_Iter(vimg, label=vll, num_hidden=3, batch_size=1)

from solver import Solver
from train import make_net
from settings import PARAMS, SOLVE

# Solver settings. 'load' / 'load_perfix' / 'load_epoch' presumably make the
# Solver resume from the checkpoint below — confirm against Solver.
SOLVE['is_rnn'] = True
SOLVE['load']   = True
# NOTE(review): the key is spelled 'load_perfix'. It must match whatever
# Solver reads, so do not rename it here alone.
SOLVE['load_perfix'] = '/home/zijia/HeartDeepLearning/R_LSTM/Result/<26-12:43:22>[E5]/[ACC-0.97747 E4]'
SOLVE['load_epoch']  = 4
 
# Validation iterator plus the marks for the training / validation splits.
PARAMS['eval_data'] = val 
PARAMS['marks'] = marks[:-split]
PARAMS['e_marks'] = marks[-split:]
PARAMS['ctx'] = mu.gpu(1)

# NOTE(review): learning rate of 1 — confirm this is intended for fine-tuning.
PARAMS['learning_rate'] = 1

s = Solver(make_net(), train, SOLVE, **PARAMS)

s.train()
s.predict()
开发者ID:ZijiaLewisLu,项目名称:HeartDeep-Kaggle-DSB2,代码行数:32,代码来源:finetune.py


注:本文中的solver.Solver.train方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。