This page collects typical usage examples of `config.dataset` in Python. If you have been wondering what `config.dataset` does and how to use it, the curated code examples below may help; you can also explore the wider `config` module they come from.
Seven code examples of `config.dataset` are shown below, sorted by popularity by default.
Example 1: path_for
# Required module: import config [as alias]
# Or: from config import dataset [as alias]
import os

import config


def path_for(train=False, val=False, test=False, question=False, trainval=False, answer=False):
    # Exactly one split flag and exactly one of question/answer may be set.
    assert train + val + test + trainval == 1
    assert question + answer == 1
    if train:
        split = 'train2014'
    elif val:
        split = 'val2014'
    elif trainval:
        split = 'trainval2014'
    else:
        split = config.test_split
    if question:
        fmt = 'v2_{0}_{1}_{2}_questions.json'
    else:
        if test:
            # just load validation data in the test=answer=True case, will be ignored anyway
            split = 'val2014'
        fmt = 'v2_{1}_{2}_annotations.json'
    s = fmt.format(config.task, config.dataset, split)
    return os.path.join(config.qa_path, s)
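A minimal usage sketch for this helper. The config values below are assumptions for illustration; the real module is not shown in the snippet:

# Hypothetical config: task='OpenEnded', dataset='mscoco',
# test_split='test2015', qa_path='data/vqa'
train_questions = path_for(train=True, question=True)
# -> data/vqa/v2_OpenEnded_mscoco_train2014_questions.json
val_answers = path_for(val=True, answer=True)
# -> data/vqa/v2_mscoco_val2014_annotations.json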
Example 2: path_for
# Required module: import config [as alias]
# Or: from config import dataset [as alias]
import os

import config


def path_for(train=False, val=False, test=False, question=False, answer=False):
    # Exactly one split flag and exactly one of question/answer may be set.
    assert train + val + test == 1
    assert question + answer == 1
    if train:
        split = 'train2014'
    elif val:
        split = 'val2014'
    else:
        split = config.test_split
    if question:
        fmt = 'v2_{0}_{1}_{2}_questions.json'
    else:
        if test:
            # just load validation data in the test=answer=True case, will be ignored anyway
            split = 'val2014'
        fmt = 'v2_{1}_{2}_annotations.json'
    s = fmt.format(config.task, config.dataset, split)
    return os.path.join(config.qa_path, s)
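This variant is the same as Example 1 except that it offers no trainval split; the usage sketch under Example 1 applies unchanged.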
Example 3: load_dataset
# Required module: import config [as alias]
# Or: from config import dataset [as alias]
import os

import numpy as np

import config
import dataset


def load_dataset(dataset_spec=None, verbose=True, **spec_overrides):
    if verbose: print('Loading dataset...')
    if dataset_spec is None: dataset_spec = config.dataset
    dataset_spec = dict(dataset_spec)  # take a copy of the dict before modifying it
    dataset_spec.update(spec_overrides)
    dataset_spec['h5_path'] = os.path.join(config.data_dir, dataset_spec['h5_path'])
    if 'label_path' in dataset_spec:
        dataset_spec['label_path'] = os.path.join(config.data_dir, dataset_spec['label_path'])
    training_set = dataset.Dataset(**dataset_spec)
    if verbose: print('Dataset shape =', np.int32(training_set.shape).tolist())
    drange_orig = training_set.get_dynamic_range()
    if verbose: print('Dynamic range =', drange_orig)
    return training_set, drange_orig
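A hedged usage sketch. In the codebase this comes from, config.dataset appears to be a dict of dataset.Dataset keyword arguments; the spec below and the max_images override are assumptions for illustration:

# Assumed in config.py: dataset = dict(h5_path='celeba-128x128.h5')
# Assumed Dataset kwarg: max_images (forwarded through **spec_overrides)
training_set, drange_orig = load_dataset(max_images=10000)
print('shape:', training_set.shape, 'dynamic range:', drange_orig)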
Example 4: __init__
# Required module: import config [as alias]
# Or: from config import dataset [as alias]
import os
import sys

import torch.nn as nn
import torch.optim as optim

import data
import model


def __init__(self, config, args):
    self.config = config
    # Copy every command-line argument onto the config object.
    for k, v in args.__dict__.items():
        setattr(self.config, k, v)
    setattr(self.config, 'save_dir', '{}_log'.format(self.config.dataset))

    disp_str = ''
    for attr in sorted(dir(self.config), key=lambda x: len(x)):
        if not attr.startswith('__'):
            disp_str += '{} : {}\n'.format(attr, getattr(self.config, attr))
    sys.stdout.write(disp_str)
    sys.stdout.flush()

    self.labeled_loader, self.unlabeled_loader, self.unlabeled_loader2, self.dev_loader, self.special_set = data.get_svhn_loaders(config)

    self.dis = model.Discriminative(config).cuda()
    self.gen = model.Generator(image_size=config.image_size, noise_size=config.noise_size).cuda()

    self.dis_optimizer = optim.Adam(self.dis.parameters(), lr=config.dis_lr, betas=(0.5, 0.999))  # 0.0 0.9999
    self.gen_optimizer = optim.Adam(self.gen.parameters(), lr=config.gen_lr, betas=(0.0, 0.999))  # 0.0 0.9999

    self.d_criterion = nn.CrossEntropyLoss()

    if not os.path.exists(self.config.save_dir):
        os.makedirs(self.config.save_dir)

    log_path = os.path.join(self.config.save_dir, '{}.FM+PT+ENT.{}.txt'.format(self.config.dataset, self.config.suffix))
    self.logger = open(log_path, 'w')  # text mode: disp_str is a str, not bytes
    self.logger.write(disp_str)
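A sketch of how this constructor might be driven from the command line. The enclosing class is not shown above, so Trainer is an assumed name, as are the argument defaults:

import argparse

import config  # the project's config module

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='svhn')
parser.add_argument('--suffix', default='run0')
args = parser.parse_args()

trainer = Trainer(config, args)  # Trainer is hypothetical; args get copied onto config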
Example 5: visualize
# Required module: import config [as alias]
# Or: from config import dataset [as alias]
import os

import torch
import torchvision.utils as vutils
from torch.autograd import Variable


def visualize(self):
    self.gen.eval()
    self.dis.eval()

    vis_size = 100
    noise = Variable(torch.Tensor(vis_size, self.config.noise_size).uniform_().cuda())
    gen_images = self.gen(noise)

    save_path = os.path.join(self.config.save_dir, '{}.FM+PT+Ent.{}.png'.format(self.config.dataset, self.config.suffix))
    vutils.save_image(gen_images.data.cpu(), save_path, normalize=True, range=(-1, 1), nrow=10)
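Note: in recent torchvision releases the range keyword of save_image was renamed value_range, so this call as written targets an older torchvision.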
Example 6: __init__
# Required module: import config [as alias]
# Or: from config import dataset [as alias]
import torch.nn as nn

import config as cfg
# The remaining names (create_logger, Signal, load_dict, GenDataIter,
# BLEU, NLL, ACC, PPL) are project-local helpers from the same codebase.


def __init__(self, opt):
    self.log = create_logger(__name__, silent=False, to_disk=True,
                             log_file=cfg.log_filename if cfg.if_test
                             else [cfg.log_filename, cfg.save_root + 'log.txt'])
    self.sig = Signal(cfg.signal_file)
    self.opt = opt
    self.show_config()

    self.clas = None

    # load dictionary
    self.word2idx_dict, self.idx2word_dict = load_dict(cfg.dataset)

    # Dataloader
    try:
        self.train_data = GenDataIter(cfg.train_data)
        self.test_data = GenDataIter(cfg.test_data, if_test_data=True)
    except Exception:
        # Single-label data files may be absent for this dataset.
        pass

    try:
        self.train_data_list = [GenDataIter(cfg.cat_train_data.format(i)) for i in range(cfg.k_label)]
        self.test_data_list = [GenDataIter(cfg.cat_test_data.format(i), if_test_data=True) for i in
                               range(cfg.k_label)]
        self.clas_data_list = [GenDataIter(cfg.cat_test_data.format(str(i)), if_test_data=True) for i in
                               range(cfg.k_label)]
        self.train_samples_list = [self.train_data_list[i].target for i in range(cfg.k_label)]
        self.clas_samples_list = [self.clas_data_list[i].target for i in range(cfg.k_label)]
    except Exception:
        # Category-wise (multi-label) data files may be absent for this dataset.
        pass

    # Criterion
    self.mle_criterion = nn.NLLLoss()
    self.dis_criterion = nn.CrossEntropyLoss()
    self.clas_criterion = nn.CrossEntropyLoss()

    # Optimizer
    self.clas_opt = None

    # Metrics
    self.bleu = BLEU('BLEU', gram=[2, 3, 4, 5], if_use=cfg.use_bleu)
    self.nll_gen = NLL('NLL_gen', if_use=cfg.use_nll_gen, gpu=cfg.CUDA)
    self.nll_div = NLL('NLL_div', if_use=cfg.use_nll_div, gpu=cfg.CUDA)
    self.self_bleu = BLEU('Self-BLEU', gram=[2, 3, 4], if_use=cfg.use_self_bleu)
    self.clas_acc = ACC(if_use=cfg.use_clas_acc)
    self.ppl = PPL(self.train_data, self.test_data, n_gram=5, if_use=cfg.use_ppl)
    self.all_metrics = [self.bleu, self.nll_gen, self.nll_div, self.self_bleu, self.ppl]
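For context, a minimal sketch of what a load_dict helper of this shape could look like. The real implementation ships with the same codebase and is not shown here; the vocabulary file path and format below are assumptions:

def load_dict(dataset_name):
    """Hypothetical reconstruction: build word<->index maps for a dataset vocabulary."""
    word2idx, idx2word = {}, {}
    # Assumed on-disk format: one token per line, line number = index.
    with open('dataset/{}_vocab.txt'.format(dataset_name), encoding='utf-8') as f:
        for idx, line in enumerate(f):
            token = line.strip()
            word2idx[token] = idx
            idx2word[idx] = token
    return word2idx, idx2word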
Example 7: train
# Required module: import config [as alias]
# Or: from config import dataset [as alias]
import sys
from collections import OrderedDict


def train(self):
    config = self.config
    self.param_init()

    self.iter_cnt = 0
    iter, min_dev_incorrect = 0, 1e6  # note: `iter` shadows the built-in of the same name
    monitor = OrderedDict()

    # Ceiling division: number of batches per epoch.
    batch_per_epoch = int((len(self.unlabeled_loader) + config.train_batch_size - 1) / config.train_batch_size)
    min_lr = config.min_lr if hasattr(config, 'min_lr') else 0.0
    while True:
        if iter % batch_per_epoch == 0:
            epoch = iter // batch_per_epoch  # integer division (was Python 2's `/`)
            if config.dataset != 'svhn' and epoch >= config.max_epochs:
                break
            epoch_ratio = float(epoch) / float(config.max_epochs)
            # use another outer max to prevent any float computation precision problem
            self.dis_optimizer.param_groups[0]['lr'] = max(min_lr, config.dis_lr * min(3. * (1. - epoch_ratio), 1.))
            self.gen_optimizer.param_groups[0]['lr'] = max(min_lr, config.gen_lr * min(3. * (1. - epoch_ratio), 1.))

        iter_vals = self._train()

        for k, v in iter_vals.items():
            if k not in monitor:  # dict.has_key() was removed in Python 3
                monitor[k] = 0.
            monitor[k] += v

        if iter % config.vis_period == 0:
            self.visualize()

        if iter % config.eval_period == 0:
            train_loss, train_incorrect = self.eval(self.labeled_loader)
            dev_loss, dev_incorrect = self.eval(self.dev_loader)
            unl_acc, gen_acc, max_unl_acc, max_gen_acc = self.eval_true_fake(self.dev_loader, 10)

            train_incorrect /= 1.0 * len(self.labeled_loader)
            dev_incorrect /= 1.0 * len(self.dev_loader)
            min_dev_incorrect = min(min_dev_incorrect, dev_incorrect)

            disp_str = '#{}\ttrain: {:.4f}, {:.4f} | dev: {:.4f}, {:.4f} | best: {:.4f}'.format(
                iter, train_loss, train_incorrect, dev_loss, dev_incorrect, min_dev_incorrect)
            for k, v in monitor.items():
                disp_str += ' | {}: {:.4f}'.format(k, v / config.eval_period)
            disp_str += ' | [Eval] unl acc: {:.4f}, gen acc: {:.4f}, max unl acc: {:.4f}, max gen acc: {:.4f}'.format(
                unl_acc, gen_acc, max_unl_acc, max_gen_acc)
            disp_str += ' | lr: {:.5f}'.format(self.dis_optimizer.param_groups[0]['lr'])
            disp_str += '\n'

            monitor = OrderedDict()

            self.logger.write(disp_str)
            sys.stdout.write(disp_str)
            sys.stdout.flush()

        iter += 1
        self.iter_cnt += 1
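The learning-rate rule above keeps the base rate for the first two thirds of training, then decays it linearly toward min_lr. The same schedule as a standalone function:

def scheduled_lr(base_lr, epoch, max_epochs, min_lr=0.0):
    # Flat while epoch/max_epochs <= 2/3, then linear decay, floored at min_lr.
    ratio = float(epoch) / float(max_epochs)
    return max(min_lr, base_lr * min(3.0 * (1.0 - ratio), 1.0))

# With base_lr=3e-4 and max_epochs=900:
#   scheduled_lr(3e-4, 300, 900) -> 3e-4    (flat region)
#   scheduled_lr(3e-4, 750, 900) -> 1.5e-4  (halfway through the decay)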