本文整理汇总了Python中models.setup方法的典型用法代码示例。如果您正苦于以下问题：Python models.setup方法的具体用法？Python models.setup怎么用？Python models.setup使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块models的用法示例。
在下文中一共展示了models.setup方法的2个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test
# 需要导入模块: import models [as 别名]
# 或者: from models import setup [as 别名]
def test(opt):
    """Evaluate a freshly set-up model on the test data of ``VISTDataset``.

    Loads the dataset, copies its vocabulary size and story length onto
    ``opt`` (before ``models.setup`` reads ``opt``), calls ``dataset.test()``,
    builds the model, moves it to the GPU and runs ``Evaluator.test_story``.

    Args:
        opt: configuration namespace. ``batch_size`` and ``workers`` are read
            here; ``vocab_size`` and ``seq_length`` are written here.

    Returns:
        The ``(predictions, metrics)`` pair produced by
        ``evaluator.test_story``. These were previously computed and then
        discarded, leaving the caller with nothing.
    """
    # The instance is never used below; it is kept in case constructing it
    # has side effects. NOTE(review): confirm whether Logger(opt) is needed.
    logger = Logger(opt)
    dataset = VISTDataset(opt)
    # Data-dependent sizes must be on opt before models.setup(opt) runs.
    opt.vocab_size = dataset.get_vocab_size()
    opt.seq_length = dataset.get_story_length()
    dataset.test()  # presumably selects the test split - verify in VISTDataset

    test_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=False,
                             num_workers=opt.workers)
    evaluator = Evaluator(opt, 'test')
    model = models.setup(opt)
    model.cuda()  # unconditional: this path requires a CUDA device
    predictions, metrics = evaluator.test_story(model, dataset, test_loader, opt)
    return predictions, metrics
示例2: train
# 需要导入模块: import models [as 别名]
# 或者: from models import setup [as 别名]
def train(opt, train_iter, test_iter, verbose=True):
    """Train the model described by ``opt`` and keep only the best checkpoint.

    Builds the model with ``models.setup(opt)`` and optimises it with
    cross-entropy for ``opt.max_epoch`` passes over ``train_iter``. After
    every pass it evaluates the model on ``test_iter``. Whenever the score
    improves, the model is saved through ``model.save`` and the previously
    saved file is deleted. At the end, the best score is written into the
    tab-separated table at ``performance_log_file``: the row is the
    hyper-parameter string and the column is ``opt.dataset``.

    Args:
        opt: configuration namespace. Attributes read here: ``optimizer``,
            ``learning_rate``, ``lr_scheduler``, ``max_epoch``,
            ``from_torchtext``, ``grad_clip`` and ``dataset``, plus whatever
            ``models.setup`` reads.
        train_iter: iterable of batches exposing ``.text`` and ``.label``.
            When ``opt.from_torchtext`` is true, ``batch.text[0]`` is used.
        test_iter: evaluation data passed to ``utils.evaluation``.
        verbose: if true, log the loss of every step and the score of every
            epoch.

    Returns:
        The best evaluation score seen, or ``None`` when ``opt.max_epoch`` is
        not positive. In that case no epoch ran and nothing is saved or
        recorded (previously this crashed on ``max([])``).

    Note:
        ``performance_log_file`` is not defined in this function;
        NOTE(review): presumably a module-level path - confirm.
    """
    global_start = time.time()
    logger = utils.getLogger()
    model = models.setup(opt)
    if torch.cuda.is_available():
        model.cuda()

    params = [param for param in model.parameters() if param.requires_grad]
    # This string doubles as the row label of the performance table, so its
    # construction (exact-type filter included) is kept unchanged.
    model_info = ";".join(
        str(k) + ":" + str(v)
        for k, v in opt.__dict__.items()
        if type(v) in (str, int, float, list, bool)
    )
    logger.info("# parameters:" + str(sum(param.numel() for param in params)))
    logger.info(model_info)

    optimizer = utils.getOptimizer(params, name=opt.optimizer, lr=opt.learning_rate,
                                   scheduler=utils.get_lr_scheduler(opt.lr_scheduler))
    loss_fun = F.cross_entropy
    filename = None
    best_precision = None

    for epoch in range(opt.max_epoch):
        # Re-enter training mode every epoch. The evaluation at the end of
        # the previous epoch may have left the model in eval mode.
        # NOTE(review): depends on utils.evaluation, which is not visible here.
        model.train()
        for step, batch in enumerate(train_iter):
            optimizer.zero_grad()
            start = time.time()
            text = batch.text[0] if opt.from_torchtext else batch.text
            predicted = model(text)
            loss = loss_fun(predicted, batch.label)
            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            if verbose:
                # loss.item() handles the 0-dim loss on CPU and GPU alike. The
                # old CPU path indexed a 0-dim numpy array and raised IndexError.
                logger.info("%d iteration %d epoch with loss : %.5f in %.4f seconds"
                            % (epoch, step, loss.item(), time.time() - start))

        precision = utils.evaluation(model, test_iter, opt.from_torchtext)
        if verbose:
            logger.info("%d iteration with percision %.4f" % (epoch, precision))
        if best_precision is None or precision > best_precision:
            # Keep only the best checkpoint on disk.
            if filename:
                os.remove(filename)
            filename = model.save(metric=precision)
            best_precision = precision

    if best_precision is None:
        logger.info(model_info + " with no training epochs; nothing recorded")
        return None

    _record_performance(performance_log_file, model_info, opt.dataset, best_precision)
    summary = (model_info + " with time :" + str(time.time() - global_start)
               + " ->" + str(best_precision))
    logger.info(summary)
    print(summary)
    return best_precision


def _record_performance(path, row, column, value):
    """Store ``value`` at (``row``, ``column``) in the tab-separated table at ``path``.

    The file's first column is used as the index. A missing file is treated
    as an empty table and created, so a finished training run does not fail
    just because the log has not been initialised yet.
    """
    if os.path.exists(path):
        df = pd.read_csv(path, index_col=0, sep="\t")
    else:
        df = pd.DataFrame()
    df.loc[row, column] = value
    df.to_csv(path, sep="\t")