当前位置: 首页>>代码示例>>Python>>正文


Python caffe.TRAIN属性代码示例

本文整理汇总了 Python 中 caffe.TRAIN 属性的典型用法代码示例。如果您想知道：caffe.TRAIN 属性的具体用法是什么？Python 中 caffe.TRAIN 怎么用？有哪些使用 caffe.TRAIN 的例子？那么这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在模块 caffe 的其他用法示例。


下文共展示了 caffe.TRAIN 属性的 15 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Python 代码示例。

示例1: add_batchnormscale

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import TRAIN [as 别名]
def add_batchnormscale(self, input, name, split_phases=True, moving_average_fraction=0.95):
        """Append a BatchNorm + Scale pair to ``self.net_spec``.

        With ``split_phases`` true (the default, and the only path the former
        hard-coded ``if True:`` could take), two BatchNorm layers are emitted.
        Both are named ``<name>_bn`` and share one parameter spec.

        * A TEST-phase layer with ``use_global_stats=True``. It is stored on
          the net spec as ``<name>_bn``.
        * A TRAIN-phase layer with ``use_global_stats=False``, ``ntop=0`` and
          an explicit ``top='<name>_bn'``. It is stored as ``<name>_bn_train``.

        A Scale layer with a bias term is then applied to ``<name>_bn`` and
        stored as ``<name>``.

        With ``split_phases`` false, a single BatchNorm built with Caffe's
        default settings feeds the Scale layer.

        Args:
            input: bottom produced by a previous NetSpec layer.
            name: base name for the layers added to the net spec.
            split_phases: emit separate TRAIN/TEST BatchNorm layers.
            moving_average_fraction: ``moving_average_fraction`` passed to both
                BatchNorm layers in the split configuration.

        Returns:
            The Scale layer output, also stored as ``self.net_spec.<name>``.
        """
        if not split_phases:
            l = L.Scale(L.BatchNorm(input), scale_param={'bias_term': True})
            setattr(self.net_spec, name, l)
            return l

        bn_name = name + '_bn'
        # lr_mult=0 on all three BatchNorm blobs: the solver must not apply
        # gradient updates to them.
        param = [dict(lr_mult=0), dict(lr_mult=0), dict(lr_mult=0)]

        # TEST phase: normalise with the accumulated global statistics.
        bn_test = L.BatchNorm(input, name=bn_name,
                              batch_norm_param={'moving_average_fraction': moving_average_fraction,
                                                'use_global_stats': True},
                              param=param, include={'phase': caffe.TEST}, ntop=1)
        setattr(self.net_spec, bn_name, bn_test)

        # TRAIN phase: same layer name, but batch statistics. With ntop=0,
        # NetSpec creates no new top for this layer. The explicit top= names
        # its output blob <name>_bn, the same blob the TEST layer produces.
        bn_train = L.BatchNorm(input, name=bn_name, top=bn_name,
                               batch_norm_param={'moving_average_fraction': moving_average_fraction,
                                                 'use_global_stats': False},
                               param=param, include={'phase': caffe.TRAIN}, ntop=0)
        setattr(self.net_spec, bn_name + '_train', bn_train)

        l = L.Scale(getattr(self.net_spec, bn_name), scale_param={'bias_term': True})
        setattr(self.net_spec, name, l)
        return l
开发者ID:peterneher,项目名称:peters-stuff,代码行数:21,代码来源:CaffeUNet_3D.py

示例2: add_batchnormscale

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import TRAIN [as 别名]
def add_batchnormscale(self, input, name, split_phases=True, moving_average_fraction=0.95):
        """Append a BatchNorm + Scale pair to ``self.net_spec``.

        With ``split_phases`` true (the default, and the only path the former
        hard-coded ``if True:`` could take), two BatchNorm layers are emitted.
        Both are named ``<name>_bn`` and share one parameter spec.

        * A TEST-phase layer with ``use_global_stats=True``. It is stored on
          the net spec as ``<name>_bn``.
        * A TRAIN-phase layer with ``use_global_stats=False``, ``ntop=0`` and
          an explicit ``top='<name>_bn'``. It is stored as ``<name>_bn_train``.

        A Scale layer with a bias term is then applied to ``<name>_bn`` and
        stored as ``<name>``.

        With ``split_phases`` false, a single BatchNorm built with Caffe's
        default settings feeds the Scale layer.

        Args:
            input: bottom produced by a previous NetSpec layer.
            name: base name for the layers added to the net spec.
            split_phases: emit separate TRAIN/TEST BatchNorm layers.
            moving_average_fraction: ``moving_average_fraction`` passed to both
                BatchNorm layers in the split configuration.

        Returns:
            The Scale layer output, also stored as ``self.net_spec.<name>``.
        """
        if not split_phases:
            l = L.Scale(L.BatchNorm(input), scale_param={'bias_term': True})
            setattr(self.net_spec, name, l)
            return l

        bn_name = name + '_bn'
        # lr_mult=0 on all three BatchNorm blobs: the solver must not apply
        # gradient updates to them.
        param = [dict(lr_mult=0), dict(lr_mult=0), dict(lr_mult=0)]

        # TEST phase: normalise with the accumulated global statistics.
        bn_test = L.BatchNorm(input, name=bn_name,
                              batch_norm_param={'moving_average_fraction': moving_average_fraction,
                                                'use_global_stats': True},
                              param=param, include={'phase': caffe.TEST}, ntop=1)
        setattr(self.net_spec, bn_name, bn_test)

        # TRAIN phase: same layer name, but batch statistics. With ntop=0,
        # NetSpec creates no new top for this layer. The explicit top= names
        # its output blob <name>_bn, the same blob the TEST layer produces.
        bn_train = L.BatchNorm(input, name=bn_name, top=bn_name,
                               batch_norm_param={'moving_average_fraction': moving_average_fraction,
                                                 'use_global_stats': False},
                               param=param, include={'phase': caffe.TRAIN}, ntop=0)
        setattr(self.net_spec, bn_name + '_train', bn_train)

        l = L.Scale(getattr(self.net_spec, bn_name), scale_param={'bias_term': True})
        setattr(self.net_spec, name, l)
        return l
开发者ID:peterneher,项目名称:peters-stuff,代码行数:21,代码来源:CaffeUNet_2D.py

示例3: response_to_lmdb

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import TRAIN [as 别名]
def response_to_lmdb(fpath_net,
                     fpath_weights,
                     keys,
                     dst_prefix,
                     modes=None,
                     ):
    """Extract network responses per phase and write them to LMDBs.

    Args:
        fpath_net: path to the network definition.
        fpath_weights: path to the weights loaded into each per-phase net.
        keys: names of the responses to extract; each must be valid for
            every requested mode.
        dst_prefix: output path prefix. The destination handed to
            ``infer_to_lmdb`` is ``dst_prefix + '%s_' + <phase> + '_lmdb'``,
            where ``<phase>`` is 'train' or 'test'. The ``%s`` placeholder is
            left unfilled here (presumably ``infer_to_lmdb`` fills in each key).
        modes: caffe phase constants to process. Any falsy value (None or an
            empty list) means ``[caffe.TRAIN, caffe.TEST]``.

    Returns:
        dict mapping each mode to whatever ``infer_to_lmdb`` returned for it.
    """
    # Phase constants are used directly as indices into this list, i.e. the
    # code assumes caffe.TRAIN == 0 and caffe.TEST == 1.
    phase_names = ['train', 'test']
    if not modes:
        modes = [caffe.TRAIN, caffe.TEST]
    results = dict.fromkeys(modes)

    for mode in modes:
        phase = phase_names[mode]
        num_passes = est_min_num_fwd_passes(fpath_net, phase)
        net = caffe.Net(fpath_net, fpath_weights, mode)
        destination = dst_prefix + '%s_' + phase + '_lmdb'
        results[mode] = infer_to_lmdb(net, keys, num_passes, destination)
    return results
开发者ID:nigroup,项目名称:nideep,代码行数:21,代码来源:inference.py

示例4: test_save_and_read

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import TRAIN [as 别名]
def test_save_and_read(self):
        """Round-trip the net's weights through a file and compare them.

        Saves ``self.net`` to a temporary file, then builds nets from the same
        definition with both constructors:

        * the legacy one, with weights as the 2nd positional argument;
        * the named one, with ``weights=``.

        It then checks that every parameter blob of the named-constructor net
        equals the original. Temporary files are removed even if saving or
        loading fails.
        """
        f = tempfile.NamedTemporaryFile(mode='w+', delete=False)
        f.close()
        net_file = None
        try:
            self.net.save(f.name)
            net_file = simple_net_file(self.num_output)
            # Test legacy constructor; expected to print a deprecation
            # warning. The resulting net is intentionally discarded.
            caffe.Net(net_file, f.name, caffe.TRAIN)
            # Test named constructor
            net2 = caffe.Net(net_file, caffe.TRAIN, weights=f.name)
        finally:
            if net_file is not None:
                os.remove(net_file)
            os.remove(f.name)
        for name, blobs in self.net.params.items():
            loaded = net2.params[name]
            # zip() would silently truncate, so check the counts explicitly.
            self.assertEqual(len(blobs), len(loaded))
            for original, restored in zip(blobs, loaded):
                self.assertEqual(abs(original.data - restored.data).sum(), 0)
开发者ID:QinganZhao,项目名称:Deep-Learning-Based-Structural-Damage-Detection,代码行数:18,代码来源:test_net.py

示例5: print_network_sizes

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import TRAIN [as 别名]
def print_network_sizes(model_file):
    """Print the name and data shape of every blob in a network.

    The network is built from ``model_file`` in the TRAIN phase without
    loading any weights. One ``<blob name> <shape>`` line is printed per blob,
    in the order ``net.blobs`` yields them.

    Args:
        model_file: path to the network definition.
    """
    net = caffe.Net(model_file, caffe.TRAIN)
    for blob_name, blob in net.blobs.items():
        # Explicit formatting prints the same text on Python 2 and 3. The
        # Python 2 ``print a, b`` statement is a SyntaxError on Python 3.
        print('%s %s' % (blob_name, blob.data.shape))
开发者ID:peterneher,项目名称:peters-stuff,代码行数:6,代码来源:unet_segmentation_no_db_example.py

示例6: print_network_sizes

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import TRAIN [as 别名]
def print_network_sizes(self, model_file):
        """Print the name and data shape of every blob in a network.

        The network is built from ``model_file`` in the TRAIN phase without
        loading any weights. One ``<blob name> <shape>`` line is printed per
        blob, in the order ``net.blobs`` yields them.

        Args:
            model_file: path to the network definition.
        """
        net = caffe.Net(model_file, caffe.TRAIN)
        for blob_name, blob in net.blobs.items():
            # Explicit formatting prints the same text on Python 2 and 3. The
            # Python 2 ``print a, b`` statement is a SyntaxError on Python 3.
            print('%s %s' % (blob_name, blob.data.shape))
开发者ID:peterneher,项目名称:peters-stuff,代码行数:7,代码来源:CaffeUNet_3D.py

示例7: parse_args

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import TRAIN [as 别名]
def parse_args():
    """Parse the command-line arguments for drawing a Caffe net.

    Options:
        --input_net_proto_file: network prototxt to draw (required).
        --output_image_file: destination image path (required).
        --rankdir: Graphviz rank direction (default 'LR').
        --phase: TRAIN, TEST or ALL (default 'ALL').

    Returns:
        argparse.Namespace with one attribute per option above.
    """
    parser = ArgumentParser(description=__doc__,
                            formatter_class=ArgumentDefaultsHelpFormatter)

    # Both paths are used unconditionally by the drawing code. Making them
    # required lets argparse report a missing flag, instead of a later
    # TypeError from open(None).
    parser.add_argument('--input_net_proto_file', required=True,
                        help='Input network prototxt file')
    parser.add_argument('--output_image_file', required=True,
                        help='Output image file')
    # Help text previously advertised "RL (right-left)" while the default is
    # LR; it now describes the actual default.
    parser.add_argument('--rankdir',
                        help=('One of TB (top-bottom, i.e., vertical), '
                              'LR (left-right, i.e., horizontal), or another '
                              'valid dot option; see '
                              'http://www.graphviz.org/doc/info/'
                              'attrs.html#k:rankdir'),
                        default='LR')
    parser.add_argument('--phase',
                        help=('Which network phase to draw: can be TRAIN, '
                              'TEST, or ALL.  If ALL, then all layers are drawn '
                              'regardless of phase.'),
                        default="ALL")

    args = parser.parse_args()
    return args
开发者ID:CUHKSZ-TQL,项目名称:EverybodyDanceNow_reproduce_pytorch,代码行数:28,代码来源:draw_caffe_net.py

示例8: main

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import TRAIN [as 别名]
def main():
    """Render the network described by a prototxt file to an image.

    Reads options with ``parse_args`` and maps ``--phase`` to a phase:
    ``caffe.TRAIN``, ``caffe.TEST``, or ``None`` for ALL, meaning layers of
    every phase are drawn. It then parses the text-format ``NetParameter``
    and hands it to ``caffe.draw.draw_net_to_file``.

    Raises:
        ValueError: if ``--phase`` is not TRAIN, TEST or ALL.
    """
    args = parse_args()
    # Validate the phase before any file I/O so a bad value fails fast.
    phases = {'TRAIN': caffe.TRAIN, 'TEST': caffe.TEST, 'ALL': None}
    if args.phase not in phases:
        raise ValueError("Unknown phase: " + args.phase)
    phase = phases[args.phase]

    net = caffe_pb2.NetParameter()
    # Context manager so the prototxt handle is closed deterministically.
    with open(args.input_net_proto_file) as proto_file:
        text_format.Merge(proto_file.read(), net)
    print('Drawing net to %s' % args.output_image_file)
    caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir,
                                phase)
开发者ID:CUHKSZ-TQL,项目名称:EverybodyDanceNow_reproduce_pytorch,代码行数:16,代码来源:draw_caffe_net.py

示例9: setUp

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import TRAIN [as 别名]
def setUp(self):
        """Build a TRAIN-phase test net and fill its labels with valid values.

        The net definition file produced by ``simple_net_file`` is removed
        right after construction, even if construction fails. The ``label``
        blob is then filled with random integers in ``[0, self.num_output)``.
        """
        self.num_output = 13
        net_file = simple_net_file(self.num_output)
        try:
            self.net = caffe.Net(net_file, caffe.TRAIN)
        finally:
            os.remove(net_file)
        # fill in valid labels: class indices strictly below num_output
        label_data = self.net.blobs['label'].data
        label_data[...] = np.random.randint(self.num_output,
                                            size=label_data.shape)
开发者ID:XiaohangZhan,项目名称:mix-and-match,代码行数:11,代码来源:test_net.py

示例10: test_save_and_read

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import TRAIN [as 别名]
def test_save_and_read(self):
        """Round-trip the net's weights through a file and compare them.

        Saves ``self.net`` to a temporary file, loads it into a fresh net
        built from the same definition, and checks that every parameter blob
        matches. Temporary files are removed even if saving or loading fails.
        """
        f = tempfile.NamedTemporaryFile(delete=False)
        f.close()
        net_file = None
        try:
            self.net.save(f.name)
            net_file = simple_net_file(self.num_output)
            net2 = caffe.Net(net_file, f.name, caffe.TRAIN)
        finally:
            if net_file is not None:
                os.remove(net_file)
            os.remove(f.name)
        for name, blobs in self.net.params.items():
            loaded = net2.params[name]
            # zip() would silently truncate, so check the counts explicitly.
            self.assertEqual(len(blobs), len(loaded))
            for original, restored in zip(blobs, loaded):
                self.assertEqual(abs(original.data - restored.data).sum(), 0)
开发者ID:XiaohangZhan,项目名称:mix-and-match,代码行数:14,代码来源:test_net.py

示例11: setUp

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import TRAIN [as 别名]
def setUp(self):
        """Build a TRAIN-phase net from the Python-layer test definition.

        The definition file from ``python_net_file`` is removed right after
        construction, even if construction fails.
        """
        net_file = python_net_file()
        try:
            self.net = caffe.Net(net_file, caffe.TRAIN)
        finally:
            os.remove(net_file)
开发者ID:XiaohangZhan,项目名称:mix-and-match,代码行数:6,代码来源:test_python_layer.py

示例12: est_min_num_fwd_passes

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import TRAIN [as 别名]
def est_min_num_fwd_passes(fpath_net, mode_str, key=None):
    """Estimate the minimum number of forward passes needed to cover a dataset.

    Every layer in ``fpath_net`` whose type contains 'data' is inspected. A
    layer counts for the requested mode if either:

    * its source path contains ``mode_str`` (case-insensitive), or
    * its ``include`` rules consist of exactly one rule for that phase.

    HDF5Data layers are read from ``hdf5_data_param``; all other data layers
    from ``data_param``. For each counted layer the pass count is
    ``ceil(num_entries / batch_size)``.

    Args:
        fpath_net: path to the network definition.
        mode_str: 'train' or 'test' (case-insensitive).
        key: forwarded to ``CreateDatasource.from_path``.

    Returns:
        The largest pass count over all counted data layers. When several
        sources serve the same mode, the biggest one decides.

    Raises:
        KeyError: if ``mode_str`` is neither 'train' nor 'test'.
        ValueError: if no data layer matches the requested mode.
    """
    from nideep.proto.proto_utils import Parser
    mode_str = mode_str.lower()
    mode_num = {'train': caffe.TRAIN,
                'test': caffe.TEST}[mode_str]
    # Not named ``np`` to avoid shadowing numpy.
    net_params = Parser().from_net_params_file(fpath_net)
    num_passes_each = []
    for l in net_params.layer:
        layer_type = l.type.lower()
        if 'data' not in layer_type:
            continue
        # A phase-restricted layer counts only if restricted to exactly this phase.
        in_phase = [x.phase for x in l.include] == [mode_num]
        param = l.hdf5_data_param if 'hdf5data' in layer_type else l.data_param
        if mode_str in param.source.lower() or in_phase:
            num_passes_each.append(
                _num_fwd_passes(param.source, param.batch_size, key=key))
    if not num_passes_each:
        raise ValueError("No data layer found for mode '%s' in %s"
                         % (mode_str, fpath_net))
    return max(num_passes_each)


def _num_fwd_passes(source, batch_size, key=None):
    """Return ceil(num_entries / batch_size) for the data source at ``source``."""
    num_entries = CreateDatasource.from_path(source, key=key).num_entries()
    num_passes, remainder = divmod(num_entries, batch_size)
    if remainder != 0:
        logger.warning("db size not a multiple of batch size. Adding another fwd. pass.")
        num_passes += 1
    logger.info("%d fwd. passes with batch size %d", num_passes, batch_size)
    return num_passes
开发者ID:nigroup,项目名称:nideep,代码行数:38,代码来源:inference.py

示例13: test_response_to_lmdb

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import TRAIN [as 别名]
def test_response_to_lmdb(self, mock_net, mock_num):
        """response_to_lmdb creates one LMDB per requested key and mode.

        ``mock_net`` and ``mock_num`` are presumably patches of ``caffe.Net``
        and ``est_min_num_fwd_passes``; the decorators are not visible in this
        snippet, so confirm.
        """
        # fake minimal test data: three blobs, each with 4 entries on axis 0
        b = {k: Bunch(data=np.random.rand(4, 1, 3, 2)) for k in ['x', 'y', 'z']}

        # mock methods and properties of Net objects
        mock_num.return_value = 3
        mock_net.return_value.forward.return_value = np.zeros(1)
        type(mock_net.return_value).blobs = PropertyMock(return_value=b)
        net = mock_net()

        dst_prefix = os.path.join(self.dir_tmp, 'test_response_to_lmdb_')
        for m in ['train', 'test']:
            for k in b:
                assert_false(os.path.isdir(dst_prefix + ('%s_' + m + '_lmdb') % k))
        import nideep
        out = nideep.eval.inference.response_to_lmdb("net.prototxt",
                                                     "w.caffemodel",
                                                     ['x', 'z'],
                                                     dst_prefix)

        assert_equal(net.forward.call_count, 3 * 2)  # double for both modes
        from caffe import TRAIN, TEST
        # dict.keys() is a view (not a list) on Python 3, and key order is not
        # guaranteed on older Pythons. Compare as sorted lists.
        assert_list_equal(sorted(out.keys()), sorted([TRAIN, TEST]))
        assert_list_equal(out[TRAIN], [3 * 4] * 2)
        assert_list_equal(out[TEST], [3 * 4] * 2)

        for m in ['train', 'test']:
            for k in b:
                path = dst_prefix + ('%s_' + m + '_lmdb') % k
                if k in ['x', 'z']:
                    assert_true(os.path.isdir(path))
                else:
                    assert_false(os.path.isdir(path))
开发者ID:nigroup,项目名称:nideep,代码行数:35,代码来源:test_inference.py

示例14: setUp

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import TRAIN [as 别名]
def setUp(self):
        """Build a TRAIN-phase net from the Python-layer-with-param_str definition.

        The definition file from ``python_param_net_file`` is removed right
        after construction, even if construction fails.
        """
        net_file = python_param_net_file()
        try:
            self.net = caffe.Net(net_file, caffe.TRAIN)
        finally:
            os.remove(net_file)
开发者ID:QinganZhao,项目名称:Deep-Learning-Based-Structural-Damage-Detection,代码行数:6,代码来源:test_python_layer_with_param_str.py

示例15: test_save_hdf5

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import TRAIN [as 别名]
def test_save_hdf5(self):
        """Round-trip the net's weights through HDF5 and compare them.

        Saves ``self.net`` with ``save_hdf5``, loads the file into a fresh net
        built from the same definition via ``load_hdf5``, and checks that
        every parameter blob matches. Temporary files are removed even if
        saving or loading fails.
        """
        f = tempfile.NamedTemporaryFile(mode='w+', delete=False)
        f.close()
        net_file = None
        try:
            self.net.save_hdf5(f.name)
            net_file = simple_net_file(self.num_output)
            net2 = caffe.Net(net_file, caffe.TRAIN)
            net2.load_hdf5(f.name)
        finally:
            if net_file is not None:
                os.remove(net_file)
            os.remove(f.name)
        for name, blobs in self.net.params.items():
            loaded = net2.params[name]
            # zip() would silently truncate, so check the counts explicitly.
            self.assertEqual(len(blobs), len(loaded))
            for original, restored in zip(blobs, loaded):
                self.assertEqual(abs(original.data - restored.data).sum(), 0)
开发者ID:QinganZhao,项目名称:Deep-Learning-Based-Structural-Damage-Detection,代码行数:15,代码来源:test_net.py


注:本文中的caffe.TRAIN属性示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。