

Python transforms.TenCrop Method Code Examples

This article collects typical usage examples of the torchvision.transforms.TenCrop method in Python. If you are wondering what transforms.TenCrop does, how to call it, or what real-world usage looks like, the curated examples below should help. You can also explore further usage examples for torchvision.transforms, the module this method belongs to.


The following presents 11 code examples of the transforms.TenCrop method, sorted by popularity by default.
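
As a quick orientation before the examples: transforms.TenCrop(size) crops an image into its four corners and center, plus the horizontally flipped version of each, and returns a tuple of 10 PIL images. Because this breaks the usual one-image-in/one-tensor-out contract, it is almost always followed by a Lambda that converts and stacks the crops. Below is a minimal illustrative sketch (the image path is a placeholder, not taken from any example in this article):

import torch
from PIL import Image
from torchvision import transforms

tta_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.TenCrop(224),  # 4 corners + center, each plus its horizontal flip
    transforms.Lambda(lambda crops: torch.stack(
        [transforms.ToTensor()(crop) for crop in crops])),
])

img = Image.open("example.jpg").convert("RGB")  # placeholder image path
crops = tta_transform(img)
print(crops.shape)  # torch.Size([10, 3, 224, 224])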

Example 1: scale_crop

# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import TenCrop [as alias]
def scale_crop(input_size, scale_size=None, num_crops=1, normalize=_IMAGENET_STATS):
    assert num_crops in [1, 5, 10], "num crops must be in {1,5,10}"
    convert_tensor = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize(**normalize)])
    if num_crops == 1:
        t_list = [
            transforms.CenterCrop(input_size),
            convert_tensor
        ]
    else:
        if num_crops == 5:
            t_list = [transforms.FiveCrop(input_size)]
        elif num_crops == 10:
            t_list = [transforms.TenCrop(input_size)]
        # returns a 4D tensor
        t_list.append(transforms.Lambda(lambda crops:
                                        torch.stack([convert_tensor(crop) for crop in crops])))

    if scale_size is not None and scale_size != input_size:  # resize only when an explicit, different scale_size is given
        t_list = [transforms.Resize(scale_size)] + t_list

    return transforms.Compose(t_list) 
Author: eladhoffer | Project: convNet.pytorch | Lines of code: 24 | Source file: preprocess.py
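
A hedged usage sketch for scale_crop above; it assumes the surrounding preprocess.py context (torch, torchvision.transforms, and the _IMAGENET_STATS normalization dict) plus a PIL image img, which are placeholders here. With num_crops of 5 or 10 the result carries an extra crop dimension that the downstream model loop has to fold away:

transform_10crop = scale_crop(input_size=224, scale_size=256, num_crops=10)
crops = transform_10crop(img)        # tensor of shape [10, 3, 224, 224]

transform_center = scale_crop(input_size=224, scale_size=256, num_crops=1)
single = transform_center(img)       # tensor of shape [3, 224, 224]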

Example 2: create_test_transforms

# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import TenCrop [as alias]
def create_test_transforms(config, crop, scale, ten_crops):
    normalize = transforms.Normalize(mean=config["mean"], std=config["std"])

    val_transforms = []
    if scale != -1:
        val_transforms.append(transforms.Resize(scale))
    if ten_crops:
        val_transforms += [
            transforms.TenCrop(crop),
            transforms.Lambda(lambda crops: [transforms.ToTensor()(crop) for crop in crops]),
            transforms.Lambda(lambda crops: [normalize(crop) for crop in crops]),
            transforms.Lambda(lambda crops: torch.stack(crops))
        ]
    else:
        val_transforms += [
            transforms.CenterCrop(crop),
            transforms.ToTensor(),
            normalize
        ]

    return val_transforms 
Author: mapillary | Project: inplace_abn | Lines of code: 23 | Source file: utils.py
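
When a ten-crop transform like the one above is used at evaluation time, every batch from the DataLoader carries an extra crop dimension. A minimal evaluation sketch under that assumption (model, loader, and device names are placeholders):

import torch

@torch.no_grad()
def evaluate_tencrop(model, loader, device="cuda"):
    model.eval()
    correct = total = 0
    for images, targets in loader:                             # images: [B, 10, C, H, W]
        b, ncrops, c, h, w = images.shape
        logits = model(images.view(-1, c, h, w).to(device))    # [B * 10, num_classes]
        logits = logits.view(b, ncrops, -1).mean(dim=1)        # average over the 10 crops
        correct += (logits.argmax(dim=1).cpu() == targets).sum().item()
        total += b
    return correct / total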

Example 3: ten_crop

# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import TenCrop [as alias]
def ten_crop(cfg, **kwargs):
    size = kwargs["input_size"] if kwargs["input_size"] is not None else cfg.INPUT_SIZE
    return transforms.TenCrop(size) 
Author: Megvii-Nanjing | Project: BBN | Lines of code: 5 | Source file: transform_wrapper.py

Example 4: extract_features_CUHK03

# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import TenCrop [as alias]
def extract_features_CUHK03(model, scale_image_size, data, extract_features_folder, logger, batch_size=128, workers=4, is_tencrop=False, normalize=None):
    logger.info('Begin extract features')
    if normalize is None:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    if is_tencrop:
        logger.info('==> Using TenCrop')
        tencrop = transforms.Compose([
            transforms.Resize([int(x*1.125) for x in scale_image_size]),
            transforms.TenCrop(scale_image_size)])
    else:
        tencrop = None
    transform = transforms.Compose([
        transforms.Resize(scale_image_size),
        transforms.ToTensor(),
        normalize, ])
    train_data_folder = data
    logger.info('Begin load train data from '+train_data_folder)
    train_dataloader = torch.utils.data.DataLoader(
        Datasets.CUHK03EvaluateDataset(folder=train_data_folder, transform=transform, tencrop=tencrop),
        batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True)

    train_features = extract_features(model, train_dataloader, is_tencrop)
    if os.path.isdir(extract_features_folder) is False:
        os.makedirs(extract_features_folder)

    sio.savemat(os.path.join(extract_features_folder, 'train_features.mat'), {'feature_train_new': train_features})
    return 
Author: mileyan | Project: DARENet | Lines of code: 31 | Source file: extract_features.py

Example 5: __init__

# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import TenCrop [as alias]
def __init__(self, data_path, is_train = True, *args, **kwargs):
        super(Market1501, self).__init__(*args, **kwargs)
        self.is_train = is_train
        self.data_path = data_path
        self.imgs = os.listdir(data_path)
        self.imgs = [el for el in self.imgs if os.path.splitext(el)[1] == '.jpg']
        self.lb_ids = [int(el.split('_')[0]) for el in self.imgs]
        self.lb_cams = [int(el.split('_')[1][1]) for el in self.imgs]
        self.imgs = [os.path.join(data_path, el) for el in self.imgs]
        if is_train:
            self.trans = transforms.Compose([
                transforms.Resize((288, 144)),
                transforms.RandomCrop((256, 128)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.486, 0.459, 0.408), (0.229, 0.224, 0.225)),
                RandomErasing(0.5, mean=[0.0, 0.0, 0.0])
            ])
        else:
            self.trans_tuple = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.486, 0.459, 0.408), (0.229, 0.224, 0.225))
                ])
            self.Lambda = transforms.Lambda(
                lambda crops: [self.trans_tuple(crop) for crop in crops])
            self.trans = transforms.Compose([
                transforms.Resize((288, 144)),
                transforms.TenCrop((256, 128)),
                self.Lambda,
            ])

        # useful for sampler
        self.lb_img_dict = dict()
        self.lb_ids_uniq = set(self.lb_ids)
        lb_array = np.array(self.lb_ids)
        for lb in self.lb_ids_uniq:
            idx = np.where(lb_array == lb)[0]
            self.lb_img_dict.update({lb: idx}) 
Author: CoinCheung | Project: triplet-reid-pytorch | Lines of code: 40 | Source file: Market1501.py
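
Note that in the evaluation branch above the final Lambda returns a plain Python list of ten tensors rather than a stacked tensor, so each sample has to be stacked (or handled by a custom collate function) downstream. A small hedged sketch, assuming torch is imported, a Market1501 instance built with is_train=False, and a PIL image img:

crops = dataset.trans(img)            # list of 10 tensors, each [3, 256, 128]
crops = torch.stack(crops, dim=0)     # [10, 3, 256, 128], ready for the model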

Example 6: __init__

# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import TenCrop [as alias]
def __init__(self, base_dataset, input_sz=None, include_rgb=None):
    super(TenCropAndFinish, self).__init__()

    self.base_dataset = base_dataset
    self.num_tfs = 10
    self.input_sz = input_sz
    self.include_rgb = include_rgb

    self.crops_tf = transforms.TenCrop(self.input_sz)
    self.finish_tf = custom_greyscale_to_tensor(self.include_rgb) 
Author: xu-ji | Project: IIC | Lines of code: 12 | Source file: dataset.py

Example 7: get_transform_for_test

# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import TenCrop [as alias]
def get_transform_for_test():

    transform_list = []
    
    transform_list.append(transforms.Lambda(lambda img:scale_keep_ar_min_fixed(img, 560)))
   
    transform_list.append(transforms.TenCrop(448)) 
    
    transform_list.append(transforms.Lambda(lambda crops: torch.stack([transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))((transforms.ToTensor())(crop)) for crop in crops])) )
    
    #transform_list.append(transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5)))
    
    return transforms.Compose(transform_list) 
Author: songdejia | Project: DFL-CNN | Lines of code: 15 | Source file: transform.py

Example 8: loaders_imagenet

# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import TenCrop [as alias]
def loaders_imagenet(dataset_name, batch_size, cuda,
                     train_size, augment=True, val_size=50000,
                     test_batch_size=256, topk=None, noise=False,
                     multiple_crops=False, data_root=None):

    assert dataset_name == 'imagenet'
    data_root = data_root if data_root is not None else os.environ['VISION_DATA_SSD']
    root = '{}/ILSVRC2012-prepr-split/images'.format(data_root)

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    traindir = os.path.join(root, 'train')
    valdir = os.path.join(root, 'val')
    testdir = os.path.join(root, 'test')

    normalize = transforms.Normalize(mean=mean, std=std)

    if multiple_crops:
        print('Using multiple crops')
        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            lambda x: [normalize(transforms.functional.to_tensor(img)) for img in x]])
    else:
        transform_test = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize])

    if augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize])
    else:
        transform_train = transform_test

    dataset_train = datasets.ImageFolder(traindir, transform_train)
    dataset_val = datasets.ImageFolder(valdir, transform_test)
    dataset_test = datasets.ImageFolder(testdir, transform_test)

    return create_loaders(dataset_name, dataset_train, dataset_val,
                          dataset_test, train_size, val_size, batch_size,
                          test_batch_size, cuda, noise=noise, num_workers=4) 
Author: oval-group | Project: smooth-topk | Lines of code: 49 | Source file: main.py

Example 9: extract_features_MARS

# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import TenCrop [as alias]
def extract_features_MARS(model, scale_image_size, info_folder, data, extract_features_folder, logger, batch_size=128, workers=4, is_tencrop=False):
    logger.info('Begin extract features')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if is_tencrop:
        logger.info('==> Using TenCrop')
        tencrop = transforms.Compose([
            transforms.Resize([int(x*1.125) for x in scale_image_size]),
            transforms.TenCrop(scale_image_size)])
    else:
        tencrop = None
    transform = transforms.Compose([
        transforms.Resize(scale_image_size),
        transforms.ToTensor(),
        normalize, ])
    train_name_path = os.path.join(info_folder, 'train_name.txt')
    test_name_path = os.path.join(info_folder, 'test_name.txt')
    train_data_folder = os.path.join(data, 'bbox_train')
    test_data_folder = os.path.join(data, 'bbox_test')
    logger.info('Train data folder: '+train_data_folder)
    logger.info('Test data folder: '+test_data_folder)
    logger.info('Begin load train data')
    train_dataloader = torch.utils.data.DataLoader(
        Datasets.MARSEvalDataset(folder=train_data_folder,
                                    image_name_file=train_name_path,
                                    transform=transform, tencrop=tencrop),
        batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True)
    logger.info('Begin load test data')
    test_dataloader = torch.utils.data.DataLoader(
        Datasets.MARSEvalDataset(folder=test_data_folder,
                                    image_name_file=test_name_path,
                                    transform=transform, tencrop=tencrop),
        batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True)

    train_features = extract_features(model, train_dataloader, is_tencrop)
    test_features = extract_features(model, test_dataloader, is_tencrop)
    if os.path.isdir(extract_features_folder) is False:
        os.makedirs(extract_features_folder)

    sio.savemat(os.path.join(extract_features_folder, 'train_features.mat'), {'feature_train_new': train_features})
    sio.savemat(os.path.join(extract_features_folder, 'test_features.mat'), {'feature_test_new': test_features})
    return 
Author: mileyan | Project: DARENet | Lines of code: 46 | Source file: extract_features.py

Example 10: extract_features_Market1501

# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import TenCrop [as alias]
def extract_features_Market1501(model, scale_image_size, data, extract_features_folder, logger, batch_size=128, workers=4, is_tencrop=False, gen_stage_features = False):
    logger.info('Begin extract features')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if is_tencrop:
        logger.info('==> Using TenCrop')
        tencrop = transforms.Compose([
            transforms.Resize([int(x*1.125) for x in scale_image_size]),
            transforms.TenCrop(scale_image_size)])
    else:
        tencrop = None
    transform = transforms.Compose([
        transforms.Resize(scale_image_size),
        transforms.ToTensor(),
        normalize, ])
    train_data_folder = os.path.join(data, 'bounding_box_train')
    test_data_folder = os.path.join(data, 'bounding_box_test')
    query_data_folder = os.path.join(data, 'query')
    logger.info('Begin load train data from '+train_data_folder)
    train_dataloader = torch.utils.data.DataLoader(
        Datasets.Market1501EvaluateDataset(folder=train_data_folder, transform=transform, tencrop=tencrop),
        batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True)
    logger.info('Begin load test data from '+test_data_folder)
    test_dataloader = torch.utils.data.DataLoader(
        Datasets.Market1501EvaluateDataset(folder=test_data_folder, transform=transform, tencrop=tencrop),
        batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True)

    logger.info('Begin load query data from '+query_data_folder)
    query_dataloader = torch.utils.data.DataLoader(
        Datasets.Market1501EvaluateDataset(folder=query_data_folder, transform=transform, tencrop=tencrop),
        batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True)
    if not gen_stage_features:
        train_features = extract_features(model, train_dataloader, is_tencrop)
        test_features = extract_features(model, test_dataloader, is_tencrop)
        query_features = extract_features(model, query_dataloader, is_tencrop)
        if os.path.isdir(extract_features_folder) is False:
            os.makedirs(extract_features_folder)

        sio.savemat(os.path.join(extract_features_folder, 'train_features.mat'), {'feature_train_new': train_features})
        sio.savemat(os.path.join(extract_features_folder, 'test_features.mat'), {'feature_test_new': test_features})
        sio.savemat(os.path.join(extract_features_folder, 'query_features.mat'), {'feature_query_new': query_features})
    else:
        # model.gen_stage_features = True

        train_features = extract_stage_features(model, train_dataloader, is_tencrop)
        test_features = extract_stage_features(model, test_dataloader, is_tencrop)
        query_features = extract_stage_features(model, query_dataloader, is_tencrop)
        if os.path.isdir(extract_features_folder) is False:
            os.makedirs(extract_features_folder)

        for i in range(4):
            sio.savemat(os.path.join(extract_features_folder, 'train_features_{}.mat'.format(i + 1)), {'feature_train_new': train_features[i]})
            sio.savemat(os.path.join(extract_features_folder, 'test_features_{}.mat'.format(i + 1)), {'feature_test_new': test_features[i]})
            sio.savemat(os.path.join(extract_features_folder, 'query_features_{}.mat'.format(i + 1)), {'feature_query_new': query_features[i]})

        sio.savemat(os.path.join(extract_features_folder, 'train_features_fusion.mat'), {'feature_train_new': train_features[4]})
        sio.savemat(os.path.join(extract_features_folder, 'test_features_fusion.mat'), {'feature_test_new': test_features[4]})
        sio.savemat(os.path.join(extract_features_folder, 'query_features_fusion.mat'), {'feature_query_new': query_features[4]}) 
Author: mileyan | Project: DARENet | Lines of code: 63 | Source file: extract_features.py

Example 11: data_loader_predict

# Required import: from torchvision import transforms [as alias]
# Or: from torchvision.transforms import TenCrop [as alias]
def data_loader_predict(data_dir, input_shape, name):
    if name in ["inceptionv4", "inceptionresnetv2", "inception_v3"]:
        scale = 360
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]

    elif name == "bninception":
        scale = 256
        mean = [104, 117, 128]
        std =  [1, 1, 1]

    elif name == "vggm":
        scale = 256
        mean = [123.68, 116.779, 103.939]
        std = [1, 1, 1]

    elif name == "nasnetalarge":
        scale = 354
        mean = [0.5, 0.5, 0.5]
        std = [1, 1, 1]

    else:
        scale = 256
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    print("[Scale: {} , mean: {}, std: {}]".format(scale, mean, std))
    if name == "bninception":
        val = transforms.Compose([transforms.Scale(scale),  # transforms.Scale is the legacy name for transforms.Resize
                              transforms.TenCrop(input_shape),
                              transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                              transforms.Lambda(lambda bgr: torch.stack([ToSpaceBGR(True)(bgrformat) for bgrformat in bgr])),
                              transforms.Lambda(lambda range255: torch.stack([ToRange255(True)(ranges) for ranges in range255])),
                              transforms.Lambda(lambda normal: torch.stack([transforms.Normalize(mean, std)(normalize) for normalize in normal]))])
    else:
        val = transforms.Compose([transforms.Scale(scale),
                              transforms.TenCrop(input_shape),
                              transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                              transforms.Lambda(lambda normal: torch.stack([transforms.Normalize(mean, std)(normalize) for normalize in normal]))])
    image_datasets = datasets.ImageFolder(data_dir, val)
    dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=1,
                                         shuffle=False, num_workers=1)
    return dataloaders, image_datasets 
Author: prakashjayy | Project: pytorch_classifiers | Lines of code: 44 | Source file: tars_data_loaders.py
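
A brief usage sketch for data_loader_predict above; the data directory, model, and torch import are placeholders/assumptions, and with batch_size=1 each batch has shape [1, 10, C, H, W]:

loader, dataset = data_loader_predict("data/predict", input_shape=224, name="resnet50")
for images, _ in loader:                              # images: [1, 10, 3, 224, 224]
    logits = model(images.squeeze(0))                 # run all 10 crops in one forward pass
    probs = torch.softmax(logits, dim=1).mean(dim=0)  # average predictions over the crops
    print(dataset.classes[probs.argmax().item()])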


Note: The torchvision.transforms.TenCrop examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are taken from open-source projects contributed by their respective authors; copyright remains with the original authors, and any use or redistribution should follow the corresponding project's license. Do not reproduce without permission.