本文整理汇总了Python中utils.metrics.Evaluator方法的典型用法代码示例。如果您正苦于以下问题：Python metrics.Evaluator方法的具体用法？Python metrics.Evaluator怎么用？Python metrics.Evaluator使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块utils.metrics的用法示例。
在下文中一共展示了metrics.Evaluator方法的3个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __init__
# 需要导入模块: from utils import metrics [as 别名]
# 或者: from utils.metrics import Evaluator [as 别名]
def __init__(self, args, model, train_set, val_set, test_set, class_weights, saver):
    """Set up data loaders, TensorBoard writers, optimizer, LR scheduler, loss and evaluator.

    Args:
        args: Experiment configuration. Attributes read here: ``batch_size``,
            ``workers``, ``lr``, ``use_balanced_weights``, ``optimizer``
            ('SGD' or 'Adam'), ``momentum``, ``weight_decay``, ``nesterov``,
            ``use_lr_scheduler``, ``lr_scheduler``, ``step_size``
            (comma-separated epoch milestones), ``cuda`` and ``loss_type``.
        model: Network exposing ``get_1x_lr_params()`` and
            ``get_10x_lr_params()``; the second parameter group is trained
            with 10x the base learning rate.
        train_set, val_set, test_set: Datasets wrapped in DataLoaders;
            ``train_set`` must also expose ``num_classes``.
        class_weights: Object with ``astype`` (presumably a NumPy array of
            per-class weights); only used when ``args.use_balanced_weights``.
        saver: Provides ``save_experiment_config()`` and ``experiment_dir``.

    Raises:
        NotImplementedError: If ``args.optimizer`` is neither 'SGD' nor
            'Adam', or if ``args.use_lr_scheduler`` is set and
            ``args.lr_scheduler`` is not 'step'.
    """
    self.args = args
    self.saver = saver
    self.saver.save_experiment_config()

    # Only the training loader shuffles; val/test keep a deterministic order.
    self.train_dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
    self.val_dataloader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    self.test_dataloader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # Separate TensorBoard writers for the train and validation curves.
    self.train_summary = TensorboardSummary(os.path.join(self.saver.experiment_dir, "train"))
    self.train_writer = self.train_summary.create_summary()
    self.val_summary = TensorboardSummary(os.path.join(self.saver.experiment_dir, "validation"))
    self.val_writer = self.val_summary.create_summary()

    self.model = model
    self.dataset_size = {'train': len(train_set), 'val': len(val_set), 'test': len(test_set)}

    # Two parameter groups: the "10x" group is trained with a larger step.
    train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                    {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

    if args.use_balanced_weights:
        weight = torch.from_numpy(class_weights.astype(np.float32))
    else:
        weight = None

    if args.optimizer == 'SGD':
        print('Using SGD')
        self.optimizer = torch.optim.SGD(train_params, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)
    elif args.optimizer == 'Adam':
        print('Using Adam')
        # momentum / nesterov are not used by this branch.
        self.optimizer = torch.optim.Adam(train_params, weight_decay=args.weight_decay)
    else:
        raise NotImplementedError

    self.lr_scheduler = None
    if args.use_lr_scheduler:
        if args.lr_scheduler == 'step':
            print('Using step lr scheduler')
            milestones = [int(x) for x in args.step_size.split(",")]
            self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=milestones, gamma=0.1)
        else:
            # An unknown scheduler name used to silently disable scheduling;
            # fail loudly instead, like the optimizer selection above.
            raise NotImplementedError(
                "Unsupported lr_scheduler '{}'; only 'step' is implemented".format(args.lr_scheduler))

    # ignore_index=255 is passed through to the loss builder.
    self.criterion = SegmentationLosses(weight=weight, ignore_index=255, cuda=args.cuda).build_loss(mode=args.loss_type)
    self.evaluator = Evaluator(train_set.num_classes)
    self.best_pred = 0.0
示例2: __init__
# 需要导入模块: from utils import metrics [as 别名]
# 或者: from utils.metrics import Evaluator [as 别名]
def __init__(self, args):
    """Build saver, TensorBoard writer, data loaders, SCNN model, optimizer,
    loss, evaluator and LR scheduler, optionally resuming from a checkpoint.

    Args:
        args: Experiment configuration. Attributes read here: ``workers``,
            ``backbone``, ``out_stride``, ``cuda``, ``lr``, ``momentum``,
            ``weight_decay``, ``nesterov``, ``loss_type``, ``lr_scheduler``,
            ``epochs``, ``gpu_ids``, ``resume`` (checkpoint path or None)
            and ``ft``. When resuming, ``args.start_epoch`` is overwritten
            with the checkpoint's ``'epoch'`` value.

    Raises:
        RuntimeError: If ``args.resume`` is set but is not an existing file.
    """
    self.args = args
    # Define Saver
    self.saver = Saver(args)
    self.saver.save_experiment_config()
    # Define Tensorboard Summary
    self.summary = TensorboardSummary(self.saver.experiment_dir)
    self.writer = self.summary.create_summary()
    # Define Dataloader
    kwargs = {'num_workers': args.workers, 'pin_memory': True}
    self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
    # Define network
    model = SCNN(nclass=self.nclass, backbone=args.backbone, output_stride=args.out_stride, cuda=args.cuda)
    # Define Optimizer
    optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay, nesterov=args.nesterov)
    # Define Criterion (no class re-weighting)
    self.criterion = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
    self.model, self.optimizer = model, optimizer
    # Define Evaluator
    self.evaluator = Evaluator(self.nclass)
    # Define lr scheduler (needs the number of iterations per epoch)
    self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                  args.epochs, len(self.train_loader))
    # Using cuda: wrap in DataParallel, so weights then live under `.module`.
    if args.cuda:
        self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
        self.model = self.model.cuda()
    # Resuming checkpoint
    self.best_pred = 0.0
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
        # A checkpoint saved from GPU tensors cannot be deserialized on a
        # CPU-only run unless it is remapped; keep the saved device with CUDA.
        checkpoint = torch.load(args.resume, map_location=None if args.cuda else 'cpu')
        args.start_epoch = checkpoint['epoch']
        if args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])
        # When fine-tuning (args.ft), start from fresh optimizer state.
        if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
示例3: forward_all
# 需要导入模块: from utils import metrics [as 别名]
# 或者: from utils.metrics import Evaluator [as 别名]
def _show_batch(image, label, pred, dataset):
    """Show each sample's image, ground truth and prediction in OpenCV windows, waiting for a key per sample."""
    for jj in range(image.shape[0]):
        segmap_label = decode_segmap(label[jj], dataset=dataset)
        segmap_pred = decode_segmap(pred[jj], dataset=dataset)
        # CHW -> HWC, then undo mean/std normalisation. Assumes the loader
        # normalised with these statistics — confirm against the transforms.
        img_tmp = np.transpose(image[jj], axes=[1, 2, 0])
        img_tmp *= (0.229, 0.224, 0.225)
        img_tmp += (0.485, 0.456, 0.406)
        img_tmp *= 255.0
        img_tmp = img_tmp.astype(np.uint8)
        # Reverse the channel order for cv2.imshow (presumably RGB -> BGR).
        cv2.imshow('image', img_tmp[:, :, [2, 1, 0]])
        cv2.imshow('gt', segmap_label)
        cv2.imshow('pred', segmap_pred)
        cv2.waitKey(0)


def forward_all(net_inference, dataloader, visualize=False, opt=None,
                num_classes=21, dataset='pascal', result_path="seg_result.txt"):
    """Run segmentation inference over ``dataloader`` and report pixel metrics.

    Prints pixel accuracy, per-class accuracy, mIoU and FWIoU. When ``opt`` is
    given, also appends the run configuration and the metrics to
    ``result_path``.

    Args:
        net_inference: Callable taking an image batch and returning either a
            logits tensor or a dict of activations (the last entry is used).
        dataloader: Iterable of dicts with 'image' and 'label' tensors; both
            are moved to the GPU with ``.cuda()``.
        visualize: If True, display every sample and wait for a key press.
        opt: Optional options object whose fields are logged to the file.
        num_classes: Number of classes passed to ``Evaluator``.
        dataset: Dataset name passed to ``decode_segmap`` for visualisation.
        result_path: File the results are appended to when ``opt`` is set.
    """
    evaluator = Evaluator(num_classes)
    evaluator.reset()
    with torch.no_grad():
        for sample in dataloader:
            image, label = sample['image'].cuda(), sample['label'].cuda()
            activations = net_inference(image)
            image = image.cpu().numpy()
            label = label.cpu().numpy().astype(np.uint8)
            # Networks may return a bare tensor or an ordered dict of
            # intermediate activations whose last entry holds the logits.
            if isinstance(activations, torch.Tensor):
                logits = activations
            else:
                logits = activations[list(activations.keys())[-1]]
            # Class index of the maximum logit along dim 1.
            pred = torch.max(logits, 1)[1].cpu().numpy().astype(np.uint8)
            evaluator.add_batch(label, pred)
            if visualize:
                _show_batch(image, label, pred, dataset)

    acc = evaluator.Pixel_Accuracy()
    acc_class = evaluator.Pixel_Accuracy_Class()
    miou = evaluator.Mean_Intersection_over_Union()
    fwiou = evaluator.Frequency_Weighted_Intersection_over_Union()
    print("Acc: {}".format(acc))
    print("Acc_class: {}".format(acc_class))
    print("mIoU: {}".format(miou))
    print("FWIoU: {}".format(fwiou))
    if opt is not None:
        with open(result_path, 'a') as ww:
            ww.write("{}, quant: {}, relu: {}, equalize: {}, absorption: {}, correction: {}, clip: {}, distill_range: {}\n".format(
                opt.dataset, opt.quantize, opt.relu, opt.equalize, opt.absorption, opt.correction, opt.clip_weight, opt.distill_range
            ))
            ww.write("Acc: {}, Acc_class: {}, mIoU: {}, FWIoU: {}\n\n".format(acc, acc_class, miou, fwiou))