当前位置: 首页>>代码示例>>Python>>正文


Python caffe.get_solver方法代码示例

本文整理汇总了Python中caffe.get_solver方法的典型用法代码示例。如果您正苦于以下问题:Python caffe.get_solver方法的具体用法?Python caffe.get_solver怎么用?Python caffe.get_solver使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块caffe的用法示例。


在下文中一共展示了caffe.get_solver方法的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: train

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import get_solver [as 别名]
def train(solver_proto_path, snapshot_solver_path=None, init_weights=None, GPU_ID=0):
    """
    Train the net described by a solver prototxt on a single GPU.

    This function was used for prototyping; the final net was trained with
    the caffe executable instead, for multi-GPU use.

    Args:
        solver_proto_path: Solver definition path passed to caffe.get_solver.
        snapshot_solver_path: Optional solverstate to resume training from.
            When given it takes precedence and init_weights is ignored.
        init_weights: Optional weights file copied into the net before
            training; only used when snapshot_solver_path is None.
        GPU_ID: Device index passed to caffe.set_device.
    """
    import time

    t0 = time.time()
    caffe.set_mode_gpu()
    caffe.set_device(GPU_ID)
    solver = caffe.get_solver(solver_proto_path)
    if snapshot_solver_path is not None:
        solver.solve(snapshot_solver_path)  # train from previous solverstate
    else:
        if init_weights is not None:
            # for copying weights from a model without solverstate
            solver.net.copy_from(init_weights)
        solver.solve()  # train from scratch

    t1 = time.time()
    # A single pre-formatted string prints the same text under both the
    # Python 2 print statement and the Python 3 print function.
    print('Total training time:  %s  sec' % (t1 - t0))
    # NOTE(review): %I is the 12-hour clock, so directory names can repeat
    # between AM and PM on the same day -- confirm whether %H was intended.
    model_dir = "calc_" + time.strftime("%d-%m-%Y_%I%M%S")
    moveModel(model_dir=model_dir)  # move all the model files to a directory
    print("Moved model to model/" + model_dir)
开发者ID:rpng,项目名称:calc,代码行数:24,代码来源:makeNet.py

示例2: setUp

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import get_solver [as 别名]
def setUp(self):
        """Build an SGD solver from temporary net and solver definition files.

        Writes a solver definition that points at the net file returned by
        simple_net_file, constructs the solver, fills the 'label' blobs of the
        train net and the first test net with random class indices below
        num_output, and finally deletes both temporary files.
        """
        self.num_output = 13
        net_f = simple_net_file(self.num_output)
        # Text mode is required: the definition is a str, and the
        # NamedTemporaryFile default mode ('w+b') rejects str on Python 3.
        # delete=False keeps the file on disk after it is closed so that the
        # solver can be constructed from its name.
        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
            f.write("""net: '""" + net_f + """'
        test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9
        weight_decay: 0.0005 lr_policy: 'inv' gamma: 0.0001 power: 0.75
        display: 100 max_iter: 100 snapshot_after_train: false""")
        self.solver = caffe.SGDSolver(f.name)
        # also make sure get_solver runs
        caffe.get_solver(f.name)
        caffe.set_mode_cpu()
        # fill in valid labels
        self.solver.net.blobs['label'].data[...] = \
                np.random.randint(self.num_output,
                    size=self.solver.net.blobs['label'].data.shape)
        self.solver.test_nets[0].blobs['label'].data[...] = \
                np.random.randint(self.num_output,
                    size=self.solver.test_nets[0].blobs['label'].data.shape)
        # The definitions are removed once the solver has been constructed.
        os.remove(f.name)
        os.remove(net_f)
开发者ID:XiaohangZhan,项目名称:mix-and-match,代码行数:24,代码来源:test_solver.py

示例3: setUp

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import get_solver [as 别名]
def setUp(self):
        """Prepare a throwaway SGD solver with randomly labelled blobs.

        The solver definition is written to a named temporary file that refers
        to the net file produced by simple_net_file. Both files are deleted
        again at the end of the set-up.
        """
        self.num_output = 13
        net_path = simple_net_file(self.num_output)

        solver_spec = """net: '""" + net_path + """'
        test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9
        weight_decay: 0.0005 lr_policy: 'inv' gamma: 0.0001 power: 0.75
        display: 100 max_iter: 100 snapshot_after_train: false
        snapshot_prefix: "model" """
        solver_file = tempfile.NamedTemporaryFile(mode='w+', delete=False)
        solver_file.write(solver_spec)
        solver_file.close()
        solver_path = solver_file.name

        self.solver = caffe.SGDSolver(solver_path)
        # get_solver is called only to check that it runs
        caffe.get_solver(solver_path)
        caffe.set_mode_cpu()

        # Give the train net, then the first test net, valid random labels.
        for net in (self.solver.net, self.solver.test_nets[0]):
            label_data = net.blobs['label'].data
            label_data[...] = np.random.randint(self.num_output,
                                                size=label_data.shape)

        for leftover in (solver_path, net_path):
            os.remove(leftover)
开发者ID:QinganZhao,项目名称:Deep-Learning-Based-Structural-Damage-Detection,代码行数:25,代码来源:test_solver.py

示例4: train_network

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import get_solver [as 别名]
def train_network(solver_file, num_classes, batch_size, num_iterations, use_gpu=True):
    """Train a net on batches from get_data and plot the last batch.

    Args:
        solver_file: Solver definition path passed to caffe.get_solver.
        num_classes: Forwarded to get_data as ``numclasses``.
        batch_size: Size of the first dimension of the 'data' and 'target'
            blobs, and the batch size requested from get_data.
        num_iterations: Number of single solver steps to run. When no step
            is run there is nothing to visualise, so no figure is shown.
        use_gpu: Select Caffe GPU mode when true, CPU mode otherwise.
    """
    if use_gpu:
        caffe.set_mode_gpu()
    else:
        caffe.set_mode_cpu()

    solver = caffe.get_solver(solver_file)
    # Only the first dimension changes; the other three are kept as they are.
    for blob_name in ('data', 'target'):
        blob = solver.net.blobs[blob_name]
        blob.reshape(batch_size, blob.shape[1], blob.shape[2], blob.shape[3])
    solver.net.reshape()

    data = target = output = None
    for _ in range(num_iterations):
        data, target = get_data(batch_size, numclasses=num_classes)

        solver.net.blobs['data'].data[...] = data
        solver.net.blobs['target'].data[...] = target
        solver.step(1)
        output = solver.net.blobs['argmax'].data[...]

    if output is None:
        # No iteration ran, so there is no batch to plot. Previously this
        # case failed with an unbound-variable error in the plotting code.
        return

    # Show the first sample of the last batch next to its target and the
    # contents of the 'argmax' blob.
    fig, sub = plt.subplots(ncols=3, figsize=(15, 5))
    sub[0].set_title('Input')
    sub[0].imshow(data[0, 0, :, :])
    sub[1].set_title('Ground Truth')
    sub[1].imshow(np.argmax(target[0, :, :, :], axis=0))
    sub[2].set_title('Segmentation')
    sub[2].imshow(output[0, 0, :, :])
    plt.show()
开发者ID:peterneher,项目名称:peters-stuff,代码行数:31,代码来源:unet_segmentation_no_db_example.py

示例5: train

# 需要导入模块: import caffe [as 别名]
# 或者: from caffe import get_solver [as 别名]
def train(self, train_hdf5_fname, val_hdf5_fname=None, solverstate_fname=None, solver_param=None, batch_size=32, visualize_response_maps=False):
        """Write the train/val net and solver files, run the solver, and copy the weights back.

        Args:
            train_hdf5_fname: Path of the training HDF5 file.
            val_hdf5_fname: Optional path of the validation HDF5 file. When
                given, a TEST-phase net is also built and its data layers are
                merged into the written net definition.
            solverstate_fname: Optional solver state to restore before
                solving. A value that does not end with '.solverstate' is
                expanded to '<snapshot prefix>_iter_<value>.solverstate'.
            solver_param: Optional pb2.SolverParameter. A new one is created
                when None. It is passed to self.add_default_parameters
                (presumably modified in place -- confirm) and written to the
                solver file.
            batch_size: Forwarded to self.net_func.
            visualize_response_maps: Forwarded to self.solve.

        On return, self.train_net holds the solver's net and, when a
        validation file was given, self.val_net holds its first test net.
        """
        # For each HDF5 file, make sure a hidden '.<name>.txt' file holding its
        # path exists in the same directory. An existing file is not rewritten.
        # NOTE(review): presumably the source list read by the HDF5 data
        # layers -- confirm against net_func.
        hdf5_txt_fnames = []
        for hdf5_fname in [train_hdf5_fname, val_hdf5_fname]:
            if hdf5_fname is not None:
                head, tail = os.path.split(hdf5_fname)
                root, _ = os.path.splitext(tail)
                hdf5_txt_fname = os.path.join(head, '.' + root + '.txt')
                if not os.path.isfile(hdf5_txt_fname):
                    with open(hdf5_txt_fname, 'w') as f:
                        f.write(hdf5_fname + '\n')
                hdf5_txt_fnames.append(hdf5_txt_fname)
            else:
                # keep the positions aligned for the unpacking below
                hdf5_txt_fnames.append(None)
        train_hdf5_txt_fname, val_hdf5_txt_fname = hdf5_txt_fnames

        input_shapes = (self.x_shape, self.u_shape)
        train_net_param, weight_fillers = self.net_func(input_shapes, train_hdf5_txt_fname, batch_size, self.net_name, phase=caffe.TRAIN)
        if val_hdf5_fname is not None:
            # the second return value of the TEST-phase call is not used
            val_net_param, _ = self.net_func(input_shapes, val_hdf5_txt_fname, batch_size, self.net_name, phase=caffe.TEST)

        self.train_val_net_param = train_net_param
        if val_hdf5_fname is not None:
            # Reorder the layers to: train data layers, validation data layers,
            # then every remaining layer. `layers` is a snapshot of the layer
            # list taken before anything is removed.
            layers = [layer for layer in self.train_val_net_param.layer]
            # remove layers except for data layers
            for layer in layers:
                if 'Data' not in layer.type:
                    self.train_val_net_param.layer.remove(layer)
            # add data layers from the validation net
            self.train_val_net_param.layer.extend([layer for layer in val_net_param.layer if 'Data' in layer.type])
            # add back the layers that are not data layers
            self.train_val_net_param.layer.extend([layer for layer in layers if 'Data' not in layer.type])
        self.train_val_net_param = net_caffe.train_val_net(self.train_val_net_param)
        # write the text form of the net definition
        train_val_fname = self.get_model_fname('train_val')
        with open(train_val_fname, 'w') as f:
            f.write(str(self.train_val_net_param))

        if solver_param is None:
            solver_param = pb2.SolverParameter()
        self.add_default_parameters(solver_param, val_net=val_hdf5_fname is not None)

        # write the text form of the solver definition
        solver_fname = self.get_model_fname('solver')
        with open(solver_fname, 'w') as f:
            f.write(str(solver_param))

        solver = caffe.get_solver(solver_fname)
        self.set_weight_fillers(solver.net.params, weight_fillers)
        # Copy the values held in self.params into the solver's net. This runs
        # after set_weight_fillers, so these values overwrite whatever that
        # call wrote for the same parameter names.
        for param_name, param in self.params.items():
            for blob, solver_blob in zip(param, solver.net.params[param_name]):
                solver_blob.data[...] = blob.data
        if solverstate_fname is not None:
            if not solverstate_fname.endswith('.solverstate'):
                # treat the value as the part after '_iter_' of a snapshot name
                solverstate_fname = self.get_snapshot_prefix() + '_iter_' + solverstate_fname + '.solverstate'
            solver.restore(solverstate_fname)
        self.solve(solver, solver_param, visualize_response_maps=visualize_response_maps)
        # Copy the values from the solver's net back into self.params.
        for param_name, param in self.params.items():
            for blob, solver_blob in zip(param, solver.net.params[param_name]):
                blob.data[...] = solver_blob.data

        self.train_net = solver.net
        if val_hdf5_fname is not None:
            self.val_net = solver.test_nets[0] 
开发者ID:alexlee-gk,项目名称:visual_dynamics,代码行数:63,代码来源:predictor_caffe.py


注:本文中的caffe.get_solver方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。