当前位置: 首页>>代码示例>>Python>>正文


Python opts.opts方法代码示例

本文整理汇总了Python中opts.opts方法的典型用法代码示例。如果您正苦于以下问题：Python opts.opts方法的具体用法？Python opts.opts怎么用？Python opts.opts使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块opts的用法示例。


在下文中一共展示了opts.opts方法的13个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_stacked_hourglass

# 需要导入模块: import opts [as 别名]
# 或者: from opts import opts [as 别名]
def test_stacked_hourglass():
    """Smoke-test ``stacked_hourglass`` with one random 3-channel 512x512 input.

    Builds a 2-stack, 5-level hourglass whose output heads come from
    ``opts().init()``, runs a single forward pass and prints the input shape
    and the number of outputs.
    """
    level       = 5
    channels    = [256, 256, 384, 384, 384, 512]
    num_blocks  = [2, 2, 2, 2, 2, 4]
    num_stacks  = 2

    import sys
    # Developer-specific checkouts that provide the ``opts`` module. Insert each
    # one only if missing so repeated calls do not keep growing sys.path.
    for repo_dir in ("/export/guanghan/CenterNet-Gluon/",
                     "/Users/guanghan.ning/Desktop/dev/CenterNet-Gluon/"):
        if repo_dir not in sys.path:
            sys.path.insert(0, repo_dir)
    from opts import opts
    opt = opts().init()
    print(opt.arch)
    print(opt.heads)

    blk = stacked_hourglass(level, num_stacks, channels, num_blocks, opt.heads)
    blk.initialize()
    X   = nd.random.uniform(shape=(1, 3, 512, 512))
    Y   = blk(X)
    print("\t Input shape: ", X.shape)
    print("\t output len:", len(Y))
开发者ID:Guanghan,项目名称:mxnet-centernet,代码行数:22,代码来源:hourglass.py

示例2: test_inference

# 需要导入模块: import opts [as 别名]
# 或者: from opts import opts [as 别名]
def test_inference(img_path="/Users/guanghan.ning/Desktop/dev/CenterNet-Gluon/assets/demo.jpg"):
    """Run ``PoseDetector`` on a single image and print its timing breakdown.

    Args:
        img_path: image file passed to ``detector.run``. Defaults to the demo
            image in the original developer's checkout.

    Returns:
        dict mapping ``img_path`` to the ``'results'`` entry returned by the
        detector.
    """
    from opts import opts
    from detectors.pose_detector import PoseDetector
    opt = opts().init()
    detector = PoseDetector(opt)
    # NOTE(review): ``ctx`` is not defined in this function; it is assumed to be
    # a module-level device context in test_decoder.py -- confirm.
    detector.model = load_model(detector.model, opt.pretrained_path, ctx = ctx)

    ret = detector.run(img_path)
    # There is no dataset image id for a single file, so key by its path.
    results = {img_path: ret['results']}

    time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge']
    avg_time_stats = {t: AverageMeter() for t in time_stats}
    parts = []
    for t, meter in avg_time_stats.items():
        meter.update(ret[t])
        parts.append('|{} {:.3f} '.format(t, meter.avg))
    print(''.join(parts))
    return results
开发者ID:Guanghan,项目名称:mxnet-centernet,代码行数:22,代码来源:test_decoder.py

示例3: main

# 需要导入模块: import opts [as 别名]
# 或者: from opts import opts [as 别名]
def main():
	"""Overfit a 3D pose network on the H36M 'train' split, logging each epoch.

	The network is chosen by ``opt.loadModel``: ``'none'`` inflates a 2D model
	via ``inflate``, ``'scratch'`` builds a fresh ``Pose3D``, and any other value
	is used as a path for ``torch.load``. The loader does not shuffle, and each
	epoch's training loss and accuracy are written to the logger.
	"""
	opt = opts().parse()
	stamp = datetime.datetime.now().isoformat()
	logger = Logger(opt.saveDir + '/logs_{}'.format(stamp))

	source = opt.loadModel
	if source == 'none':
		net = inflate(opt)
	elif source == 'scratch':
		net = Pose3D(opt.nChannels, opt.nStack, opt.nModules, opt.numReductions, opt.nRegModules, opt.nRegFrames, ref.nJoints)
	else:
		net = torch.load(source)
	model = net.cuda()

	dataset = h36m('train', opt)
	train_loader = torch.utils.data.DataLoader(
		dataset,
		batch_size=opt.dataloaderSize,
		shuffle=False,
		num_workers=int(ref.nThreads),
	)

	param_groups = [{'params': model.parameters(), 'lr': opt.LRhg}]
	optimizer = torch.optim.RMSprop(
		param_groups,
		alpha=ref.alpha,
		eps=ref.epsilon,
		weight_decay=ref.weightDecay,
		momentum=ref.momentum,
	)

	for epoch in range(1, opt.nEpochs + 1):
		loss, acc = train(epoch, opt, train_loader, model, optimizer)
		for tag, value in (('loss_train', loss), ('acc_train', acc)):
			logger.scalar_summary(tag, value, epoch)
		logger.write('{:8f} {:8f} \n'.format(loss, acc))

	logger.close()
开发者ID:Naman-ntc,项目名称:3D-HourGlass-Network,代码行数:38,代码来源:overfit.py

示例4: inflate

# 需要导入模块: import opts [as 别名]
# 或者: from opts import opts [as 别名]
def inflate(opt = None):
	"""Build a Pose3D model and initialise its weights from a pretrained 2D model.

	Args:
		opt: parsed options. When None, they are parsed from the command line
			with ``opts().parse()``.

	Returns:
		The inflated Pose3D model. It is also saved to ``inflatedModel.pth`` in
		the current working directory.
	"""
	if opt is None:
		opt = opts().parse()
	model3d = Pose3D(opt.nChannels, opt.nStack, opt.nModules, opt.numReductions, opt.nRegModules, opt.nRegFrames, ref.nJoints, ref.temporal)
	# Configure the Inflate helper the same way whether or not opt was passed in.
	Inflate.nChannels = opt.nChannels
	Inflate.nStack = opt.nStack
	Inflate.nModules = opt.nModules
	Inflate.nRegFrames = opt.nRegFrames
	Inflate.nJoints = ref.nJoints
	Inflate.scheme = opt.scheme
	Inflate.mult = opt.mult
	# Globally patch pickle to decode with latin1 (presumably so checkpoints
	# pickled under Python 2 can be loaded).
	pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
	pickle.load = partial(pickle.load, encoding="latin1")
	# opt is always set here, so the 2D checkpoint path always comes from opt.Model2D.
	model = torch.load(opt.Model2D)

	Inflate.inflatePose3D(model3d, model)

	# Use a context manager so the output file handle is closed.
	with open('inflatedModel.pth', 'wb') as f:
		torch.save(model3d, f)

	return model3d


#inflate() 
开发者ID:Naman-ntc,项目名称:3D-HourGlass-Network,代码行数:36,代码来源:inflateScript.py

示例5: main

# 需要导入模块: import opts [as 别名]
# 或者: from opts import opts [as 别名]
def main():
    """Train the panorama-completion learner on two-view data.

    Parses the command-line arguments, prepares the experiment directories and
    platform settings, builds the train/val loaders and the learner, and runs a
    ``trainer`` with ``max_epoch=200``. ``model.save_checkpoint`` and
    ``model.evalPlot`` are registered as ``PeriodicCallback``s at epoch end
    with ``pstep=5``.
    """
    args = opts().parse()
    train_op.initialize_experiment_directories(args)
    train_op.platform_specific_initialization(args)

    # data
    train_loader, val_loader = buildDataset(args)

    # learner
    model = learner(args, learnerParam())

    # trainer plus periodic hooks, then launch
    runner = trainer(model, train_loader, val_loader, max_epoch=200)
    every_five_epochs = dict(cb_loc=CallbackLoc.epoch_end, pstep=5)
    for hook in (model.save_checkpoint, model.evalPlot):
        runner.add_callbacks([PeriodicCallback(func=hook, **every_five_epochs)])

    runner.run()
开发者ID:zhenpeiyang,项目名称:RelativePose,代码行数:28,代码来源:mainPanoCompletion2view.py

示例6: test_load

# 需要导入模块: import opts [as 别名]
# 或者: from opts import opts [as 别名]
def test_load():
    """Smoke-test the CenterNet COCO training data pipeline.

    Builds ``CenterCOCODataset`` for the train split, iterates the whole
    loader and prints the shape of every batch field before and after
    splitting it across the available devices.
    """
    from opts import opts
    opt = opts().init()

    batch_size = 16
    # One Stack() per sample field: image, heatmaps, scale, offset, ind, mask
    num_fields = 6
    batchify_fn = Tuple(*[Stack() for _ in range(num_fields)])
    num_workers = 2

    train_dataset = CenterCOCODataset(opt, split = 'train')
    train_loader = gluon.data.DataLoader( train_dataset,
        batch_size, True, batchify_fn=batchify_fn, last_batch='rollover', num_workers=num_workers)
    # Negative ids such as '-1' do not name a GPU; skip them so an empty list
    # falls back to the CPU instead of creating an invalid GPU context.
    ctx = [mx.gpu(int(i)) for i in opt.gpus_str.split(',') if i.strip() and int(i) >= 0]
    ctx = ctx if ctx else [mx.cpu()]

    for i, batch in enumerate(train_loader):
        print("{} Batch".format(i))
        print("image batch shape: ", batch[0].shape)
        print("heatmap batch shape", batch[1].shape)
        print("scale batch shape", batch[2].shape)
        print("offset batch shape", batch[3].shape)
        print("indices batch shape", batch[4].shape)
        print("mask batch shape", batch[5].shape)

        # heatmaps: (batch, num_classes, H/S, W/S); scale: wh (batch, 2, H/S, W/S);
        # offset: xy (batch, 2, H/S, W/S)
        (X, targets_heatmaps, targets_scale, targets_offset,
         targets_inds, targets_mask) = [
            gluon.utils.split_and_load(batch[k], ctx_list=ctx, batch_axis=0)
            for k in range(num_fields)]

        print("len(targets_heatmaps): ", len(targets_heatmaps))
        print("First item: image shape: ", X[0].shape)
        print("First item: heatmaps shape: ", targets_heatmaps[0].shape)
        print("First item: scalemaps shape: ", targets_scale[0].shape)
        print("First item: offsetmaps shape: ", targets_offset[0].shape)
        print("First item: indices shape: ", targets_inds[0].shape)
        print("First item: mask shape: ", targets_mask[0].shape)
    return
开发者ID:Guanghan,项目名称:mxnet-centernet,代码行数:41,代码来源:test_custom_dataloader.py

示例7: prefetch_test

# 需要导入模块: import opts [as 别名]
# 或者: from opts import opts [as 别名]
def prefetch_test(opt):
  """Evaluate a detector on the val (or, with ``opt.trainval``, test) split.

  Images are pre-processed by a single-worker ``PrefetchDataset`` loader. A
  ``Bar`` shows, for each timing key in the detector's return dict, the latest
  value and the running average. The collected results go to
  ``dataset.run_eval``.
  """
  os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str

  dataset_cls = dataset_factory[opt.dataset]
  opt = opts().update_dataset_info_and_set_heads(opt, dataset_cls)
  print(opt)
  Logger(opt)
  detector_cls = detector_factory[opt.task]

  split = 'test' if opt.trainval else 'val'
  dataset = dataset_cls(opt, split)
  detector = detector_cls(opt)

  prefetch = PrefetchDataset(opt, dataset, detector.pre_process)
  data_loader = torch.utils.data.DataLoader(
    prefetch, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)

  results = {}
  num_iters = len(dataset)
  bar = Bar('{}'.format(opt.exp_id), max=num_iters)
  stat_names = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge']
  meters = {name: AverageMeter() for name in stat_names}
  for step, (img_id, images) in enumerate(data_loader):
    ret = detector.run(images)
    key = img_id.numpy().astype(np.int32)[0]
    results[key] = ret['results']
    suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
      step, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
    for name, meter in meters.items():
      meter.update(ret[name])
      suffix += '|{} {tm.val:.3f}s ({tm.avg:.3f}s) '.format(name, tm=meter)
    Bar.suffix = suffix
    bar.next()
  bar.finish()
  dataset.run_eval(results, opt.save_dir)
开发者ID:CaoWGG,项目名称:CenterNet-CondInst,代码行数:36,代码来源:test.py

示例8: test

# 需要导入模块: import opts [as 别名]
# 或者: from opts import opts [as 别名]
def test(opt):
  """Run a detector on every image of the val (or test) split and evaluate.

  The split is 'test' when ``opt.trainval`` is set, otherwise 'val'. For the
  'ddd' task, the image's ``calib`` entry is passed to ``detector.run`` as an
  extra argument. Progress is shown with ``tqdm``.
  """
  os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str

  dataset_cls = dataset_factory[opt.dataset]
  opt = opts().update_dataset_info_and_set_heads(opt, dataset_cls)
  print(opt)
  Logger(opt)
  detector_cls = detector_factory[opt.task]

  dataset = dataset_cls(opt, 'test' if opt.trainval else 'val')
  detector = detector_cls(opt)

  results = {}
  for idx in tqdm(range(len(dataset))):
    img_id = dataset.images[idx]
    info = dataset.coco.loadImgs(ids=[img_id])[0]
    path = os.path.join(dataset.img_dir, info['file_name'])
    extra_args = (info['calib'],) if opt.task == 'ddd' else ()
    results[img_id] = detector.run(path, *extra_args)['results']
  dataset.run_eval(results, opt.save_dir)
开发者ID:CaoWGG,项目名称:CenterNet-CondInst,代码行数:29,代码来源:test.py

示例9: test

# 需要导入模块: import opts [as 别名]
# 或者: from opts import opts [as 别名]
def test(opt):
  """Run a detector image-by-image on the val (or test) split and evaluate.

  The split is 'test' when ``opt.trainval`` is set, otherwise 'val'. For the
  'ddd' task, the image's ``calib`` entry is also passed to ``detector.run``.
  A ``Bar`` suffix shows the running average of each timing key in the
  detector's return dict. The collected results go to ``dataset.run_eval``.
  """
  os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str

  dataset_cls = dataset_factory[opt.dataset]
  opt = opts().update_dataset_info_and_set_heads(opt, dataset_cls)
  print(opt)
  Logger(opt)
  detector_cls = detector_factory[opt.task]

  split = 'test' if opt.trainval else 'val'
  dataset = dataset_cls(opt, split)
  detector = detector_cls(opt)

  results = {}
  num_iters = len(dataset)
  bar = Bar('{}'.format(opt.exp_id), max=num_iters)
  stat_names = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge']
  meters = {name: AverageMeter() for name in stat_names}
  for idx in range(num_iters):
    img_id = dataset.images[idx]
    info = dataset.coco.loadImgs(ids=[img_id])[0]
    path = os.path.join(dataset.img_dir, info['file_name'])

    if opt.task != 'ddd':
      ret = detector.run(path)
    else:
      ret = detector.run(path, info['calib'])
    results[img_id] = ret['results']

    suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
      idx, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
    for name, meter in meters.items():
      meter.update(ret[name])
      suffix += '|{} {:.3f} '.format(name, meter.avg)
    Bar.suffix = suffix
    bar.next()
  bar.finish()
  dataset.run_eval(results, opt.save_dir)
开发者ID:kimyoon-young,项目名称:centerNet-deep-sort,代码行数:40,代码来源:test.py

示例10: main

# 需要导入模块: import opts [as 别名]
# 或者: from opts import opts [as 别名]
def main():
  """Run the StarMap demo on a single image and visualise the result.

  Loads the model from ``opt.loadModel`` (``'../models/Pascal3D-cpu.pth'``
  when empty) and runs it on the centre crop of ``opt.demo``. Keypoints are
  extracted from the last output with ``parseHeatmap``. ``horn87`` then
  aligns the decoded canonical points to the predicted points. The images and
  3D point sets are shown through ``Debugger``.

  Raises:
    IOError: if ``opt.demo`` cannot be read as an image.
  """
  opt = opts().parse()
  if opt.loadModel == '':
    opt.loadModel = '../models/Pascal3D-cpu.pth'
  model = torch.load(opt.loadModel)
  img = cv2.imread(opt.demo)
  if img is None:
    # cv2.imread reports failure by returning None rather than raising.
    raise IOError('Could not read demo image: {}'.format(opt.demo))
  s = max(img.shape[0], img.shape[1]) * 1.0
  c = np.array([img.shape[1] / 2., img.shape[0] / 2.])
  img = Crop(img, c, s, 0, ref.inputRes).astype(np.float32).transpose(2, 0, 1) / 256.
  input_tensor = torch.from_numpy(img.copy()).float()
  input_tensor = input_tensor.view(1, input_tensor.size(0), input_tensor.size(1), input_tensor.size(2))
  input_var = torch.autograd.Variable(input_tensor).float()
  if opt.GPU > -1:
    model = model.cuda(opt.GPU)
    input_var = input_var.cuda(opt.GPU)

  output = model(input_var)
  hm = output[-1].data.cpu().numpy()

  debugger = Debugger()
  img = (input_tensor[0].numpy().transpose(1, 2, 0) * 256).astype(np.uint8).copy()
  inp = img.copy()
  # Channel 0 (also used below as the per-point score) is resized to the input
  # resolution, scaled to [0, 255] and blended over the image for display.
  star = np.clip(cv2.resize(hm[0, 0], (ref.inputRes, ref.inputRes)) * 255, 0, 255)
  star = np.tile(star, (3, 1, 1)).transpose(1, 2, 0)
  trans = 0.8
  star = (trans * star + (1. - trans) * img).astype(np.uint8)

  ps = parseHeatmap(hm[0], thresh = 0.1)
  canonical, pred, color, score = [], [], [], []
  for row, col in zip(ps[0], ps[1]):
    # Channels 1-3 at the peak decode to canonical (x, y, z) and channel 4 to a
    # depth, each via (value + 0.5) * outputRes.
    x, y, z = ((hm[0, 1:4, row, col] + 0.5) * ref.outputRes).astype(np.int32)
    dep = ((hm[0, 4, row, col] + 0.5) * ref.outputRes).astype(np.int32)
    canonical.append([x, y, z])
    pred.append([col, ref.outputRes - dep, ref.outputRes - row])
    score.append(hm[0, 0, row, col])
    color.append((1.0 * x / ref.outputRes, 1.0 * y / ref.outputRes, 1.0 * z / ref.outputRes))
    # NOTE(review): the factor 4 assumes inputRes == 4 * outputRes -- confirm in ref.
    cv2.circle(img, (col * 4, row * 4), 4, (255, 255, 255), -1)
    cv2.circle(img, (col * 4, row * 4), 2, (int(z * 4), int(y * 4), int(x * 4)), -1)

  pred = np.array(pred).astype(np.float32)
  canonical = np.array(canonical).astype(np.float32)

  pointS = canonical * 1.0 / ref.outputRes
  pointT = pred * 1.0 / ref.outputRes
  R, t, s = horn87(pointS.transpose(), pointT.transpose(), score)

  rotated_pred = s * np.dot(R, canonical.transpose()).transpose() + t * ref.outputRes

  debugger.addImg(inp, 'inp')
  debugger.addImg(star, 'star')
  debugger.addImg(img, 'nms')
  debugger.addPoint3D(canonical / ref.outputRes - 0.5, c = color, marker = '^')
  debugger.addPoint3D(pred / ref.outputRes - 0.5, c = color, marker = 'x')
  debugger.addPoint3D(rotated_pred / ref.outputRes - 0.5, c = color, marker = '*')

  debugger.showAllImg(pause = True)
  debugger.show3D()
开发者ID:xingyizhou,项目名称:StarMap,代码行数:62,代码来源:demo.py

示例11: test_load

# 需要导入模块: import opts [as 别名]
# 或者: from opts import opts [as 别名]
def test_load():
    """Smoke-test the CenterNet KITTI 3D-detection training data pipeline.

    Builds ``CenterKITTIDataset`` for the train split, iterates the whole
    loader and prints the shape of every batch field, then splits each field
    across the available devices.
    """
    from opts import opts
    opt = opts().init()

    batch_size = 16
    # Batch fields, in order: inp, hm, wh, reg, dep, dim, rotbin, rotres, ind,
    # reg_mask, rot_mask (11 fields, one Stack() each).
    num_fields = 11
    batchify_fn = Tuple(*[Stack() for _ in range(num_fields)])
    num_workers = 2

    train_dataset = CenterKITTIDataset(opt, split = 'train')
    train_loader = gluon.data.DataLoader( train_dataset,
        batch_size, True, batchify_fn=batchify_fn, last_batch='rollover', num_workers=num_workers)
    # Negative ids such as '-1' do not name a GPU; skip them so an empty list
    # falls back to the CPU instead of creating an invalid GPU context.
    ctx = [mx.gpu(int(i)) for i in opt.gpus_str.split(',') if i.strip() and int(i) >= 0]
    ctx = ctx if ctx else [mx.cpu()]

    for i, batch in enumerate(train_loader):
        print("{} Batch".format(i))
        print("image batch shape: ", batch[0].shape)
        print("center batch shape", batch[1].shape)

        print("2d wh batch shape", batch[2].shape)
        print("2d offset batch shape", batch[3].shape)

        print("3d depth batch shape", batch[4].shape)
        print("3d dimension batch shape", batch[5].shape)
        print("3d rotbin batch shape", batch[6].shape)
        print("3d rotres batch shape", batch[7].shape)

        print("indices batch shape", batch[8].shape)
        print("2d offset mask batch shape", batch[9].shape)
        print("3d rotation mask batch shape", batch[10].shape)

        # center heatmaps: (batch, num_classes, H/S, W/S); wh and offset: (batch, 2, H/S, W/S)
        (X, targets_center, targets_2d_wh, targets_2d_offset,
         targets_3d_depth, targets_3d_dim, targets_3d_rotbin, targets_3d_rotres,
         targets_inds, targets_2d_wh_mask, targets_3d_rot_mask) = [
            gluon.utils.split_and_load(batch[k], ctx_list=ctx, batch_axis=0)
            for k in range(num_fields)]

        print("len(targets_center): ", len(targets_center))
        print("First item: image shape: ", X[0].shape)
        print("First item: center heatmap shape: ", targets_center[0].shape)
    return
开发者ID:Guanghan,项目名称:mxnet-centernet,代码行数:52,代码来源:test_custom_dataloader_3dod.py

示例12: test_load

# 需要导入模块: import opts [as 别名]
# 或者: from opts import opts [as 别名]
def test_load():
    """Smoke-test the CenterNet multi-person 2D pose training data pipeline.

    Builds ``CenterMultiPoseDataset`` for the train split, iterates the whole
    loader and prints the shape of every batch field, then splits each field
    across the available devices.
    """
    from opts import opts
    opt = opts().init()

    batch_size = 16
    # Batch fields, in order (12 fields, one Stack() each):
    #   inp, ind,
    #   hm, wh, reg, reg_mask,               # 2d detection: center, wh, offset
    #   kps, kps_mask,                       # 2d pose: joint locations relative to center
    #   hm_hp, hp_offset, hp_ind, hp_mask    # 2d pose: joint heatmaps (and offsets that
    #                                        # compensate for discretization)
    num_fields = 12
    batchify_fn = Tuple(*[Stack() for _ in range(num_fields)])
    num_workers = 2

    train_dataset = CenterMultiPoseDataset(opt, split = 'train')
    train_loader = gluon.data.DataLoader( train_dataset,
        batch_size, True, batchify_fn=batchify_fn, last_batch='rollover', num_workers=num_workers)
    # Negative ids such as '-1' do not name a GPU; skip them so an empty list
    # falls back to the CPU instead of creating an invalid GPU context.
    ctx = [mx.gpu(int(i)) for i in opt.gpus_str.split(',') if i.strip() and int(i) >= 0]
    ctx = ctx if ctx else [mx.cpu()]

    for i, batch in enumerate(train_loader):
        print("{} Batch".format(i))
        print("image batch shape: ", batch[0].shape)
        print("indices batch shape", batch[1].shape)

        print("center batch shape", batch[2].shape)
        print("2d wh batch shape", batch[3].shape)
        print("2d offset batch shape", batch[4].shape)
        print("2d offset mask batch shape", batch[5].shape)

        print("pose relative to center batch shape", batch[6].shape)
        print("pose relative to center mask batch shape", batch[7].shape)

        print("pose heatmap batch shape", batch[8].shape)
        print("pose heatmap offset batch shape", batch[9].shape)
        print("pose heatmap ind shape", batch[10].shape)
        print("pose heatmap mask batch shape", batch[11].shape)

        # center heatmaps: (batch, num_classes, H/S, W/S); wh and offset: (batch, 2, H/S, W/S)
        (X, targets_inds,
         targets_center, targets_2d_wh, targets_2d_offset, targets_2d_wh_mask,
         targets_poserel, targets_poserel_mask,
         targets_posemap, targets_posemap_offset, targets_posemap_ind, targets_posemap_mask) = [
            gluon.utils.split_and_load(batch[k], ctx_list=ctx, batch_axis=0)
            for k in range(num_fields)]

        print("len(targets_center): ", len(targets_center))
        print("First item: image shape: ", X[0].shape)
        print("First item: center heatmap shape: ", targets_center[0].shape)
    return
开发者ID:Guanghan,项目名称:mxnet-centernet,代码行数:62,代码来源:test_custom_dataloader_2dpose.py

示例13: main

# 需要导入模块: import opts [as 别名]
# 或者: from opts import opts [as 别名]
def main():
    """Train (or, with ``opt.test``, only evaluate) the PyraNet model on MPII.

    Evaluation mode runs one validation pass and saves the predictions to
    ``<saveDir>/preds.mat``. Training mode runs for ``opt.nEpochs`` epochs and
    logs the train metrics each epoch. Every ``opt.valIntervals`` epochs it
    validates and saves the full model plus the predictions. Every
    ``opt.dropLR`` epochs it sets the learning rate to
    ``opt.LR * 0.1 ** (epoch // opt.dropLR)``. At the end it saves a CPU copy
    of the model to ``<saveDir>/model_cpu.pth``. The logger is closed on every
    exit path.
    """
    opt = opts().parse()
    now = datetime.datetime.now()
    logger = Logger(opt.saveDir, now.isoformat())
    model, optimizer = getModel(opt)
    criterion = torch.nn.MSELoss().cuda()
    model = model.cuda()

    try:
        val_loader = torch.utils.data.DataLoader(
                MPII(opt, 'val'),
                batch_size = 1,
                shuffle = False,
                num_workers = int(ref.nThreads)
        )

        if opt.test:
            _, preds = val(0, opt, val_loader, model, criterion)
            sio.savemat(os.path.join(opt.saveDir, 'preds.mat'), mdict = {'preds': preds})
            return
        # PyramidNet pretraining: build the training data loader first.
        train_loader = torch.utils.data.DataLoader(
                MPII(opt, 'train'),
                batch_size = opt.trainBatch,
                shuffle = opt.DEBUG == 0,
                num_workers = int(ref.nThreads)
        )
        # Main training loop.
        for epoch in range(1, opt.nEpochs + 1):
            log_dict_train, _ = train(epoch, opt, train_loader, model, criterion, optimizer)
            for k, v in log_dict_train.items():
                logger.scalar_summary('train_{}'.format(k), v, epoch)
                logger.write('{} {:8f} | '.format(k, v))
            if epoch % opt.valIntervals == 0:
                log_dict_val, preds = val(epoch, opt, val_loader, model, criterion)
                for k, v in log_dict_val.items():
                    logger.scalar_summary('val_{}'.format(k), v, epoch)
                    logger.write('{} {:8f} | '.format(k, v))
                torch.save(model, os.path.join(opt.saveDir, 'model_{}.pth'.format(epoch)))
                sio.savemat(os.path.join(opt.saveDir, 'preds_{}.mat'.format(epoch)), mdict = {'preds': preds})
            logger.write('\n')
            if epoch % opt.dropLR == 0:
                lr = opt.LR * (0.1 ** (epoch // opt.dropLR))
                print('Drop LR to {}'.format(lr))
                adjust_learning_rate(optimizer, lr)
    finally:
        # Close the logger on every exit path, including the early test return.
        logger.close()
    torch.save(model.cpu(), os.path.join(opt.saveDir, 'model_cpu.pth'))
开发者ID:IcewineChen,项目名称:pytorch-PyraNet,代码行数:55,代码来源:main.py


注:本文中的opts.opts方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。