當前位置: 首頁>>代碼示例>>Python>>正文


Python resnet.resnet18方法代碼示例

本文整理匯總了Python中models.resnet.resnet18方法的典型用法代碼示例。如果您正苦於以下問題:Python resnet.resnet18方法的具體用法?Python resnet.resnet18怎麽用?Python resnet.resnet18使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在models.resnet的用法示例。


在下文中一共展示了resnet.resnet18方法的3個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: get_model_param

# 需要導入模塊: from models import resnet [as 別名]
# 或者: from models.resnet import resnet18 [as 別名]
def get_model_param(args):
    # assert args.model in ['resnet', 'vgg']

    if args.model == 'resnet':
        assert args.model_depth in [18, 34, 50, 101, 152]

        from models.resnet import get_fine_tuning_parameters

        if args.model_depth == 18:
            model = resnet.resnet18(pretrained=False, input_size=args.input_size, num_classes=args.n_classes)
        elif args.model_depth == 34:
            model = resnet.resnet34(pretrained=False, input_size=args.input_size, num_classes=args.n_classes)
        elif args.model_depth == 50:
            model = resnet.resnet50(pretrained=False, input_size=args.input_size, num_classes=args.n_classes)
        elif args.model_depth == 101:
            model = resnet.resnet101(pretrained=False, input_size=args.input_size, num_classes=args.n_classes)
        elif args.model_depth == 152:
            model = resnet.resnet152(pretrained=False, input_size=args.input_size, num_classes=args.n_classes)

    # elif args.model == 'vgg':
    #     pass

    # Load pretrained model here
    if args.finetune:
        pretrained_model = model_path[args.arch]
        args.pretrain_path = os.path.join(args.root_path, 'pretrained_models', pretrained_model)
        print("=> loading pretrained model '{}'...".format(pretrained_model))

        model.load_state_dict(torch.load(args.pretrain_path))

        # Only modify the last layer
        if args.model == 'resnet':
            model.fc = nn.Linear(model.fc.in_features, args.n_finetune_classes)
        # elif args.model == 'vgg':
        #     pass

        parameters = get_fine_tuning_parameters(model, args.ft_begin_index, args.lr_mult1, args.lr_mult2)
        return model, parameters

    return model, model.parameters() 
開發者ID:husencd,項目名稱:DriverPostureClassification,代碼行數:42,代碼來源:model.py

示例2: get_net

# 需要導入模塊: from models import resnet [as 別名]
# 或者: from models.resnet import resnet18 [as 別名]
def get_net(num_classes=None):  # pylint: disable=missing-docstring
  architecture = FLAGS.architecture
  task = FLAGS.task

  if "resnet18" in architecture:
    net = resnet.resnet18
  elif "resnet34" in architecture:
    net = resnet.resnet34
  elif "resnet50" in architecture or "resnext50" in architecture:
    net = resnet.resnet50
  elif "resnet101" in architecture or "resnext101" in architecture:
    net = resnet.resnet101
  elif "resnet152" in architecture or "resnext152" in architecture:
    net = resnet.resnet152
  elif "revnet18" in architecture:
    net = resnet.revnet18
  elif "revnet34" in architecture:
    net = resnet.revnet34
  elif "revnet50" in architecture:
    net = resnet.revnet50
  elif "revnet101" in architecture:
    net = resnet.revnet101
  elif "revnet152" in architecture:
    net = resnet.revnet152
  else:
    raise ValueError("Unsupported architecture: %s" % architecture)

  net = functools.partial(net, filters_factor=FLAGS.filters_factor, mode="v2")

  if "resnext" in architecture:
    net = functools.partial(net, groups=32)

  # Few things that are common across all models.
  net = functools.partial(
      net, num_classes=num_classes,
      weight_decay=FLAGS.weight_decay)

  return net 
開發者ID:google-research,項目名稱:s4l,代碼行數:40,代碼來源:utils.py

示例3: generate_model

# 需要導入模塊: from models import resnet [as 別名]
# 或者: from models.resnet import resnet18 [as 別名]
def generate_model(opt):
    assert opt.model in ['resnet', 'alexnet']

    if opt.model=='alexnet':
        model=deepSBD()
    elif opt.model == 'resnet':
        from models.resnet import get_fine_tuning_parameters

        model = resnet.resnet18(num_classes=opt.n_classes,
                                sample_size=opt.sample_size, sample_duration=opt.sample_duration)
    else:
        raise Exception("Unknown model name")

    return model 
開發者ID:Tangshitao,項目名稱:ClipShots_basline,代碼行數:16,代碼來源:__init__.py


注:本文中的models.resnet.resnet18方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。