当前位置: 首页>>代码示例>>Python>>正文


Python torch.backends方法代码示例

本文整理汇总了Python中torch.backends方法的典型用法代码示例。如果您正苦于以下问题:Python torch.backends方法的具体用法?Python torch.backends怎么用?Python torch.backends使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch的用法示例。


在下文中一共展示了torch.backends方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: import torch [as 别名]
# 或者: from torch import backends [as 别名]
def main(config):
    from torch.backends import cudnn
    # For fast training
    cudnn.benchmark = True

    data_loader = get_loader(
        config.mode_data,
        config.image_size,
        config.batch_size,
        config.dataset_fake,
        config.mode,
        num_workers=config.num_workers,
        all_attr=config.ALL_ATTR,
        c_dim=config.c_dim)

    from misc.scores import set_score
    if set_score(config):
        return

    if config.mode == 'train':
        from train import Train
        Train(config, data_loader)
        from test import Test
        test = Test(config, data_loader)
        test(dataset=config.dataset_real)

    elif config.mode == 'test':
        from test import Test
        test = Test(config, data_loader)
        if config.DEMO_PATH:
            test.DEMO(config.DEMO_PATH)
        else:
            test(dataset=config.dataset_real) 
开发者ID:BCV-Uniandes,项目名称:SMIT,代码行数:35,代码来源:main.py

示例2: main

# 需要导入模块: import torch [as 别名]
# 或者: from torch import backends [as 别名]
def main(args):
  np.random.seed(args.seed)
  torch.manual_seed(args.seed)
  torch.cuda.manual_seed(args.seed)
  torch.cuda.manual_seed_all(args.seed)
  cudnn.benchmark = True
  torch.backends.cudnn.deterministic = True

  args.cuda = args.cuda and torch.cuda.is_available()
  if args.cuda:
    print('using cuda.')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
  else:
    torch.set_default_tensor_type('torch.FloatTensor')
  
  # Create data loaders
  if args.height is None or args.width is None:
    args.height, args.width = (32, 100)

  dataset_info = DataInfo(args.voc_type)

  # Create model
  model = ModelBuilder(arch=args.arch, rec_num_classes=dataset_info.rec_num_classes,
                       sDim=args.decoder_sdim, attDim=args.attDim, max_len_labels=args.max_len,
                       eos=dataset_info.char2id[dataset_info.EOS], STN_ON=args.STN_ON)

  # Load from checkpoint
  if args.resume:
    checkpoint = load_checkpoint(args.resume)
    model.load_state_dict(checkpoint['state_dict'])

  if args.cuda:
    device = torch.device("cuda")
    model = model.to(device)
    model = nn.DataParallel(model)

  # Evaluation
  model.eval()
  img = image_process(args.image_path)
  with torch.no_grad():
    img = img.to(device)
  input_dict = {}
  input_dict['images'] = img.unsqueeze(0)
  # TODO: testing should be more clean.
  # to be compatible with the lmdb-based testing, need to construct some meaningless variables.
  rec_targets = torch.IntTensor(1, args.max_len).fill_(1)
  rec_targets[:,args.max_len-1] = dataset_info.char2id[dataset_info.EOS]
  input_dict['rec_targets'] = rec_targets
  input_dict['rec_lengths'] = [args.max_len]
  output_dict = model(input_dict)
  pred_rec = output_dict['output']['pred_rec']
  pred_str, _ = get_str_list(pred_rec, input_dict['rec_targets'], dataset=dataset_info)
  print('Recognition result: {0}'.format(pred_str[0])) 
开发者ID:ayumiymk,项目名称:aster.pytorch,代码行数:55,代码来源:demo.py


注:本文中的torch.backends方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。