当前位置: 首页>>代码示例>>Python>>正文


Python transforms.Lambda方法代码示例

本文整理汇总了Python中torchvision.transforms.Lambda方法的典型用法代码示例。如果您正苦于以下问题:Python transforms.Lambda方法的具体用法?Python transforms.Lambda怎么用?Python transforms.Lambda使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torchvision.transforms的用法示例。


在下文中一共展示了transforms.Lambda方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: scale_crop

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Lambda [as 别名]
def scale_crop(input_size, scale_size=None, num_crops=1, normalize=_IMAGENET_STATS):
    """Build a torchvision transform: optional resize, crop(s), tensor + normalize.

    Args:
        input_size: target crop size (int or tuple).
        scale_size: size to resize to before cropping; None (default) means no resize.
        num_crops: 1 (center crop), 5 (FiveCrop) or 10 (TenCrop). Multi-crop
            variants produce a 4D tensor stacked over the crops.
        normalize: keyword arguments forwarded to transforms.Normalize.

    Returns:
        A transforms.Compose pipeline.
    """
    assert num_crops in [1, 5, 10], "num crops must be in {1,5,10}"
    convert_tensor = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize(**normalize)])
    if num_crops == 1:
        t_list = [
            transforms.CenterCrop(input_size),
            convert_tensor
        ]
    else:
        if num_crops == 5:
            t_list = [transforms.FiveCrop(input_size)]
        elif num_crops == 10:
            t_list = [transforms.TenCrop(input_size)]
        # FiveCrop/TenCrop return a tuple of PIL images; stack them into one
        # 4D tensor so downstream code sees a batch-of-crops.
        t_list.append(transforms.Lambda(lambda crops:
                                        torch.stack([convert_tensor(crop) for crop in crops])))

    # Bug fix: with the default scale_size=None, `None != input_size` was True
    # and transforms.Resize(None) raised. Only resize when an explicit,
    # different scale_size was requested.
    if scale_size is not None and scale_size != input_size:
        t_list = [transforms.Resize(scale_size)] + t_list

    return transforms.Compose(t_list)
开发者ID:eladhoffer,项目名称:convNet.pytorch,代码行数:24,代码来源:preprocess.py

示例2: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Lambda [as 别名]
def __init__(self, train_mode, loader_params, dataset_params, augmentation_params):
    """Set up image/mask transforms, augmenters, and the dataset class by target format."""
    super().__init__(train_mode, loader_params, dataset_params, augmentation_params)

    # Images: replicate grayscale into 3 channels, then tensor + normalize.
    self.image_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=self.dataset_params.MEAN,
                             std=self.dataset_params.STD),
    ])
    # Masks: numpy conversion followed by tensor conversion.
    self.mask_transform = transforms.Compose([
        transforms.Lambda(to_array),
        transforms.Lambda(to_tensor),
    ])

    # One ImgAug augmenter per configured pipeline stage.
    for stage in ('image_augment_train',
                  'image_augment_with_target_train',
                  'image_augment_inference',
                  'image_augment_with_target_inference'):
        setattr(self, stage, ImgAug(self.augmentation_params[stage]))

    self.dataset = {'png': ImageSegmentationPngDataset,
                    'json': ImageSegmentationJsonDataset}.get(self.dataset_params.target_format)
    if self.dataset is None:
        raise Exception('files must be png or json')
开发者ID:minerva-ml,项目名称:steppy-toolkit,代码行数:26,代码来源:segmentation.py

示例3: get_mnist

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Lambda [as 别名]
def get_mnist(train, get_dataset=False, batch_size=cfg.batch_size):
    """Get MNIST dataset loader."""
    # Pre-processing: tensor + normalize, then replicate the single grayscale
    # channel three times so downstream models see 3-channel input.
    pre_process = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=cfg.dataset_mean, std=cfg.dataset_std),
        transforms.Lambda(lambda x: torch.cat([x, x, x], 0)),
    ])

    mnist_dataset = datasets.MNIST(root=cfg.data_root,
                                   train=train,
                                   transform=pre_process,
                                   download=True)

    if get_dataset:
        return mnist_dataset
    return torch.utils.data.DataLoader(dataset=mnist_dataset,
                                       batch_size=batch_size,
                                       shuffle=True)
开发者ID:corenel,项目名称:pytorch-atda,代码行数:27,代码来源:mnist.py

示例4: get_transform

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Lambda [as 别名]
def get_transform(params, image_size, num_channels):
    """Compose PIL transforms: optional resize, optional Gray<->RGB, tensor + normalize."""
    transform = []

    # Resize only when the requested size differs from the dataset's native size.
    if image_size != params.image_size:
        transform.append(transforms.Resize(image_size))

    # Convert between grayscale ('L') and RGB when channel counts differ.
    if num_channels != params.num_channels:
        if num_channels == 1:
            transform.append(transforms.Lambda(lambda x: x.convert('L')))
        elif num_channels == 3:
            transform.append(transforms.Lambda(lambda x: x.convert('RGB')))
        else:
            print('NumChannels should be 1 or 3', num_channels)
            raise Exception

    transform.append(transforms.ToTensor())
    transform.append(transforms.Normalize((params.mean,), (params.std,)))
    return transforms.Compose(transform)
开发者ID:jhoffman,项目名称:cycada_release,代码行数:26,代码来源:data_loader.py

示例5: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Lambda [as 别名]
def __init__(self, train_mode, loader_params, dataset_params, augmentation_params):
    """Set up image transform (with depth channels), scalar target transform and augmenters."""
    super().__init__(train_mode, loader_params, dataset_params, augmentation_params)

    # Images: grayscale replicated to 3 channels, tensor + normalize,
    # then the project-specific AddDepthChannels step.
    self.image_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=self.dataset_params.MEAN,
                             std=self.dataset_params.STD),
        AddDepthChannels(),
    ])
    self.mask_transform = transforms.Lambda(preprocess_emptiness_target)

    # One ImgAug augmenter per configured pipeline stage.
    for stage in ('image_augment_train',
                  'image_augment_with_target_train',
                  'image_augment_inference',
                  'image_augment_with_target_inference'):
        setattr(self, stage, ImgAug(self.augmentation_params[stage]))

    self.dataset = EmptinessDataset
开发者ID:neptune-ai,项目名称:open-solution-salt-identification,代码行数:20,代码来源:loaders.py

示例6: get_transform

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Lambda [as 别名]
def get_transform():
    """Return a dict with transforms for input images ('img') and ground truth ('gt')."""
    # Images: resize with interpolation mode 3 (bicubic), tensor, fixed normalization.
    image_transform = transforms.Compose([
        transforms.Resize((256, 256), 3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Ground truth: resize with interpolation mode 0 (nearest, preserves labels),
    # then keep as a uint8 numpy array rather than a tensor.
    gt_transform = transforms.Compose([
        transforms.Resize((256, 256), 0),
        transforms.Lambda(lambda img: np.asarray(img, dtype=np.uint8)),
    ])
    return {'img': image_transform, 'gt': gt_transform}
开发者ID:hyk1996,项目名称:Single-Human-Parsing-LIP,代码行数:19,代码来源:eval.py

示例7: get_transform

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Lambda [as 别名]
def get_transform(opt):
    """Build the image transform pipeline selected by opt.resize_or_crop.

    Options: 'resize_and_crop' (resize to loadSize, random-crop to fineSize),
    'crop' (random crop only), 'scale_width' / 'scale_width_and_crop'
    (width-preserving rescale via __scale_width, optionally followed by a
    random crop). Training adds a random horizontal flip unless opt.no_flip.
    The result ends with ToTensor + Normalize to roughly [-1, 1].
    """
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        osize = [opt.loadSize, opt.loadSize]
        # Fix: transforms.Scale is a deprecated alias of transforms.Resize
        # (identical behavior) and was removed in later torchvision releases.
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
开发者ID:aayushbansal,项目名称:Recycle-GAN,代码行数:25,代码来源:base_dataset.py

示例8: _transform_row

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Lambda [as 别名]
def _transform_row(mnist_row):
    """Transform one petastorm MNIST row: reshape the image to HWC, tensor, normalize."""
    # The stored images are plain (28, 28) ndarrays, but the training network
    # expects 3-dim images — hence the extra reshape lambda before ToTensor.
    image_transform = transforms.Compose([
        transforms.Lambda(lambda nd: nd.reshape(28, 28, 1)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    # The petastorm pytorch DataLoader has no separate data/target transform
    # notion, which lets us apply this partial, per-field transform instead.
    return {
        'image': image_transform(mnist_row['image']),
        'digit': mnist_row['digit'],
    }
开发者ID:uber,项目名称:petastorm,代码行数:19,代码来源:pytorch_example.py

示例9: test_torch_transform_spec

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Lambda [as 别名]
def test_torch_transform_spec(spark_test_ctx):
    """A TransformSpec scaling ids 0..7 by 0.1 should keep every batch value in [0, 1)."""
    df = spark_test_ctx.spark.range(8)
    conv = make_spark_converter(df)

    from torchvision import transforms
    from petastorm import TransformSpec

    def _scale_row(df_row):
        scaler = transforms.Compose([transforms.Lambda(lambda x: x * 0.1)])
        return scaler(df_row)

    spec = TransformSpec(_scale_row)
    with conv.make_torch_dataloader(transform_spec=spec,
                                    num_epochs=1) as dataloader:
        for batch in dataloader:
            assert min(batch['id']) >= 0
            assert max(batch['id']) < 1
开发者ID:uber,项目名称:petastorm,代码行数:20,代码来源:test_spark_dataset_converter.py

示例10: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Lambda [as 别名]
def __init__(self, train_mode, loader_params, dataset_params, augmentation_params):
    """Set up transforms, augmenters, and the dataset class by target format.

    Supported target formats: 'png', 'json', 'joblib'.
    """
    super().__init__(train_mode, loader_params, dataset_params, augmentation_params)

    # Images: tensor + normalize with dataset statistics.
    self.image_transform = transforms.Compose([transforms.ToTensor(),
                                               transforms.Normalize(mean=self.dataset_params.MEAN,
                                                                    std=self.dataset_params.STD),
                                               ])
    self.mask_transform = transforms.Compose([transforms.Lambda(preprocess_target),
                                              ])

    self.image_augment_train = ImgAug(self.augmentation_params['image_augment_train'])
    self.image_augment_with_target_train = ImgAug(self.augmentation_params['image_augment_with_target_train'])
    self.image_augment_inference = ImgAug(self.augmentation_params['image_augment_inference'])
    self.image_augment_with_target_inference = ImgAug(
        self.augmentation_params['image_augment_with_target_inference'])

    if self.dataset_params.target_format == 'png':
        self.dataset = ImageSegmentationPngDataset
    elif self.dataset_params.target_format == 'json':
        self.dataset = ImageSegmentationJsonDataset
    elif self.dataset_params.target_format == 'joblib':
        self.dataset = ImageSegmentationJoblibDataset
    else:
        # Bug fix: the message previously omitted the supported 'joblib' format.
        raise Exception('files must be png, json or joblib')
开发者ID:minerva-ml,项目名称:open-solution-ship-detection,代码行数:26,代码来源:loaders.py

示例11: create_loaders

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Lambda [as 别名]
def create_loaders():
    """Build the training DataLoader; the second return slot (test loader) is unused."""
    loader_kwargs = ({'num_workers': args.num_workers, 'pin_memory': args.pin_memory}
                     if args.cuda else {})

    transform = transforms.Compose([
        transforms.Lambda(np_reshape),
        transforms.ToTensor(),
    ])

    train_dataset = TotalDatasetsLoader(datasets_path=args.dataroot,
                                        train=True,
                                        n_triplets=args.n_pairs,
                                        fliprot=True,
                                        batch_size=args.batch_size,
                                        download=True,
                                        transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               **loader_kwargs)
    return train_loader, None
开发者ID:ducha-aiki,项目名称:affnet,代码行数:20,代码来源:train_AffNet_test_on_graffity.py

示例12: get_transform

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Lambda [as 别名]
def get_transform(opt):
    """Build the image transform pipeline selected by opt.resize_or_crop.

    Options: 'resize_and_crop' (resize to loadSize, random-crop to fineSize),
    'crop' (random crop only), 'scale_width' / 'scale_width_and_crop'
    (width-preserving rescale via __scale_width, optionally followed by a
    random crop). The result ends with ToTensor + Normalize to roughly [-1, 1].
    """
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        osize = [opt.loadSize, opt.loadSize]
        # Fix: transforms.Scale is a deprecated alias of transforms.Resize
        # (identical behavior) and was removed in later torchvision releases.
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))

    # Horizontal flip deliberately disabled in this variant.
    # if opt.isTrain and not opt.no_flip:
    #     transform_list.append(transforms.RandomHorizontalFlip())

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
开发者ID:jessemelpolio,项目名称:non-stationary_texture_syn,代码行数:25,代码来源:base_dataset.py

示例13: create_test_transforms

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Lambda [as 别名]
def create_test_transforms(config, crop, scale, ten_crops):
    """Build the evaluation transform list (returned as a list, not a Compose)."""
    normalize = transforms.Normalize(mean=config["mean"], std=config["std"])

    # scale == -1 means "no resize before cropping".
    val_transforms = [] if scale == -1 else [transforms.Resize(scale)]

    if ten_crops:
        # Convert and normalize each of the ten crops individually, then
        # stack everything into a single 4D tensor.
        val_transforms.extend([
            transforms.TenCrop(crop),
            transforms.Lambda(lambda crops: [transforms.ToTensor()(c) for c in crops]),
            transforms.Lambda(lambda crops: [normalize(c) for c in crops]),
            transforms.Lambda(lambda crops: torch.stack(crops)),
        ])
    else:
        val_transforms.extend([
            transforms.CenterCrop(crop),
            transforms.ToTensor(),
            normalize,
        ])

    return val_transforms
开发者ID:mapillary,项目名称:inplace_abn,代码行数:23,代码来源:utils.py

示例14: test_transform

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Lambda [as 别名]
def test_transform(self):
    """Labelled items use the train transform; pool items use the pool-specific one."""
    train_transform = Lambda(lambda k: 1)
    test_transform = Lambda(lambda k: 0)
    dataset = ActiveLearningDataset(MyDataset(train_transform),
                                    pool_specifics={'transform': test_transform},
                                    make_unlabelled=lambda x: (x[0], -1))
    dataset.label(np.arange(10))
    pool = dataset.pool
    # Pool (indices 10..99) goes through test_transform -> 0, made unlabelled -> -1.
    assert np.equal(list(pool), [(0, -1) for _ in np.arange(10, 100)]).all()
    # Labelled items (indices 0..9) go through train_transform -> 1.
    assert np.equal(list(dataset), [(1, i) for i in np.arange(10)]).all()

    # eval_transform is deprecated and must emit exactly one warning.
    with pytest.warns(DeprecationWarning) as e:
        ActiveLearningDataset(MyDataset(train_transform), eval_transform=train_transform)
    assert len(e) == 1

    # Unknown pool_specifics keys must raise when the pool is built.
    with pytest.raises(ValueError):
        ActiveLearningDataset(MyDataset(train_transform), pool_specifics={'whatever': 123}).pool
开发者ID:ElementAI,项目名称:baal,代码行数:19,代码来源:dataset_test.py

示例15: img_transformer

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Lambda [as 别名]
def img_transformer(self):
    """Compose the PIL image -> tensor pipeline selected by opt.resize_or_crop."""
    opt = self.opt
    if opt.resize_or_crop == 'resize_and_crop':
        transform_list = [transforms.Resize([opt.load_size, opt.load_size], Image.BICUBIC),
                          transforms.RandomCrop(opt.final_size)]
    elif opt.resize_or_crop == 'crop':
        transform_list = [transforms.RandomCrop(opt.final_size)]
    elif opt.resize_or_crop == 'none':
        # Identity transform keeps the pipeline shape uniform.
        transform_list = [transforms.Lambda(lambda image: image)]
    else:
        raise ValueError("--resize_or_crop %s is not a valid option." % opt.resize_or_crop)

    if self.is_train and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]

    return transforms.Compose(transform_list)
开发者ID:donydchen,项目名称:ganimation_replicate,代码行数:23,代码来源:base_dataset.py


注:本文中的torchvision.transforms.Lambda方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。