當前位置: 首頁>>代碼示例>>Python>>正文


Python utils.create_exp_dir方法代碼示例

本文整理匯總了Python中torch.utils.create_exp_dir方法的典型用法代碼示例。如果您正苦於以下問題:Python utils.create_exp_dir方法的具體用法?Python utils.create_exp_dir怎麽用?Python utils.create_exp_dir使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch.utils的用法示例。


在下文中一共展示了utils.create_exp_dir方法的3個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: __init__

# 需要導入模塊: from torch import utils [as 別名]
# 或者: from torch.utils import create_exp_dir [as 別名]
def __init__(self):
        super(Helper, self).__init__()

        self.args._save = copy(self.args.save)
        self.args.save = '{}/{}/{}/{}_{}-{}'.format(self.args.save,
                                                    self.args.space,
                                                    self.args.dataset,
                                                    self.args.search_dp,
                                                    self.args.search_wd,
                                                    self.args.job_id)

        utils.create_exp_dir(self.args.save)

        config_filename = os.path.join(self.args._save, 'config.yaml')
        if not os.path.exists(config_filename):
            with open(config_filename, 'w') as f:
                yaml.dump(self.args_to_log, f, default_flow_style=False)

        if self.args.dataset != 'cifar100':
            self.args.n_classes = 10
        else:
            self.args.n_classes = 100 
開發者ID:automl,項目名稱:RobustDARTS,代碼行數:24,代碼來源:args.py

示例2: __init__

# 需要導入模塊: from torch import utils [as 別名]
# 或者: from torch.utils import create_exp_dir [as 別名]
def __init__(self):
        super(Helper, self).__init__()

        self.args._save = copy(self.args.save)
        self.args.save = '{}/{}/{}/{}_{}-{}'.format(self.args.save,
                                                    self.args.space,
                                                    self.args.dataset,
                                                    self.args.drop_path_prob,
                                                    self.args.weight_decay,
                                                    self.args.job_id)

        utils.create_exp_dir(self.args.save)

        config_filename = os.path.join(self.args._save, 'config.yaml')
        if not os.path.exists(config_filename):
            with open(config_filename, 'w') as f:
                yaml.dump(self.args_to_log, f, default_flow_style=False)

        if self.args.dataset != 'cifar100':
            self.args.n_classes = 10
        else:
            self.args.n_classes = 100

        # set cutout to False if the drop_prob is 0
        if self.args.drop_path_prob == 0:
            self.args.cutout = False 
開發者ID:automl,項目名稱:RobustDARTS,代碼行數:28,代碼來源:args.py

示例3: initialize_run

# 需要導入模塊: from torch import utils [as 別名]
# 或者: from torch.utils import create_exp_dir [as 別名]
def initialize_run(self, sub_dir_path=None):
        args = self.args
        utils = project_utils
        if not self.args.continue_train:
            self.sub_directory_path = 'WeightSharingNasBenchNetRandom-{}_SEED_{}'.format(self.args.save, self.args.seed)
            self.exp_dir = os.path.join(self.args.main_path, self.sub_directory_path)
            utils.create_exp_dir(self.exp_dir)
        if self.args.visualize:
            self.viz_dir_path = utils.create_viz_dir(self.exp_dir)

        if self.args.tensorboard:
            self.tb_dir = self.exp_dir
            tboard_dir = os.path.join(self.args.tboard_dir, self.sub_directory_path)
            self.writer = SummaryWriter(tboard_dir)

        if self.args.debug:
            torch.autograd.set_detect_anomaly(True)

        # Set logger.
        self.logger = utils.get_logger(
            "train_search",
            file_handler=utils.get_file_handler(os.path.join(self.exp_dir, 'log.txt')),
            level=logging.INFO if not args.debug else logging.DEBUG
        )
        logging.info(f"setting random seed as {args.seed}")
        utils.torch_random_seed(args.seed)
        logging.info('gpu number = %d' % args.gpus)
        logging.info("args = %s", args)

        criterion = nn.CrossEntropyLoss().cuda()
        eval_criterion = nn.CrossEntropyLoss().cuda()
        self.eval_loss = eval_criterion

        train_transform, valid_transform = utils._data_transforms_cifar10(args.cutout_length if args.cutout else None)
        train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=valid_transform)
        test_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)

        num_train = len(train_data)
        indices = list(range(num_train))
        split = int(np.floor(args.nao_search_config.ratio * num_train))

        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
            pin_memory=True, num_workers=2)

        valid_queue = torch.utils.data.DataLoader(
            valid_data, batch_size=args.nao_search_config.child_eval_batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
            pin_memory=True, num_workers=2)

        test_queue = torch.utils.data.DataLoader(
            test_data, batch_size=args.evaluate_batch_size,
            shuffle=False, pin_memory=True, num_workers=8)

        return train_queue, valid_queue, test_queue, criterion, eval_criterion 
開發者ID:kcyu2014,項目名稱:eval-nas,代碼行數:59,代碼來源:nao_search_policy.py


注:本文中的torch.utils.create_exp_dir方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。