本文整理汇总了Python中utils.create_exp_dir方法的典型用法代码示例（注意：create_exp_dir 并不是 PyTorch 的 torch.utils 所提供的函数，而是各项目自带的 utils 工具模块中的辅助函数）。如果您正苦于以下问题：Python utils.create_exp_dir方法的具体用法？Python utils.create_exp_dir怎么用？Python utils.create_exp_dir使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在的项目 utils 模块的其他用法示例。
在下文中一共展示了utils.create_exp_dir方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __init__
# 需要导入模块: import utils [as 别名]（项目自带的 utils.py 工具模块，并非 torch.utils）
# 注意: torch.utils 中没有 create_exp_dir，写成 from torch.utils import create_exp_dir 会抛出 ImportError
def __init__(self):
    """Derive the per-run save directory and record the shared config.

    Rewrites ``self.args.save`` into a run-specific path of the form
    ``<save>/<space>/<dataset>/<search_dp>_<search_wd>-<job_id>`` and passes
    it to ``utils.create_exp_dir`` (presumably creating the directory). The
    original, un-suffixed save root is kept in ``self.args._save``.
    ``config.yaml`` is written into that root only if it does not exist yet.
    Finally sets ``self.args.n_classes``: 100 for ``'cifar100'``, else 10.

    NOTE(review): ``self.args`` and ``self.args_to_log`` are read right after
    the parent ``__init__`` runs, so they are presumably populated by the
    parent class. Confirm against its definition.
    """
    super(Helper, self).__init__()

    # Remember the user-supplied root before it is replaced by the run path.
    self.args._save = copy(self.args.save)
    self.args.save = '{}/{}/{}/{}_{}-{}'.format(self.args.save,
                                                self.args.space,
                                                self.args.dataset,
                                                self.args.search_dp,
                                                self.args.search_wd,
                                                self.args.job_id)
    utils.create_exp_dir(self.args.save)

    # config.yaml lives in the shared root, not in the per-job directory, so
    # several runs may try to create it. Exclusive-create mode ('x') makes
    # "write only if absent" one atomic step instead of a racy
    # exists()-then-open() check. A run that finds the file already there
    # simply skips writing it, as before.
    config_filename = os.path.join(self.args._save, 'config.yaml')
    try:
        config_file = open(config_filename, 'x')
    except FileExistsError:
        pass
    else:
        with config_file as f:
            yaml.dump(self.args_to_log, f, default_flow_style=False)

    # Only 'cifar100' gets 100 classes; every other dataset is treated as 10-class.
    self.args.n_classes = 100 if self.args.dataset == 'cifar100' else 10
示例2: __init__
# 需要导入模块: import utils [as 别名]（项目自带的 utils.py 工具模块，并非 torch.utils）
# 注意: torch.utils 中没有 create_exp_dir，写成 from torch.utils import create_exp_dir 会抛出 ImportError
def __init__(self):
    """Derive the per-run save directory, record the shared config, tweak args.

    Rewrites ``self.args.save`` into a run-specific path of the form
    ``<save>/<space>/<dataset>/<drop_path_prob>_<weight_decay>-<job_id>``
    and passes it to ``utils.create_exp_dir`` (presumably creating the
    directory). The original, un-suffixed save root is kept in
    ``self.args._save``. ``config.yaml`` is written into that root only if it
    does not exist yet.

    Also sets ``self.args.n_classes`` (100 for ``'cifar100'``, else 10). It
    forces ``self.args.cutout`` to False when ``drop_path_prob`` is 0.

    NOTE(review): ``self.args`` and ``self.args_to_log`` are read right after
    the parent ``__init__`` runs, so they are presumably populated by the
    parent class. Confirm against its definition.
    """
    super(Helper, self).__init__()

    # Remember the user-supplied root before it is replaced by the run path.
    self.args._save = copy(self.args.save)
    self.args.save = '{}/{}/{}/{}_{}-{}'.format(self.args.save,
                                                self.args.space,
                                                self.args.dataset,
                                                self.args.drop_path_prob,
                                                self.args.weight_decay,
                                                self.args.job_id)
    utils.create_exp_dir(self.args.save)

    # config.yaml lives in the shared root, not in the per-job directory, so
    # several runs may try to create it. Exclusive-create mode ('x') makes
    # "write only if absent" one atomic step instead of a racy
    # exists()-then-open() check. A run that finds the file already there
    # simply skips writing it, as before.
    config_filename = os.path.join(self.args._save, 'config.yaml')
    try:
        config_file = open(config_filename, 'x')
    except FileExistsError:
        pass
    else:
        with config_file as f:
            yaml.dump(self.args_to_log, f, default_flow_style=False)

    # Only 'cifar100' gets 100 classes; every other dataset is treated as 10-class.
    self.args.n_classes = 100 if self.args.dataset == 'cifar100' else 10

    # Set cutout to False if the drop_path_prob is 0.
    # NOTE(review): this couples the two regularizers; presumably 0 means
    # "run without regularization". Confirm the intent.
    if self.args.drop_path_prob == 0:
        self.args.cutout = False
示例3: initialize_run
# 需要导入模块: 项目自带的工具模块（本例在函数内通过 utils = project_utils 绑定），并非 torch.utils
# 注意: torch.utils 中没有 create_exp_dir，写成 from torch.utils import create_exp_dir 会抛出 ImportError
def initialize_run(self, sub_dir_path=None):
    """Set up directories, logging and seeding, then build CIFAR-10 loaders.

    Side effects on ``self``:
      * when ``args.continue_train`` is falsy: sets ``sub_directory_path`` and
        ``exp_dir``, and calls ``utils.create_exp_dir(exp_dir)``;
      * when ``args.visualize``: sets ``viz_dir_path``;
      * when ``args.tensorboard``: sets ``tb_dir`` and ``writer``
        (a ``SummaryWriter`` under ``args.tboard_dir``);
      * always: sets ``logger`` and ``eval_loss``.
    It also calls ``utils.torch_random_seed(args.seed)``. When
    ``args.debug`` is set, it enables ``torch.autograd`` anomaly detection.

    Args:
        sub_dir_path: Accepted but not read by this method.
            NOTE(review): possibly intended to pick the directory when
            resuming. Confirm with callers before relying on it.

    Returns:
        ``(train_queue, valid_queue, test_queue, criterion, eval_criterion)``.
        The train and valid queues sample disjoint index ranges of the
        CIFAR-10 training split: the first ``args.nao_search_config.ratio``
        fraction and the remainder, respectively. The test queue iterates the
        CIFAR-10 test split without shuffling. Both criteria are
        ``nn.CrossEntropyLoss`` moved to CUDA, so a GPU is required.
    """
    args = self.args
    utils = project_utils

    if not args.continue_train:
        self.sub_directory_path = 'WeightSharingNasBenchNetRandom-{}_SEED_{}'.format(args.save, args.seed)
        self.exp_dir = os.path.join(args.main_path, self.sub_directory_path)
        utils.create_exp_dir(self.exp_dir)
    # NOTE(review): when resuming, self.exp_dir and self.sub_directory_path
    # must already have been set elsewhere, or the uses below raise
    # AttributeError.

    if args.visualize:
        self.viz_dir_path = utils.create_viz_dir(self.exp_dir)
    if args.tensorboard:
        self.tb_dir = self.exp_dir
        tboard_dir = os.path.join(args.tboard_dir, self.sub_directory_path)
        self.writer = SummaryWriter(tboard_dir)
    if args.debug:
        torch.autograd.set_detect_anomaly(True)

    # Set logger: file output goes to <exp_dir>/log.txt, DEBUG level in debug mode.
    self.logger = utils.get_logger(
        "train_search",
        file_handler=utils.get_file_handler(os.path.join(self.exp_dir, 'log.txt')),
        level=logging.DEBUG if args.debug else logging.INFO,
    )
    # Pass values as lazy %-style arguments so logging formats them itself.
    logging.info("setting random seed as %s", args.seed)
    utils.torch_random_seed(args.seed)
    logging.info('gpu number = %d', args.gpus)
    logging.info("args = %s", args)

    criterion = nn.CrossEntropyLoss().cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()
    self.eval_loss = eval_criterion

    # The transform helper receives the cutout length only when cutout is enabled.
    train_transform, valid_transform = utils._data_transforms_cifar10(
        args.cutout_length if args.cutout else None)
    # Train and valid both read the CIFAR-10 *training* split (with different
    # transforms). They are restricted to disjoint index subsets below.
    train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
    valid_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=valid_transform)
    test_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.nao_search_config.ratio * num_train))

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=2)
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.nao_search_config.child_eval_batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True, num_workers=2)
    test_queue = torch.utils.data.DataLoader(
        test_data, batch_size=args.evaluate_batch_size,
        shuffle=False, pin_memory=True, num_workers=8)
    return train_queue, valid_queue, test_queue, criterion, eval_criterion