当前位置: 首页>>代码示例>>Python>>正文


Python transforms.Grayscale方法代码示例

本文整理汇总了 Python 中 torchvision.transforms.Grayscale 方法的典型用法代码示例。如果您想了解 transforms.Grayscale 的具体用法、调用方式或使用示例，这里精选的代码示例或许能为您提供帮助。您也可以进一步了解该方法所在模块 torchvision.transforms 的其他用法示例。


下文共展示了 transforms.Grayscale 方法的 15 个代码示例，默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将帮助系统推荐更优质的 Python 代码示例。

示例1: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Grayscale [as 别名]
def __init__(self, train_mode, loader_params, dataset_params, augmentation_params):
    """Configure transforms, augmenters and the dataset class for segmentation.

    Images are converted to grayscale replicated over 3 channels, turned into
    tensors and normalized with ``dataset_params.MEAN`` / ``dataset_params.STD``.
    Masks are passed through the ``to_array`` and ``to_tensor`` helpers.
    Four ``ImgAug`` augmenters are built from the identically named keys of
    ``self.augmentation_params``.

    Raises:
        ValueError: if ``dataset_params.target_format`` is neither ``'png'``
            nor ``'json'``. ``ValueError`` subclasses ``Exception``, so
            existing ``except Exception`` handlers still catch it.
    """
    super().__init__(train_mode, loader_params, dataset_params, augmentation_params)

    # num_output_channels=3 repeats the gray channel so the tensor keeps
    # three channels, matching a 3-value MEAN/STD (presumably) -- confirm.
    self.image_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=self.dataset_params.MEAN,
                             std=self.dataset_params.STD),
    ])
    self.mask_transform = transforms.Compose([
        transforms.Lambda(to_array),
        transforms.Lambda(to_tensor),
    ])

    self.image_augment_train = ImgAug(self.augmentation_params['image_augment_train'])
    self.image_augment_with_target_train = ImgAug(
        self.augmentation_params['image_augment_with_target_train'])
    self.image_augment_inference = ImgAug(self.augmentation_params['image_augment_inference'])
    self.image_augment_with_target_inference = ImgAug(
        self.augmentation_params['image_augment_with_target_inference'])

    target_format = self.dataset_params.target_format
    if target_format == 'png':
        self.dataset = ImageSegmentationPngDataset
    elif target_format == 'json':
        self.dataset = ImageSegmentationJsonDataset
    else:
        # A specific exception type instead of the bare Exception.
        raise ValueError('files must be png or json')
开发者ID:minerva-ml,项目名称:steppy-toolkit,代码行数:26,代码来源:segmentation.py

示例2: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Grayscale [as 别名]
def __init__(self, train_mode, loader_params, dataset_params, augmentation_params):
    """Configure transforms and augmenters for the ``EmptinessDataset`` loader.

    The image pipeline converts to 3-channel grayscale, makes a tensor,
    normalizes with ``dataset_params.MEAN`` / ``dataset_params.STD`` and
    finally applies ``AddDepthChannels``. Targets are mapped through
    ``preprocess_emptiness_target``. One ``ImgAug`` augmenter is created per
    augmentation key, stored under an attribute of the same name.
    """
    super().__init__(train_mode, loader_params, dataset_params, augmentation_params)

    pipeline = [
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=self.dataset_params.MEAN,
                             std=self.dataset_params.STD),
        AddDepthChannels(),
    ]
    self.image_transform = transforms.Compose(pipeline)
    self.mask_transform = transforms.Lambda(preprocess_emptiness_target)

    # Attribute names coincide with the config keys they are built from.
    augment_keys = (
        'image_augment_train',
        'image_augment_with_target_train',
        'image_augment_inference',
        'image_augment_with_target_inference',
    )
    for key in augment_keys:
        setattr(self, key, ImgAug(self.augmentation_params[key]))

    self.dataset = EmptinessDataset
开发者ID:neptune-ai,项目名称:open-solution-salt-identification,代码行数:20,代码来源:loaders.py

示例3: get_data

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Grayscale [as 别名]
def get_data(train):
    """Load CIFAR-10 as a DataFrame of flattened grayscale pixels with binary labels.

    Each image is converted to single-channel grayscale, resized to 20x20,
    turned into a tensor and flattened into a 1-D NumPy vector. Labels below
    5 are mapped to 0 and all other labels to 1 (stored as int32).

    Args:
        train: selects the CIFAR-10 training split when truthy, test otherwise.

    Returns:
        tuple: ``(data, mean, std)`` -- a DataFrame with one column per pixel
        plus a ``COLUMN_LABEL`` column, and the mean / standard deviation over
        all pixel values.
    """
    to_vector = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((20, 20)),
        transforms.ToTensor(),
        lambda x: x.numpy().flatten(),
    ])
    cifar = datasets.CIFAR10('../data/dl/', train=train, download=True, transform=to_vector)

    images, labels = zip(*cifar)
    pixels = np.array(images)
    raw_labels = np.array(labels, dtype='int32').reshape(-1, 1)

    # Collapse the classes into two groups: < 5 -> 0, otherwise -> 1.
    binary_labels = np.where(raw_labels < 5, 0, 1).astype('int32')

    frame = pd.DataFrame(pixels)
    frame[COLUMN_LABEL] = binary_labels
    return frame, pixels.mean(), pixels.std()

#--- 
开发者ID:jaromiru,项目名称:cwcf,代码行数:27,代码来源:conv_cifar_2.py

示例4: get_data

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Grayscale [as 别名]
def get_data(train):
    """Load CIFAR-10 as a DataFrame of flattened 20x20 grayscale pixels.

    Each image is converted to single-channel grayscale, resized to 20x20,
    turned into a tensor and flattened into a 1-D NumPy vector. Labels are
    kept as loaded (cast to int32).

    Args:
        train: selects the CIFAR-10 training split when truthy, test otherwise.

    Returns:
        tuple: ``(data, mean, std)`` -- a DataFrame with one column per pixel
        plus a ``COLUMN_LABEL`` column, and the mean / standard deviation over
        all pixel values.
    """
    to_vector = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((20, 20)),
        transforms.ToTensor(),
        lambda x: x.numpy().flatten(),
    ])
    cifar = datasets.CIFAR10('../data/dl/', train=train, download=True, transform=to_vector)

    images, labels = zip(*cifar)
    pixels = np.array(images)
    targets = np.array(labels, dtype='int32').reshape(-1, 1)

    frame = pd.DataFrame(pixels)
    frame[COLUMN_LABEL] = targets
    return frame, pixels.mean(), pixels.std()

#--- 
开发者ID:jaromiru,项目名称:cwcf,代码行数:20,代码来源:conv_cifar.py

示例5: get_transforms

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Grayscale [as 别名]
def get_transforms(eval=False, aug=None):
    """Build the torchvision preprocessing pipeline described by ``aug``.

    Args:
        eval: when True, use a deterministic center crop instead of a random
            crop, and skip random horizontal flipping.
        aug: dict of augmentation settings. Optional keys (missing or falsy
            means "off"): ``"randcrop"`` (crop size), ``"flip"``,
            ``"grayscale"``, ``"mean"``. When ``"grayscale"`` is truthy,
            ``"bw_mean"`` and ``"bw_std"`` are required. When only ``"mean"``
            is truthy, ``"std"`` is required. ``None`` (the default) means no
            augmentation: the pipeline is just ``ToTensor``.

    Returns:
        transforms.Compose: the assembled pipeline.
    """
    # The documented default used to crash: aug=None failed on aug["randcrop"].
    aug = {} if aug is None else aug
    trans = []

    crop_size = aug.get("randcrop")
    if crop_size:
        # Random crop while training, deterministic center crop for evaluation.
        if eval:
            trans.append(transforms.CenterCrop(crop_size))
        else:
            trans.append(transforms.RandomCrop(crop_size))

    if aug.get("flip") and not eval:
        trans.append(transforms.RandomHorizontalFlip())

    if aug.get("grayscale"):
        trans.append(transforms.Grayscale())
        trans.append(transforms.ToTensor())
        trans.append(transforms.Normalize(mean=aug["bw_mean"], std=aug["bw_std"]))
    elif aug.get("mean"):
        trans.append(transforms.ToTensor())
        trans.append(transforms.Normalize(mean=aug["mean"], std=aug["std"]))
    else:
        trans.append(transforms.ToTensor())

    return transforms.Compose(trans)
开发者ID:loeweX,项目名称:Greedy_InfoMax,代码行数:26,代码来源:get_dataloader.py

示例6: __get_transforms

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Grayscale [as 别名]
def __get_transforms(self, patch_size):
    """Return ``(train_transforms, val_transforms)`` for square patches.

    Both pipelines resize to ``patch_size`` x ``patch_size`` and convert to a
    tensor. When ``self.gray_scale`` is truthy, a single-channel
    ``Grayscale`` step is inserted between resizing and tensor conversion.
    The two pipelines are identical but returned as separate ``Compose``
    objects.
    """
    use_gray = self.gray_scale

    def build_pipeline():
        # Fresh transform objects per call so train/val do not share instances.
        steps = [transforms.Resize(size=(patch_size, patch_size))]
        if use_gray:
            steps.append(transforms.Grayscale(num_output_channels=1))
        steps.append(transforms.ToTensor())
        return transforms.Compose(steps)

    train_pipeline = build_pipeline()
    val_pipeline = build_pipeline()
    return train_pipeline, val_pipeline
开发者ID:MrtnMndt,项目名称:OCDVAEContinualLearning,代码行数:27,代码来源:datasets.py

示例7: load_data

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Grayscale [as 别名]
def load_data(root_dir, domain, batch_size):
    """Return a shuffled training DataLoader over the image folder ``root_dir + domain``.

    Images are converted to single-channel grayscale, resized to 28x28 and
    turned into tensors. The path is built by plain string concatenation,
    so ``root_dir`` must already end with a path separator. The loader uses
    two worker processes and drops the final incomplete batch.
    """
    pipeline = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize([28, 28]),
        transforms.ToTensor(),
        # mean 0 / std 1 leaves values unchanged (presumably a placeholder).
        transforms.Normalize(mean=(0,), std=(1,)),
    ])
    folder = datasets.ImageFolder(root=root_dir + domain, transform=pipeline)
    return torch.utils.data.DataLoader(
        dataset=folder,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        drop_last=True,
    )
开发者ID:jindongwang,项目名称:transferlearning,代码行数:17,代码来源:data_loader.py

示例8: load_test

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Grayscale [as 别名]
def load_test(root_dir, domain, batch_size):
    """Return an unshuffled evaluation DataLoader over the image folder ``root_dir + domain``.

    Images are converted to single-channel grayscale, resized to 28x28 and
    turned into tensors. The path is built by plain string concatenation,
    so ``root_dir`` must already end with a path separator. The loader uses
    two worker processes and keeps the final partial batch.
    """
    pipeline = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize([28, 28]),
        transforms.ToTensor(),
        # mean 0 / std 1 leaves values unchanged (presumably a placeholder).
        transforms.Normalize(mean=(0,), std=(1,)),
    ])
    folder = datasets.ImageFolder(root=root_dir + domain, transform=pipeline)
    return torch.utils.data.DataLoader(
        dataset=folder,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
    )
开发者ID:jindongwang,项目名称:transferlearning,代码行数:17,代码来源:data_loader.py

示例9: get_transform

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Grayscale [as 别名]
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Build the image preprocessing pipeline selected by ``opt.preprocess``.

    Steps, in order:
      * optional single-channel grayscale conversion;
      * resize to ``opt.load_size`` square if ``'resize'`` is in
        ``opt.preprocess``, otherwise ``__scale_width`` if ``'scale_width'``
        is in it;
      * crop to ``opt.crop_size`` if ``'crop'`` is in ``opt.preprocess``:
        random when ``params`` is None, else at ``params['crop_pos']``;
      * ``__make_power_2`` (base 4) when ``opt.preprocess == 'none'``;
      * horizontal flip unless ``opt.no_flip``: random when ``params`` is
        None, else applied only when ``params['flip']`` is truthy;
      * if ``convert``: ``ToTensor`` then ``Normalize`` with mean/std 0.5 per
        channel (one channel for grayscale, three otherwise).
    """
    pipeline = []

    if grayscale:
        pipeline.append(transforms.Grayscale(1))

    if 'resize' in opt.preprocess:
        target = [opt.load_size, opt.load_size]
        pipeline.append(transforms.Resize(target, method))
    elif 'scale_width' in opt.preprocess:
        pipeline.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))

    if 'crop' in opt.preprocess:
        if params is not None:
            pipeline.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
        else:
            pipeline.append(transforms.RandomCrop(opt.crop_size))

    if opt.preprocess == 'none':
        pipeline.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None:
            pipeline.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            pipeline.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        half = (0.5,) if grayscale else (0.5, 0.5, 0.5)
        pipeline.extend([transforms.ToTensor(), transforms.Normalize(half, half)])

    return transforms.Compose(pipeline)
开发者ID:Mingtzge,项目名称:2019-CCF-BDCI-OCR-MCZJ-OCR-IdentificationIDElement,代码行数:34,代码来源:base_dataset.py

示例10: get_transform

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Grayscale [as 别名]
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Build the image preprocessing pipeline selected by ``opt.preprocess``.

    Steps, in order:
      * optional single-channel grayscale conversion;
      * resize to ``opt.load_size`` square if ``'resize'`` is in
        ``opt.preprocess``, otherwise ``__scale_width`` if ``'scale_width'``
        is in it;
      * crop to ``opt.crop_size`` if ``'crop'`` is in ``opt.preprocess``:
        random when ``params`` is None, else at ``params['crop_pos']``;
      * ``__make_power_2`` (base 4) when ``opt.preprocess == 'none'``;
      * horizontal flip unless ``opt.no_flip``: random when ``params`` is
        None, else applied only when ``params['flip']`` is truthy;
      * if ``convert``: ``ToTensor`` then ``Normalize`` with mean/std 0.5 per
        channel.
    """
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if 'resize' in opt.preprocess:
        osize = [opt.load_size, opt.load_size]
        transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))

    if 'crop' in opt.preprocess:
        if params is None:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

    if opt.preprocess == 'none':
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        # Grayscale(1) produces a 1-channel tensor; a 3-value Normalize cannot
        # be applied to it in place, so match the stats to the channel count.
        n_channels = 1 if grayscale else 3
        half = (0.5,) * n_channels
        transform_list += [transforms.ToTensor(),
                           transforms.Normalize(half, half)]
    return transforms.Compose(transform_list)
开发者ID:WANG-Chaoyue,项目名称:EvolutionaryGAN-pytorch,代码行数:32,代码来源:base_dataset.py

示例11: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Grayscale [as 别名]
def __init__(self, img_size, mask_descriptor):
    """Store the mask configuration; for cached blob masks, also build a loader.

    Args:
        img_size: indexable whose items 1 and 2 are multiplied into
            ``num_pixels`` (presumably (C, H, W) -- confirm with callers).
        mask_descriptor: pair ``(mask_type, mask_attribute)``. For
            ``'random_blob_cache'``, ``mask_attribute[0]`` is an ImageFolder
            root and ``mask_attribute[1]`` the loader batch size.
    """
    self.img_size = img_size
    height, width = img_size[1], img_size[2]
    self.num_pixels = height * width
    self.mask_type, self.mask_attribute = mask_descriptor

    if self.mask_type != 'random_blob_cache':
        return

    # Masks are read as single-channel grayscale tensors.
    to_gray_tensor = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
    blob_folder = datasets.ImageFolder(self.mask_attribute[0], transform=to_gray_tensor)
    self.data_loader = DataLoader(blob_folder, batch_size=self.mask_attribute[1], shuffle=True)
开发者ID:Schlumberger,项目名称:pixel-constrained-cnn-pytorch,代码行数:12,代码来源:masks.py

示例12: load_image

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Grayscale [as 别名]
def load_image(file, grayscale=False, target_size=None, to_tensor=True, mean=0.5, std=0.5, interpolation=Image.BILINEAR):
    """Load an image and apply optional grayscale, center-crop, tensor and normalization steps.

    :param file: path or file object accepted by ``PIL.Image.open``.
    :param grayscale: convert to a single-channel grayscale image.
    :param target_size: int or (h, w). If given, the image is center-cropped
        (not resized) to this size.
    :param to_tensor: convert the PIL image to a tensor with ``ToTensor``.
    :param mean: ``Normalize`` mean. A scalar is wrapped into a 1-tuple
        (applied to every channel); tuples and lists are used as-is. Only
        applied when ``to_tensor`` is True, because ``Normalize`` needs a tensor.
    :param std: ``Normalize`` std, handled like ``mean``.
    :param interpolation: not used by this function; kept for API compatibility.
    :return: a tensor if ``to_tensor`` is True, otherwise a PIL image.
    """
    # Close the file handle deterministically; convert() returns a new,
    # fully loaded image that no longer depends on the open file.
    with Image.open(file) as src:
        img = src.convert("RGB")

    transformations = []

    if grayscale:
        transformations.append(transforms.Grayscale())

    if target_size is not None:
        crop_size = (target_size, target_size) if isinstance(target_size, int) else target_size
        transformations.append(transforms.CenterCrop(crop_size))

    if to_tensor:
        transformations.append(transforms.ToTensor())
        # Normalize only accepts tensors. Applying it to a PIL image (as the
        # default mean/std used to do when to_tensor=False) raised an error.
        if mean is not None and std is not None:
            if not isinstance(mean, (tuple, list)):
                mean = (mean,)
            if not isinstance(std, (tuple, list)):
                std = (std,)
            transformations.append(transforms.Normalize(mean=mean, std=std))

    return transforms.Compose(transformations)(img)
开发者ID:johnolafenwa,项目名称:TorchFusion,代码行数:41,代码来源:utils.py

示例13: get_loader

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Grayscale [as 别名]
def get_loader(image_path, proto_same_path, proto_oppo_path, metadata_path,
               crop_size=(224, 224), image_size=(224, 224), batch_size=64,
               dataset='CelebA', mode='train',
               num_workers=1):
    """Build and return a DataLoader over ``CelebaDataset``.

    All images are converted to grayscale, resized to ``image_size`` and
    turned into tensors. In ``'train'`` mode a random ``crop_size`` crop is
    taken before resizing, a random horizontal flip follows resizing, and the
    loader shuffles. The ``dataset`` argument is accepted but not consulted:
    ``CelebaDataset`` is always used.
    """
    is_train = mode == 'train'

    steps = [transforms.Grayscale()]
    if is_train:
        steps.append(transforms.RandomCrop(size=crop_size))
    steps.append(transforms.Resize(image_size))
    if is_train:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    pipeline = transforms.Compose(steps)

    celeba = CelebaDataset(image_path, proto_same_path, proto_oppo_path,
                           metadata_path, pipeline, mode)

    return DataLoader(dataset=celeba,
                      batch_size=batch_size,
                      shuffle=is_train,
                      num_workers=num_workers)
开发者ID:iPRoBe-lab,项目名称:semi-adversarial-networks,代码行数:37,代码来源:dataset_loader.py

示例14: get_transform

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Grayscale [as 别名]
def get_transform(opt, grayscale=False, convert=True, crop=True, flip=True):
    """Create a torchvision transformation function.

    The type of transformation is defined by options (e.g. ``opt.preprocess``,
    ``opt.load_size``, ``opt.crop_size``) and can be overridden by the
    ``convert``, ``crop`` and ``flip`` arguments.

    Parameters:
        opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        grayscale (bool)   -- whether to convert the input RGB image to single-channel grayscale
        convert (bool)     -- whether to convert the image to a tensor normalized to [-1, 1]
        crop    (bool)     -- whether to apply cropping (honoured by every *crop* preprocess mode)
        flip    (bool)     -- whether to apply random horizontal flipping

    Raises:
        ValueError: if ``opt.preprocess`` is not one of 'resize_and_crop',
            'crop', 'scale_width', 'scale_width_and_crop' or 'none'.
    """
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))

    if opt.preprocess == 'resize_and_crop':
        osize = [opt.load_size, opt.load_size]
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        # Honour crop=False here too, as 'scale_width_and_crop' already does.
        if crop:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
    elif opt.preprocess == 'crop':
        # 'crop' is a valid mode even with crop=False. Previously that case
        # fell through to the ValueError below.
        if crop:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
    elif opt.preprocess == 'scale_width':
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.crop_size)))
    elif opt.preprocess == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size)))
        if crop:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
    elif opt.preprocess == 'none':
        transform_list.append(transforms.Lambda(lambda img: __adjust(img)))
    else:
        raise ValueError('--preprocess %s is not a valid option.' % opt.preprocess)

    if not opt.no_flip and flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    if convert:
        # Grayscale(1) produces a 1-channel tensor; a 3-value Normalize cannot
        # be applied to it in place, so match the stats to the channel count.
        n_channels = 1 if grayscale else 3
        half = (0.5,) * n_channels
        transform_list += [transforms.ToTensor(),
                           transforms.Normalize(half, half)]
    return transforms.Compose(transform_list)
开发者ID:RogerZhangzz,项目名称:CAG_UDA,代码行数:42,代码来源:base_dataset.py

示例15: __getitem__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Grayscale [as 别名]
def __getitem__(self, idx):
    """Return ``(img, attr)`` or, when masks are configured, ``(img, attr, mask)``.

    ``img`` is loaded from ``self.pathdb`` and passed through
    ``self.transform`` when one is set. ``attr`` is a ``torch.long`` tensor:
    if ``self.hasAttrib``, a vector of length ``self.totAttribSize`` where
    each attribute's encoded value is written at its ``shiftAttrib`` offset;
    otherwise ``[0]``. When ``self.pathMask`` is set, the mask
    ``<stem>_mask.jpg`` is loaded, converted to 1-channel grayscale and
    transformed like the image.
    """
    img_name = self.listImg[idx]
    img = pil_loader(os.path.join(self.pathdb, img_name))

    if self.transform is not None:
        img = self.transform(img)

    # Build the attribute vector only when attributes are available.
    if self.hasAttrib:
        attr = [0] * self.totAttribSize
        for key, val in self.attribDict[img_name].items():
            attr[self.shiftAttrib[key]] = self.shiftAttribVal[key][val]
    else:
        attr = [0]
    attr_tensor = torch.tensor(attr, dtype=torch.long)

    if self.pathMask is None:
        return img, attr_tensor

    mask_path = os.path.join(
        self.pathMask, os.path.splitext(img_name)[0] + "_mask.jpg")
    mask = Transforms.Grayscale(1)(pil_loader(mask_path))
    # Same None guard as for the image; previously a None transform crashed here.
    # NOTE(review): the transform is called separately for image and mask, so
    # any random ops inside it are not shared between them -- confirm intended.
    if self.transform is not None:
        mask = self.transform(mask)

    return img, attr_tensor, mask
开发者ID:facebookresearch,项目名称:pytorch_GAN_zoo,代码行数:32,代码来源:attrib_dataset.py


注:本文中的torchvision.transforms.Grayscale方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。