This article collects typical usage examples of Python's torchvision.transforms.CenterCrop. If you are wondering what transforms.CenterCrop does, how to call it, or what real-world code that uses it looks like, the curated examples below should help. You can also browse further usage examples from the enclosing torchvision.transforms module.
Below are 15 code examples of transforms.CenterCrop, sorted by popularity by default.
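As a quick orientation before the examples, here is a minimal sketch of the basic API; the file name and the 224-pixel size are illustrative placeholders.
from PIL import Image
from torchvision import transforms

img = Image.open('example.jpg')       # placeholder input image
crop = transforms.CenterCrop(224)     # size may be an int or a (height, width) tuple
out = crop(img)                       # returns the 224x224 center region of the image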
Example 1: _get_ds_val
# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import CenterCrop [as alias]
def _get_ds_val(self, images_spec, crop=False, truncate=False):
    img_to_tensor_t = [images_loader.IndexImagesDataset.to_tensor_uint8_transform()]
    if crop:
        img_to_tensor_t.insert(0, transforms.CenterCrop(crop))
    img_to_tensor_t = transforms.Compose(img_to_tensor_t)
    fixed_first = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'fixedimg.jpg')
    if not os.path.isfile(fixed_first):
        print(f'INFO: No file found at {fixed_first}')
        fixed_first = None
    ds = images_loader.IndexImagesDataset(
        images=images_loader.ImagesCached(
            images_spec, self.config_dl.image_cache_pkl,
            min_size=self.config_dl.val_glob_min_size),
        to_tensor_transform=img_to_tensor_t,
        fixed_first=fixed_first)  # fix a first image to have consistency in TensorBoard
    if truncate:
        ds = pe.TruncatedDataset(ds, num_elemens=truncate)
    return ds
Example 2: get_lsun_dataloader
# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import CenterCrop [as alias]
def get_lsun_dataloader(path_to_data='../lsun', dataset='bedroom_train',
                        batch_size=64):
    """LSUN dataloader with (128, 128) sized images.

    path_to_data : str
        Path to the LSUN database directory.
    dataset : str
        One of 'bedroom_val' or 'bedroom_train'.
    """
    # Compose transforms
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor()
    ])
    # Get dataset
    lsun_dset = datasets.LSUN(db_path=path_to_data, classes=[dataset],
                              transform=transform)
    # Create dataloader
    return DataLoader(lsun_dset, batch_size=batch_size, shuffle=True)
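A hedged usage sketch of the function above; the data path is simply the function's default and is an assumption about where LSUN was downloaded.
loader = get_lsun_dataloader(path_to_data='../lsun', dataset='bedroom_train', batch_size=64)
images, labels = next(iter(loader))
print(images.shape)  # expected: torch.Size([64, 3, 128, 128]) after Resize(128) + CenterCrop(128)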
Example 3: save_distorted
# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import CenterCrop [as alias]
def save_distorted(method=gaussian_noise):
    for severity in range(1, 6):
        print(method.__name__, severity)
        distorted_dataset = DistortImageFolder(
            root="/share/data/vision-greg/ImageNet/clsloc/images/val",
            method=method, severity=severity,
            transform=trn.Compose([trn.Resize(256), trn.CenterCrop(224)]))
        distorted_dataset_loader = torch.utils.data.DataLoader(
            distorted_dataset, batch_size=100, shuffle=False, num_workers=4)
        for _ in distorted_dataset_loader:
            continue

# /////////////// End Further Setup ///////////////

# /////////////// Display Results ///////////////
Example 4: transform
# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import CenterCrop [as alias]
def transform(is_train=True, normalize=True):
    """
    Returns a transform object
    """
    filters = []
    filters.append(Scale(256))
    if is_train:
        filters.append(RandomCrop(224))
    else:
        filters.append(CenterCrop(224))
    if is_train:
        filters.append(RandomHorizontalFlip())
    filters.append(ToTensor())
    if normalize:
        filters.append(Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]))
    return Compose(filters)
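A possible way to call it; the image path is a placeholder, and Scale, CenterCrop, and the other transform names are assumed to be imported from torchvision.transforms as in the snippet's source file.
from PIL import Image

val_transform = transform(is_train=False, normalize=True)
tensor = val_transform(Image.open('example.jpg'))  # 3 x 224 x 224 normalized tensor for an RGB input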
Example 5: __init__
# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import CenterCrop [as alias]
def __init__(
    self,
    resize: int = ImagenetConstants.RESIZE,
    crop_size: int = ImagenetConstants.CROP_SIZE,
    mean: List[float] = ImagenetConstants.MEAN,
    std: List[float] = ImagenetConstants.STD,
):
    """The constructor method of ImagenetNoAugmentTransform class.

    Args:
        resize: expected image size per dimension after resizing
        crop_size: expected size for a dimension of central cropping
        mean: a 3-tuple denoting the pixel RGB mean
        std: a 3-tuple denoting the pixel RGB standard deviation
    """
    self.transform = transforms.Compose(
        [
            transforms.Resize(resize),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
Example 6: make
# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import CenterCrop [as alias]
def make(sz_resize=256, sz_crop=227, mean=[104, 117, 128],
         std=[1, 1, 1], rgb_to_bgr=True, is_train=True,
         intensity_scale=None):
    return transforms.Compose([
        RGBToBGR() if rgb_to_bgr else Identity(),
        transforms.RandomResizedCrop(sz_crop) if is_train else Identity(),
        transforms.Resize(sz_resize) if not is_train else Identity(),
        transforms.CenterCrop(sz_crop) if not is_train else Identity(),
        transforms.RandomHorizontalFlip() if is_train else Identity(),
        transforms.ToTensor(),
        ScaleIntensities(
            *intensity_scale) if intensity_scale is not None else Identity(),
        transforms.Normalize(
            mean=mean,
            std=std,
        )
    ])
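An illustrative evaluation-time call, assuming the repo-specific helpers (RGBToBGR, Identity, ScaleIntensities) are in scope and accept PIL images; the input path is a placeholder.
from PIL import Image

eval_transform = make(is_train=False)               # Resize(256) -> CenterCrop(227) at eval time
tensor = eval_transform(Image.open('example.jpg'))  # 3 x 227 x 227 tensor, channels swapped to BGR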
Example 7: test_on_validation_set
# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import CenterCrop [as alias]
def test_on_validation_set(model, validation_set=None):
    if validation_set is None:
        validation_set = get_validation_set()
    total_ssim = 0
    total_psnr = 0
    iters = len(validation_set.tuples)
    crop = CenterCrop(config.CROP_SIZE)
    for i, tup in enumerate(validation_set.tuples):
        x1, gt, x2 = [crop(load_img(p)) for p in tup]
        pred = interpolate(model, x1, x2)
        gt = pil_to_tensor(gt)
        pred = pil_to_tensor(pred)
        total_ssim += ssim(pred, gt).item()
        total_psnr += psnr(pred, gt).item()
        print(f'#{i+1} done')
    avg_ssim = total_ssim / iters
    avg_psnr = total_psnr / iters
    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')
Example 8: test_linear_interp
# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import CenterCrop [as alias]
def test_linear_interp(validation_set=None):
    if validation_set is None:
        validation_set = get_validation_set()
    total_ssim = 0
    total_psnr = 0
    iters = len(validation_set.tuples)
    crop = CenterCrop(config.CROP_SIZE)
    for tup in validation_set.tuples:
        x1, gt, x2 = [pil_to_tensor(crop(load_img(p))) for p in tup]
        pred = torch.mean(torch.stack((x1, x2), dim=0), dim=0)
        total_ssim += ssim(pred, gt).item()
        total_psnr += psnr(pred, gt).item()
    avg_ssim = total_ssim / iters
    avg_psnr = total_psnr / iters
    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')
Example 9: __init__
# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import CenterCrop [as alias]
def __init__(self, patches, use_cache, augment_data):
    super(PatchDataset, self).__init__()
    self.patches = patches
    self.crop = CenterCrop(config.CROP_SIZE)
    if augment_data:
        self.random_transforms = [RandomRotation((90, 90)), RandomVerticalFlip(1.0), RandomHorizontalFlip(1.0),
                                  (lambda x: x)]
        self.get_aug_transform = (lambda: random.sample(self.random_transforms, 1)[0])
    else:
        # Transform does nothing. Not sure if horrible or very elegant...
        self.get_aug_transform = (lambda: (lambda x: x))
    if use_cache:
        self.load_patch = data_manager.load_cached_patch
    else:
        self.load_patch = data_manager.load_patch
    print('Dataset ready with {} tuples.'.format(len(patches)))
Example 10: preprocess
# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import CenterCrop [as alias]
def preprocess(self):
    if self.train:
        return transforms.Compose([
            transforms.RandomResizedCrop(self.image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])
    else:
        return transforms.Compose([
            transforms.Resize((int(self.image_size / 0.875), int(self.image_size / 0.875))),
            transforms.CenterCrop(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])
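The eval branch follows the common ImageNet convention of resizing by 1/0.875 before the center crop; a quick check of the arithmetic with an assumed image_size of 224:
image_size = 224                   # assumed value for illustration
resize = int(image_size / 0.875)   # 256, so CenterCrop(224) keeps the central 87.5% of each side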
Example 11: __getitem__
# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import CenterCrop [as alias]
def __getitem__(self, index):
    # get downscaled, cropped and gt (if available) image
    hr_image = Image.open(self.hr_files[index])
    w, h = hr_image.size
    cs = utils.calculate_valid_crop_size(min(w, h), self.upscale_factor)
    if self.crop_size is not None:
        cs = min(cs, self.crop_size)
    cropped_image = TF.to_tensor(T.CenterCrop(cs // self.upscale_factor)(hr_image))
    hr_image = T.CenterCrop(cs)(hr_image)
    hr_image = TF.to_tensor(hr_image)
    resized_image = utils.imresize(hr_image, 1.0 / self.upscale_factor, True)
    if self.lr_files is None:
        return resized_image, cropped_image, resized_image
    else:
        lr_image = Image.open(self.lr_files[index])
        lr_image = TF.to_tensor(T.CenterCrop(cs // self.upscale_factor)(lr_image))
        return resized_image, cropped_image, lr_image
Example 12: __init__
# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import CenterCrop [as alias]
def __init__(self, options):
    transform_list = []
    if options.image_size is not None:
        transform_list.append(transforms.Resize((options.image_size, options.image_size)))
        # transform_list.append(transforms.CenterCrop(options.image_size))
    transform_list.append(transforms.ToTensor())
    if options.image_colors == 1:
        transform_list.append(transforms.Normalize(mean=[0.5], std=[0.5]))
    elif options.image_colors == 3:
        transform_list.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    transform = transforms.Compose(transform_list)
    dataset = ImagePairs(options.data_dir, split=options.split, transform=transform)
    self.dataloader = DataLoader(
        dataset,
        batch_size=options.batch_size,
        num_workers=options.loader_workers,
        shuffle=True,
        drop_last=True,
        pin_memory=options.pin_memory
    )
    self.iterator = iter(self.dataloader)
Example 13: __init__
# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import CenterCrop [as alias]
def __init__(self, path, classes, stage='train'):
    self.data = []
    for i, c in enumerate(classes):
        cls_path = osp.join(path, c)
        images = os.listdir(cls_path)
        for image in images:
            self.data.append((osp.join(cls_path, image), i))
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if stage == 'train':
        self.transforms = transforms.Compose([transforms.RandomResizedCrop(224),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.ToTensor(),
                                              normalize])
    if stage == 'test':
        self.transforms = transforms.Compose([transforms.Resize(256),
                                              transforms.CenterCrop(224),
                                              transforms.ToTensor(),
                                              normalize])
Example 14: __init__
# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import CenterCrop [as alias]
def __init__(self, opt):
    self.image_path = opt.dataroot
    self.is_train = opt.is_train
    self.d_num = opt.n_attribute
    print('Start preprocessing dataset..!')
    random.seed(1234)
    self.preprocess()
    print('Finished preprocessing dataset..!')
    if self.is_train:
        trs = [transforms.Resize(opt.load_size, interpolation=Image.ANTIALIAS), transforms.RandomCrop(opt.fine_size)]
    else:
        trs = [transforms.Resize(opt.load_size, interpolation=Image.ANTIALIAS), transforms.CenterCrop(opt.fine_size)]
    if opt.is_flip:
        trs.append(transforms.RandomHorizontalFlip())
    self.transform = transforms.Compose(trs)
    self.norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    self.num_data = max(self.num)
Example 15: __init__
# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import CenterCrop [as alias]
def __init__(self, opt):
    '''Initialize this dataset class.

    We need to specify the path of the dataset and the domain label of each image.
    '''
    self.image_list = []
    self.label_list = []
    if opt.is_train:
        trs = [transforms.Resize(opt.load_size, interpolation=Image.ANTIALIAS), transforms.RandomCrop(opt.fine_size)]
    else:
        trs = [transforms.Resize(opt.load_size, interpolation=Image.ANTIALIAS), transforms.CenterCrop(opt.fine_size)]
    if opt.is_flip:
        trs.append(transforms.RandomHorizontalFlip())
    trs.append(transforms.ToTensor())
    trs.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    self.transform = transforms.Compose(trs)
    self.num_data = len(self.image_list)