當前位置: 首頁>>代碼示例>>Python>>正文


Python torch.backends方法代碼示例

本文整理匯總了Python中torch.backends方法的典型用法代碼示例。如果您正苦於以下問題:Python torch.backends方法的具體用法?Python torch.backends怎麽用?Python torch.backends使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch的用法示例。


在下文中一共展示了torch.backends方法的2個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: main

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import backends [as 別名]
def main(config):
    from torch.backends import cudnn
    # For fast training
    cudnn.benchmark = True

    data_loader = get_loader(
        config.mode_data,
        config.image_size,
        config.batch_size,
        config.dataset_fake,
        config.mode,
        num_workers=config.num_workers,
        all_attr=config.ALL_ATTR,
        c_dim=config.c_dim)

    from misc.scores import set_score
    if set_score(config):
        return

    if config.mode == 'train':
        from train import Train
        Train(config, data_loader)
        from test import Test
        test = Test(config, data_loader)
        test(dataset=config.dataset_real)

    elif config.mode == 'test':
        from test import Test
        test = Test(config, data_loader)
        if config.DEMO_PATH:
            test.DEMO(config.DEMO_PATH)
        else:
            test(dataset=config.dataset_real) 
開發者ID:BCV-Uniandes,項目名稱:SMIT,代碼行數:35,代碼來源:main.py

示例2: main

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import backends [as 別名]
def main(args):
  np.random.seed(args.seed)
  torch.manual_seed(args.seed)
  torch.cuda.manual_seed(args.seed)
  torch.cuda.manual_seed_all(args.seed)
  cudnn.benchmark = True
  torch.backends.cudnn.deterministic = True

  args.cuda = args.cuda and torch.cuda.is_available()
  if args.cuda:
    print('using cuda.')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
  else:
    torch.set_default_tensor_type('torch.FloatTensor')
  
  # Create data loaders
  if args.height is None or args.width is None:
    args.height, args.width = (32, 100)

  dataset_info = DataInfo(args.voc_type)

  # Create model
  model = ModelBuilder(arch=args.arch, rec_num_classes=dataset_info.rec_num_classes,
                       sDim=args.decoder_sdim, attDim=args.attDim, max_len_labels=args.max_len,
                       eos=dataset_info.char2id[dataset_info.EOS], STN_ON=args.STN_ON)

  # Load from checkpoint
  if args.resume:
    checkpoint = load_checkpoint(args.resume)
    model.load_state_dict(checkpoint['state_dict'])

  if args.cuda:
    device = torch.device("cuda")
    model = model.to(device)
    model = nn.DataParallel(model)

  # Evaluation
  model.eval()
  img = image_process(args.image_path)
  with torch.no_grad():
    img = img.to(device)
  input_dict = {}
  input_dict['images'] = img.unsqueeze(0)
  # TODO: testing should be more clean.
  # to be compatible with the lmdb-based testing, need to construct some meaningless variables.
  rec_targets = torch.IntTensor(1, args.max_len).fill_(1)
  rec_targets[:,args.max_len-1] = dataset_info.char2id[dataset_info.EOS]
  input_dict['rec_targets'] = rec_targets
  input_dict['rec_lengths'] = [args.max_len]
  output_dict = model(input_dict)
  pred_rec = output_dict['output']['pred_rec']
  pred_str, _ = get_str_list(pred_rec, input_dict['rec_targets'], dataset=dataset_info)
  print('Recognition result: {0}'.format(pred_str[0])) 
開發者ID:ayumiymk,項目名稱:aster.pytorch,代碼行數:55,代碼來源:demo.py


注:本文中的torch.backends方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。