

Python unet.UNet Code Examples

This article collects typical usage examples of unet.UNet in Python. If you have been wondering how exactly to use unet.UNet, or what it looks like in practice, the curated examples below may help. You can also explore further usage examples from the unet module.


Nine code examples of unet.UNet are shown below, sorted by popularity by default.

Example 1: get_args

# Required import: import unet [as alias]
# Or: from unet import UNet [as alias]
import argparse

def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,
                        help='Batch size', dest='batchsize')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.1,
                        help='Learning rate', dest='lr')
    parser.add_argument('-f', '--load', dest='load', type=str, default=False,
                        help='Load model from a .pth file')
    parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
                        help='Downscaling factor of the images')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')

    return parser.parse_args() 
Author: milesial | Project: Pytorch-UNet | Lines: 19 | Source: train.py
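
A quick way to exercise a parser like get_args is to pass parse_args an explicit argv list instead of relying on sys.argv. Below is a minimal, self-contained sketch that re-creates just two of the arguments above:

import argparse

parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
                    help='Number of epochs', dest='epochs')
parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, default=0.1,
                    help='Learning rate', dest='lr')

# Passing an explicit list makes the parser testable without a real command line.
args = parser.parse_args(['-e', '20', '-l', '0.001'])
print(args.epochs, args.lr)  # 20 0.001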

Example 2: unet

# Required import: import unet [as alias]
# Or: from unet import UNet [as alias]
import torch
from unet import UNet

def unet(pretrained=False, **kwargs):
    """
    U-Net segmentation model with batch normalization for biomedical image segmentation
    pretrained (bool): load pretrained weights into the model
    in_channels (int): number of input channels
    out_channels (int): number of output channels
    init_features (int): number of feature-maps in the first encoder layer
    """
    model = UNet(**kwargs)

    if pretrained:
        checkpoint = "https://github.com/mateuszbuda/brain-segmentation-pytorch/releases/download/v1.0/unet-e012d006.pt"
        state_dict = torch.hub.load_state_dict_from_url(checkpoint, progress=False, map_location='cpu')
        model.load_state_dict(state_dict)

    return model 
Author: mateuszbuda | Project: brain-segmentation-pytorch | Lines: 18 | Source: hubconf.py
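
Because this entry point lives in hubconf.py, it can be consumed through torch.hub without cloning the repository. A hedged usage sketch, assuming network access and the published v1.0 weights:

import torch

model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
                       in_channels=3, out_channels=1, init_features=32,
                       pretrained=True)
model.eval()

with torch.no_grad():
    out = model(torch.rand(1, 3, 256, 256))  # spatial dims must be divisible by 16
print(out.shape)  # torch.Size([1, 1, 256, 256])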

Example 3: main

# Required import: import unet [as alias]
# Or: from unet import UNet [as alias]
def main():

    # init conv net
    
    unet = UNet(3,1)
    if os.path.exists("./unet.pkl"):
        unet.load_state_dict(torch.load("./unet.pkl"))
        print("load unet")
    unet.cuda()

    cnn = CNNEncoder()
    if os.path.exists("./cnn.pkl"):
        cnn.load_state_dict(torch.load("./cnn.pkl"))
        print("load cnn")
    cnn.cuda()

    unet.eval()
    cnn.eval()
    
    print("load ok")

    while True:
        pull_screenshot("autojump.png") # obtain screen and save it to autojump.png
        image = Image.open('./autojump.png')
        set_button_position(image)
        image = preprocess(image)
        
        image = Variable(image.unsqueeze(0)).cuda()
        mask = unet(image)

        plt.imshow(mask.squeeze(0).squeeze(0).cpu().data.numpy(), cmap='hot', interpolation='nearest')
        plt.show()
        
        segmentation = image * mask

        press_time = cnn(segmentation)
        press_time = press_time.cpu().data[0].numpy()
        print(press_time)
        jump(press_time)
        
        time.sleep(random.uniform(0.6, 1.1)) 
Author: floodsung | Project: wechat_jump_end_to_end_train | Lines: 43 | Source: run_mask.py
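
The key step in this loop is segmentation = image * mask: the single-channel mask produced by the UNet gates all three RGB channels via broadcasting before the result goes to the press-time regressor. A minimal sketch of just that step, with random tensors standing in for real data:

import torch

image = torch.rand(1, 3, 224, 224)  # N x C x H x W input batch
mask = torch.rand(1, 1, 224, 224)   # single-channel soft mask in [0, 1]

segmentation = image * mask         # the mask broadcasts over the channel dim
print(segmentation.shape)           # torch.Size([1, 3, 224, 224])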

Example 4: _compile

# Required import: import unet [as alias]
# Or: from unet import UNet [as alias]
def _compile(self):
        """Compiles model (architecture, loss function, optimizers, etc.)."""

        print('Noise2Noise: Learning Image Restoration without Clean Data (Lehtinen et al., 2018)')

        # Model (3x3=9 channels for Monte Carlo since it uses 3 HDR buffers)
        if self.p.noise_type == 'mc':
            self.is_mc = True
            self.model = UNet(in_channels=9)
        else:
            self.is_mc = False
            self.model = UNet(in_channels=3)

        # Set optimizer and loss, if in training mode
        if self.trainable:
            self.optim = Adam(self.model.parameters(),
                              lr=self.p.learning_rate,
                              betas=self.p.adam[:2],
                              eps=self.p.adam[2])

            # Learning rate adjustment
            self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optim,
                patience=self.p.nb_epochs/4, factor=0.5, verbose=True)

            # Loss function
            if self.p.loss == 'hdr':
                assert self.is_mc, 'Using HDR loss on non Monte Carlo images'
                self.loss = HDRLoss()
            elif self.p.loss == 'l2':
                self.loss = nn.MSELoss()
            else:
                self.loss = nn.L1Loss()

        # CUDA support
        self.use_cuda = torch.cuda.is_available() and self.p.cuda
        if self.use_cuda:
            self.model = self.model.cuda()
            if self.trainable:
                self.loss = self.loss.cuda() 
Author: joeylitalien | Project: noise2noise-pytorch | Lines: 41 | Source: noise2noise.py
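
Unlike epoch-indexed schedulers, ReduceLROnPlateau must be stepped with the metric it monitors, so the training loop is expected to call scheduler.step(val_loss) once per epoch. A minimal sketch of that contract, with a toy linear model standing in for the UNet and a constant stand-in loss:

import torch.nn as nn
from torch.optim import Adam, lr_scheduler

model = nn.Linear(8, 1)  # toy stand-in for the UNet
optim = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.99), eps=1e-8)
scheduler = lr_scheduler.ReduceLROnPlateau(optim, patience=5, factor=0.5)

for epoch in range(20):
    val_loss = 1.0                # stand-in for a real validation loss
    scheduler.step(val_loss)      # halves the LR once the loss stalls for `patience` epochs
print(optim.param_groups[0]['lr'])  # smaller than 1e-3 after the plateau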

Example 5: main

# Required import: import unet [as alias]
# Or: from unet import UNet [as alias]
def main():

	# init conv net
	print("init net")

	unet = UNet(3,1)
	if os.path.exists("./unet.pkl"):
		unet.load_state_dict(torch.load("./unet.pkl"))
		print("load unet")
	unet.cuda()

	cnn = CNNEncoder()
	if os.path.exists("./cnn.pkl"):
		cnn.load_state_dict(torch.load("./cnn.pkl"))
		print("load cnn")
	cnn.cuda()

	# init dataset
	print("init dataset")
	data_loader = dataset84.jump_data_loader()

	# init optimizer
	unet_optimizer = torch.optim.Adam(unet.parameters(),lr=0.001)
	cnn_optimizer = torch.optim.Adam(cnn.parameters(),lr = 0.001)
	criterion = nn.MSELoss()

	# train
	print("training...")
	for epoch in range(1000):
		for i, (images, press_times) in enumerate(data_loader):
			images = Variable(images).cuda()
			press_times = Variable(press_times.float()).cuda()

			masks = unet(images)
			
			segmentations = images * masks
			predict_press_times = cnn(segmentations)

			loss = criterion(predict_press_times,press_times)

			unet_optimizer.zero_grad()
			cnn_optimizer.zero_grad()
			loss.backward()
			unet_optimizer.step()
			cnn_optimizer.step()

			if (i+1) % 10 == 0:
				print("epoch:",epoch,"step:",i,"loss:",loss.data[0])
			if (epoch+1) % 5 == 0 and i == 0:
				torch.save(unet.state_dict(),"./unet.pkl")
				torch.save(cnn.state_dict(),"./cnn.pkl")
				print("save model") 
Author: floodsung | Project: wechat_jump_end_to_end_train | Lines: 54 | Source: train_mask.py
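
The save/load pair above relies on the state_dict round trip: torch.save(net.state_dict(), ...) serializes parameters only, so restoring requires re-instantiating the same architecture first, which is exactly what the os.path.exists checks at the top of main do. A minimal sketch with a toy module:

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
torch.save(net.state_dict(), './net.pkl')        # parameters only, no class info

net2 = nn.Linear(4, 2)                           # architecture must match
net2.load_state_dict(torch.load('./net.pkl'))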

Example 6: main

# Required import: import unet [as alias]
# Or: from unet import UNet [as alias]
def main(args):
    input_images = tf.placeholder(tf.float32, shape=[1, 480, 360, 4])

    with tf.variable_scope("Gen"):
        gen = UNet(4,4)
        output = tf.sigmoid(gen(input_images))

    global_step = tf.get_variable('global_step', initializer=0, trainable=False)

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    saver = tf.train.Saver()
    if args.checkpoint is not None and os.path.exists(os.path.join(args.logdir, 'checkpoint')):
        if args.checkpoint == -1:  # latest checkpoint
            saver.restore(sess, tf.train.latest_checkpoint(args.logdir))
        else:  # specified checkpoint
            saver.restore(sess, os.path.join(args.logdir, model_name+".ckpt-"+str(args.checkpoint)))
        logging.info('Model restored to step ' + str(global_step.eval(sess)))


    images, targets = [], []

    input_filename = args.input
    image = load_image(input_filename)
    print(image.shape)
    trimap = generate_trimap(args.object)

    image = np.array(image)
    trimap = np.array(trimap)[..., np.newaxis]
    print(image.shape)
    print(trimap.shape)
    image = np.concatenate((image, trimap), axis = 2).astype(np.float32) / 255

    result = sess.run(output, feed_dict={
            input_images: np.asarray([image]),
            })

    print(result.shape)
    image = Image.fromarray((result[0,:,:,:]*255).astype(np.uint8))
    image.save(args.output) 
Author: eti-p-doray | Project: unet-gan-matting | Lines: 44 | Source: eval.py
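
The network input here is not a plain RGB image: a single-channel trimap is appended as a fourth channel and the result rescaled to [0, 1], matching the 1 x 480 x 360 x 4 placeholder. A minimal sketch of that assembly, with random arrays standing in for the loaded image and generated trimap:

import numpy as np

image = np.random.randint(0, 256, (480, 360, 3), dtype=np.uint8)  # RGB
trimap = np.random.randint(0, 256, (480, 360), dtype=np.uint8)    # one channel

x = np.concatenate((image, trimap[..., np.newaxis]), axis=2)
x = x.astype(np.float32) / 255                                    # scale to [0, 1]
print(x.shape)  # (480, 360, 4)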

Example 7: main

# Required import: import unet [as alias]
# Or: from unet import UNet [as alias]
def main(args):
    makedirs(args)
    device = torch.device("cpu" if not torch.cuda.is_available() else args.device)

    loader = data_loader(args)

    with torch.set_grad_enabled(False):
        unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
        state_dict = torch.load(args.weights, map_location=device)
        unet.load_state_dict(state_dict)
        unet.eval()
        unet.to(device)

        input_list = []
        pred_list = []
        true_list = []

        for i, data in tqdm(enumerate(loader)):
            x, y_true = data
            x, y_true = x.to(device), y_true.to(device)

            y_pred = unet(x)
            y_pred_np = y_pred.detach().cpu().numpy()
            pred_list.extend([y_pred_np[s] for s in range(y_pred_np.shape[0])])

            y_true_np = y_true.detach().cpu().numpy()
            true_list.extend([y_true_np[s] for s in range(y_true_np.shape[0])])

            x_np = x.detach().cpu().numpy()
            input_list.extend([x_np[s] for s in range(x_np.shape[0])])

    volumes = postprocess_per_volume(
        input_list,
        pred_list,
        true_list,
        loader.dataset.patient_slice_index,
        loader.dataset.patients,
    )

    dsc_dist = dsc_distribution(volumes)

    dsc_dist_plot = plot_dsc(dsc_dist)
    imsave(args.figure, dsc_dist_plot)

    for p in volumes:
        x = volumes[p][0]
        y_pred = volumes[p][1]
        y_true = volumes[p][2]
        for s in range(x.shape[0]):
            image = gray2rgb(x[s, 1])  # channel 1 is for FLAIR
            image = outline(image, y_pred[s, 0], color=[255, 0, 0])
            image = outline(image, y_true[s, 0], color=[0, 255, 0])
            filename = "{}-{}.png".format(p, str(s).zfill(2))
            filepath = os.path.join(args.predictions, filename)
            imsave(filepath, image) 
Author: mateuszbuda | Project: brain-segmentation-pytorch | Lines: 57 | Source: inference.py
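
The three .extend([...]) calls follow one pattern: each forward pass yields an N x C x H x W batch, which is unstacked into per-slice arrays so that postprocess_per_volume can regroup slices into patient volumes. The pattern in isolation, with a zero array standing in for one output batch:

import numpy as np

pred_list = []
y_pred_np = np.zeros((8, 1, 128, 128))        # stand-in for one model output batch
pred_list.extend([y_pred_np[s] for s in range(y_pred_np.shape[0])])
print(len(pred_list), pred_list[0].shape)     # 8 (1, 128, 128)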

Example 8: _add_bgnd

# Required import: import unet [as alias]
# Or: from unet import UNet [as alias]
def _add_bgnd(fname, _model_path = model_path, _int_scale = (0, 255), cuda_id = 0):
    
    if torch.cuda.is_available():
        print("THIS IS CUDA!!!!")
        dev_str = "cuda:" + str(cuda_id)
    else:
        dev_str = 'cpu'
    device = torch.device(dev_str)
    
    model = UNet(n_channels = 1, n_classes = 1)
    state = torch.load(_model_path, map_location = 'cpu')
    model.load_state_dict(state['state_dict'])
    
    model = model.to(device)
    model.eval()
    
    with tables.File(fname, 'r+') as fid:
        full_data = fid.get_node('/full_data')
        
        if '/bgnd' in fid:
            fid.remove_node('/bgnd')
            
        bgnd = createImgGroup(fid, "/bgnd", *full_data.shape, is_expandable = False)
        bgnd._v_attrs['save_interval'] = full_data._v_attrs['save_interval']
        
        for ii in tqdm.trange(full_data.shape[0]):
            img = full_data[ii]
            
            x = img.astype(np.float32)
            x = (x - _int_scale[0])/(_int_scale[1] - _int_scale[0])
            
            
            with torch.no_grad():
                X = torch.from_numpy(x[None, None])
                X = X.to(device)
                Xhat = model(X)
            
            xhat = Xhat.squeeze().detach().cpu().numpy()
            
            bg = xhat*(_int_scale[1] - _int_scale[0]) + _int_scale[0]
            bg = bg.round().astype(img.dtype)
            bgnd[ii] = bg 
Author: ver228 | Project: tierpsy-tracker | Lines: 44 | Source: add_bgnd.py
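
The intensity handling above is a symmetric round trip: frames are mapped from _int_scale into [0, 1] for the network, and the prediction is mapped back and re-quantized to the frame's dtype. A minimal sketch of the round trip, with the model replaced by an identity so input and output must match exactly:

import numpy as np

_int_scale = (0, 255)
img = np.random.randint(0, 256, (64, 64), dtype=np.uint8)

x = (img.astype(np.float32) - _int_scale[0]) / (_int_scale[1] - _int_scale[0])
xhat = x                                       # identity stand-in for model(X)
bg = xhat * (_int_scale[1] - _int_scale[0]) + _int_scale[0]
bg = bg.round().astype(img.dtype)
assert (bg == img).all()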

Example 9: create_model

# Required import: import unet [as alias]
# Or: from unet import UNet [as alias]
def create_model(args, input_shape):
    # If using CPU or single GPU
    if args.gpus <= 1:
        if args.net == 'unet':
            from unet import UNet
            model = UNet(input_shape)
            return [model]
        elif args.net == 'tiramisu':
            from densenets import DenseNetFCN
            model = DenseNetFCN(input_shape)
            return [model]
        elif args.net == 'segcapsr1':
            from capsnet import CapsNetR1
            model_list = CapsNetR1(input_shape)
            return model_list
        elif args.net == 'segcapsr3':
            from capsnet import CapsNetR3
            model_list = CapsNetR3(input_shape)
            return model_list
        elif args.net == 'capsbasic':
            from capsnet import CapsNetBasic
            model_list = CapsNetBasic(input_shape)
            return model_list
        else:
            raise Exception('Unknown network type specified: {}'.format(args.net))
    # If using multiple GPUs
    else:
        with tf.device("/cpu:0"):
            if args.net == 'unet':
                from unet import UNet
                model = UNet(input_shape)
                return [model]
            elif args.net == 'tiramisu':
                from densenets import DenseNetFCN
                model = DenseNetFCN(input_shape)
                return [model]
            elif args.net == 'segcapsr1':
                from capsnet import CapsNetR1
                model_list = CapsNetR1(input_shape)
                return model_list
            elif args.net == 'segcapsr3':
                from capsnet import CapsNetR3
                model_list = CapsNetR3(input_shape)
                return model_list
            elif args.net == 'capsbasic':
                from capsnet import CapsNetBasic
                model_list = CapsNetBasic(input_shape)
                return model_list
            else:
                raise Exception('Unknown network type specified: {}'.format(args.net)) 
Author: lalonderodney | Project: SegCaps | Lines: 52 | Source: model_helper.py
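
The if/elif ladder is duplicated verbatim across the CPU and multi-GPU branches; a registry of builder functions removes that duplication. A hedged refactor sketch, assuming the same module layout (unet, densenets, capsnet) as the original; only two builders are shown:

import tensorflow as tf

def _build_unet(input_shape):
    from unet import UNet
    return [UNet(input_shape)]

def _build_segcapsr3(input_shape):
    from capsnet import CapsNetR3
    return CapsNetR3(input_shape)   # already returns a model list

BUILDERS = {'unet': _build_unet, 'segcapsr3': _build_segcapsr3}

def create_model(args, input_shape):
    try:
        builder = BUILDERS[args.net]
    except KeyError:
        raise Exception('Unknown network type specified: {}'.format(args.net))
    if args.gpus <= 1:
        return builder(input_shape)
    with tf.device("/cpu:0"):       # multi-GPU: build weights on the CPU
        return builder(input_shape)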


Note: The unet.UNet examples on this page were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are taken from open-source projects contributed by their original developers; copyright remains with the original authors, and distribution and use are subject to each project's license. Do not reproduce without permission.