本文整理汇总了Python中torchvision.transforms.transforms.RandomHorizontalFlip类的典型用法代码示例。如果您正苦于以下问题：Python transforms.RandomHorizontalFlip的具体用法？Python transforms.RandomHorizontalFlip怎么用？Python transforms.RandomHorizontalFlip使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torchvision.transforms.transforms的用法示例。
在下文中一共展示了transforms.RandomHorizontalFlip类的5个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_datasets
# 需要导入模块: from torchvision.transforms import transforms [as 别名]
# 或者: from torchvision.transforms.transforms import RandomHorizontalFlip [as 别名]
def get_datasets(initial_pool):
    """Build the CIFAR-10 active-learning dataset and the evaluation set.

    Training images go through resize -> random flip -> random rotation;
    evaluation images are only resized.  Both pipelines finish with
    ToTensor and a per-channel Normalize(mean=0.5, std=0.5).

    Args:
        initial_pool: Forwarded to ``active_set.label_randomly`` to seed the
            labelled set at random.

    Returns:
        A ``(active_set, test_set)`` tuple.
    """

    def normalized_tensor():
        # Fresh ToTensor/Normalize instances for each pipeline.
        return [
            transforms.ToTensor(),
            transforms.Normalize(3 * [0.5], 3 * [0.5]),
        ]

    train_transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(30),
        ]
        + normalized_tensor()
    )
    eval_transform = transforms.Compose(
        [transforms.Resize((224, 224))] + normalized_tensor()
    )

    # NOTE: the CIFAR-10 test split doubles as the evaluation set in this
    # example; a real project should carve out its own validation split.
    train_ds = datasets.CIFAR10('.', train=True, transform=train_transform,
                                target_transform=None, download=True)
    test_set = datasets.CIFAR10('.', train=False, transform=eval_transform,
                                target_transform=None, download=True)

    # pool_specifics swaps in the deterministic eval transform — presumably
    # for samples in the unlabelled pool; confirm against ActiveLearningDataset.
    active_set = ActiveLearningDataset(
        train_ds, pool_specifics={'transform': eval_transform})
    # Seed the labelled set with a random draw.
    active_set.label_randomly(initial_pool)
    return active_set, test_set
示例2: get_transforms
# 需要导入模块: from torchvision.transforms import transforms [as 别名]
# 或者: from torchvision.transforms.transforms import RandomHorizontalFlip [as 别名]
def get_transforms(eval=False, aug=None):
    """Assemble a torchvision preprocessing pipeline from an ``aug`` config.

    Args:
        eval: When true, build the evaluation pipeline: ``CenterCrop``
            replaces ``RandomCrop`` and the random horizontal flip is
            skipped.  (The name shadows the builtin but is part of the
            public signature.)
        aug: Mapping of augmentation switches, or ``None`` for no
            augmentation at all (the pipeline is then just ``ToTensor``).
            When given, it must provide ``randcrop`` (crop size, falsy to
            disable), ``flip`` and ``grayscale``.  If ``grayscale`` is truthy
            it also needs ``bw_mean``/``bw_std``; otherwise, when ``mean``
            is truthy, ``std`` is needed as well.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor`` (followed by
        ``Normalize`` when configured).

    Raises:
        KeyError: If a provided ``aug`` mapping lacks a key that is read.
    """
    if aug is None:
        # The default used to crash ('NoneType' object is not subscriptable);
        # treat a missing config as "every switch disabled".
        return transforms.Compose([transforms.ToTensor()])

    trans = []
    if aug["randcrop"]:
        # Random crops augment training; evaluation takes the deterministic
        # centre crop of the same size.
        crop = transforms.CenterCrop if eval else transforms.RandomCrop
        trans.append(crop(aug["randcrop"]))
    if aug["flip"] and not eval:
        trans.append(transforms.RandomHorizontalFlip())
    if aug["grayscale"]:
        # Grayscale inputs use their own normalization statistics.
        trans.append(transforms.Grayscale())
        trans.append(transforms.ToTensor())
        trans.append(transforms.Normalize(mean=aug["bw_mean"], std=aug["bw_std"]))
    elif aug["mean"]:
        trans.append(transforms.ToTensor())
        trans.append(transforms.Normalize(mean=aug["mean"], std=aug["std"]))
    else:
        trans.append(transforms.ToTensor())
    return transforms.Compose(trans)
示例3: handle
# 需要导入模块: from torchvision.transforms import transforms [as 别名]
# 或者: from torchvision.transforms.transforms import RandomHorizontalFlip [as 别名]
def handle(self, source, copy_to_local=False, normalize=True,
           split=None, classification_mode=False, **transform_args):
    """Register CelebA train/test datasets and their metadata on this handler.

    Args:
        source: Key resolved to a data directory via ``self.get_path``.
        copy_to_local: If true, load from the copy produced by
            ``self.copy_to_local_path`` instead of the original path.
        normalize: ``True`` selects the default per-channel (mean, std) of
            ``((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))``; a ``(mean, std)`` pair is
            used as given; a falsy value disables normalization.
        split: ``None`` builds both sets directly from the data path;
            otherwise it is forwarded to ``self.make_split``.
        classification_mode: If true, use fixed 64x64 pipelines (random
            resized crop + horizontal flip for training, resize + center
            crop for testing).  Otherwise both sets share the pipeline
            returned by ``build_transforms``.
        **transform_args: Extra keyword arguments for ``build_transforms``
            (unused in ``classification_mode``).
    """
    Dataset = self.make_indexing(CelebA)
    data_path = self.get_path(source)
    if copy_to_local:
        data_path = self.copy_to_local_path(data_path)

    # A bare `True` means "use the default statistics".
    if normalize and isinstance(normalize, bool):
        normalize = [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)]

    if classification_mode:
        def to_tensor_steps():
            # Only append Normalize when normalization is enabled: unpacking a
            # falsy `normalize` (e.g. False) would raise TypeError.
            steps = [transforms.ToTensor()]
            if normalize:
                steps.append(transforms.Normalize(*normalize))
            return steps

        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(64),
            transforms.RandomHorizontalFlip(),
        ] + to_tensor_steps())
        test_transform = transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
        ] + to_tensor_steps())
    else:
        train_transform = build_transforms(normalize=normalize,
                                           **transform_args)
        test_transform = train_transform

    if split is None:
        # NOTE(review): both sets are built from the same root with no split
        # argument; whether this CelebA class separates train from test is
        # not visible here — confirm.
        train_set = Dataset(root=data_path, transform=train_transform,
                            download=True)
        test_set = Dataset(root=data_path, transform=test_transform)
    else:
        train_set, test_set = self.make_split(
            data_path, split, Dataset, train_transform, test_transform)

    input_names = ['images', 'labels', 'attributes']
    # Read dimensions off the first training sample; the unpacking requires
    # its first element to have a 3-dimensional size().
    dim_c, dim_x, dim_y = train_set[0][0].size()
    dim_l = len(train_set.classes)
    dim_a = train_set.attributes[0].shape[0]
    dims = dict(x=dim_x, y=dim_y, c=dim_c, labels=dim_l, attributes=dim_a)

    self.add_dataset('train', train_set)
    self.add_dataset('test', test_set)
    self.set_input_names(input_names)
    self.set_dims(**dims)
    # NOTE(review): (-1, 1) matches the default 0.5/0.5 normalization only;
    # custom or disabled normalization yields a different range — confirm.
    self.set_scale((-1, 1))
示例4: _handle_STL
# 需要导入模块: from torchvision.transforms import transforms [as 别名]
# 或者: from torchvision.transforms.transforms import RandomHorizontalFlip [as 别名]
def _handle_STL(self, Dataset, data_path, transform=None,
                labeled_only=False, stl_center_crop=False,
                stl_resize_only=False, stl_no_resize=False):
    """Build STL-style train/test datasets with optional 64x64 preprocessing.

    The spatial step is chosen by the first enabled flag, in this order:
    ``stl_no_resize`` (no spatial step), ``stl_center_crop`` (CenterCrop(64)
    for both sets), ``stl_resize_only`` (Resize(64) for both).  Otherwise
    training uses RandomResizedCrop(64) and testing uses Resize(64).
    Training also applies a random horizontal flip.  Both pipelines end
    with ToTensor and Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)).

    Args:
        Dataset: Dataset class, called as
            ``Dataset(data_path, split=..., transform=..., download=True)``.
        data_path: Root directory passed to ``Dataset``.
        transform: Unused; kept for signature compatibility.
        labeled_only: Use split ``'train'`` for training instead of
            ``'train+unlabeled'``.
        stl_center_crop: Center-crop both sets to 64.
        stl_resize_only: Resize both sets to 64.
        stl_no_resize: Skip any spatial transform.

    Returns:
        ``(train_set, test_set)``; the test set uses split ``'test'``.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    if stl_no_resize:
        train_spatial, test_spatial = [], []
    elif stl_center_crop:
        train_spatial = [transforms.CenterCrop(64)]
        test_spatial = [transforms.CenterCrop(64)]
    elif stl_resize_only:
        train_spatial = [transforms.Resize(64)]
        test_spatial = [transforms.Resize(64)]
    else:
        train_spatial = [transforms.RandomResizedCrop(64)]
        test_spatial = [transforms.Resize(64)]

    train_transform = transforms.Compose(train_spatial + [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose(test_spatial + [
        transforms.ToTensor(),
        normalize,
    ])

    split = 'train' if labeled_only else 'train+unlabeled'
    train_set = Dataset(
        data_path, split=split, transform=train_transform, download=True)
    test_set = Dataset(
        data_path, split='test', transform=test_transform, download=True)
    return train_set, test_set
示例5: __init__（注意：本例中的RandomHorizontalFlip已被注释掉，训练时并未实际启用）
# 需要导入模块: from torchvision.transforms import transforms [as 别名]
# 或者: from torchvision.transforms.transforms import RandomHorizontalFlip [as 别名]
def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, startidx=0):
    """Index one mini-ImageNet split and prepare its episodes.

    :param root: mini-ImageNet root; images are read from ``root/images``
        and the split listing from ``root/<mode>.csv``.
    :param mode: split name ('train', 'val' or 'test'); selects the CSV file.
    :param batchsz: number of episodes (sets), not number of images.
    :param n_way: classes per episode.
    :param k_shot: support images per class.
    :param k_query: query images per class.
    :param resize: images are resized to ``resize`` x ``resize``.
    :param startidx: first integer label; classes are numbered from here.
    """
    # Episode geometry.
    self.batchsz = batchsz
    self.n_way = n_way
    self.k_shot = k_shot
    self.k_query = k_query
    self.setsz = n_way * k_shot  # support samples per episode
    self.querysz = n_way * k_query  # query samples per episode
    self.resize = resize
    self.startidx = startidx

    summary = 'shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d' % (
        mode, batchsz, n_way, k_shot, k_query, resize)
    print(summary)

    def load_rgb(path):
        return Image.open(path).convert('RGB')

    # Train and eval share one pipeline: the 'train' augmentation
    # (RandomHorizontalFlip, RandomRotation(5)) is disabled, so `mode`
    # does not change the transform.
    self.transform = transforms.Compose([
        load_rgb,
        transforms.Resize((self.resize, self.resize)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    self.path = os.path.join(root, 'images')
    csvdata = self.loadCSV(os.path.join(root, mode + '.csv'))

    # One entry per CSV key, in iteration order; each key maps to a
    # consecutive label starting at `startidx`.
    entries = list(csvdata.items())
    self.data = [imgs for _, imgs in entries]
    self.img2label = {name: self.startidx + offset
                      for offset, (name, _) in enumerate(entries)}
    self.cls_num = len(self.data)

    # Presumably pre-samples the episodes; see create_batch.
    self.create_batch(self.batchsz)