当前位置: 首页>>代码示例>>Python>>正文


Python resnet.ResNet方法代码示例

本文整理汇总了Python中resnet.ResNet方法的典型用法代码示例。如果您正苦于以下问题:Python resnet.ResNet方法的具体用法?Python resnet.ResNet怎么用?Python resnet.ResNet使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在resnet的用法示例。


在下文中一共展示了resnet.ResNet方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: import resnet [as 别名]
# 或者: from resnet import ResNet [as 别名]
def main(argv=None):  # pylint: disable=unused-argument
    assert args.detect or args.segment, "Either detect or segment should be True"
    assert args.ckpt > 0, "Specify the number of checkpoint"
    net = ResNet(config=net_config, depth=50, training=False)
    loader = Loader(osp.join(EVAL_DIR, 'demodemo'))


    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)) as sess:
        detector = Detector(sess, net, loader, net_config, no_gt=args.no_seg_gt,
                            folder=osp.join(loader.folder, 'output'))
        detector.restore_from_ckpt(args.ckpt)
        for name in loader.get_filenames():
            image = loader.load_image(name)
            h, w = image.shape[:2]
            print('Processing {}'.format(name + loader.data_format))
            detector.feed_forward(img=image, name=name, w=w, h=h, draw=True,
                                  seg_gt=None, gt_bboxes=None, gt_cats=None)
    print('Done') 
开发者ID:dvornikita,项目名称:blitznet,代码行数:21,代码来源:demo.py

示例2: __init__

# 需要导入模块: import resnet [as 别名]
# 或者: from resnet import ResNet [as 别名]
def __init__(self, num_kpt=7, image_size=(80, 80), onnx_mode=False, init_weight=True):
        super(KeypointNet, self).__init__()
        net_size = 16

        self.conv = nn.Conv2d(in_channels=3, out_channels=net_size, kernel_size=7, stride=1, padding=3)
        # torch.nn.init.xavier_uniform(self.conv.weight)
        self.bn = nn.BatchNorm2d(net_size)
        self.relu = nn.ReLU()
        self.res1 = ResNet(net_size, net_size)
        self.res2 = ResNet(net_size, net_size * 2)
        self.res3 = ResNet(net_size * 2, net_size * 4)
        self.res4 = ResNet(net_size * 4, net_size * 8)
        self.out = nn.Conv2d(in_channels=net_size * 8, out_channels=num_kpt, kernel_size=1, stride=1, padding=0)
        # torch.nn.init.xavier_uniform(self.out.weight)
        if init_weight:
            self._initialize_weights()
        self.image_size = image_size
        self.num_kpt = num_kpt
        self.onnx_mode = onnx_mode 
开发者ID:cv-core,项目名称:MIT-Driverless-CV-TrainingInfra,代码行数:21,代码来源:keypoint_net.py

示例3: init_detectot

# 需要导入模块: import resnet [as 别名]
# 或者: from resnet import ResNet [as 别名]
def init_detectot(self):
        assert args.detect or args.segment, "Either detect or segment should be True"
        assert args.ckpt > 0, "Specify the number of checkpoint"
        net = ResNet(config=net_config, depth=50, training=False)
        self.loader = Loader(opj(EVAL_DIR, 'demodemo'))
        self.detector = Detector(self.sess, net, self.loader, net_config, no_gt=args.no_seg_gt,
                                 folder=opj(self.loader.folder, 'output'))
        self.detector.restore_from_ckpt(args.ckpt) 
开发者ID:dvornikita,项目名称:blitznet,代码行数:10,代码来源:main.py

示例4: net_select

# 需要导入模块: import resnet [as 别名]
# 或者: from resnet import ResNet [as 别名]
def net_select(name, data_format='NCHW', weight_decay=5e-4):
  if name == 'SphereNet':
    from sphere import SphereNet
    network = SphereNet(data_format=data_format, 
                        weight_decay=weight_decay)  
  elif name == 'ResNeXt-26':
    from resnext import ResNeXt
    network = ResNeXt(num_layers=26, num_card=32, 
                      data_format=data_format, 
                      weight_decay=weight_decay)
  elif name == 'ResNet-50':
    from resnet import ResNet
    network = ResNet(num_layers=50, 
                     data_format=data_format, 
                     weight_decay=weight_decay)
  elif name == 'ShuffleNet-v2-small':
    from shufflenet_v2 import ShuffleNet_v2_small
    network = ShuffleNet_v2_small(alpha=2.0, 
                                  se=False, residual=False,
                                  data_format=data_format, 
                                  weight_decay=weight_decay)
  elif name == 'ShuffleNet-v2-middle':
    from shufflenet_v2 import ShuffleNet_v2_middle
    network = ShuffleNet_v2_middle(se=False, residual=False,
                                   data_format=data_format, 
                                   weight_decay=weight_decay)
  elif name == 'ShuffleNet-v2-large':
    from shufflenet_v2 import ShuffleNet_v2_large
    network = ShuffleNet_v2_large(data_format=data_format, 
                                  weight_decay=weight_decay)
  elif name == 'MobileNet-v2':
    pass
  elif name == 'Inception-v4':
    pass
  elif name == 'VGG16':
    pass
  elif name == 'AlexNet':
    pass
  else:
    raise ValueError('Unsupport network architecture.')

  return network 
开发者ID:medivhna,项目名称:TF_Face_Toolbox,代码行数:44,代码来源:net_base.py


注:本文中的resnet.ResNet方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。