本文整理汇总了Python中torchvision.transforms.ColorJitter类的典型用法代码示例。如果您正苦于以下问题：Python transforms.ColorJitter的具体用法？Python transforms.ColorJitter怎么用？Python transforms.ColorJitter使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torchvision.transforms的用法示例。
在下文中一共展示了transforms.ColorJitter的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_transforms
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def get_transforms(self):
    """Build the train, test and target transform pipelines from ``self.config``.

    When ``self.config['train']['transform'] == True`` the training pipeline
    starts with a ColorJitter whose strengths are read from
    ``self.config['augmentation']`` (keys ``brightness``, ``contrast``,
    ``saturation``, ``hue``). Every pipeline ends with ``ToTensor``; the test
    and target pipelines consist of ``ToTensor`` only.

    Sets ``self.train_transform``, ``self.test_transform`` and
    ``self.target_transform``.
    """
    train_steps = []
    # Explicit ``== True`` comparison is kept: a plain truthiness test would
    # also accept non-bool config values such as the string 'False'.
    if self.config['train']['transform'] == True:
        aug_cfg = self.config['augmentation']
        train_steps.append(transforms.ColorJitter(
            brightness=aug_cfg['brightness'],
            contrast=aug_cfg['contrast'],
            saturation=aug_cfg['saturation'],
            hue=aug_cfg['hue'],
        ))
    train_steps.append(transforms.ToTensor())
    self.train_transform = transforms.Compose(train_steps)

    self.test_transform = transforms.Compose([transforms.ToTensor()])
    self.target_transform = transforms.Compose([transforms.ToTensor()])
# Does data augmentation, i.e. transforms images by changing colour, hue, brightness, etc., and returns a tensor
示例2: cifar10_train_transform
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def cifar10_train_transform(ds_metainfo,
                            mean_rgb=(0.4914, 0.4822, 0.4465),
                            std_rgb=(0.2023, 0.1994, 0.2010),
                            jitter_param=0.4):
    """Return the CIFAR-10 training pipeline.

    Pipeline: 32x32 random crop with 4-pixel padding, random horizontal flip,
    ColorJitter (brightness/contrast/saturation all set to ``jitter_param``,
    hue untouched), ToTensor, Normalize(``mean_rgb``, ``std_rgb``).

    Args:
        ds_metainfo: dataset meta-info object; must not be None and its
            ``input_image_size[0]`` must equal 32.
        mean_rgb: per-channel mean passed to Normalize.
        std_rgb: per-channel std passed to Normalize.
        jitter_param: strength shared by brightness, contrast and saturation.

    Raises:
        AssertionError: if ``ds_metainfo`` is None or its input size is not 32.
    """
    assert ds_metainfo is not None
    assert ds_metainfo.input_image_size[0] == 32

    color_jitter = transforms.ColorJitter(brightness=jitter_param,
                                          contrast=jitter_param,
                                          saturation=jitter_param)
    steps = [
        transforms.RandomCrop(size=32, padding=4),
        transforms.RandomHorizontalFlip(),
        color_jitter,
        transforms.ToTensor(),
        transforms.Normalize(mean=mean_rgb, std=std_rgb),
    ]
    return transforms.Compose(steps)
示例3: get_jig_train_transformers
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def get_jig_train_transformers(args):
    """Build the image-level and tile-level transforms for jigsaw training.

    Args:
        args: config object. Reads
            ``args.img_transform.random_resize_crop.size`` / ``.scale``,
            ``args.img_transform.random_horiz_flip``,
            ``args.img_transform.jitter``,
            ``args.jig_transform.tile_random_grayscale`` and
            ``args.normalize.mean`` / ``.std``.

    Returns:
        ``(img_transform, tile_transform)``, two ``transforms.Compose`` objects.
        The image transform does not convert to a tensor; the tile transform
        ends with ``ToTensor`` followed by ``Normalize``.
    """
    size = args.img_transform.random_resize_crop.size
    scale = args.img_transform.random_resize_crop.scale
    img_tr = [transforms.RandomResizedCrop((int(size[0]), int(size[1])), (scale[0], scale[1]))]
    if args.img_transform.random_horiz_flip > 0.0:
        img_tr.append(transforms.RandomHorizontalFlip(args.img_transform.random_horiz_flip))

    jitter = args.img_transform.jitter
    if jitter > 0.0:
        # All four strengths come from the same config field the guard checks
        # (previously saturation/hue read a different attribute, ``args.jitter``).
        # Hue is capped at 0.5, the largest value ColorJitter accepts.
        img_tr.append(transforms.ColorJitter(
            brightness=jitter, contrast=jitter,
            saturation=jitter, hue=min(0.5, jitter)))

    tile_tr = []
    if args.jig_transform.tile_random_grayscale:
        # The config value doubles as the grayscale probability.
        tile_tr.append(transforms.RandomGrayscale(args.jig_transform.tile_random_grayscale))
    mean = args.normalize.mean
    std = args.normalize.std
    tile_tr = tile_tr + [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
    return transforms.Compose(img_tr), transforms.Compose(tile_tr)
示例4: get_rot_train_transformers
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def get_rot_train_transformers(args):
    """Build the training transform for the rotation-prediction task.

    Args:
        args: config object. Reads
            ``args.img_transform.random_resize_crop.size`` / ``.scale``,
            ``args.img_transform.random_horiz_flip``,
            ``args.img_transform.jitter`` and ``args.normalize.mean`` / ``.std``.

    Returns:
        A ``transforms.Compose`` that crops, optionally flips and colour-jitters,
        then applies ``ToTensor`` and ``Normalize``.
    """
    size = args.img_transform.random_resize_crop.size
    scale = args.img_transform.random_resize_crop.scale
    img_tr = [transforms.RandomResizedCrop((int(size[0]), int(size[1])), (scale[0], scale[1]))]
    if args.img_transform.random_horiz_flip > 0.0:
        img_tr.append(transforms.RandomHorizontalFlip(args.img_transform.random_horiz_flip))

    jitter = args.img_transform.jitter
    if jitter > 0.0:
        # All four strengths come from the same config field the guard checks
        # (previously saturation/hue read a different attribute, ``args.jitter``).
        # Hue is capped at 0.5, the largest value ColorJitter accepts.
        img_tr.append(transforms.ColorJitter(
            brightness=jitter, contrast=jitter,
            saturation=jitter, hue=min(0.5, jitter)))

    mean = args.normalize.mean
    std = args.normalize.std
    img_tr += [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
    return transforms.Compose(img_tr)
示例5: preprocess
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def preprocess(self):
    """Return the preprocessing pipeline for the current split.

    Evaluation (``self.train`` falsy): resize to
    ``int(self.image_size / 0.875)`` square, center-crop ``self.image_size``,
    ToTensor, Normalize(``self.mean``, ``self.std``).

    Training: random-resized crop of ``self.image_size``, horizontal flip,
    ColorJitter(0.4, 0.4, 0.4, hue=0.2), ToTensor, Normalize.
    """
    size = self.image_size
    if not self.train:
        # Resize slightly larger than the crop (crop keeps 87.5% per side).
        enlarged = int(size / 0.875)
        return transforms.Compose([
            transforms.Resize((enlarged, enlarged)),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])

    return transforms.Compose([
        transforms.RandomResizedCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(self.mean, self.std),
    ])
示例6: __init__
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def __init__(self, root_dir, txtlist, use_noise, length):
    """Read the sample list and set up colour-jitter and normalization transforms.

    Args:
        root_dir: stored as ``self.root``.
        txtlist: path to a text file listing one sample path per line.
        use_noise: stored as ``self.use_noise``.
        length: stored as ``self.length``.

    Every line of ``txtlist`` (without its trailing newline) is appended to
    ``self.path``; lines starting with ``'data/'`` are also appended to
    ``self.real_path``.
    """
    self.path = []
    self.real_path = []
    self.use_noise = use_noise
    self.root = root_dir
    # ``with`` closes the list file even if reading raises; iterating the file
    # yields exactly the lines the former readline()/EOF loop produced.
    with open(txtlist) as input_file:
        for input_line in input_file:
            # Drop only the trailing newline; other whitespace is preserved.
            if input_line.endswith('\n'):
                input_line = input_line[:-1]
            # str is immutable, so it can be stored without copying.
            self.path.append(input_line)
            if input_line.startswith('data/'):
                self.real_path.append(input_line)
    self.length = length
    self.data_len = len(self.path)
    self.back_len = len(self.real_path)
    # Positional args: brightness, contrast, saturation, hue.
    self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
    self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # 480 rows x 640 columns of ones; dtype=int matches what np.array built
    # from the former nested list of Python ints.
    self.back_front = np.ones((480, 640), dtype=int)
示例7: __init__
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def __init__(self, transform, mode, select_attrs=None, out_img_size=64, bbox_out_size=32, randomrotate=0, scaleRange=None, squareAspectRatio=False, use_celeb=False):
    """Set up the digit/face source dataset and run ``self.preprocess()``.

    Args:
        transform: accepted for interface compatibility; not used here.
        mode: stored as ``self.mode`` and forwarded to ``CelebDataset``.
        select_attrs: stored as ``self.selected_attrs`` (defaults to a new
            empty list per instance).
        out_img_size: stored as ``self.out_img_size``.
        bbox_out_size: stored as ``self.bbox_out_size``.
        randomrotate: degree range for ``RandomRotation``.
        scaleRange: stored as ``self.scaleRange`` (defaults to a new
            ``[0.1, 0.9]`` list per instance).
        squareAspectRatio: stored as ``self.squareAspectRatio``.
        use_celeb: use CelebA (3 channels) instead of MNIST (1 channel).
    """
    # Fresh lists per call: list literals as defaults would be shared by every
    # instance, so a mutation through one object would leak into the others.
    if select_attrs is None:
        select_attrs = []
    if scaleRange is None:
        scaleRange = [0.1, 0.9]

    self.image_path = os.path.join('data', 'mnist')
    self.mode = mode
    self.iouThresh = 0.5
    self.maxDigits = 1
    self.minDigits = 1
    self.use_celeb = use_celeb
    self.scaleRange = scaleRange
    self.squareAspectRatio = squareAspectRatio
    self.nc = 1 if not self.use_celeb else 3

    # Newer torchvision releases removed RandomRotation's deprecated
    # ``resample`` keyword in favour of ``interpolation``; use ``resample`` only
    # on old releases that lack ``InterpolationMode``.
    try:
        bicubic = transforms.InterpolationMode.BICUBIC
    except AttributeError:
        rotation = transforms.RandomRotation(randomrotate, resample=Image.BICUBIC)
    else:
        rotation = transforms.RandomRotation(randomrotate, interpolation=bicubic)
    # Disabled augmentation, kept for reference: transforms.ColorJitter(0.5, 0.5, 0.5, 0.3)
    transList = [transforms.RandomHorizontalFlip(), rotation]
    self.digitTransforms = transforms.Compose(transList)

    self.dataset = MNIST(self.image_path, train=True, transform=self.digitTransforms) if not use_celeb else CelebDataset('./data/celebA/images', './data/celebA/list_attr_celeba.txt', self.digitTransforms, mode)
    self.num_data = len(self.dataset)
    self.metadata = {'images': []}
    self.catid2attr = {}
    self.out_img_size = out_img_size
    self.bbox_out_size = bbox_out_size
    self.selected_attrs = select_attrs
    print ('Start preprocessing dataset..!')
    self.preprocess()
    print ('Finished preprocessing dataset..!')
示例8: imgnet_transform
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def imgnet_transform(is_training=True):
    """Return the ImageNet preprocessing pipeline.

    Args:
        is_training: when truthy, return the augmenting training pipeline
            (random-resized 224 crop, horizontal flip, ColorJitter with
            brightness 0.5 / contrast 0.5 / saturation 0.3); otherwise the
            deterministic evaluation pipeline (resize 256, center-crop 224).

    Both pipelines finish with ToTensor and Normalize using the ImageNet
    channel statistics.
    """
    rgb_mean = [0.485, 0.456, 0.406]
    rgb_std = [0.229, 0.224, 0.225]

    if not is_training:
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=rgb_mean, std=rgb_std),
        ])

    return transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.3),
        transforms.ToTensor(),
        transforms.Normalize(mean=rgb_mean, std=rgb_std),
    ])
示例9: get_transform
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def get_transform(resize, phase='train'):
    """Return the preprocessing pipeline for ``phase``.

    Args:
        resize: ``(height, width)`` of the final crop. Images are first
            resized to ``resize / 0.875`` per side.
        phase: ``'train'`` adds a random crop, a 0.5-probability horizontal
            flip and ColorJitter(brightness=0.126, saturation=0.5); any other
            value uses a center crop instead.

    Returns:
        A ``transforms.Compose`` ending in ToTensor + ImageNet Normalize.
    """
    enlarged = (int(resize[0] / 0.875), int(resize[1] / 0.875))
    steps = [transforms.Resize(size=enlarged)]

    if phase == 'train':
        steps += [
            transforms.RandomCrop(resize),
            transforms.RandomHorizontalFlip(0.5),
            transforms.ColorJitter(brightness=0.126, saturation=0.5),
        ]
    else:
        steps.append(transforms.CenterCrop(resize))

    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)
示例10: get_data_transforms
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def get_data_transforms():
    """Return the ``'train'``, ``'val'`` and ``'unnormalize'`` transform pipelines.

    ``'train'`` center-crops to ``config.patch_size``, applies ColorJitter,
    random horizontal/vertical flips and ``Random90Rotation`` (defined
    elsewhere), then ToTensor and Normalize. ``'val'`` only center-crops,
    converts and normalizes. ``'unnormalize'`` is the exact inverse of that
    Normalize step and maps a normalized tensor back to its original range.
    """
    # Mean and standard deviations for lung adenocarcinoma resection slides.
    mean = [0.7, 0.6, 0.7]
    std = [0.15, 0.15, 0.15]
    # Normalize computes (x - mean) / std. Its inverse is
    # Normalize(mean=-mean/std, std=1/std): (y + mean/std) * std == y * std + mean.
    # (The former Normalize([1/0.15]*3, [1/0.15]*3) computed y * 0.15 - 1 instead.)
    unnorm_mean = [-m / s for m, s in zip(mean, std)]
    unnorm_std = [1 / s for s in std]

    data_transforms = {
        'train': transforms.Compose([
            transforms.CenterCrop(config.patch_size),
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            Random90Rotation(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]),
        'val': transforms.Compose([
            transforms.CenterCrop(config.patch_size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]),
        'unnormalize': transforms.Compose([
            transforms.Normalize(unnorm_mean, unnorm_std),
        ]),
    }
    return data_transforms
#printing the model
示例11: __init__
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def __init__(self, data_dir, image_size, is_train=True, **kwargs):
    """Index a folder-per-class image dataset and build its transform.

    Args:
        data_dir: directory whose entries (sorted) are the class names; each
            class's files matching ``'*.*'`` (names containing a dot) are
            collected with the class index as label.
        image_size: crop size for the training RandomResizedCrop.
        is_train: when truthy, shuffle ``self.indexes`` and use the augmenting
            pipeline; otherwise only ToTensor + Normalize.
        **kwargs: accepted and ignored.
    """
    self.image_size = image_size
    self.image_paths = []
    self.image_labels = []
    self.classes = sorted(os.listdir(data_dir))
    for idx, cls_ in enumerate(self.classes):
        # Glob once per class so paths and labels come from the same listing
        # (globbing twice scanned the directory twice and could desynchronize).
        cls_paths = glob.glob(os.path.join(data_dir, cls_, '*.*'))
        self.image_paths += cls_paths
        self.image_labels += [idx] * len(cls_paths)
    self.indexes = list(range(len(self.image_paths)))
    if is_train:
        random.shuffle(self.indexes)
        self.transform = transforms.Compose([transforms.RandomResizedCrop(image_size),
                                             transforms.RandomHorizontalFlip(),
                                             transforms.ColorJitter(brightness=1, contrast=1, saturation=0.5, hue=0.5),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    else:
        # NOTE(review): the eval pipeline does not resize or crop, so images of
        # different sizes yield tensors of different shapes — confirm callers
        # handle that (e.g. batch size 1 or equally sized inputs).
        self.transform = transforms.Compose([transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
示例12: build_train_transform
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def build_train_transform(self, distort_color, resize_scale):
    """Return the training pipeline with an optional colour-distortion step.

    Args:
        distort_color: ``'strong'`` adds ColorJitter(0.4, 0.4, 0.4, hue=0.1);
            ``'normal'`` adds ColorJitter(brightness=32/255, saturation=0.5);
            any other value adds no colour step.
        resize_scale: lower bound of the RandomResizedCrop area scale
            (upper bound is 1.0).

    Returns:
        ``transforms.Compose``: random-resized crop of ``self.image_size``,
        horizontal flip, optional colour step, ToTensor, ``self.normalize``.
    """
    print('Color jitter: %s' % distort_color)

    pipeline = [
        transforms.RandomResizedCrop(self.image_size, scale=(resize_scale, 1.0)),
        transforms.RandomHorizontalFlip(),
    ]
    if distort_color == 'strong':
        pipeline.append(transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1))
    elif distort_color == 'normal':
        pipeline.append(transforms.ColorJitter(brightness=32. / 255., saturation=0.5))
    pipeline.extend([transforms.ToTensor(), self.normalize])

    return transforms.Compose(pipeline)
示例13: data_transforms_imagenet
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def data_transforms_imagenet(crop_size=128, resize_size=256):
    """Return ``(train_transform, valid_transform)`` for ImageNet-style data.

    Args:
        crop_size: side of the training RandomResizedCrop and of the
            validation CenterCrop (default 128).
        resize_size: size passed to ``Resize`` before the validation
            CenterCrop (default 256).

    NOTE(review): with the defaults the validation center crop keeps only
    128/256 = 50% of the resized side, whereas the usual ImageNet recipe keeps
    87.5% (256 -> 224). If 256 is a leftover from a 224-pixel setup, consider
    ``resize_size=int(crop_size / 0.875)`` — confirm before changing defaults.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    valid_transform = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize,
    ])
    return train_transform, valid_transform
示例14: data_transforms_food101
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def data_transforms_food101():
    """Return ``(train_transform, valid_transform)`` for Food-101 at 128 px.

    Training: random-resized 128 crop (default bilinear interpolation),
    horizontal flip, ColorJitter(0.2, 0.2, 0.2, hue=0.2), ToTensor, Normalize.
    Validation: resize 128, center-crop 128, ToTensor, Normalize.
    Both share one Normalize instance with the ImageNet channel statistics.
    """
    side = 128
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)
    train_steps = [
        transforms.RandomResizedCrop(side),
        transforms.RandomHorizontalFlip(),
        jitter,
        transforms.ToTensor(),
        normalize,
    ]
    valid_steps = [
        transforms.Resize(side),
        transforms.CenterCrop(side),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(train_steps), transforms.Compose(valid_steps)
示例15: __init__
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def __init__(self):
    """Create the flip, self-supervised training and test transforms.

    Sets ``self.flip_lr`` (standalone 0.5-probability horizontal flip),
    ``self.train_transform`` (reflect-padded translation and ColorJitter, each
    applied with p=0.8, then grayscale with p=0.25, ToTensor, Normalize) and
    ``self.test_transform`` (ToTensor, Normalize).
    """
    # Horizontal flip kept separate; it is not part of train_transform.
    self.flip_lr = transforms.RandomHorizontalFlip(p=0.5)

    # Channel statistics are written on the 0-255 scale and rescaled to 0-1.
    mean_255 = [125.3, 123.0, 113.9]
    std_255 = [63.0, 62.1, 66.7]
    normalize = transforms.Normalize(mean=[m / 255.0 for m in mean_255],
                                     std=[s / 255.0 for s in std_255])

    colour_aug = transforms.RandomApply(
        [transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8)
    shift_aug = transforms.RandomApply(
        [RandomTranslateWithReflect(4)], p=0.8)
    gray_aug = transforms.RandomGrayscale(p=0.25)

    # Main transform for self-supervised training.
    self.train_transform = transforms.Compose(
        [shift_aug, colour_aug, gray_aug, transforms.ToTensor(), normalize])
    # Transform for testing: no augmentation.
    self.test_transform = transforms.Compose(
        [transforms.ToTensor(), normalize])