本文整理汇总了Python中unet.UNet类的典型用法代码示例。如果您正苦于以下问题:Python unet.UNet的具体用法?Python unet.UNet怎么用?Python unet.UNet使用的例子?那么恭喜您,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解UNet所在模块unet的用法示例。
在下文中一共展示了unet.UNet方法的9个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_args
# 需要导入模块: import unet [as 别名]
# 或者: from unet import UNet [as 别名]
def get_args(argv=None):
    """Parse the command-line options for training the UNet.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which preserves the original no-argument behavior.

    Returns:
        argparse.Namespace with attributes ``epochs`` (int), ``batchsize`` (int),
        ``lr`` (float), ``load`` (str or None), ``scale`` (float) and ``val``
        (float, percent of data held out for validation).

    Exits (via ``parser.error`` / SystemExit) when an option is missing its value
    or ``--validation`` is outside 0-100.
    """
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # No nargs='?' on the numeric options: without a `const`, a bare "-b" or "-l"
    # would silently store None instead of reporting the missing value.
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch-size', metavar='B', type=int, default=1,
                        help='Batch size', dest='batchsize')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, default=0.1,
                        help='Learning rate', dest='lr')
    # None (not False) keeps the default consistent with type=str while staying falsy.
    parser.add_argument('-f', '--load', dest='load', type=str, default=None,
                        help='Load model from a .pth file')
    parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
                        help='Downscaling factor of the images')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    args = parser.parse_args(argv)
    # Enforce the range promised in the help text (also rejects NaN).
    if not 0.0 <= args.val <= 100.0:
        parser.error('--validation must be between 0 and 100, got {}'.format(args.val))
    return args
示例2: unet
# 需要导入模块: import unet [as 别名]
# 或者: from unet import UNet [as 别名]
def unet(pretrained=False, **kwargs):
    """Build a U-Net segmentation model with batch normalization (torch.hub entry point).

    Args:
        pretrained (bool): when True, download the released checkpoint and load
            its weights into the freshly built model (tensors mapped to CPU).
        **kwargs: forwarded unchanged to ``UNet``, e.g. ``in_channels``,
            ``out_channels`` and ``init_features``.

    Returns:
        The constructed ``UNet`` instance.
    """
    net = UNet(**kwargs)
    if not pretrained:
        return net

    weights_url = "https://github.com/mateuszbuda/brain-segmentation-pytorch/releases/download/v1.0/unet-e012d006.pt"
    weights = torch.hub.load_state_dict_from_url(weights_url, progress=False, map_location='cpu')
    net.load_state_dict(weights)
    return net
示例3: main
# 需要导入模块: import unet [as 别名]
# 或者: from unet import UNet [as 别名]
def main():
    """Run the auto-jump loop forever.

    Each iteration captures a screenshot, predicts a mask with a UNet, shows it
    as a heat map, and feeds the masked image to a CNNEncoder. The predicted
    press time is then passed to ``jump``. Saved weights are restored from
    ``./unet.pkl`` / ``./cnn.pkl`` when those files exist.
    Requires CUDA; never returns.
    """
    # Build the segmentation net and restore its weights if a checkpoint exists.
    seg_net = UNet(3, 1)
    if os.path.exists("./unet.pkl"):
        seg_net.load_state_dict(torch.load("./unet.pkl"))
        print("load unet")
    seg_net.cuda()

    # Same for the press-time regressor.
    press_net = CNNEncoder()
    if os.path.exists("./cnn.pkl"):
        press_net.load_state_dict(torch.load("./cnn.pkl"))
        print("load cnn")
    press_net.cuda()

    for net in (seg_net, press_net):
        net.eval()
    print("load ok")

    while True:
        # obtain screen and save it to autojump.png
        pull_screenshot("autojump.png")
        frame = Image.open('./autojump.png')
        set_button_position(frame)
        inputs = Variable(preprocess(frame).unsqueeze(0)).cuda()

        mask = seg_net(inputs)
        heatmap = mask.squeeze(0).squeeze(0).cpu().data.numpy()
        plt.imshow(heatmap, cmap='hot', interpolation='nearest')
        plt.show()  # blocks until the figure window is closed

        press_time = press_net(inputs * mask).cpu().data[0].numpy()
        print(press_time)
        jump(press_time)
        time.sleep(random.uniform(0.6, 1.1))
示例4: _compile
# 需要导入模块: import unet [as 别名]
# 或者: from unet import UNet [as 别名]
def _compile(self):
    """Compile the model (architecture, loss function, optimizer, scheduler).

    Reads ``self.p`` (``noise_type``, ``learning_rate``, ``adam``, ``nb_epochs``,
    ``loss``, ``cuda``) and ``self.trainable``.

    Always sets ``self.is_mc``, ``self.model`` and ``self.use_cuda``. When
    ``self.trainable`` is true it also sets ``self.optim``, ``self.scheduler``
    and ``self.loss``.

    Raises:
        AssertionError: if the HDR loss is requested for non Monte Carlo input.
    """
    print('Noise2Noise: Learning Image Restoration without Clean Data (Lehtinen et al., 2018)')

    # Model (3x3=9 channels for Monte Carlo since it uses 3 HDR buffers)
    self.is_mc = self.p.noise_type == 'mc'
    self.model = UNet(in_channels=9 if self.is_mc else 3)

    # Set optimizer and loss, if in training mode
    if self.trainable:
        self.optim = Adam(self.model.parameters(),
                          lr=self.p.learning_rate,
                          betas=self.p.adam[:2],
                          eps=self.p.adam[2])

        # Learning rate adjustment. The scheduler compares an integer count of bad
        # epochs with `> patience`, so floor division gives exactly the same
        # behavior as true division while passing the documented int type.
        self.scheduler = lr_scheduler.ReduceLROnPlateau(
            self.optim, patience=self.p.nb_epochs // 4, factor=0.5, verbose=True)

        # Loss function
        if self.p.loss == 'hdr':
            # Explicit raise (not `assert`) so the check survives `python -O`;
            # the exception type is unchanged for any existing handlers.
            if not self.is_mc:
                raise AssertionError('Using HDR loss on non Monte Carlo images')
            self.loss = HDRLoss()
        elif self.p.loss == 'l2':
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.L1Loss()

    # CUDA support
    self.use_cuda = torch.cuda.is_available() and self.p.cuda
    if self.use_cuda:
        self.model = self.model.cuda()
        if self.trainable:
            self.loss = self.loss.cuda()
示例5: main
# 需要导入模块: import unet [as 别名]
# 或者: from unet import UNet [as 别名]
def main():
    """Jointly train the UNet mask generator and the CNNEncoder press-time regressor.

    Restores weights from ``./unet.pkl`` / ``./cnn.pkl`` when present. It then
    trains both nets end-to-end with MSE on ``images * unet(images)`` for 1000
    epochs, using one Adam optimizer per net. The loss is printed every 10
    steps, and both checkpoints are saved at the first step of every 5th epoch.
    Requires CUDA.
    """
    # init conv net
    print("init net")
    unet = UNet(3, 1)
    if os.path.exists("./unet.pkl"):
        unet.load_state_dict(torch.load("./unet.pkl"))
        print("load unet")
    unet.cuda()
    cnn = CNNEncoder()
    if os.path.exists("./cnn.pkl"):
        cnn.load_state_dict(torch.load("./cnn.pkl"))
        print("load cnn")
    cnn.cuda()

    # init dataset
    print("init dataset")
    data_loader = dataset84.jump_data_loader()

    # init optimizer
    unet_optimizer = torch.optim.Adam(unet.parameters(), lr=0.001)
    cnn_optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    # train
    print("training...")
    for epoch in range(1000):
        for i, (images, press_times) in enumerate(data_loader):
            images = Variable(images).cuda()
            press_times = Variable(press_times.float()).cuda()
            masks = unet(images)
            segmentations = images * masks
            predict_press_times = cnn(segmentations)
            loss = criterion(predict_press_times, press_times)

            unet_optimizer.zero_grad()
            cnn_optimizer.zero_grad()
            loss.backward()
            unet_optimizer.step()
            cnn_optimizer.step()

            if (i + 1) % 10 == 0:
                # `loss.data[0]` raises IndexError on 0-dim tensors in PyTorch >= 0.5;
                # use Tensor.item() and only fall back on releases that lack it.
                loss_value = loss.item() if hasattr(loss, "item") else loss.data[0]
                print("epoch:", epoch, "step:", i, "loss:", loss_value)
            if (epoch + 1) % 5 == 0 and i == 0:
                torch.save(unet.state_dict(), "./unet.pkl")
                torch.save(cnn.state_dict(), "./cnn.pkl")
                print("save model")
示例6: main
# 需要导入模块: import unet [as 别名]
# 或者: from unet import UNet [as 别名]
def main(args):
    """Run a trained TF1 U-Net generator on one image plus trimap and save the output.

    Builds ``UNet(4, 4)`` under the ``"Gen"`` variable scope on a fixed
    ``[1, 480, 360, 4]`` float32 placeholder. Weights are optionally restored
    from ``args.logdir``. The image is concatenated with the trimap along the
    channel axis and scaled by 1/255, then the sigmoid output is written to
    ``args.output`` as an 8-bit image.

    Args:
        args: parsed CLI namespace. Reads ``checkpoint`` (None to skip, -1 for
            the latest, otherwise a step number), ``logdir``, ``input``,
            ``object`` and ``output``.
    """
    input_images = tf.placeholder(tf.float32, shape=[1, 480, 360, 4])
    with tf.variable_scope("Gen"):
        gen = UNet(4, 4)
        output = tf.sigmoid(gen(input_images))
    # NOTE(review): assumed to be created outside the "Gen" scope (variable name
    # "global_step"); confirm this matches the name stored in the checkpoints.
    global_step = tf.get_variable('global_step', initializer=0, trainable=False)
    init = tf.global_variables_initializer()

    # Context manager closes the session (and frees its resources) on exit.
    with tf.Session() as sess:
        sess.run(init)
        saver = tf.train.Saver()
        if args.checkpoint is not None and os.path.exists(os.path.join(args.logdir, 'checkpoint')):
            if args.checkpoint == -1:  # latest checkpoint
                saver.restore(sess, tf.train.latest_checkpoint(args.logdir))
            else:  # specified checkpoint
                # `model_name` is expected to be a module-level global (not visible here).
                saver.restore(sess, os.path.join(args.logdir, model_name + ".ckpt-" + str(args.checkpoint)))
            logging.info('Model restored to step ' + str(global_step.eval(sess)))

        image = load_image(args.input)
        print(image.shape)
        trimap = generate_trimap(args.object)
        image = np.array(image)
        trimap = np.array(trimap)[..., np.newaxis]
        print(image.shape)
        print(trimap.shape)
        # Stack the trimap as an extra channel and normalise to [0, 1].
        image = np.concatenate((image, trimap), axis=2).astype(np.float32) / 255
        result = sess.run(output, feed_dict={
            input_images: np.asarray([image]),
        })

    print(result.shape)
    image = Image.fromarray((result[0, :, :, :] * 255).astype(np.uint8))
    image.save(args.output)
示例7: main
# 需要导入模块: import unet [as 别名]
# 或者: from unet import UNet [as 别名]
def main(args):
    """Run U-Net segmentation inference over the loader and write result images.

    Collects per-slice inputs, predictions and ground truth, then groups them
    into volumes with ``postprocess_per_volume``. Writes a DSC-distribution
    plot to ``args.figure``, and one PNG per slice into ``args.predictions``
    with the prediction outlined in red and the ground truth in green.
    """
    makedirs(args)
    cuda_ready = torch.cuda.is_available()
    device = torch.device(args.device if cuda_ready else "cpu")
    loader = data_loader(args)

    inputs, preds, truths = [], [], []
    with torch.set_grad_enabled(False):
        net = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
        net.load_state_dict(torch.load(args.weights, map_location=device))
        net.eval()
        net.to(device)

        for _, (batch_x, batch_y) in tqdm(enumerate(loader)):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_pred = net(batch_x)
            # Iterating a NumPy array walks axis 0, splitting the batch into slices.
            preds.extend(batch_pred.detach().cpu().numpy())
            truths.extend(batch_y.detach().cpu().numpy())
            inputs.extend(batch_x.detach().cpu().numpy())

    volumes = postprocess_per_volume(
        inputs,
        preds,
        truths,
        loader.dataset.patient_slice_index,
        loader.dataset.patients,
    )
    imsave(args.figure, plot_dsc(dsc_distribution(volumes)))

    for patient in volumes:
        vol_x = volumes[patient][0]
        vol_pred = volumes[patient][1]
        vol_true = volumes[patient][2]
        for idx in range(vol_x.shape[0]):
            overlay = gray2rgb(vol_x[idx, 1])  # channel 1 is for FLAIR
            overlay = outline(overlay, vol_pred[idx, 0], color=[255, 0, 0])
            overlay = outline(overlay, vol_true[idx, 0], color=[0, 255, 0])
            png_name = "{}-{}.png".format(patient, str(idx).zfill(2))
            imsave(os.path.join(args.predictions, png_name), overlay)
示例8: _add_bgnd
# 需要导入模块: import unet [as 别名]
# 或者: from unet import UNet [as 别名]
def _add_bgnd(fname, _model_path=model_path, _int_scale=(0, 255), cuda_id=0):
    """Predict a background for every frame of ``/full_data`` and store it in ``/bgnd``.

    Opens the PyTables file *fname* in ``r+`` mode and replaces any existing
    ``/bgnd`` node with a non-expandable image group shaped like
    ``/full_data``, copying its ``save_interval`` attribute. Each frame is
    normalised with *_int_scale*, passed through a 1-in/1-out UNet, mapped back
    to the original intensity range, and written in the frame's dtype.

    Args:
        fname: path of the HDF5 file to update in place.
        _model_path: checkpoint path; must hold a dict with a ``'state_dict'`` entry.
        _int_scale: ``(low, high)`` intensity range mapped to ``[0, 1]`` for the network.
        cuda_id: CUDA device index used when CUDA is available.

    Raises:
        ValueError: if ``_int_scale[0] == _int_scale[1]`` (normalisation would divide by zero).
    """
    low, high = _int_scale[0], _int_scale[1]
    span = high - low
    if span == 0:
        raise ValueError('_int_scale must span a non-empty range, got {!r}'.format(_int_scale))

    if torch.cuda.is_available():
        print("THIS IS CUDA!!!!")
        dev_str = "cuda:" + str(cuda_id)
    else:
        dev_str = 'cpu'
    device = torch.device(dev_str)

    model = UNet(n_channels=1, n_classes=1)
    # Load on CPU first so a checkpoint saved on any device can be restored.
    state = torch.load(_model_path, map_location='cpu')
    model.load_state_dict(state['state_dict'])
    model = model.to(device)
    model.eval()

    with tables.File(fname, 'r+') as fid:
        full_data = fid.get_node('/full_data')
        if '/bgnd' in fid:
            fid.remove_node('/bgnd')
        bgnd = createImgGroup(fid, "/bgnd", *full_data.shape, is_expandable=False)
        bgnd._v_attrs['save_interval'] = full_data._v_attrs['save_interval']

        # Inference only: one no_grad context for the whole loop.
        with torch.no_grad():
            for ii in tqdm.trange(full_data.shape[0]):
                img = full_data[ii]
                x = (img.astype(np.float32) - low) / span
                X = torch.from_numpy(x[None, None]).to(device)  # add batch & channel dims
                xhat = model(X).squeeze().cpu().numpy()
                bg = (xhat * span + low).round()
                if np.issubdtype(img.dtype, np.integer):
                    # astype() wraps out-of-range integers (e.g. 256 -> 0 for uint8),
                    # so saturate predictions to the representable range first.
                    info = np.iinfo(img.dtype)
                    bg = np.clip(bg, info.min, info.max)
                bgnd[ii] = bg.astype(img.dtype)
示例9: create_model
# 需要导入模块: import unet [as 别名]
# 或者: from unet import UNet [as 别名]
def create_model(args, input_shape):
    """Instantiate the network selected by ``args.net``.

    Args:
        args: namespace with ``gpus`` (int) and ``net``. ``net`` is one of
            ``'unet'``, ``'tiramisu'``, ``'segcapsr1'``, ``'segcapsr3'`` or
            ``'capsbasic'``.
        input_shape: shape forwarded to the model constructor.

    Returns:
        A list of models. UNet and DenseNetFCN are wrapped in a single-element
        list, while the CapsNet builders already return a list.

    With more than one GPU, the model is built under ``tf.device("/cpu:0")``;
    otherwise it is built in the current device context.

    Raises:
        ValueError: for an unknown ``args.net`` (a subclass of ``Exception``, so
            existing ``except Exception`` handlers still catch it).
    """
    def _build():
        # Imports stay lazy so only the selected architecture's module is loaded.
        if args.net == 'unet':
            from unet import UNet
            return [UNet(input_shape)]
        if args.net == 'tiramisu':
            from densenets import DenseNetFCN
            return [DenseNetFCN(input_shape)]
        if args.net == 'segcapsr1':
            from capsnet import CapsNetR1
            return CapsNetR1(input_shape)
        if args.net == 'segcapsr3':
            from capsnet import CapsNetR3
            return CapsNetR3(input_shape)
        if args.net == 'capsbasic':
            from capsnet import CapsNetBasic
            return CapsNetBasic(input_shape)
        raise ValueError('Unknown network type specified: {}'.format(args.net))

    # If using CPU or single GPU
    if args.gpus <= 1:
        return _build()
    # If using multiple GPUs: place the template model on the CPU.
    with tf.device("/cpu:0"):
        return _build()