当前位置: 首页>>代码示例>>Python>>正文


Python models.setup方法代码示例

本文整理汇总了Python中models.setup方法的典型用法代码示例。如果您正苦于以下问题:Python models.setup方法的具体用法?Python models.setup怎么用?Python models.setup使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在models的用法示例。


在下文中一共展示了models.setup方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test

# 需要导入模块: import models [as 别名]
# 或者: from models import setup [as 别名]
def test(opt):
    logger = Logger(opt)
    dataset = VISTDataset(opt)
    opt.vocab_size = dataset.get_vocab_size()
    opt.seq_length = dataset.get_story_length()

    dataset.test()
    test_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.workers)
    evaluator = Evaluator(opt, 'test')
    model = models.setup(opt)
    model.cuda()
    predictions, metrics = evaluator.test_story(model, dataset, test_loader, opt) 
开发者ID:eric-xw,项目名称:AREL,代码行数:14,代码来源:train_AREL.py

示例2: train

# 需要导入模块: import models [as 别名]
# 或者: from models import setup [as 别名]
def train(opt,train_iter, test_iter,verbose=True):
    global_start= time.time()
    logger = utils.getLogger()
    model=models.setup(opt)
    if torch.cuda.is_available():
        model.cuda()
    params = [param for param in model.parameters() if param.requires_grad] #filter(lambda p: p.requires_grad, model.parameters())
    
    model_info =";".join( [str(k)+":"+ str(v)  for k,v in opt.__dict__.items() if type(v) in (str,int,float,list,bool)])  
    logger.info("# parameters:" + str(sum(param.numel() for param in params)))
    logger.info(model_info)
    
    
    model.train()
    optimizer = utils.getOptimizer(params,name=opt.optimizer, lr=opt.learning_rate,scheduler= utils.get_lr_scheduler(opt.lr_scheduler))

    loss_fun = F.cross_entropy

    filename = None
    percisions=[]
    for i in range(opt.max_epoch):
        for epoch,batch in enumerate(train_iter):
            optimizer.zero_grad()
            start= time.time()
            
            text = batch.text[0] if opt.from_torchtext else batch.text
            predicted = model(text)
    
            loss= loss_fun(predicted,batch.label)
    
            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            
            if verbose:
                if  torch.cuda.is_available():
                    logger.info("%d iteration %d epoch with loss : %.5f in %.4f seconds" % (i,epoch,loss.cpu().data.numpy(),time.time()-start))
                else:
                    logger.info("%d iteration %d epoch with loss : %.5f in %.4f seconds" % (i,epoch,loss.data.numpy()[0],time.time()-start))
 
        percision=utils.evaluation(model,test_iter,opt.from_torchtext)
        if verbose:
            logger.info("%d iteration with percision %.4f" % (i,percision))
        if len(percisions)==0 or percision > max(percisions):
            if filename:
                os.remove(filename)
            filename = model.save(metric=percision)
        percisions.append(percision)
            
#    while(utils.is_writeable(performance_log_file)):
    df = pd.read_csv(performance_log_file,index_col=0,sep="\t")
    df.loc[model_info,opt.dataset] =  max(percisions) 
    df.to_csv(performance_log_file,sep="\t")    
    logger.info(model_info +" with time :"+ str( time.time()-global_start)+" ->" +str( max(percisions) ) )
    print(model_info +" with time :"+ str( time.time()-global_start)+" ->" +str( max(percisions) ) ) 
开发者ID:wabyking,项目名称:TextClassificationBenchmark,代码行数:57,代码来源:main.py


注:本文中的models.setup方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。