本文整理匯總了Python中models.setup方法的典型用法代碼示例。如果您正苦於以下問題：Python models.setup方法的具體用法？Python models.setup怎麼用？Python models.setup使用的例子？那麼，這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊models的用法示例。
在下文中一共展示了models.setup方法的2個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: test
# 需要導入模塊: import models [as 別名]
# 或者: from models import setup [as 別名]
def test(opt):
    """Evaluate the model built by ``models.setup(opt)`` on a VIST dataset.

    Builds the dataset, copies its vocabulary size and story length onto
    ``opt`` (presumably so ``models.setup`` can size the model), calls
    ``dataset.test()`` (presumably selecting the test split), and runs
    ``Evaluator.test_story`` over a non-shuffled ``DataLoader``.

    Args:
        opt: Option namespace. Mutated in place (``vocab_size`` and
            ``seq_length`` are set). Also reads ``batch_size`` and ``workers``.

    Returns:
        The ``(predictions, metrics)`` pair produced by
        ``evaluator.test_story``.
    """
    # NOTE(review): `logger` is never used below. The constructor call is kept
    # in case Logger(opt) has side effects (e.g. creating log dirs) - confirm.
    logger = Logger(opt)
    dataset = VISTDataset(opt)
    opt.vocab_size = dataset.get_vocab_size()
    opt.seq_length = dataset.get_story_length()
    dataset.test()
    test_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.workers)
    evaluator = Evaluator(opt, 'test')
    model = models.setup(opt)
    # Unconditional: this code path assumes a CUDA device is available.
    model.cuda()
    predictions, metrics = evaluator.test_story(model, dataset, test_loader, opt)
    # Return the results instead of discarding them; callers that ignored the
    # previous None return value are unaffected.
    return predictions, metrics
示例2: train
# 需要導入模塊: import models [as 別名]
# 或者: from models import setup [as 別名]
def train(opt, train_iter, test_iter, verbose=True):
    """Train the model built by ``models.setup(opt)`` and record its best score.

    For each of ``opt.max_epoch`` passes over ``train_iter``, the model is
    optimised with cross-entropy loss and then scored on ``test_iter`` via
    ``utils.evaluation``. Whenever the score improves, the model is saved
    with ``model.save(metric=...)`` and the previously saved file is deleted.
    Finally, the best score is written into the tab-separated results table
    at ``performance_log_file``. The row key is the hyper-parameter summary;
    the column is ``opt.dataset``.

    Args:
        opt: Option namespace. Fields read here include ``optimizer``,
            ``learning_rate``, ``lr_scheduler``, ``max_epoch``,
            ``from_torchtext``, ``grad_clip`` and ``dataset``.
        train_iter: Iterable of batches exposing ``.text`` and ``.label``.
            When ``opt.from_torchtext`` is true, ``batch.text[0]`` is used
            as the model input.
        test_iter: Evaluation data, passed unchanged to ``utils.evaluation``.
        verbose: If true, log the loss of every batch and the score of every
            epoch.

    Returns:
        The best evaluation score, or ``None`` when ``opt.max_epoch`` is 0.
        In that case nothing is saved or written to the results table.
    """
    global_start = time.time()
    logger = utils.getLogger()
    model = models.setup(opt)
    if torch.cuda.is_available():
        model.cuda()

    params = [p for p in model.parameters() if p.requires_grad]
    # Hyper-parameter summary. It is logged and also used as the row key in
    # the results table, so only plain scalar/list options are included.
    model_info = ";".join(
        str(k) + ":" + str(v)
        for k, v in vars(opt).items()
        if type(v) in (str, int, float, list, bool)
    )
    logger.info("# parameters:" + str(sum(p.numel() for p in params)))
    logger.info(model_info)

    optimizer = utils.getOptimizer(
        params,
        name=opt.optimizer,
        lr=opt.learning_rate,
        scheduler=utils.get_lr_scheduler(opt.lr_scheduler),
    )
    loss_fun = F.cross_entropy
    filename = None          # path of the best checkpoint saved so far
    best_precision = None    # best evaluation score so far

    for i in range(opt.max_epoch):
        # Re-enter training mode every epoch. utils.evaluation may put the
        # model into eval mode (TODO confirm); without this, later epochs
        # would train with dropout/batch-norm in inference mode.
        model.train()
        for step, batch in enumerate(train_iter):
            optimizer.zero_grad()
            start = time.time()
            text = batch.text[0] if opt.from_torchtext else batch.text
            predicted = model(text)
            loss = loss_fun(predicted, batch.label)
            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            if verbose:
                # Tensor.item() returns a Python float for a one-element tensor
                # on any device; indexing the 0-d numpy array with [0] raised.
                logger.info("%d iteration %d epoch with loss : %.5f in %.4f seconds"
                            % (i, step, loss.item(), time.time() - start))

        precision = utils.evaluation(model, test_iter, opt.from_torchtext)
        if verbose:
            logger.info("%d iteration with percision %.4f" % (i, precision))
        if best_precision is None or precision > best_precision:
            new_filename = model.save(metric=precision)
            # Remove the previous best only after the new checkpoint is
            # written, so a failed save never leaves no checkpoint at all.
            if filename and filename != new_filename:
                os.remove(filename)
            filename = new_filename
            best_precision = precision

    if best_precision is None:
        # opt.max_epoch was 0: nothing was evaluated, so there is no score to record.
        logger.info(model_info + " trained for 0 epochs; nothing recorded")
        return None

    # NOTE(review): performance_log_file is not defined in this function; it is
    # presumably a module-level path - confirm. This read-modify-write is not
    # protected against concurrent runs writing the same table.
    df = pd.read_csv(performance_log_file, index_col=0, sep="\t")
    df.loc[model_info, opt.dataset] = best_precision
    df.to_csv(performance_log_file, sep="\t")

    summary = model_info + " with time :" + str(time.time() - global_start) + " ->" + str(best_precision)
    logger.info(summary)
    print(summary)
    return best_precision