当前位置: 首页>>代码示例>>Python>>正文


Python transforms.RandomRotation方法代码示例

本文整理汇总了Python中torchvision.transforms.RandomRotation方法的典型用法代码示例。如果您正苦于以下问题:Python transforms.RandomRotation方法的具体用法?Python transforms.RandomRotation怎么用?Python transforms.RandomRotation使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torchvision.transforms的用法示例。


在下文中一共展示了transforms.RandomRotation方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: initialize_dataset

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import RandomRotation [as 别名]
def initialize_dataset(clevr_dir, dictionaries, state_description=True):
    """Create the CLEVR (train, test) dataset pair.

    Args:
        clevr_dir: Root directory of the CLEVR data, forwarded to the dataset classes.
        dictionaries: Vocabulary dictionaries forwarded to the dataset classes.
        state_description: If True, use ``ClevrDatasetStateDescription`` (no image
            transforms). Otherwise use the image-based ``ClevrDataset`` with the
            pipelines built below.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    if state_description:
        train_set = ClevrDatasetStateDescription(clevr_dir, True, dictionaries)
        test_set = ClevrDatasetStateDescription(clevr_dir, False, dictionaries)
        return train_set, test_set

    # Training images: resize to 128x128, pad by 8 px and randomly crop back to
    # 128x128 (a small random shift), then rotate by up to 2.8 degrees (~0.05 rad).
    train_pipeline = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.Pad(8),
        transforms.RandomCrop((128, 128)),
        transforms.RandomRotation(2.8),
        transforms.ToTensor(),
    ])
    # Test images: deterministic resize only.
    test_pipeline = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])

    train_set = ClevrDataset(clevr_dir, True, dictionaries, train_pipeline)
    test_set = ClevrDataset(clevr_dir, False, dictionaries, test_pipeline)
    return train_set, test_set
开发者ID:mesnico,项目名称:RelationNetworks-CLEVR,代码行数:20,代码来源:train.py

示例2: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import RandomRotation [as 别名]
def __init__(self, config=None, batch_size=50, epoch_len=100, seq_len=8, transform=None,
             training_path='/home/mayank/Desktop/GitRepos/crnn-pytorch_mnist/data/processed/training.pt',
             ip_channels=1, size=(20, 200), Type='train', profiler=None, target_transform=None):
    """Set up a digit-sequence generator backed by a saved MNIST split.

    ``config``, ``Type``, ``profiler`` and ``target_transform`` are accepted but
    not used by this constructor.

    Args:
        batch_size: Stored as ``self.batch_size``.
        epoch_len: Stored as ``self.epoch_len``.
        seq_len: Stored as ``self.seq_len``.
        transform: Stored as ``self.transform``.
        training_path: File loaded with ``torch.load`` and unpacked as
            ``(train_data, train_labels)``.
        ip_channels: Number of input channels; appended to ``size`` to form
            ``self.resized_shape``.
        size: Final output size, stored as ``self.final_size``.
    """
    self.training_path = training_path
    self.abc = '0123456789'
    self.seq_len = seq_len
    self.epoch_len = epoch_len
    self.transform = transform

    # NOTE(review): torch.load unpickles the file; only point it at trusted data.
    loaded_images, loaded_labels = torch.load(training_path)
    self.train_data = loaded_images
    self.train_labels = loaded_labels
    self.num_total = len(self.train_labels)

    self.final_size = size
    self.normal_mean = 7
    self.clip = (1, 40)
    self.ip_channels = ip_channels
    self.resized_shape = tuple(size) + (ip_channels,)

    self.target_aspect_ratio = 10

    self.out_dir = 'out'
    self.rotate = RandomRotation(10)

    self.batch_size = batch_size
    # Class index 1..10 -> digit character '0'..'9'.
    self.encoding_to_char = {index: digit for index, digit in enumerate(self.abc, start=1)}
开发者ID:mayank-git-hub,项目名称:Text-Recognition,代码行数:25,代码来源:mnist.py

示例3: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import RandomRotation [as 别名]
def __init__(self, patches, use_cache, augment_data):
        """Wrap a collection of patch tuples as a dataset.

        Args:
            patches: The patch tuples; only their count is used here (for the log line).
            use_cache: If True, patches are loaded with
                ``data_manager.load_cached_patch``, else with ``data_manager.load_patch``.
            augment_data: If True, ``get_aug_transform()`` returns one transform chosen
                uniformly at random from {90-degree rotation, vertical flip,
                horizontal flip, identity}. Otherwise it always returns an identity.
        """
        super(PatchDataset, self).__init__()
        self.patches = patches
        self.crop = CenterCrop(config.CROP_SIZE)

        if not augment_data:
            # Each call hands back a pass-through transform.
            self.get_aug_transform = lambda: (lambda x: x)
        else:
            self.random_transforms = [
                RandomRotation((90, 90)),  # degenerate range: always exactly 90 degrees
                RandomVerticalFlip(1.0),
                RandomHorizontalFlip(1.0),
                (lambda x: x),
            ]
            self.get_aug_transform = lambda: random.sample(self.random_transforms, 1)[0]

        self.load_patch = data_manager.load_cached_patch if use_cache else data_manager.load_patch

        print('Dataset ready with {} tuples.'.format(len(patches)))
开发者ID:martkartasev,项目名称:sepconv,代码行数:21,代码来源:dataset.py

示例4: test_filedataset_segmentation

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import RandomRotation [as 别名]
def test_filedataset_segmentation(self):
        """Seeded target transforms are reproducible; unseeded ones are not.

        For each seed the first sample's input and target must be equal. Two
        unshuffled DataLoader passes must then match exactly with ``seed=1337``
        and differ somewhere with ``seed=None``.
        """
        target_trans = Compose([default_image_load_fn,
                                Resize(60), RandomRotation(90), ToTensor()])

        def full_pass(dataset):
            # One unshuffled, batch-size-1 pass over the whole dataset.
            return list(DataLoader(dataset, batch_size=1, num_workers=3, shuffle=False))

        for seed, passes_should_match in ((1337, True), (None, False)):
            file_dataset = FileDataset(self.paths, self.paths, self.transform, target_trans, seed=seed)
            x, y = file_dataset[0]
            assert np.allclose(x.numpy(), y.numpy())

            first_pass = full_pass(file_dataset)
            second_pass = full_pass(file_dataset)
            passes_match = all([np.allclose(a.numpy(), b.numpy())
                                for (a, _), (b, _) in zip(first_pass, second_pass)])
            assert passes_match == passes_should_match
开发者ID:ElementAI,项目名称:baal,代码行数:20,代码来源:file_dataset_test.py

示例5: test_segmentation_pipeline

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import RandomRotation [as 别名]
def test_segmentation_pipeline(self):
        """A label painted onto a canvas survives the PIL resize/rotate pipeline.

        Each target is painted into the top-left 3x3 corner of a canvas. It is then
        resized and rotated with nearest-neighbour settings and converted with
        ``PILToLongTensor``. The result must contain only the values 0 and 1 and
        share the input's trailing (spatial) shape.
        """
        class DrawSquare:
            """Paint a value into the top-left ``side`` x ``side`` corner of a canvas."""

            def __init__(self, side):
                self.side = side

            def __call__(self, x, **kwargs):
                # x is a (value, canvas) pair, e.g. [int, ndarray].
                value, canvas = x
                edge = self.side
                canvas[:edge, :edge] = value
                return canvas

        # NOTE(review): newer torchvision releases replaced RandomRotation's
        # ``resample`` keyword with ``interpolation``.
        steps = [
            GetCanvas(),
            DrawSquare(3),
            ToPILImage(mode=None),
            Resize(60, interpolation=0),
            RandomRotation(10, resample=NEAREST, fill=0.0),
            PILToLongTensor(),
        ]
        target_trans = BaaLCompose(steps)
        labels = [1] * len(self.paths)
        file_dataset = FileDataset(self.paths, labels, self.transform, target_trans)

        x, y = file_dataset[0]
        assert np.allclose(np.unique(y), [0, 1])
        assert y.shape[1:] == x.shape[1:]
开发者ID:ElementAI,项目名称:baal,代码行数:20,代码来源:file_dataset_test.py

示例6: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import RandomRotation [as 别名]
def __init__(self, transform, mode, select_attrs=None, out_img_size=64, bbox_out_size=32, randomrotate=0, scaleRange=None, squareAspectRatio=False, use_celeb=False):
        """Set up a single-object dataset (MNIST digits or CelebA faces) and preprocess it.

        Args:
            transform: Accepted for interface compatibility; not used by this constructor.
            mode: Stored as ``self.mode`` and forwarded to ``CelebDataset``.
            select_attrs: Selected attribute names. Defaults to a fresh empty list.
            out_img_size: Stored as ``self.out_img_size``.
            bbox_out_size: Stored as ``self.bbox_out_size``.
            randomrotate: Degrees for ``RandomRotation`` applied to each source image.
            scaleRange: Stored as ``self.scaleRange``. Defaults to a fresh ``[0.1, 0.9]``.
            squareAspectRatio: Stored as ``self.squareAspectRatio``.
            use_celeb: Use CelebA (3 channels) instead of MNIST (1 channel).
        """
        # Create fresh lists per instance. Mutable default arguments are evaluated
        # once and would be shared (along with any mutation) by every instance
        # built with the defaults.
        if select_attrs is None:
            select_attrs = []
        if scaleRange is None:
            scaleRange = [0.1, 0.9]

        self.image_path = os.path.join('data','mnist')
        self.mode = mode
        self.iouThresh = 0.5
        self.maxDigits= 1
        self.minDigits = 1
        self.use_celeb = use_celeb
        self.scaleRange = scaleRange
        self.squareAspectRatio = squareAspectRatio
        self.nc = 1 if not self.use_celeb else 3
        # NOTE(review): newer torchvision releases replaced ``resample`` with ``interpolation``.
        transList = [transforms.RandomHorizontalFlip(), transforms.RandomRotation(randomrotate,resample=Image.BICUBIC)]#, transforms.ColorJitter(0.5,0.5,0.5,0.3)
        self.digitTransforms = transforms.Compose(transList)
        self.dataset = MNIST(self.image_path,train=True, transform=self.digitTransforms) if not use_celeb else CelebDataset('./data/celebA/images', './data/celebA/list_attr_celeba.txt', self.digitTransforms, mode)
        self.num_data = len(self.dataset)
        self.metadata = {'images':[]}
        self.catid2attr = {}
        self.out_img_size = out_img_size
        self.bbox_out_size = bbox_out_size
        self.selected_attrs = select_attrs

        print ('Start preprocessing dataset..!')
        self.preprocess()
        print ('Finished preprocessing dataset..!')
开发者ID:rakshithShetty,项目名称:adversarial-object-removal,代码行数:25,代码来源:data_loader_stargan.py

示例7: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import RandomRotation [as 别名]
def __init__(self, name, size, scale, ratio, colorjitter):
        """Build the image preprocessing pipeline for one split.

        Only the requested pipeline is constructed. The train-only arguments are
        therefore neither indexed nor validated when building ``'val'``.

        Args:
            name: ``'train'`` or ``'val'``. Any other value raises ``KeyError``.
            size: Output size used for resizing/cropping.
            scale: ``RandomResizedCrop`` scale range (train only).
            ratio: ``RandomResizedCrop`` aspect-ratio range (train only).
            colorjitter: ``(brightness, contrast, saturation)`` for ``ColorJitter``
                (train only).

        Raises:
            KeyError: If ``name`` is not a known split.
        """
        if name == 'val':
            steps = [
                transforms.Resize(size),
                transforms.CenterCrop(size=size),
            ]
        elif name == 'train':
            steps = [
                transforms.RandomResizedCrop(size, scale=scale, ratio=ratio),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=colorjitter[0],
                                       contrast=colorjitter[1],
                                       saturation=colorjitter[2]),
                transforms.RandomRotation(degrees=15),
            ]
        else:
            # Same exception type the original dict lookup raised.
            raise KeyError(name)

        # Shared tail: tensor conversion plus per-channel mean/std normalisation.
        steps += [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        self.transfs = transforms.Compose(steps)
开发者ID:ksanjeevan,项目名称:crnn-audio-classification,代码行数:21,代码来源:transforms.py

示例8: get_dataloaders

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import RandomRotation [as 别名]
def get_dataloaders(data, batch_size=8, study_level=False):
    """Build shuffled DataLoaders, one per data category.

    Args:
        data: Mapping indexed with every name in the module-level ``data_cat``.
            Each value is passed to ``ImageDataset``.
        batch_size: Batch size used by every loader.
        study_level: Accepted but not used here.

    Returns:
        Dict from each name in ``data_cat`` to a ``DataLoader``. ``'train'`` images
        get horizontal-flip and +/-10 degree rotation augmentation. Every loader
        shuffles and uses 4 workers.
    """
    def with_tensor_tail(*head):
        # Append the common tensor conversion + normalisation steps.
        return transforms.Compose(list(head) + [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    data_transforms = {
        'train': with_tensor_tail(transforms.Resize((224, 224)),
                                  transforms.RandomHorizontalFlip(),
                                  transforms.RandomRotation(10)),
        'valid': with_tensor_tail(transforms.Resize((224, 224))),
    }

    image_datasets = {}
    for phase in data_cat:
        image_datasets[phase] = ImageDataset(data[phase], transform=data_transforms[phase])

    return {
        phase: DataLoader(image_datasets[phase], batch_size=batch_size, shuffle=True, num_workers=4)
        for phase in data_cat
    }
开发者ID:pyaf,项目名称:DenseNet-MURA-PyTorch,代码行数:23,代码来源:pipeline.py

示例9: __call__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import RandomRotation [as 别名]
def __call__(self, task_description):
        """Attach a fixed, per-class rotation to every item of a task.

        For each class in ``task_description`` an angle is drawn once with
        ``random.choice(self.degrees)``. Every item of that class then gets a
        transform that rotates its input ``x[0]`` by exactly that angle and
        passes ``x[1]`` through unchanged.

        Args:
            task_description: Iterable of data descriptions, each exposing
                ``.index`` and a ``.transforms`` list.

        Returns:
            The same ``task_description``, with transforms appended in place.
        """
        def rotate_input(rotation):
            # Factory: each returned transform keeps its own ``rotation``. A lambda
            # written directly in the loop would look ``rotation`` up only when it
            # runs (after the loop), so every item would get the last class's angle.
            return lambda x: (rotation(x[0]), x[1])

        rotations = {}
        for data_description in task_description:
            c = self.dataset.indices_to_labels[data_description.index]
            if c not in rotations:
                rot = random.choice(self.degrees)
                try:
                    rotations[c] = transforms.Compose([
                        transforms.ToPILImage(),
                        transforms.RandomRotation((rot, rot), fill=(0, )),
                        transforms.ToTensor(),
                    ])
                except Exception:
                    # Presumably for torchvision versions whose RandomRotation
                    # does not accept ``fill`` in this form.
                    rotations[c] = transforms.Compose([
                        transforms.ToPILImage(),
                        transforms.RandomRotation((rot, rot)),
                        transforms.ToTensor(),
                    ])
            data_description.transforms.append(rotate_input(rotations[c]))
        return task_description
开发者ID:learnables,项目名称:learn2learn,代码行数:23,代码来源:transforms.py

示例10: get_augmentor

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import RandomRotation [as 别名]
def get_augmentor(is_train, image_size, strong=False):
    """Build the preprocessing pipeline for training or evaluation.

    Training: an optional +/-10 degree rotation (only when ``strong``), then a
    random resized crop, colour jitter and a horizontal flip.

    Evaluation: resize, then center crop to ``image_size``. The resize target is
    ``int(image_size / 0.875 + 0.5)`` (i.e. 256) when ``image_size == 224``,
    otherwise ``image_size`` itself.

    Both variants end with ``ToTensor`` and mean/std normalisation.
    """
    if not is_train:
        if image_size == 224:
            resize_to = int(image_size / 0.875 + 0.5)
        else:
            resize_to = image_size
        head = [
            transforms.Resize(resize_to, interpolation=Image.BILINEAR),
            transforms.CenterCrop(image_size),
        ]
    else:
        head = [transforms.RandomRotation(10)] if strong else []
        head.extend([
            transforms.RandomResizedCrop(image_size, interpolation=Image.BILINEAR),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(),
        ])

    tail = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(head + tail)
开发者ID:IBM,项目名称:BigLittleNet,代码行数:29,代码来源:imagenet_utils.py

示例11: setUp

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import RandomRotation [as 别名]
def setUp(self):
        """Build the FileDataset and ActiveLearningDataset fixtures.

        Training samples use resize(60) + random rotation(90). The pool is given
        the deterministic resize(32) pipeline via ``pool_specifics``.
        """
        self.lbls = None
        self.transform = Compose([Resize(60), RandomRotation(90), ToTensor()])
        pool_transform = Compose([Resize(32), ToTensor()])

        # NOTE(review): this unlabelled dataset is replaced a few lines below;
        # presumably kept as a smoke check that ``lbls=None`` is accepted — confirm.
        self.dataset = FileDataset(self.paths, self.lbls, transform=self.transform)

        self.lbls = self.generate_labels(len(self.paths), 10)
        self.dataset = FileDataset(self.paths, self.lbls, transform=self.transform)
        # 1 where the label is not -1, else 0; passed as the initial labelled mask.
        labelled_mask = (np.array(self.lbls) != -1).astype(np.uint8)
        self.active = ActiveLearningDataset(
            self.dataset,
            pool_specifics={'transform': pool_transform},
            labelled=torch.from_numpy(labelled_mask),
        )
开发者ID:ElementAI,项目名称:baal,代码行数:13,代码来源:file_dataset_test.py

示例12: __init__

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import RandomRotation [as 别名]
def __init__(self, cfg):
        """Read SiamFC training settings from ``cfg`` and load the annotation index.

        Args:
            cfg: Config object. Reads ``cfg.SIAMFC.TRAIN.*`` and ``cfg.SIAMFC.DATASET.*``,
                including the ``VID`` / ``GOT10K`` sub-sections.

        Raises:
            ValueError: If ``cfg.SIAMFC.TRAIN.WHICH_USE`` is neither ``'VID'`` nor ``'GOT10K'``.
        """
        super(SiamFCDataset, self).__init__()
        # pair information
        self.template_size = cfg.SIAMFC.TRAIN.TEMPLATE_SIZE
        self.search_size = cfg.SIAMFC.TRAIN.SEARCH_SIZE
        self.size = (self.search_size - self.template_size) // cfg.SIAMFC.TRAIN.STRIDE + 1   # from cross-correlation

        # aug information (color/flip/rotation are compared against random.random() below)
        self.color = cfg.SIAMFC.DATASET.COLOR
        self.flip = cfg.SIAMFC.DATASET.FLIP
        self.rotation = cfg.SIAMFC.DATASET.ROTATION
        self.blur = cfg.SIAMFC.DATASET.BLUR
        self.shift = cfg.SIAMFC.DATASET.SHIFT
        self.scale = cfg.SIAMFC.DATASET.SCALE

        # NOTE(review): random.random() is drawn only once here, so each of these
        # augmentations is on or off for the whole dataset instance rather than
        # per sample — confirm this is intended.
        self.transform_extra = transforms.Compose(
            [transforms.ToPILImage(), ]
            + ([transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), ] if self.color > random.random() else [])
            + ([transforms.RandomHorizontalFlip(), ] if self.flip > random.random() else [])
            + ([transforms.RandomRotation(degrees=10), ] if self.rotation > random.random() else [])
        )

        # train data information
        which_use = cfg.SIAMFC.TRAIN.WHICH_USE
        if which_use == 'VID':
            source = cfg.SIAMFC.DATASET.VID
        elif which_use == 'GOT10K':
            source = cfg.SIAMFC.DATASET.GOT10K
        else:
            raise ValueError('not supported training dataset')
        self.anno = source.ANNOTATION
        self.num_use = cfg.SIAMFC.TRAIN.PAIRS
        self.root = source.PATH

        # Context manager so the annotation file handle is closed promptly.
        with open(self.anno, 'r') as anno_file:
            self.labels = json.load(anno_file)
        self.videos = list(self.labels.keys())
        self.num = len(self.videos)   # video number
        self.frame_range = 100
        self.pick = self._shuffle()
开发者ID:researchmm,项目名称:SiamDW,代码行数:41,代码来源:siamfc.py

示例13: calc_accuracy

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import RandomRotation [as 别名]
def calc_accuracy(model, input_image_size, use_google_testset=False, testset_path=None, batch_size=32,
                  norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)):
    """
    Calculate the mean accuracy of the model on the test set.

    The test images go through a randomised pipeline: rotation of up to 45
    degrees, a random horizontal flip and shuffled loading. The torch, CUDA,
    NumPy and ``random`` RNGs are all seeded with 33 first so repeated runs see
    the same augmentations.

    :param model: Model to evaluate. It is put in eval mode and moved to CUDA when
        available; its output is arg-maxed along dim 1 to get predictions.
    :param input_image_size: Side of the center crop fed to the model. Images are
        first resized to ``input_image_size + 32``.
    :param use_google_testset: If true, fetch (via ``download_test_set``) and use
        the test set derived from Google images, in ``./google_test_data``.
    :param testset_path: ImageFolder root. If None (and ``use_google_testset`` is
        false), use a default test set: images missing from the Udacity dataset,
        downloaded from http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz.
    :param batch_size: DataLoader batch size.
    :param norm_mean: Per-channel mean passed to ``Normalize``.
    :param norm_std: Per-channel std passed to ``Normalize``.
    :return: The mean of the per-batch accuracies. NOTE: batches are weighted
        equally, so a smaller final batch counts as much as a full one.
    """
    if use_google_testset:
        testset_path = "./google_test_data"
        url = 'https://www.dropbox.com/s/3zmf1kq58o909rq/google_test_data.zip?dl=1'
        download_test_set(testset_path, url)
    if testset_path is None:
        # NOTE(review): the local directory is spelled "orginal"; left unchanged
        # so data already downloaded to that path keeps being used.
        testset_path = "./flower_data_orginal_test"
        url = 'https://www.dropbox.com/s/da6ye9genbsdzbq/flower_data_original_test.zip?dl=1'
        download_test_set(testset_path, url)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(device=device)
    with torch.no_grad():
        batch_accuracy = []
        # Seed every RNG so the random test-time augmentation is repeatable.
        torch.manual_seed(33)
        torch.cuda.manual_seed(33)
        np.random.seed(33)
        random.seed(33)
        torch.backends.cudnn.deterministic = True
        datatransform = transforms.Compose([transforms.RandomRotation(45),
                                            transforms.Resize(input_image_size + 32),
                                            transforms.CenterCrop(input_image_size),
                                            transforms.RandomHorizontalFlip(),
                                            transforms.ToTensor(),
                                            transforms.Normalize(norm_mean, norm_std)])
        image_dataset = datasets.ImageFolder(testset_path, transform=datatransform)
        dataloader = torch.utils.data.DataLoader(image_dataset, batch_size=batch_size, shuffle=True, worker_init_fn=_init_fn)
        for inputs, labels in dataloader:
            if device == 'cuda':
                inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model.forward(inputs)
            _, predicted = outputs.max(dim=1)
            equals = predicted == labels.data
            accuracy = equals.float().mean()
            # Report the real batch size: the last batch may be smaller than batch_size.
            print("Batch accuracy (Size {}): {}".format(labels.size(0), accuracy))
            batch_accuracy.append(accuracy.cpu().numpy())
        mean_acc = np.mean(batch_accuracy)
        print("Mean accuracy: {}".format(mean_acc))
    return mean_acc
开发者ID:GabrielePicco,项目名称:deep-learning-flower-identifier,代码行数:53,代码来源:test_model_pytorch_facebook_challenge.py

示例14: setup_datasets

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import RandomRotation [as 别名]
def setup_datasets(self):
        """Load the COCO pair datasets and wrap them in DataLoaders.

        Returns:
            ``(train_loader, train_subset_loader, val_loader)``:
            - the shuffled training loader;
            - an unshuffled loader over every ``5 * self.dataset_size_ratio``-th
              training sample;
            - the unshuffled validation loader.
        """
        def normalize():
            return transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )

        def make_loader(dataset, shuffle):
            return DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=shuffle,
                num_workers=self.num_workers
            )

        crop = self.crop_size
        jitter = self.random_scale
        # NOTE(review): newer torchvision releases replaced ``resample`` with ``interpolation``.
        train_transform = transforms.Compose([
            transforms.Resize(crop),
            transforms.RandomRotation(degrees=self.random_angle, resample=Image.BILINEAR),
            transforms.RandomResizedCrop(size=crop, scale=(1 - jitter, 1 + jitter), ratio=(1, 1)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize(),
        ])
        val_transform = transforms.Compose([
            transforms.Resize(crop),
            transforms.CenterCrop(crop),
            transforms.ToTensor(),
            normalize(),
        ])

        train_dataset = CocoDatasetPairs(
            root_dir=self.coco_path,
            set_name='train2014',
            transform=train_transform,
            dataset_size_ratio=self.dataset_size_ratio
        )
        subset_step = 5 * self.dataset_size_ratio
        train_subset_dataset = Subset(train_dataset, range(0, len(train_dataset), subset_step))
        val_dataset = CocoDatasetPairs(
            root_dir=self.coco_path,
            set_name='val2014',
            transform=val_transform,
        )

        return (
            make_loader(train_dataset, True),
            make_loader(train_subset_dataset, False),
            make_loader(val_dataset, False),
        )
开发者ID:leokarlin,项目名称:LaSO,代码行数:63,代码来源:train_setops_stripped.py

示例15: q_eval

# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import RandomRotation [as 别名]
def q_eval(net, dataset, q_idx, flip=False, rotate=False, scale=1):
    """Compute the L2-normalised descriptor of one query image.

    Args:
        net: Network exposing ``out_features``; it is switched to eval mode.
        dataset: Object whose ``get_image(q_idx)`` returns the query image.
        q_idx: Index of the query image.
        flip: If truthy, always apply a horizontal flip.
        rotate: If truthy, rotate by (essentially exactly) ``rotate`` degrees.
        scale: Which resize to apply:
            - 0: resize to 224x224;
            - 1: ``imresize`` with edge 800;
            - 1.5: edge 1200;
            - 2: multiscale over edges 600/800/1000/1200.

    Returns:
        1-D numpy array: the mean of the per-chain descriptors, L2-normalised.

    Raises:
        ValueError: If ``scale`` is not one of 0, 1, 1.5, 2.
    """
    # load query image
    q_im = dataset.get_image(q_idx)

    def edge_resizer(edge):
        # Factory so each resize function keeps its own ``edge``. Lambdas created
        # directly in a comprehension would all see the last edge in the list.
        return lambda im: imresize(im, edge)

    # list of transformation lists
    trfs_chains = [[]]
    if rotate:
        eps = 1e-6
        # A degenerate range makes RandomRotation rotate by (almost) exactly `rotate`.
        trfs_chains[0] += [RandomRotation((rotate - eps, rotate + eps))]
    if flip:
        trfs_chains[0] += [RandomHorizontalFlip(1)]
    if scale == 0:  # AlexNet asks for resized images of 224x224
        resize_list = [Resize((edge, edge)) for edge in [224]]
    elif scale == 1:
        resize_list = [edge_resizer(edge) for edge in [800]]
    elif scale == 1.5:
        resize_list = [edge_resizer(edge) for edge in [1200]]
    elif scale == 2:  # multiscale
        resize_list = [edge_resizer(edge) for edge in [600, 800, 1000, 1200]]
    else:
        raise ValueError('unsupported scale: {!r} (expected 0, 1, 1.5 or 2)'.format(scale))

    if len(resize_list) == 1:
        trfs_chains[0] += resize_list
    else:
        # add_trf (defined elsewhere) presumably expands trfs_chains, one chain per resize.
        add_trf(trfs_chains, resize_list)

    # default transformations
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    for chain in trfs_chains:
        chain += [ToTensor(), Normalize(mean, std)]

    net = net.eval()
    q_feat = torch.zeros((len(trfs_chains), net.out_features))
    print('Computing the forward pass and extracting the image representation...')
    # Inference only: no autograd graph is needed (the result is detached anyway).
    with torch.no_grad():
        for i, chain in enumerate(trfs_chains):
            q_tensor = Compose(chain)(q_im)
            # Add a leading batch dimension of size 1.
            q_feat[i] = net.forward(q_tensor.unsqueeze(0))
    return F.normalize(q_feat.mean(dim=0), dim=0).detach().numpy()
开发者ID:almazan,项目名称:paiss,代码行数:49,代码来源:test.py


注:本文中的torchvision.transforms.RandomRotation方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。