本文整理汇总了 Python 中 torchvision.transforms.transforms.CenterCrop 方法的典型用法代码示例。如果您正苦于以下问题：Python transforms.CenterCrop 方法的具体用法？Python transforms.CenterCrop 怎么用？Python transforms.CenterCrop 使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 torchvision.transforms.transforms 的用法示例。
在下文中一共展示了 transforms.CenterCrop 方法的 4 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: get_transforms
# 需要导入模块: from torchvision.transforms import transforms [as 别名]
# 或者: from torchvision.transforms.transforms import CenterCrop [as 别名]
def get_transforms(eval=False, aug=None):
    """Build a torchvision transform pipeline from an augmentation config.

    Args:
        eval: When truthy, build the evaluation pipeline. ``aug["randcrop"]``
            becomes a deterministic ``CenterCrop`` instead of ``RandomCrop``,
            and random horizontal flipping is skipped. (The name shadows the
            builtin ``eval`` but is kept so keyword callers keep working.)
        aug: Mapping of augmentation options. Keys read here:
            ``randcrop`` (crop size; falsy disables cropping),
            ``flip`` (enable random horizontal flip for training),
            ``grayscale`` (convert to grayscale, then normalise with
            ``bw_mean`` / ``bw_std``), and
            ``mean`` / ``std`` (normalisation statistics for the
            non-grayscale path).
            ``aug=None`` and missing option keys are treated as "disabled".

    Returns:
        A ``transforms.Compose`` that always contains ``ToTensor`` and, when
        configured, a trailing ``Normalize``.
    """
    # The original indexed aug[...] unconditionally, so the advertised
    # default aug=None crashed with TypeError. Treat None as an empty config.
    aug = {} if aug is None else aug
    trans = []

    crop = aug.get("randcrop")
    if crop:
        # Random crop for training; same-size deterministic center crop
        # for evaluation.
        if eval:
            trans.append(transforms.CenterCrop(crop))
        else:
            trans.append(transforms.RandomCrop(crop))

    if aug.get("flip") and not eval:
        trans.append(transforms.RandomHorizontalFlip())

    if aug.get("grayscale"):
        trans.append(transforms.Grayscale())
        trans.append(transforms.ToTensor())
        trans.append(
            transforms.Normalize(mean=aug["bw_mean"], std=aug["bw_std"]))
    elif aug.get("mean"):
        trans.append(transforms.ToTensor())
        trans.append(transforms.Normalize(mean=aug["mean"], std=aug["std"]))
    else:
        trans.append(transforms.ToTensor())
    return transforms.Compose(trans)
示例2: __init__
# 需要导入模块: from torchvision.transforms import transforms [as 别名]
# 或者: from torchvision.transforms.transforms import CenterCrop [as 别名]
def __init__(self, size: Union[Tuple[int, int], int]):
    """Wrap a torchvision center-crop transform of the requested size.

    Args:
        size: Output crop size, either a ``(height, width)`` pair or a
            single int for a square crop.
    """
    # Keep the parent initialiser first: if the base class is an
    # nn.Module-style class, attribute assignment may depend on it.
    super().__init__()
    center_crop = tv.CenterCrop(size)
    self._image_transform = center_crop
示例3: handle
# 需要导入模块: from torchvision.transforms import transforms [as 别名]
# 或者: from torchvision.transforms.transforms import CenterCrop [as 别名]
def handle(self, source, copy_to_local=False, normalize=True,
           split=None, classification_mode=False, **transform_args):
    """Load CelebA and register its train/test sets, input names, dims and scale.

    Args:
        source: Source key, resolved to a filesystem path via
            ``self.get_path``.
        copy_to_local: If True, copy the data with
            ``self.copy_to_local_path`` and load from the copy.
        normalize: ``True`` selects mean/std ``((0.5,) * 3, (0.5,) * 3)``.
            A ``(mean, std)`` pair supplies custom statistics. A falsy
            value skips normalisation in classification mode; otherwise it
            is passed through to ``build_transforms``.
        split: ``None`` builds the train and test sets directly from the
            dataset class. Any other value is forwarded to
            ``self.make_split`` together with both transforms.
        classification_mode: If True, use fixed 64px pipelines
            (RandomResizedCrop + horizontal flip for training, Resize +
            CenterCrop for testing). Otherwise ``build_transforms`` supplies
            one pipeline shared by both splits.
        **transform_args: Forwarded to ``build_transforms``; ignored in
            classification mode.
    """
    Dataset = self.make_indexing(CelebA)
    data_path = self.get_path(source)
    if copy_to_local:
        data_path = self.copy_to_local_path(data_path)

    if normalize and isinstance(normalize, bool):
        normalize = [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)]

    if classification_mode:
        # Normalize(*normalize) raised TypeError for a falsy normalize
        # (e.g. normalize=False); only append it when requested.
        norm = [transforms.Normalize(*normalize)] if normalize else []
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(64),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ] + norm)
        test_transform = transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
        ] + norm)
    else:
        train_transform = build_transforms(normalize=normalize,
                                           **transform_args)
        test_transform = train_transform

    if split is None:
        train_set = Dataset(root=data_path, transform=train_transform,
                            download=True)
        test_set = Dataset(root=data_path, transform=test_transform)
    else:
        train_set, test_set = self.make_split(
            data_path, split, Dataset, train_transform, test_transform)

    input_names = ['images', 'labels', 'attributes']
    # Probe the first training sample for image dims; assumes the first
    # element of a sample is a (C, H, W) tensor — TODO confirm.
    dim_c, dim_x, dim_y = train_set[0][0].size()
    dim_l = len(train_set.classes)
    dim_a = train_set.attributes[0].shape[0]
    dims = dict(x=dim_x, y=dim_y, c=dim_c, labels=dim_l, attributes=dim_a)

    self.add_dataset('train', train_set)
    self.add_dataset('test', test_set)
    self.set_input_names(input_names)
    self.set_dims(**dims)
    # NOTE(review): (-1, 1) matches the default 0.5/0.5 normalisation; with
    # normalize falsy in classification mode the tensors stay in [0, 1]
    # (ToTensor output). Confirm whether the scale should follow normalize.
    self.set_scale((-1, 1))
示例4: _handle_STL
# 需要导入模块: from torchvision.transforms import transforms [as 别名]
# 或者: from torchvision.transforms.transforms import CenterCrop [as 别名]
def _handle_STL(self, Dataset, data_path, transform=None,
                labeled_only=False, stl_center_crop=False,
                stl_resize_only=False, stl_no_resize=False):
    """Build STL train/test datasets with 64px (or native-size) transforms.

    Args:
        Dataset: Dataset class, called as
            ``Dataset(data_path, split=..., transform=..., download=True)``.
        data_path: Root directory passed to ``Dataset``.
        transform: Unused; kept for interface compatibility.
            NOTE(review): callers passing it have no effect — confirm intended.
        labeled_only: Train on the ``'train'`` split only; otherwise use
            ``'train+unlabeled'``.
        stl_center_crop: Center-crop both splits to 64.
        stl_resize_only: Resize both splits to 64 (no cropping).
        stl_no_resize: Keep native resolution. Takes precedence over the
            other resize flags.

    Returns:
        ``(train_set, test_set)``. The test set uses the ``'test'`` split.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    # Geometry steps prepended to each pipeline. Flag precedence matches the
    # original: no_resize > center_crop > resize_only > default. (The
    # original's inner `elif stl_no_resize: pass` was unreachable and would
    # have left the transforms unbound if reached; it is removed.)
    if stl_no_resize:
        train_geom, test_geom = [], []
    elif stl_center_crop:
        train_geom = [transforms.CenterCrop(64)]
        test_geom = [transforms.CenterCrop(64)]
    elif stl_resize_only:
        train_geom = [transforms.Resize(64)]
        test_geom = [transforms.Resize(64)]
    else:
        train_geom = [transforms.RandomResizedCrop(64)]
        test_geom = [transforms.Resize(64)]

    train_transform = transforms.Compose(train_geom + [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose(test_geom + [
        transforms.ToTensor(),
        normalize,
    ])

    split = 'train' if labeled_only else 'train+unlabeled'
    train_set = Dataset(
        data_path, split=split, transform=train_transform, download=True)
    test_set = Dataset(
        data_path, split='test', transform=test_transform, download=True)
    return train_set, test_set