This article collects typical usage examples of loss.Loss in Python. If you have been wondering what loss.Loss does or how to call it in practice, the curated examples below should help; you can also explore further examples from the loss module in which this class is defined.
The following presents 2 code examples of loss.Loss, ordered by popularity by default.
Example 1: main
# Required import: import loss [as alias]
# Or: from loss import Loss [as alias]
# Module-level context assumed by this snippet (EDSR-style script layout):
#   import data, model, loss, utility
#   from option import args
#   from trainer import Trainer
#   checkpoint = utility.checkpoint(args)
def main():
    global model
    if args.data_test == ['video']:
        # Video input: build the model and run inference only.
        from videotester import VideoTester
        model = model.Model(args, checkpoint)
        t = VideoTester(args, model, checkpoint)
        t.test()
    else:
        if checkpoint.ok:
            loader = data.Data(args)
            _model = model.Model(args, checkpoint)
            # The training loss is only needed when not in test-only mode.
            _loss = loss.Loss(args, checkpoint) if not args.test_only else None
            t = Trainer(args, loader, _model, _loss, checkpoint)
            while not t.terminate():
                t.train()
                t.test()
            checkpoint.done()
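
For context, here is a minimal sketch of the kind of Loss module Example 1 assumes. In the EDSR-style codebase this snippet comes from, loss.Loss parses a weighted-sum specification such as args.loss = '1*L1' and also handles logging through the checkpoint object; the simplified version below keeps only the spec parsing and the weighted summation, so the constructor signature, the spec format, and the supported term names are assumptions for illustration, not the exact upstream API.

import torch.nn as nn

class Loss(nn.Module):
    # Simplified sketch: a weighted sum of loss terms parsed from a
    # spec string like '1*L1+0.1*MSE'. Logging and GAN terms from the
    # real implementation are omitted.
    def __init__(self, args, checkpoint=None):
        super().__init__()
        self.terms = []
        for item in args.loss.split('+'):
            weight, name = item.split('*')
            if name == 'L1':
                fn = nn.L1Loss()
            elif name == 'MSE':
                fn = nn.MSELoss()
            else:
                raise NotImplementedError('unknown loss term: ' + name)
            self.terms.append((float(weight), fn))

    def forward(self, sr, hr):
        # Weighted sum over all configured terms.
        return sum(w * fn(sr, hr) for w, fn in self.terms)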
Example 2: train
# Required import: import loss [as alias]
# Or: from loss import Loss [as alias]
# Imports this example relies on (dataset, model, and loss are
# project-local modules in the EAST repository this snippet comes from):
import os
import time

import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.utils import data

from dataset import custom_dataset
from loss import Loss
from model import EAST


def train(train_img_path, train_gt_path, pths_path, batch_size, lr, num_workers, epoch_iter, interval):
    file_num = len(os.listdir(train_img_path))
    trainset = custom_dataset(train_img_path, train_gt_path)
    train_loader = data.DataLoader(trainset, batch_size=batch_size,
                                   shuffle=True, num_workers=num_workers, drop_last=True)

    criterion = Loss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = EAST()
    data_parallel = False
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        data_parallel = True
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Decay the learning rate by 10x halfway through training.
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[epoch_iter // 2], gamma=0.1)

    for epoch in range(epoch_iter):
        model.train()
        epoch_loss = 0
        epoch_time = time.time()
        for i, (img, gt_score, gt_geo, ignored_map) in enumerate(train_loader):
            start_time = time.time()
            img, gt_score, gt_geo, ignored_map = \
                img.to(device), gt_score.to(device), gt_geo.to(device), ignored_map.to(device)
            pred_score, pred_geo = model(img)
            loss = criterion(gt_score, pred_score, gt_geo, pred_geo, ignored_map)

            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print('Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format(
                epoch + 1, epoch_iter, i + 1, int(file_num / batch_size), time.time() - start_time, loss.item()))

        # In PyTorch >= 1.1 the scheduler must step after the optimizer,
        # so the call is moved here from the top of the epoch loop.
        scheduler.step()

        print('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(
            epoch_loss / int(file_num / batch_size), time.time() - epoch_time))
        print(time.asctime(time.localtime(time.time())))
        print('=' * 50)

        if (epoch + 1) % interval == 0:
            state_dict = model.module.state_dict() if data_parallel else model.state_dict()
            torch.save(state_dict, os.path.join(pths_path, 'model_epoch_{}.pth'.format(epoch + 1)))
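
A possible invocation of train, with placeholder paths and hyperparameters in the spirit of the EAST reference implementation; every value below is an illustrative assumption, not an upstream default.

if __name__ == '__main__':
    # All paths and hyperparameter values here are illustrative placeholders.
    train_img_path = './data/train_img'
    train_gt_path = './data/train_gt'
    pths_path = './pths'
    os.makedirs(pths_path, exist_ok=True)
    train(train_img_path, train_gt_path, pths_path,
          batch_size=24, lr=1e-3, num_workers=4,
          epoch_iter=600, interval=5)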