当前位置: 首页>>代码示例>>Python>>正文


Python transforms.CenterCrop方法代码示例

本文整理汇总了Python中torchvision.transforms.CenterCrop方法的典型用法代码示例。如果您正苦于以下问题:Python transforms.CenterCrop方法的具体用法?Python transforms.CenterCrop怎么用?Python transforms.CenterCrop使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torchvision.transforms的用法示例。


在下文中一共展示了transforms.CenterCrop方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _get_ds_val

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import CenterCrop [as 别名]
def _get_ds_val(self, images_spec, crop=False, truncate=False):
        """Build the validation dataset described by `images_spec`.

        Args:
            images_spec: forwarded unchanged to `images_loader.ImagesCached`.
            crop: when truthy, passed to `transforms.CenterCrop`, which then runs
                before the uint8 tensor conversion. When falsy, no cropping.
            truncate: when truthy, the dataset is wrapped in
                `pe.TruncatedDataset` with this value as `num_elemens`.

        Returns:
            The `IndexImagesDataset`, or its `TruncatedDataset` wrapper when
            `truncate` is truthy.
        """
        to_tensor = images_loader.IndexImagesDataset.to_tensor_uint8_transform()
        steps = [transforms.CenterCrop(crop), to_tensor] if crop else [to_tensor]
        pipeline = transforms.Compose(steps)

        # An image stored next to this module is pinned as the first element so
        # tensor board shows a consistent picture; fall back to None if missing.
        module_dir = os.path.dirname(os.path.abspath(__file__))
        fixed_first = os.path.join(module_dir, 'fixedimg.jpg')
        if not os.path.isfile(fixed_first):
            print(f'INFO: No file found at {fixed_first}')
            fixed_first = None

        cached_images = images_loader.ImagesCached(
                images_spec, self.config_dl.image_cache_pkl,
                min_size=self.config_dl.val_glob_min_size)
        dataset = images_loader.IndexImagesDataset(
                images=cached_images,
                to_tensor_transform=pipeline,
                fixed_first=fixed_first)

        if not truncate:
            return dataset
        # Keyword spelling (`num_elemens`) is kept exactly as the original call.
        return pe.TruncatedDataset(dataset, num_elemens=truncate)
开发者ID:fab-jul,项目名称:L3C-PyTorch,代码行数:24,代码来源:multiscale_trainer.py

示例2: get_lsun_dataloader

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import CenterCrop [as 别名]
def get_lsun_dataloader(path_to_data='../lsun', dataset='bedroom_train',
                        batch_size=64, image_size=128):
    """Return a shuffled DataLoader over square LSUN images.

    Parameters
    ----------
    path_to_data : str
        Location of the LSUN database; passed to `datasets.LSUN` as `db_path`.

    dataset : str
        LSUN class to load, e.g. 'bedroom_val' or 'bedroom_train'.

    batch_size : int
        Number of images per batch.

    image_size : int
        Side length of the output images. Each image is resized with
        `transforms.Resize(image_size)` and then center-cropped to
        (image_size, image_size). Defaults to 128, the previously
        hard-coded value.
    """
    # Resize first, then center-crop, so every sample ends up square.
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor()
    ])

    # `classes` expects a list, hence the single-element wrapper.
    lsun_dset = datasets.LSUN(db_path=path_to_data, classes=[dataset],
                              transform=transform)

    return DataLoader(lsun_dset, batch_size=batch_size, shuffle=True)
开发者ID:vandit15,项目名称:Self-Supervised-Gans-Pytorch,代码行数:21,代码来源:dataloaders.py

示例3: save_distorted

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import CenterCrop [as 别名]
def save_distorted(method=gaussian_noise):
    """Run `method` over the validation images at severities 1 through 5.

    For each severity a `DistortImageFolder` is built over the hard-coded
    validation directory and fully iterated through a DataLoader. The batches
    themselves are discarded.
    NOTE(review): presumably `DistortImageFolder` writes the distorted images
    as a side effect of being read - confirm against its definition.
    """
    for severity in range(1, 6):
        print(method.__name__, severity)

        preprocessing = trn.Compose([trn.Resize(256), trn.CenterCrop(224)])
        dataset = DistortImageFolder(
            root="/share/data/vision-greg/ImageNet/clsloc/images/val",
            method=method,
            severity=severity,
            transform=preprocessing)
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=100, shuffle=False, num_workers=4)

        # Exhaust the loader; only the act of loading matters here.
        for _batch in loader:
            pass


# /////////////// End Further Setup ///////////////


# /////////////// Display Results /////////////// 
开发者ID:hendrycks,项目名称:robustness,代码行数:19,代码来源:make_imagenet_c.py

示例4: transform

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import CenterCrop [as 别名]
def transform(is_train=True, normalize=True):
    """Assemble the image preprocessing pipeline.

    Every pipeline starts with `Scale(256)` and a 224 crop and ends with
    `ToTensor()`. Training uses a random crop followed by a random horizontal
    flip; evaluation uses a center crop and no flip. When `normalize` is true,
    a `Normalize` step with the mean/std constants below is appended.
    NOTE(review): `Scale` is presumably torchvision's legacy name for
    `Resize` - confirm against this file's imports.

    Returns:
        A `Compose` wrapping the selected steps.
    """
    steps = [
        Scale(256),
        RandomCrop(224) if is_train else CenterCrop(224),
    ]
    if is_train:
        steps.append(RandomHorizontalFlip())
    steps.append(ToTensor())

    if normalize:
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        steps.append(Normalize(mean=channel_mean, std=channel_std))

    return Compose(steps)
开发者ID:uwnlp,项目名称:verb-attributes,代码行数:22,代码来源:imsitu_loader.py

示例5: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import CenterCrop [as 别名]
def __init__(
        self,
        resize: int = ImagenetConstants.RESIZE,
        crop_size: int = ImagenetConstants.CROP_SIZE,
        mean: List[float] = ImagenetConstants.MEAN,
        std: List[float] = ImagenetConstants.STD,
    ):
        """Set up the augmentation-free ImageNet preprocessing pipeline.

        The pipeline applies, in order: resize, center crop, tensor
        conversion, and per-channel normalization. It is stored on
        `self.transform`.

        Args:
            resize: size handed to `transforms.Resize`
            crop_size: size handed to `transforms.CenterCrop`
            mean: per-channel RGB mean used for normalization
            std: per-channel RGB standard deviation used for normalization

        """
        pipeline = [
            transforms.Resize(resize),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
        self.transform = transforms.Compose(pipeline)
开发者ID:facebookresearch,项目名称:ClassyVision,代码行数:26,代码来源:util.py

示例6: make

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import CenterCrop [as 别名]
def make(sz_resize=256, sz_crop=227, mean=[104, 117, 128],
         std=[1, 1, 1], rgb_to_bgr=True, is_train=True,
         intensity_scale=None):
    """Compose the image transform for training or evaluation.

    The returned `Compose` always has eight slots. Steps that do not apply
    to the selected mode are filled with `Identity()` placeholders.

    Args:
        sz_resize: size for `transforms.Resize` (evaluation only).
        sz_crop: size for the random-resized crop (training) or the center
            crop (evaluation).
        mean, std: arguments for the final `transforms.Normalize`.
        rgb_to_bgr: when truthy, the first step is `RGBToBGR()`.
        is_train: selects the training or the evaluation steps.
        intensity_scale: when not None, unpacked into `ScaleIntensities`,
            which runs after `ToTensor()`.
    """
    channel_step = RGBToBGR() if rgb_to_bgr else Identity()

    if is_train:
        # random resized crop ... horizontal flip; no resize / center crop
        spatial_steps = [
            transforms.RandomResizedCrop(sz_crop),
            Identity(),
            Identity(),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        # deterministic resize + center crop; no random crop / flip
        spatial_steps = [
            Identity(),
            transforms.Resize(sz_resize),
            transforms.CenterCrop(sz_crop),
            Identity(),
        ]

    to_tensor = transforms.ToTensor()
    if intensity_scale is not None:
        intensity_step = ScaleIntensities(*intensity_scale)
    else:
        intensity_step = Identity()
    normalize_step = transforms.Normalize(mean=mean, std=std)

    return transforms.Compose(
        [channel_step] + spatial_steps
        + [to_tensor, intensity_step, normalize_step])
开发者ID:CompVis,项目名称:metric-learning-divide-and-conquer,代码行数:19,代码来源:transform.py

示例7: test_on_validation_set

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import CenterCrop [as 别名]
def test_on_validation_set(model, validation_set=None):
    """Print the average SSIM and PSNR of `model` over a validation set.

    Each entry of `validation_set.tuples` holds three image paths: first
    frame, ground-truth middle frame, second frame. All three are loaded and
    center-cropped to `config.CROP_SIZE`. The model interpolates between the
    outer frames and the result is scored against the ground truth.

    Args:
        model: passed to `interpolate` together with the two outer frames.
        validation_set: object exposing `.tuples`; when None it is obtained
            from `get_validation_set()`.
    """
    if validation_set is None:
        validation_set = get_validation_set()

    samples = validation_set.tuples
    ssim_sum = 0
    psnr_sum = 0
    iters = len(samples)

    crop = CenterCrop(config.CROP_SIZE)

    for i, tup in enumerate(samples):
        first, truth, second = (crop(load_img(path)) for path in tup)
        interpolated = interpolate(model, first, second)

        truth_tensor = pil_to_tensor(truth)
        pred_tensor = pil_to_tensor(interpolated)

        ssim_sum += ssim(pred_tensor, truth_tensor).item()
        psnr_sum += psnr(pred_tensor, truth_tensor).item()
        print(f'#{i+1} done')

    avg_ssim = ssim_sum / iters
    avg_psnr = psnr_sum / iters

    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')
开发者ID:martkartasev,项目名称:sepconv,代码行数:26,代码来源:experiments.py

示例8: test_linear_interp

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import CenterCrop [as 别名]
def test_linear_interp(validation_set=None):
    """Print the average SSIM and PSNR of a plain frame-averaging baseline.

    Each entry of `validation_set.tuples` holds three image paths: first
    frame, ground-truth middle frame, second frame. All three are loaded,
    center-cropped to `config.CROP_SIZE` and converted to tensors. The
    prediction is the element-wise mean of the two outer frames.

    Args:
        validation_set: object exposing `.tuples`; when None it is obtained
            from `get_validation_set()`.
    """
    if validation_set is None:
        validation_set = get_validation_set()

    samples = validation_set.tuples
    ssim_sum = 0
    psnr_sum = 0
    iters = len(samples)

    crop = CenterCrop(config.CROP_SIZE)

    for tup in samples:
        first, truth, second = (
            pil_to_tensor(crop(load_img(path))) for path in tup)
        # Stack the outer frames and average over the new leading axis.
        blended = torch.stack((first, second), dim=0).mean(dim=0)
        ssim_sum += ssim(blended, truth).item()
        psnr_sum += psnr(blended, truth).item()

    avg_ssim = ssim_sum / iters
    avg_psnr = psnr_sum / iters

    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')
开发者ID:martkartasev,项目名称:sepconv,代码行数:23,代码来源:experiments.py

示例9: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import CenterCrop [as 别名]
def __init__(self, patches, use_cache, augment_data):
        """Store the patch list and choose the augmentation and loader.

        Args:
            patches: collection of patch tuples; its length is printed.
            use_cache: when truthy, patches are read through
                `data_manager.load_cached_patch`, otherwise through
                `data_manager.load_patch`.
            augment_data: when truthy, `self.get_aug_transform()` returns one
                entry drawn at random from `self.random_transforms`.
                Otherwise it always returns a pass-through function.
        """
        super(PatchDataset, self).__init__()

        def _identity(x):
            return x

        self.patches = patches
        self.crop = CenterCrop(config.CROP_SIZE)

        if augment_data:
            # Candidates: rotation with the range pinned to (90, 90), the two
            # flips with probability 1.0, and "leave unchanged".
            self.random_transforms = [
                RandomRotation((90, 90)),
                RandomVerticalFlip(1.0),
                RandomHorizontalFlip(1.0),
                _identity,
            ]

            def _draw_transform():
                return random.sample(self.random_transforms, 1)[0]

            self.get_aug_transform = _draw_transform
        else:
            # Without augmentation the factory hands back a no-op transform.
            def _no_augmentation():
                return _identity

            self.get_aug_transform = _no_augmentation

        self.load_patch = (data_manager.load_cached_patch if use_cache
                           else data_manager.load_patch)

        print('Dataset ready with {} tuples.'.format(len(patches)))
开发者ID:martkartasev,项目名称:sepconv,代码行数:21,代码来源:dataset.py

示例10: preprocess

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import CenterCrop [as 别名]
def preprocess(self):
        """Return the transform pipeline for the current mode.

        Evaluation (`self.train` falsy): resize to a square of side
        `int(self.image_size / 0.875)`, center-crop to `self.image_size`,
        convert to tensor, normalize with `self.mean` / `self.std`.

        Training: random resized crop to `self.image_size`, random horizontal
        flip, colour jitter, tensor conversion, and the same normalization.
        """
        if not self.train:
            # The crop keeps the central 87.5% of the enlarged image side.
            enlarged = int(self.image_size / 0.875)
            eval_steps = [
                transforms.Resize((enlarged, enlarged)),
                transforms.CenterCrop(self.image_size),
                transforms.ToTensor(),
                transforms.Normalize(self.mean, self.std),
            ]
            return transforms.Compose(eval_steps)

        train_steps = [
            transforms.RandomResizedCrop(self.image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ]
        return transforms.Compose(train_steps)
开发者ID:wandering007,项目名称:nasnet-pytorch,代码行数:18,代码来源:imagenet.py

示例11: __getitem__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import CenterCrop [as 别名]
def __getitem__(self, index):
        """Return the image triple for sample `index`.

        The high-resolution image is opened and a crop size is obtained from
        `utils.calculate_valid_crop_size` on its shorter side, capped by
        `self.crop_size` when that is set. Three tensors are produced:

        * resized: the center crop of the HR image, scaled by
          `1 / self.upscale_factor` with `utils.imresize`;
        * cropped: a center crop of the HR image whose side is the crop size
          divided by the upscale factor;
        * low-res: the same small center crop of the matching file in
          `self.lr_files`, or the resized tensor again when `self.lr_files`
          is None.

        Returns:
            Tuple `(resized, cropped, low_res)`.
        """
        hr_image = Image.open(self.hr_files[index])
        width, height = hr_image.size

        crop_side = utils.calculate_valid_crop_size(min(width, height), self.upscale_factor)
        if self.crop_size is not None:
            crop_side = min(crop_side, self.crop_size)

        small_crop = T.CenterCrop(crop_side // self.upscale_factor)
        cropped_image = TF.to_tensor(small_crop(hr_image))

        hr_tensor = TF.to_tensor(T.CenterCrop(crop_side)(hr_image))
        resized_image = utils.imresize(hr_tensor, 1.0 / self.upscale_factor, True)

        if self.lr_files is None:
            # No separate low-res data: reuse the downscaled HR image.
            return resized_image, cropped_image, resized_image

        lr_image = TF.to_tensor(small_crop(Image.open(self.lr_files[index])))
        return resized_image, cropped_image, lr_image
开发者ID:ManuelFritsche,项目名称:real-world-sr,代码行数:19,代码来源:data_loader.py

示例12: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import CenterCrop [as 别名]
def __init__(self, options):
        """Create the image-pair DataLoader and an iterator over it.

        Reads from `options`: `image_size`, `image_colors`, `data_dir`,
        `split`, `batch_size`, `loader_workers` and `pin_memory`.

        Transform pipeline: optional resize to a square of `image_size`,
        tensor conversion, then normalization with mean 0.5 / std 0.5 per
        channel when `image_colors` is 1 or 3. Any other value means no
        normalization step.
        """
        steps = []
        side = options.image_size
        if side is not None:
            # Center cropping is disabled; the resize alone fixes the size.
            steps.append(transforms.Resize((side, side)))
        steps.append(transforms.ToTensor())

        if options.image_colors == 1:
            normalizer = transforms.Normalize(mean=[0.5], std=[0.5])
            steps.append(normalizer)
        elif options.image_colors == 3:
            normalizer = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            steps.append(normalizer)

        pairs = ImagePairs(
            options.data_dir,
            split=options.split,
            transform=transforms.Compose(steps))

        loader_kwargs = dict(
            batch_size=options.batch_size,
            num_workers=options.loader_workers,
            shuffle=True,
            drop_last=True,
            pin_memory=options.pin_memory,
        )
        self.dataloader = DataLoader(pairs, **loader_kwargs)
        self.iterator = iter(self.dataloader)
开发者ID:unicredit,项目名称:ganzo,代码行数:25,代码来源:data.py

示例13: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import CenterCrop [as 别名]
def __init__(self, path, classes, stage='train'):
        """Index the images under `path` and pick the stage's transforms.

        Args:
            path: root directory holding one sub-directory per class.
            classes: sub-directory names; the position in this sequence is
                used as the integer label.
            stage: 'train' selects random resized crop + horizontal flip;
                'test' selects resize to 256 + center crop to 224. Both end
                with tensor conversion and the normalization below.

        NOTE(review): any other `stage` value leaves `self.transforms`
        unset - confirm whether callers rely on that.
        """
        # self.data holds (image path, label index) pairs.
        self.data = []
        for label, class_name in enumerate(classes):
            class_dir = osp.join(path, class_name)
            self.data.extend(
                (osp.join(class_dir, file_name), label)
                for file_name in os.listdir(class_dir))

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        if stage == 'train':
            steps = [transforms.RandomResizedCrop(224),
                     transforms.RandomHorizontalFlip(),
                     transforms.ToTensor(),
                     normalize]
            self.transforms = transforms.Compose(steps)
        elif stage == 'test':
            steps = [transforms.Resize(256),
                     transforms.CenterCrop(224),
                     transforms.ToTensor(),
                     normalize]
            self.transforms = transforms.Compose(steps)
开发者ID:cyvius96,项目名称:DGP,代码行数:23,代码来源:image_folder.py

示例14: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import CenterCrop [as 别名]
def __init__(self, opt):
        """Configure the dataset from `opt` and build its transforms.

        Reads from `opt`: `dataroot`, `is_train`, `n_attribute`, `load_size`,
        `fine_size` and `is_flip`.

        `self.transform` resizes to `load_size` and then crops to `fine_size`
        (random crop when training, center crop otherwise), optionally
        followed by a random horizontal flip. `self.norm` converts to a
        tensor and normalizes each channel with mean 0.5 / std 0.5.
        """
        self.image_path = opt.dataroot
        self.is_train = opt.is_train
        self.d_num = opt.n_attribute
        print ('Start preprocessing dataset..!')
        # Seeds the global `random` module before preprocessing runs.
        random.seed(1234)
        self.preprocess()
        print ('Finished preprocessing dataset..!')

        # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same
        # filter under its current name and also exists in older releases.
        resize = transforms.Resize(opt.load_size, interpolation=Image.LANCZOS)
        if self.is_train:
            crop = transforms.RandomCrop(opt.fine_size)
        else:
            crop = transforms.CenterCrop(opt.fine_size)
        trs = [resize, crop]
        if opt.is_flip:
            trs.append(transforms.RandomHorizontalFlip())
        self.transform = transforms.Compose(trs)

        self.norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        # NOTE(review): `self.num` is not assigned in this method; presumably
        # `self.preprocess()` sets it - confirm.
        self.num_data = max(self.num)
开发者ID:Xiaoming-Yu,项目名称:DMIT,代码行数:21,代码来源:season_transfer_dataset.py

示例15: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import CenterCrop [as 别名]
def __init__(self, opt):
        """Initialize this dataset class.

        We need to specify the path of the dataset and the domain label of
        each image. In this template both lists start out empty, so
        `self.num_data` is 0 until code that fills them is added before the
        final line.

        Reads from `opt`: `is_train`, `load_size`, `fine_size` and `is_flip`.
        `self.transform` resizes to `load_size`, crops to `fine_size` (random
        crop when training, center crop otherwise), optionally flips
        horizontally at random, converts to a tensor, and normalizes each
        channel with mean 0.5 / std 0.5.
        """
        self.image_list = []
        self.label_list = []

        # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same
        # filter under its current name and also exists in older releases.
        resize = transforms.Resize(opt.load_size, interpolation=Image.LANCZOS)
        if opt.is_train:
            crop = transforms.RandomCrop(opt.fine_size)
        else:
            crop = transforms.CenterCrop(opt.fine_size)
        trs = [resize, crop]
        if opt.is_flip:
            trs.append(transforms.RandomHorizontalFlip())
        trs.append(transforms.ToTensor())
        trs.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(trs)
        self.num_data = len(self.image_list)
开发者ID:Xiaoming-Yu,项目名称:DMIT,代码行数:18,代码来源:template_dataset.py


注:本文中的torchvision.transforms.CenterCrop方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。