本文整理汇总了Python中dataset.alignCollate方法的典型用法代码示例。如果您正苦于以下问题：Python dataset.alignCollate方法的具体用法？Python dataset.alignCollate怎么用？Python dataset.alignCollate使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块dataset的用法示例。
在下文中一共展示了dataset.alignCollate方法的2个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: data_loader
# 需要导入模块: import dataset [as 别名]
# 或者: from dataset import alignCollate [as 别名]
def data_loader():
    """Create the training and validation DataLoaders.

    LMDB roots are read from ``args.trainroot`` / ``args.valroot``; batch
    size, worker count, image geometry and the sampling mode are read from
    ``params``.

    Training batches are assembled by ``dataset.alignCollate``; validation
    samples are transformed one by one with ``dataset.resizeNormalize``.

    Returns:
        tuple: ``(train_loader, val_loader)``.

    Raises:
        ValueError: if either dataset evaluates as falsy (e.g. empty).
    """
    # --- train ---
    train_dataset = dataset.lmdbDataset(root=args.trainroot)
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not train_dataset:
        raise ValueError('training dataset is empty: %r' % (args.trainroot,))

    if params.random_sample:
        # No custom sampler: let the DataLoader shuffle the whole dataset.
        sampler = None
    else:
        sampler = dataset.randomSequentialSampler(train_dataset, params.batchSize)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=params.batchSize,
        # DataLoader raises ValueError when shuffle=True is combined with a
        # custom sampler, so only shuffle when no sampler was built.
        shuffle=sampler is None,
        sampler=sampler,
        num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.imgH, imgW=params.imgW,
                                        keep_ratio=params.keep_ratio))

    # --- val ---
    val_dataset = dataset.lmdbDataset(
        root=args.valroot,
        transform=dataset.resizeNormalize((params.imgW, params.imgH)))
    if not val_dataset:
        raise ValueError('validation dataset is empty: %r' % (args.valroot,))

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        shuffle=True,
        batch_size=params.batchSize,
        num_workers=int(params.workers))
    return train_loader, val_loader
示例2: val
# 需要导入模块: import dataset [as 别名]
# 或者: from dataset import alignCollate [as 别名]
def val(net, test_dataset, criterion, max_iter=2):
    """Evaluate ``net`` on up to ``max_iter`` batches of ``test_dataset``.

    Sets ``requires_grad = False`` on every parameter of ``net`` (this is not
    undone here) and calls ``net.eval()``. Batches are drawn with
    ``dataset.randomSequentialSampler`` and collated by ``dataset.alignCollate``.
    Raw and decoded predictions of the last evaluated batch (at most
    ``opt.n_test_disp`` of them) are printed.

    Relies on module-level names defined elsewhere: ``opt``, ``dataset``,
    ``utils``, ``converter``, ``image``, ``text``, ``length``, ``ifUnicode``,
    ``clean_txt``, ``distance``, ``Variable``.

    Args:
        net: model to evaluate; called as ``net(image)``.
        test_dataset: dataset to sample evaluation batches from.
        criterion: loss, called as ``criterion(preds, text, preds_size, length)``.
        max_iter: maximum number of batches; clipped to ``len(data_loader)``.

    Returns:
        tuple: ``(test_loss, accuracy, test_distance)`` where ``test_loss`` is
        ``loss_avg.val()``, ``accuracy`` is the fraction of evaluated samples
        whose stripped prediction equals the stripped target, and
        ``test_distance`` is the mean ``distance.nlevenshtein(..., method=2)``
        over evaluated samples. Both ratios are 0.0 if nothing was evaluated.
    """
    print('Start val')

    # Freeze the model being evaluated (previously this used the global
    # ``crnn`` instead of the ``net`` argument).
    for p in net.parameters():
        p.requires_grad = False
    net.eval()

    data_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=opt.batchSize, num_workers=int(opt.workers),
        sampler=dataset.randomSequentialSampler(test_dataset, opt.batchSize),
        collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                        keep_ratio=opt.keep_ratio))
    val_iter = iter(data_loader)

    n_correct = 0
    n_samples = 0  # samples actually evaluated; the last batch may be partial
    test_distance = 0
    loss_avg = utils.averager()

    max_iter = min(max_iter, len(data_loader))
    for _ in range(max_iter):
        # Built-in next() works on both Python 2 and 3 iterators.
        cpu_images, cpu_texts = next(val_iter)
        batch_size = cpu_images.size(0)
        n_samples += batch_size

        utils.loadData(image, cpu_images)
        if ifUnicode:
            cpu_texts = [clean_txt(tx.decode('utf-8')) for tx in cpu_texts]
        t, l = converter.encode(cpu_texts)
        utils.loadData(text, t)
        utils.loadData(length, l)

        preds = net(image)
        # Every sample in the batch gets the same output length preds.size(0).
        preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        loss_avg.add(cost)

        # NOTE(review): assumes preds is (seq_len, batch, n_class) so that
        # max(2) picks the best class per step and transpose(1, 0) makes the
        # flattened sequence batch-major — confirm against the model.
        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        sim_preds = converter.decode(preds.data, preds_size.data, raw=False)

        for pred, target in zip(sim_preds, cpu_texts):
            pred, target = pred.strip(), target.strip()
            if pred == target:
                n_correct += 1
            test_distance += distance.nlevenshtein(pred, target, method=2)

    # Show a sample of the last batch; skipped when no batch was evaluated
    # (otherwise preds/sim_preds/cpu_texts would be unbound).
    if max_iter > 0:
        raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:opt.n_test_disp]
        for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
            print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

    # Normalize by the real sample count rather than max_iter * batchSize,
    # which over-counts when the final batch is partial.
    if n_samples:
        accuracy = n_correct / float(n_samples)
        test_distance = test_distance / float(n_samples)
    else:
        accuracy = 0.0
        test_distance = 0.0

    testLoss = loss_avg.val()
    return testLoss, accuracy, test_distance