当前位置: 首页>>代码示例>>Python>>正文


Python transforms.ColorJitter方法代码示例

本文整理汇总了Python中torchvision.transforms.ColorJitter方法的典型用法代码示例。如果您正苦于以下问题：Python transforms.ColorJitter方法的具体用法？Python transforms.ColorJitter怎么用？Python transforms.ColorJitter使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torchvision.transforms的用法示例。


在下文中一共展示了transforms.ColorJitter方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_transforms

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def get_transforms(self):
    """Build the train, test and target transform pipelines on ``self``.

    Data augmentation (changing colour, hue, brightness, etc.) is applied to
    training images only, and every pipeline ends by converting to a tensor.

    Sets:
        self.train_transform: ``ColorJitter`` + ``ToTensor`` when
            ``self.config['train']['transform'] == True``, otherwise just
            ``ToTensor``. Jitter ranges come from
            ``self.config['augmentation']`` (keys ``brightness``,
            ``contrast``, ``saturation``, ``hue``).
        self.test_transform: ``ToTensor`` only.
        self.target_transform: ``ToTensor`` only.

    Returns:
        None; the pipelines are stored on the instance.
    """
    train_steps = []
    # Strict ``== True`` comparison: a truthy non-bool value (e.g. "yes")
    # does not enable augmentation.
    if self.config['train']['transform'] == True:
        aug_cfg = self.config['augmentation']
        train_steps.append(transforms.ColorJitter(
            brightness=aug_cfg['brightness'],
            contrast=aug_cfg['contrast'],
            saturation=aug_cfg['saturation'],
            hue=aug_cfg['hue'],
        ))
    train_steps.append(transforms.ToTensor())
    self.train_transform = transforms.Compose(train_steps)

    self.test_transform = transforms.Compose([transforms.ToTensor()])
    self.target_transform = transforms.Compose([transforms.ToTensor()])
开发者ID:mayank-git-hub,项目名称:Text-Recognition,代码行数:23,代码来源:Dlmodel.py

示例2: cifar10_train_transform

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def cifar10_train_transform(ds_metainfo,
                            mean_rgb=(0.4914, 0.4822, 0.4465),
                            std_rgb=(0.2023, 0.1994, 0.2010),
                            jitter_param=0.4):
    """Return the CIFAR-10 training augmentation pipeline.

    Steps: 32x32 random crop from a 4-pixel padded image, random horizontal
    flip, colour jitter (brightness, contrast and saturation all set to
    ``jitter_param``; hue is left unchanged), tensor conversion and
    per-channel normalisation.

    Args:
        ds_metainfo: dataset meta-information; must not be ``None`` and
            ``ds_metainfo.input_image_size[0]`` must equal 32.
        mean_rgb: per-channel mean passed to ``Normalize``.
        std_rgb: per-channel standard deviation passed to ``Normalize``.
        jitter_param: strength shared by the three ``ColorJitter`` factors.

    Returns:
        A ``transforms.Compose`` pipeline.

    Raises:
        AssertionError: if the preconditions above fail (assertions enabled).
    """
    assert ds_metainfo is not None
    assert ds_metainfo.input_image_size[0] == 32
    steps = [
        transforms.RandomCrop(size=32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=jitter_param,
                               contrast=jitter_param,
                               saturation=jitter_param),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean_rgb, std=std_rgb),
    ]
    return transforms.Compose(steps)
开发者ID:osmr,项目名称:imgclsmob,代码行数:22,代码来源:cifar10_cls_dataset.py

示例3: get_jig_train_transformers

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def get_jig_train_transformers(args):
    """Build the image-level and tile-level training transforms for the jigsaw task.

    Args:
        args: configuration object. Reads
            ``args.img_transform.random_resize_crop.size`` / ``.scale``,
            ``args.img_transform.random_horiz_flip``,
            ``args.img_transform.jitter``,
            ``args.jig_transform.tile_random_grayscale`` and
            ``args.normalize.mean`` / ``.std``.

    Returns:
        ``(img_transform, tile_transform)``:
        * ``img_transform`` -- random resized crop, optional horizontal flip and
          optional colour jitter. It contains no ``ToTensor``.
        * ``tile_transform`` -- optional random grayscale, ``ToTensor`` and
          ``Normalize``.
    """
    img_cfg = args.img_transform
    size = img_cfg.random_resize_crop.size
    scale = img_cfg.random_resize_crop.scale
    img_tr = [transforms.RandomResizedCrop((int(size[0]), int(size[1])), (scale[0], scale[1]))]
    if img_cfg.random_horiz_flip > 0.0:
        img_tr.append(transforms.RandomHorizontalFlip(img_cfg.random_horiz_flip))
    jitter = img_cfg.jitter
    if jitter > 0.0:
        # All four factors use the same ``img_transform.jitter`` value that
        # gates this branch. Hue is clamped to 0.5, the largest scalar hue
        # ColorJitter accepts.
        img_tr.append(transforms.ColorJitter(
            brightness=jitter, contrast=jitter,
            saturation=jitter, hue=min(0.5, jitter)))

    tile_tr = []
    grayscale_p = args.jig_transform.tile_random_grayscale
    if grayscale_p:
        # A falsy (0) probability skips the grayscale step entirely.
        tile_tr.append(transforms.RandomGrayscale(grayscale_p))
    mean = args.normalize.mean
    std = args.normalize.std
    tile_tr += [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]

    return transforms.Compose(img_tr), transforms.Compose(tile_tr)
开发者ID:Jiaolong,项目名称:self-supervised-da,代码行数:21,代码来源:data_loader.py

示例4: get_rot_train_transformers

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def get_rot_train_transformers(args):
    """Build the training transform for the rotation-prediction task.

    Args:
        args: configuration object. Reads
            ``args.img_transform.random_resize_crop.size`` / ``.scale``,
            ``args.img_transform.random_horiz_flip``,
            ``args.img_transform.jitter`` and ``args.normalize.mean`` / ``.std``.

    Returns:
        A ``transforms.Compose``: random resized crop, optional horizontal flip,
        optional colour jitter, ``ToTensor`` and ``Normalize``.
    """
    img_cfg = args.img_transform
    size = img_cfg.random_resize_crop.size
    scale = img_cfg.random_resize_crop.scale
    img_tr = [transforms.RandomResizedCrop((int(size[0]), int(size[1])), (scale[0], scale[1]))]
    if img_cfg.random_horiz_flip > 0.0:
        img_tr.append(transforms.RandomHorizontalFlip(img_cfg.random_horiz_flip))
    jitter = img_cfg.jitter
    if jitter > 0.0:
        # All four factors use the same ``img_transform.jitter`` value that
        # gates this branch. Hue is clamped to 0.5, the largest scalar hue
        # ColorJitter accepts.
        img_tr.append(transforms.ColorJitter(
            brightness=jitter, contrast=jitter,
            saturation=jitter, hue=min(0.5, jitter)))

    mean = args.normalize.mean
    std = args.normalize.std
    img_tr += [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]

    return transforms.Compose(img_tr)
开发者ID:Jiaolong,项目名称:self-supervised-da,代码行数:18,代码来源:data_loader.py

示例5: preprocess

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def preprocess(self):
    """Return the preprocessing pipeline for the current split.

    Evaluation (``self.train`` falsy): resize to a square of side
    ``int(self.image_size / 0.875)``, center-crop to ``self.image_size``,
    convert to a tensor and normalise with ``self.mean`` / ``self.std``.

    Training: random resized crop to ``self.image_size``, random horizontal
    flip, colour jitter (brightness/contrast/saturation 0.4, hue 0.2), then
    the same tensor conversion and normalisation.
    """
    if not self.train:
        side = int(self.image_size / 0.875)
        return transforms.Compose([
            transforms.Resize((side, side)),
            transforms.CenterCrop(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])
    return transforms.Compose([
        transforms.RandomResizedCrop(self.image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(self.mean, self.std),
    ])
开发者ID:wandering007,项目名称:nasnet-pytorch,代码行数:18,代码来源:imagenet.py

示例6: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def __init__(self, root_dir, txtlist, use_noise, length):
    """Read the sample list and set up the colour and normalisation transforms.

    Args:
        root_dir: stored as ``self.root``.
        txtlist: path of a text file with one sample path per line.
        use_noise: stored as ``self.use_noise``.
        length: stored as ``self.length``.

    Every line of ``txtlist`` (with its trailing newline removed) is appended
    to ``self.path``. Lines starting with ``'data/'`` are also appended to
    ``self.real_path``. ``self.data_len`` and ``self.back_len`` hold the two
    list lengths.
    """
    self.path = []
    self.real_path = []
    self.use_noise = use_noise
    self.root = root_dir
    # ``with`` guarantees the list file is closed even if reading fails.
    with open(txtlist) as input_file:
        for input_line in input_file:
            if input_line.endswith('\n'):
                input_line = input_line[:-1]
            # str is immutable, so no copy is needed before storing it twice.
            self.path.append(input_line)
            if input_line.startswith('data/'):
                self.real_path.append(input_line)

    self.length = length
    self.data_len = len(self.path)
    self.back_len = len(self.real_path)

    # Positional ColorJitter args: brightness, contrast, saturation, hue.
    self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
    self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # 480 x 640 integer array of ones (presumably an image-sized mask --
    # confirm against its users).
    self.back_front = np.ones((480, 640), dtype=int)
开发者ID:j96w,项目名称:DenseFusion,代码行数:26,代码来源:data_controller.py

示例7: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def __init__(self, transform, mode, select_attrs=None, out_img_size=64,
             bbox_out_size=32, randomrotate=0, scaleRange=None,
             squareAspectRatio=False, use_celeb=False):
    """Build the single-digit MNIST (or CelebA) source dataset and preprocess it.

    Args:
        transform: accepted but not used by this constructor.
        mode: stored as ``self.mode`` and forwarded to ``CelebDataset``.
        select_attrs: stored as ``self.selected_attrs``; ``None`` means a new
            empty list.
        out_img_size: stored as ``self.out_img_size``.
        bbox_out_size: stored as ``self.bbox_out_size``.
        randomrotate: degrees passed to ``transforms.RandomRotation``.
        scaleRange: stored as ``self.scaleRange``; ``None`` means a new
            ``[0.1, 0.9]`` list.
        squareAspectRatio: stored as ``self.squareAspectRatio``.
        use_celeb: use ``CelebDataset`` (3 channels) instead of MNIST (1 channel).

    Calls ``self.preprocess()`` before returning.
    """
    # Fresh lists per instance. A list literal used as a default value is
    # created once, so every instance would otherwise store (and could mutate)
    # the same list object.
    if select_attrs is None:
        select_attrs = []
    if scaleRange is None:
        scaleRange = [0.1, 0.9]

    self.image_path = os.path.join('data', 'mnist')
    self.mode = mode
    self.iouThresh = 0.5
    self.maxDigits = 1
    self.minDigits = 1
    self.use_celeb = use_celeb
    self.scaleRange = scaleRange
    self.squareAspectRatio = squareAspectRatio
    self.nc = 1 if not self.use_celeb else 3
    # transforms.ColorJitter(0.5, 0.5, 0.5, 0.3) is not applied; it was left
    # commented out of this list.
    # NOTE(review): newer torchvision releases replace RandomRotation's
    # ``resample`` argument with ``interpolation`` -- confirm against the
    # installed version.
    transList = [transforms.RandomHorizontalFlip(),
                 transforms.RandomRotation(randomrotate, resample=Image.BICUBIC)]
    self.digitTransforms = transforms.Compose(transList)
    if use_celeb:
        self.dataset = CelebDataset('./data/celebA/images', './data/celebA/list_attr_celeba.txt', self.digitTransforms, mode)
    else:
        self.dataset = MNIST(self.image_path, train=True, transform=self.digitTransforms)
    self.num_data = len(self.dataset)
    self.metadata = {'images': []}
    self.catid2attr = {}
    self.out_img_size = out_img_size
    self.bbox_out_size = bbox_out_size
    self.selected_attrs = select_attrs

    print('Start preprocessing dataset..!')
    self.preprocess()
    print('Finished preprocessing dataset..!')
开发者ID:rakshithShetty,项目名称:adversarial-object-removal,代码行数:25,代码来源:data_loader_stargan.py

示例8: imgnet_transform

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def imgnet_transform(is_training=True):
    """Return the ImageNet preprocessing pipeline.

    Training: random resized crop to 224, random horizontal flip, colour
    jitter (brightness 0.5, contrast 0.5, saturation 0.3).
    Evaluation: resize to 256, then center-crop to 224.
    Both end with ``ToTensor`` and normalisation by the ImageNet mean/std.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    if is_training:
        spatial_and_colour = [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.3),
        ]
    else:
        spatial_and_colour = [transforms.Resize(256), transforms.CenterCrop(224)]
    pipeline = transforms.Compose(spatial_and_colour + [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return pipeline
开发者ID:zzzxxxttt,项目名称:pytorch_DoReFaNet,代码行数:19,代码来源:preprocessing.py

示例9: get_transform

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def get_transform(resize, phase='train'):
    """Return the preprocessing pipeline for ``phase``.

    Both phases start by resizing to ``resize / 0.875`` per dimension
    (truncated to int) and end with ``ToTensor`` plus ImageNet mean/std
    normalisation. In between:

    * ``'train'``: random crop of size ``resize``, horizontal flip (p=0.5) and
      colour jitter (brightness 0.126, saturation 0.5);
    * any other value: center crop of size ``resize``.

    Args:
        resize: ``(height, width)`` of the final crop.
        phase: ``'train'`` selects the augmenting pipeline.
    """
    enlarged = (int(resize[0] / 0.875), int(resize[1] / 0.875))
    steps = [transforms.Resize(size=enlarged)]
    if phase == 'train':
        steps += [
            transforms.RandomCrop(resize),
            transforms.RandomHorizontalFlip(0.5),
            transforms.ColorJitter(brightness=0.126, saturation=0.5),
        ]
    else:
        steps.append(transforms.CenterCrop(resize))
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)
开发者ID:GuYuc,项目名称:WS-DAN.PyTorch,代码行数:19,代码来源:utils.py

示例10: get_data_transforms

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def get_data_transforms():
    """Return the train / val / unnormalize transform pipelines.

    Returns:
        dict with keys
        * ``'train'`` -- center crop to ``config.patch_size``, colour jitter,
          random horizontal and vertical flips, ``Random90Rotation``,
          ``ToTensor`` and normalisation;
        * ``'val'`` -- center crop, ``ToTensor`` and normalisation;
        * ``'unnormalize'`` -- the exact inverse of that normalisation, applied
          to tensors.
    """
    # Mean and standard deviations for lung adenocarcinoma resection slides.
    mean = [0.7, 0.6, 0.7]
    std = [0.15, 0.15, 0.15]
    # Normalize computes (x - m) / s. Its inverse, x * s + m, is itself a
    # Normalize with mean -m / s and std 1 / s.
    inv_mean = [-m / s for m, s in zip(mean, std)]
    inv_std = [1 / s for s in std]

    data_transforms = {
        'train': transforms.Compose([
            transforms.CenterCrop(config.patch_size),
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            Random90Rotation(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]),
        'val': transforms.Compose([
            transforms.CenterCrop(config.patch_size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]),
        'unnormalize': transforms.Compose([
            transforms.Normalize(inv_mean, inv_std),
        ]),
    }

    return data_transforms

#printing the model 
开发者ID:BMIRDS,项目名称:HistoGAN,代码行数:27,代码来源:utils_model.py

示例11: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def __init__(self, data_dir, image_size, is_train=True, **kwargs):
    """Index an image-folder dataset and build its transform.

    Args:
        data_dir: directory whose sorted entry names become ``self.classes``.
            The position of each name is its integer label. Files matching
            ``<data_dir>/<class>/*.*`` are collected as images.
        image_size: output size of ``RandomResizedCrop`` (training only).
        is_train: when true, shuffle ``self.indexes`` and use the augmenting
            transform. Otherwise only ``ToTensor`` + ``Normalize`` is applied.
        **kwargs: accepted and ignored.
    """
    self.image_size = image_size
    self.image_paths = []
    self.image_labels = []
    self.classes = sorted(os.listdir(data_dir))
    for idx, cls_ in enumerate(self.classes):
        # Glob once per class so paths and labels always line up (and the
        # directory is scanned only once).
        cls_paths = glob.glob(os.path.join(data_dir, cls_, '*.*'))
        self.image_paths += cls_paths
        self.image_labels += [idx] * len(cls_paths)
    self.indexes = list(range(len(self.image_paths)))
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    if is_train:
        random.shuffle(self.indexes)
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=1, contrast=1, saturation=0.5, hue=0.5),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        self.transform = transforms.Compose([transforms.ToTensor(), normalize])
开发者ID:CharlesPikachu,项目名称:garbageClassifier,代码行数:21,代码来源:datasets.py

示例12: build_train_transform

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def build_train_transform(self, distort_color, resize_scale):
    """Return the training pipeline with an optional colour distortion.

    Args:
        distort_color: ``'strong'`` adds ``ColorJitter(0.4, 0.4, 0.4, 0.1)``;
            ``'normal'`` adds ``ColorJitter(brightness=32/255, saturation=0.5)``;
            any other value adds no colour jitter.
        resize_scale: lower bound of the ``RandomResizedCrop`` scale range
            (upper bound is 1.0).

    Returns:
        A ``transforms.Compose``: random resized crop to ``self.image_size``,
        random horizontal flip, optional colour jitter, ``ToTensor`` and
        ``self.normalize``.
    """
    print('Color jitter: %s' % distort_color)
    color_transform = None
    if distort_color == 'strong':
        color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
    elif distort_color == 'normal':
        color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)

    pipeline = [
        transforms.RandomResizedCrop(self.image_size, scale=(resize_scale, 1.0)),
        transforms.RandomHorizontalFlip(),
    ]
    if color_transform is not None:
        pipeline.append(color_transform)
    pipeline.extend([transforms.ToTensor(), self.normalize])
    return transforms.Compose(pipeline)
开发者ID:microsoft,项目名称:nni,代码行数:26,代码来源:datasets.py

示例13: data_transforms_imagenet

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def data_transforms_imagenet():
    """Return ``(train_transform, valid_transform)`` for 128-pixel ImageNet crops.

    Train: random resized crop to 128, horizontal flip, colour jitter
    (brightness/contrast/saturation 0.4, hue 0.2).
    Valid: resize to 256, then center-crop to 128.
    Both end with ``ToTensor`` and one shared ImageNet ``Normalize`` instance.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def _with_tensor_tail(steps):
        # Each pipeline gets its own ToTensor; the Normalize object is shared.
        return transforms.Compose(steps + [transforms.ToTensor(), normalize])

    train_transform = _with_tensor_tail([
        transforms.RandomResizedCrop(128),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
    ])
    # NOTE(review): a 128 center crop of a 256-pixel resize keeps only the
    # central region of the image -- confirm this is the intended eval scale.
    valid_transform = _with_tensor_tail([
        transforms.Resize(256),
        transforms.CenterCrop(128),
    ])

    return train_transform, valid_transform
开发者ID:antoyang,项目名称:NAS-Benchmark,代码行数:25,代码来源:utils.py

示例14: data_transforms_food101

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def data_transforms_food101():
    """Return ``(train_transform, valid_transform)`` for 128-pixel Food-101 crops.

    Train: random resized crop to 128 (torchvision's default bilinear
    interpolation), horizontal flip, colour jitter (all four factors 0.2).
    Valid: resize to 128, then center-crop to 128.
    Both end with ``ToTensor`` and one shared ImageNet ``Normalize`` instance.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_steps = [
        transforms.RandomResizedCrop(128),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        normalize,
    ]
    train_transform = transforms.Compose(train_steps)

    valid_steps = [
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
        normalize,
    ]
    valid_transform = transforms.Compose(valid_steps)

    return train_transform, valid_transform
开发者ID:antoyang,项目名称:NAS-Benchmark,代码行数:25,代码来源:utils.py

示例15: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import ColorJitter [as 别名]
def __init__(self):
    """Set up the augmentation pipelines used for self-supervised training.

    Sets:
        self.flip_lr: horizontal flip with probability 0.5.
        self.train_transform: with p=0.8 ``RandomTranslateWithReflect(4)``,
            with p=0.8 ``ColorJitter(0.4, 0.4, 0.4, 0.2)``, grayscale with
            p=0.25, then ``ToTensor`` and normalisation.
        self.test_transform: ``ToTensor`` and normalisation only.
    """
    # Mirror the image left-right (i.e. along the vertical axis).
    self.flip_lr = transforms.RandomHorizontalFlip(p=0.5)

    # Per-channel statistics are written on a 0-255 scale and rescaled to 0-1.
    mean_255 = [125.3, 123.0, 113.9]
    std_255 = [63.0, 62.1, 66.7]
    normalize = transforms.Normalize(mean=[v / 255.0 for v in mean_255],
                                     std=[v / 255.0 for v in std_255])

    colour = transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)
    maybe_colour = transforms.RandomApply([colour], p=0.8)
    translate = RandomTranslateWithReflect(4)
    maybe_translate = transforms.RandomApply([translate], p=0.8)
    maybe_gray = transforms.RandomGrayscale(p=0.25)

    # Main transform for self-supervised training.
    self.train_transform = transforms.Compose(
        [maybe_translate, maybe_colour, maybe_gray, transforms.ToTensor(), normalize]
    )
    # Transform for testing: no augmentation.
    self.test_transform = transforms.Compose([transforms.ToTensor(), normalize])
开发者ID:Philip-Bachman,项目名称:amdim-public,代码行数:26,代码来源:datasets.py


注：本文中的torchvision.transforms.ColorJitter方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。