本文整理汇总了 Python 中 torchvision.datasets.ImageNet 方法的典型用法代码示例。如果您正苦于以下问题:Python datasets.ImageNet 方法的具体用法?Python datasets.ImageNet 怎么用?Python datasets.ImageNet 使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类 torchvision.datasets 的用法示例。
在下文中一共展示了datasets.ImageNet方法的8个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_loaders
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageNet [as 别名]
def get_loaders(dataroot, val_batch_size, train_batch_size, input_size, workers, num_nodes, local_rank):
    """Build the distributed training and validation data loaders.

    Both datasets are read with ``datasets.ImageFolder`` from the ``val`` and
    ``train`` sub-directories of ``dataroot``. Each dataset gets its own
    ``DistributedSampler`` constructed from ``num_nodes`` and ``local_rank``,
    and that sampler is handed to the corresponding ``DataLoader``.

    Args:
        dataroot: directory containing the ``train`` and ``val`` folders.
        val_batch_size: batch size of the validation loader.
        train_batch_size: batch size of the training loader.
        input_size: forwarded to ``get_transform`` for both splits.
        workers: ``num_workers`` used by both loaders.
        num_nodes: second positional argument of ``DistributedSampler``.
        local_rank: third positional argument of ``DistributedSampler``.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    # TODO: pin-memory currently broken for distributed
    # TODO: datasets.ImageNet
    shared_loader_args = {'num_workers': workers, 'pin_memory': False}

    # Validation split (built first, then the training split).
    val_dir = os.path.join(dataroot, 'val')
    val_data = datasets.ImageFolder(root=val_dir, transform=get_transform(False, input_size))
    val_loader = torch.utils.data.DataLoader(
        val_data,
        batch_size=val_batch_size,
        sampler=DistributedSampler(val_data, num_nodes, local_rank),
        **shared_loader_args
    )

    # Training split; get_transform is called with input_size only here.
    train_dir = os.path.join(dataroot, 'train')
    train_data = datasets.ImageFolder(root=train_dir, transform=get_transform(input_size=input_size))
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=train_batch_size,
        sampler=DistributedSampler(train_data, num_nodes, local_rank),
        **shared_loader_args
    )

    return train_loader, val_loader
示例2: load_targets
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageNet [as 别名]
def load_targets(self):
    """
    Fetch the pickled ImageNet validation labels and load them into ``self.targets``.

    ``download_url`` is invoked with the URL and MD5 stored under
    ``ARCHIVE_DICT['labels']`` and with ``self.root`` as destination. Afterwards
    ``imagenet_val_targets.pkl`` is opened from ``self.root`` and its unpickled
    content is assigned to ``self.targets``.

    :return: None - the loaded labels are kept on ``self.targets``
    """
    labels_archive = ARCHIVE_DICT['labels']
    download_url(url=labels_archive['url'], root=self.root, md5=labels_archive['md5'])

    targets_path = os.path.join(self.root, 'imagenet_val_targets.pkl')
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; this relies on the downloaded archive being trusted.
    with open(targets_path, 'rb') as targets_file:
        self.targets = pickle.load(targets_file)
示例3: save
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageNet [as 别名]
def save(self):
    """
    Finalise the evaluation and wrap everything in a BenchmarkResult.

    The results are recomputed through ``self.get_results()`` first. Timing
    entries in ``self.speed_mem_metrics`` are filled in only when the results
    were not cached; for cached results the same three keys are set to None.

    NOTE(review): according to the original documentation, on the
    sotabench.com server this produces a JSON serialisation in
    sotabench_results.json - that happens outside this method.

    :return: BenchmarkResult object with results and metadata
    """
    # Recalculate so that batch-by-batch bookkeeping errors cannot leak through.
    self.get_results()

    if self.cached_results:
        # Nothing was actually evaluated in this run, so there is no timing to report.
        timing = {
            'Tasks / Evaluation Time': None,
            'Tasks': None,
            'Evaluation Time': None,
        }
    else:
        # First run of this model: record how long the evaluation took.
        elapsed = time.time() - self.init_time
        task_count = len(self.outputs)
        timing = {
            'Tasks / Evaluation Time': task_count / elapsed,
            'Tasks': task_count,
            'Evaluation Time': elapsed,
        }
    for metric_name, metric_value in timing.items():
        self.speed_mem_metrics[metric_name] = metric_value

    return BenchmarkResult(
        task=self.task,
        config={},
        dataset='ImageNet',
        results=self.results,
        speed_mem_metrics=self.speed_mem_metrics,
        model=self.model_name,
        model_description=self.model_description,
        arxiv_id=self.paper_arxiv_id,
        pwc_id=self.paper_pwc_id,
        paper_results=self.paper_results,
        run_hash=self.batch_hash,
    )
示例4: __init__
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageNet [as 别名]
def __init__(self,
             root,
             train=True,
             transform=None,
             download=False):
    """
    Initialise the dataset on top of the parent class' ``split`` interface.

    :param root: dataset root, forwarded unchanged to the parent constructor
    :param train: selects the "train" split when truthy, otherwise "val"
    :param transform: forwarded unchanged to the parent constructor
    :param download: must be falsy; downloading is not supported here
    """
    assert not download, "Download dataset by yourself!"

    # Translate the boolean flag into the split name the parent expects.
    if train:
        split_name = "train"
    else:
        split_name = "val"

    super(ImageNet, self).__init__(root, split=split_name, transform=transform)
示例5: get_dataset
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageNet [as 别名]
def get_dataset(self):
    """
    Build the ImageNet training and validation datasets.

    Both are created with ``torchvision.datasets.ImageNet``: the training set
    from 'datasets/ImageNet/train/' with ``self.train_transforms`` and the
    validation set from 'datasets/ImageNet/val/' with ``self.val_transforms``.
    ``download=True`` and ``target_transform=None`` are passed in both cases.

    NOTE(review): whether ``download=True`` actually downloads anything
    depends on the installed torchvision version - confirm.

    Returns:
        tuple: (trainset, valset) as returned by ``datasets.ImageNet``
    """
    common_args = {'target_transform': None, 'download': True}

    trainset = datasets.ImageNet(
        'datasets/ImageNet/train/',
        split='train',
        transform=self.train_transforms,
        **common_args
    )
    valset = datasets.ImageNet(
        'datasets/ImageNet/val/',
        split='val',
        transform=self.val_transforms,
        **common_args
    )

    return trainset, valset
示例6: get_train_val_loaders
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageNet [as 别名]
def get_train_val_loaders(
    root_path: str,
    train_transforms: Callable,
    val_transforms: Callable,
    batch_size: int = 16,
    num_workers: int = 8,
    val_batch_size: Optional[int] = None,
    limit_train_num_samples: Optional[int] = None,
    limit_val_num_samples: Optional[int] = None,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Create the training, validation and train-evaluation data loaders.

    The datasets are ``ImageNet`` instances read with ``opencv_loader``; the
    given transforms are called as ``transforms(image=sample)["image"]``.

    Args:
        root_path: root directory passed to ``ImageNet``.
        train_transforms: transform applied to training samples.
        val_transforms: transform applied to validation samples.
        batch_size: batch size of the training loader.
        num_workers: worker count used by all three loaders.
        val_batch_size: batch size of both evaluation loaders; defaults to
            ``batch_size * 4`` when None.
        limit_train_num_samples: if given, keep only that many randomly
            chosen training samples.
        limit_val_num_samples: if given, keep only that many randomly chosen
            validation samples.

    Returns:
        ``(train_loader, val_loader, train_eval_loader)`` where the last one
        iterates over a random part of the training data no larger than the
        validation data.
    """

    def _seeded_subset(dataset, count):
        # Seed numpy's global RNG with the requested size so that the same
        # size always selects the same indices, then keep the first `count`
        # entries of a permutation over the whole dataset.
        np.random.seed(count)
        picked = np.random.permutation(len(dataset))[:count]
        return Subset(dataset, picked)

    train_ds = ImageNet(
        root_path,
        split="train",
        transform=lambda img: train_transforms(image=img)["image"],
        loader=opencv_loader,
    )
    val_ds = ImageNet(
        root_path,
        split="val",
        transform=lambda img: val_transforms(image=img)["image"],
        loader=opencv_loader,
    )

    if limit_train_num_samples is not None:
        train_ds = _seeded_subset(train_ds, limit_train_num_samples)

    if limit_val_num_samples is not None:
        val_ds = _seeded_subset(val_ds, limit_val_num_samples)

    # Random samples for evaluation on the training dataset: never more than
    # the validation set holds.
    val_size = len(val_ds)
    if val_size < len(train_ds):
        train_eval_ds = _seeded_subset(train_ds, val_size)
    else:
        train_eval_ds = train_ds

    train_loader = idist.auto_dataloader(
        train_ds, shuffle=True, batch_size=batch_size, num_workers=num_workers, drop_last=True
    )

    if val_batch_size is None:
        val_batch_size = batch_size * 4

    # Both evaluation loaders share the same, order-preserving configuration.
    eval_loader_args = {
        "shuffle": False,
        "batch_size": val_batch_size,
        "num_workers": num_workers,
        "drop_last": False,
    }
    val_loader = idist.auto_dataloader(val_ds, **eval_loader_args)
    train_eval_loader = idist.auto_dataloader(train_eval_ds, **eval_loader_args)

    return train_loader, val_loader, train_eval_loader
示例7: get_results
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageNet [as 别名]
def get_results(self):
    """
    Compute (or return the cached) Top 1 / Top 5 accuracy of the evaluator.

    Cached results are returned untouched. Otherwise the IDs in
    ``self.outputs`` must match the IDs in ``self.targets`` exactly; if they
    do not, a ValueError listing the missing (and, if any, unmatched) IDs is
    raised. With complete predictions both accuracies are averaged over all
    targets, stored on ``self.results``, and the maximum allocated memory is
    recorded in ``self.speed_mem_metrics``.

    :return: dict with Top 1 and Top 5 Accuracy
    :raises ValueError: if the output IDs and target IDs differ
    """
    if self.cached_results:
        return self.results

    target_ids = set(self.targets.keys())
    output_ids = set(self.outputs.keys())

    if target_ids != output_ids:
        missing_ids = target_ids - output_ids
        unmatched_ids = output_ids - target_ids

        if not unmatched_ids:
            template = ('''There are {mis_no} missing image IDs\n\n'''
                        '''Missing IDs are {missing}''')
            raise ValueError(template.format(mis_no=len(missing_ids), missing=missing_ids))

        template = ('''There are {mis_no} missing and {un_no} unmatched image IDs\n\n'''
                    '''Missing IDs are {missing}\n\n'''
                    '''Unmatched IDs are {unmatched}''')
        raise ValueError(template.format(mis_no=len(missing_ids),
                                         un_no=len(unmatched_ids),
                                         missing=missing_ids,
                                         unmatched=unmatched_ids))

    # Every target has a prediction, so the metrics can be accumulated.
    self.top1 = AverageMeter()
    self.top5 = AverageMeter()

    for image_id in tqdm.tqdm(self.targets.keys()):
        prediction = self.outputs[image_id]
        label = self.targets[image_id]

        # Score k=1 first, then k=5; each call receives its own one-row array.
        prec1, prec5 = [
            top_k_accuracy_score(y_true=label, y_pred=np.array([prediction]), k=top_k)
            for top_k in (1, 5)
        ]

        self.top1.update(prec1, 1)
        self.top5.update(prec5, 1)

    self.results = {'Top 1 Accuracy': self.top1.avg, 'Top 5 Accuracy': self.top5.avg}
    self.speed_mem_metrics['Max Memory Allocated (Total)'] = get_max_memory_allocated()

    return self.results
示例8: pre_benchmark_atk
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageNet [as 别名]
def pre_benchmark_atk(**kwargs):
    """
    Helper function that sets all the defaults while performing checks
    for all the options passed before benchmarking attacks.

    Recognised options (each falls back to the default shown when absent):
        bs (4): batch size of the data loader built here.
        trf (get_trf('rz256_cc224_tt_normimgnet')): dataset transform.
        dset ('NA'): dataset identifier; 'NA' means "use `loader` if given,
            otherwise build the dataset with `dfunc`".
        root ('./'): dataset root directory.
        topk ((1, 5)): returned unchanged to the caller.
        dfunc (datasets.ImageFolder): dataset constructor for the 'NA' case.
        download (True): forwarded to the dataset constructors.
        loader (no default): ready-made loader, only used when dset is 'NA'.

    Returns:
        tuple: ``(loader, topk, kwargs)`` where ``kwargs`` holds only the
        options that were not consumed here (they are later used to
        initialise the attack).

    Raises:
        NotImplementedError: if ``dset`` is neither 'NA' nor one of the
            datasets handled below. This is a RuntimeError subclass, so
            callers that caught the RuntimeError produced by the previous
            bare ``raise`` keep working.
    """
    # Set the default options if nothing explicit was provided
    def_dict = {
        'bs': 4,
        'trf': get_trf('rz256_cc224_tt_normimgnet'),
        'dset': 'NA',
        'root': './',
        'topk': (1, 5),
        'dfunc': datasets.ImageFolder,
        'download': True,
    }
    for key, val in def_dict.items():
        kwargs.setdefault(key, val)

    if kwargs['dset'] == 'NA':
        if 'loader' in kwargs:
            loader = kwargs['loader']
        else:
            dset = kwargs['dfunc'](kwargs['root'], transform=kwargs['trf'])
            loader = DataLoader(dset, batch_size=kwargs['bs'], num_workers=2)
    else:
        # Set dataset specific functions here
        if kwargs['dset'] == IMGNET12:
            # torchvision's ImageNet only documents the 'train' and 'val'
            # splits; the held-out evaluation data is the 'val' split
            # (the previous split='test' is not a supported value).
            # NOTE(review): recent torchvision versions may reject
            # download=True for ImageNet - confirm against the pinned version.
            dset = datasets.ImageNet(kwargs['root'], split='val',
                                     download=kwargs['download'], transform=kwargs['trf'])
        elif kwargs['dset'] == MNIST:
            kwargs['trf'] = get_trf('tt_normmnist')
            kwargs['dfunc'] = datasets.MNIST
            dset = kwargs['dfunc'](kwargs['root'], train=False,
                                   download=kwargs['download'], transform=kwargs['trf'])
        else:
            # A bare `raise` here had no active exception to re-raise and
            # only produced an opaque RuntimeError; say what went wrong.
            raise NotImplementedError(
                'Unsupported dataset {!r}: no dataset specific setup is '
                'implemented for it.'.format(kwargs['dset']))

        loader = DataLoader(dset, shuffle=False, batch_size=kwargs['bs'])

    topk = kwargs['topk']

    for key, val in kwargs.items():
        print('[INFO] Setting {} to {}.'.format(key, val))

    # Deleting keys that are used just for the benchmark_atk() function is
    # important as the same kwargs dict is passed to initialize the attack.
    # Otherwise the attack would throw an exception.
    for key in def_dict:
        del kwargs[key]
    kwargs.pop('loader', None)

    return loader, topk, kwargs