本文整理汇总了Python中miscc.config.cfg.DATA_DIR属性的典型用法代码示例。如果您正苦于以下问题：Python cfg.DATA_DIR属性的具体用法？Python cfg.DATA_DIR怎么用？Python cfg.DATA_DIR使用的例子？那么恭喜您，这里精选的属性代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在对象miscc.config.cfg的用法示例。
在下文中一共展示了cfg.DATA_DIR属性的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_dataset_indices
# 需要导入模块: from miscc.config import cfg [as 别名]
# 或者: from miscc.config.cfg import DATA_DIR [as 别名]
def get_dataset_indices(split="train", num_max_objects=10):
    """Group dataset sample indices by the position of their first ``-1`` label.

    Only active when ``cfg.TRAIN.OPTIMIZE_DATA_LOADING`` is set; otherwise
    ``None`` is returned.

    Reads ``<cfg.DATA_DIR>/<split>/labels_large.pickle``, a pickled sequence
    with one label row per sample. A ``-1`` entry in a row presumably marks
    the first unused object slot (padding) — TODO confirm against the data.

    Args:
        split: sub-directory of ``cfg.DATA_DIR`` holding the split's labels.
        num_max_objects: number of buckets minus one.

    Returns:
        A list of ``num_max_objects + 1`` lists. ``result[k]`` holds the
        indices of samples whose first ``-1`` is at position ``k``; samples
        containing no ``-1`` are put in the last bucket. ``None`` when
        optimized data loading is disabled.

    Raises:
        IndexError: if a row's first ``-1`` sits at a position greater than
            ``num_max_objects``.
    """
    if not cfg.TRAIN.OPTIMIZE_DATA_LOADING:
        return None

    label_path = os.path.join(cfg.DATA_DIR, split, 'labels_large.pickle')
    # NOTE(review): pickle.load can execute arbitrary code; only load label
    # files from a trusted data directory.
    with open(label_path, "rb") as f:
        # encoding='latin1' presumably lets Python 3 read a pickle written
        # by Python 2 — confirm how the label file was produced.
        labels = np.array(pickle.load(f, encoding='latin1'))

    dataset_indices = [[] for _ in range(num_max_objects + 1)]
    for index, label in enumerate(labels):
        for idx, value in enumerate(label):
            if value == -1:
                dataset_indices[idx].append(index)
                break
        else:
            # No -1 anywhere in the row: goes into the last bucket.
            dataset_indices[-1].append(index)
    return dataset_indices
开发者ID:tohinz,项目名称:semantic-object-accuracy-for-generative-text-to-image-synthesis,代码行数:22,代码来源:main.py
示例2: save_model
# 需要导入模块: from miscc.config import cfg [as 别名]
# 或者: from miscc.config.cfg import DATA_DIR [as 别名]
def save_model(netG, avg_param_G, netsD, epoch, model_dir):
    """Checkpoint the generator and all discriminators.

    Side effect: ``netG``'s parameters are first overwritten with
    ``avg_param_G`` via ``load_params``. Callers that continue training with
    the non-averaged weights must restore them afterwards.

    Files written:
      * ``<model_dir>/netG_<epoch>.pth`` and ``<model_dir>/netD<i>.pth``;
      * ``netG.pth``, ``netD<i>.pth`` and ``count.txt`` (holding ``epoch``)
        under ``<cfg.DATA_DIR>/<cfg.LAST_RUN_DIR>/Model``.

    Args:
        netG: generator module.
        avg_param_G: parameters loaded into ``netG`` before saving.
        netsD: sequence of discriminator modules.
        epoch: value embedded in the generator filename and ``count.txt``.
        model_dir: directory for the per-epoch checkpoints.
    """
    import os

    last_run_dir = cfg.DATA_DIR + '/' + cfg.LAST_RUN_DIR + '/Model'
    load_params(netG, avg_param_G)
    # The last-run directory may not exist yet on a fresh run.
    os.makedirs(last_run_dir, exist_ok=True)

    netG_state = netG.state_dict()
    torch.save(netG_state, '%s/netG_%d.pth' % (model_dir, epoch))
    torch.save(netG_state, '%s/netG.pth' % (last_run_dir))
    for i, netD in enumerate(netsD):
        netD_state = netD.state_dict()
        torch.save(netD_state, '%s/netD%d.pth' % (model_dir, i))
        torch.save(netD_state, '%s/netD%d.pth' % (last_run_dir, i))

    # Write the resume marker only after every checkpoint above was saved,
    # so an interrupted save never leaves count.txt describing a partial set.
    with open(last_run_dir + '/count.txt', 'w') as f:
        f.write(str(epoch))
    print('Save G/Ds models.')
示例3: gen_example
# 需要导入模块: from miscc.config import cfg [as 别名]
# 或者: from miscc.config.cfg import DATA_DIR [as 别名]
def gen_example(wordtoix, algo):
    '''Generate images from example sentences.

    Reads ``<cfg.DATA_DIR>/example_filenames.txt``. Each non-empty line names
    a caption file ``<cfg.DATA_DIR>/<name>.txt`` with one sentence per line.
    Every sentence is lower-cased, split into runs of word characters, and
    mapped to vocabulary ids through ``wordtoix``. Tokens missing from
    ``wordtoix`` are dropped.

    For each caption file the dict passed to ``algo.gen_example`` receives
    ``key -> [cap_array, cap_lens, sorted_indices]``:
      * ``key`` is the part of ``name`` after its last ``/``;
      * ``cap_array`` is an int64 matrix of zero-padded captions ordered by
        descending length;
      * ``cap_lens`` holds the matching lengths;
      * ``sorted_indices`` gives each row's position in the original sentence
        order.

    Caption files that yield no usable sentence are skipped with a message,
    instead of crashing on ``np.max`` of an empty list.
    '''
    import io
    from nltk.tokenize import RegexpTokenizer

    # The tokenizer is stateless, so build it once rather than per sentence.
    tokenizer = RegexpTokenizer(r'\w+')
    filepath = '%s/example_filenames.txt' % (cfg.DATA_DIR)
    data_dic = {}
    # io.open with an explicit encoding returns decoded text on both Python 2
    # and 3. Calling .decode() on a text-mode read fails on Python 3.
    with io.open(filepath, "r", encoding='utf8') as f:
        filenames = f.read().split('\n')
    for name in filenames:
        if len(name) == 0:
            continue
        caption_path = '%s/%s.txt' % (cfg.DATA_DIR, name)
        with io.open(caption_path, "r", encoding='utf8') as f:
            print('Load from:', name)
            sentences = f.read().split('\n')
        captions, cap_lens = _encode_example_sentences(
            sentences, wordtoix, tokenizer)
        if not captions:
            print('No usable captions in', name)
            continue
        key = name[(name.rfind('/') + 1):]
        data_dic[key] = _pad_and_sort_captions(captions, cap_lens)
    algo.gen_example(data_dic)


def _encode_example_sentences(sentences, wordtoix, tokenizer):
    """Convert raw sentences to vocabulary-id lists; return (captions, cap_lens)."""
    captions = []
    cap_lens = []
    for sent in sentences:
        if len(sent) == 0:
            continue
        sent = sent.replace("\ufffd\ufffd", " ")
        tokens = tokenizer.tokenize(sent.lower())
        if len(tokens) == 0:
            print('sent', sent)
            continue
        rev = []
        for t in tokens:
            # Strip non-ASCII characters before the vocabulary lookup.
            t = t.encode('ascii', 'ignore').decode('ascii')
            if len(t) > 0 and t in wordtoix:
                rev.append(wordtoix[t])
        captions.append(rev)
        cap_lens.append(len(rev))
    return captions, cap_lens


def _pad_and_sort_captions(captions, cap_lens):
    """Zero-pad captions into an int64 matrix ordered by descending length."""
    max_len = np.max(cap_lens)
    sorted_indices = np.argsort(cap_lens)[::-1]
    sorted_lens = np.asarray(cap_lens)[sorted_indices]
    cap_array = np.zeros((len(captions), max_len), dtype='int64')
    for row, idx in enumerate(sorted_indices):
        cap = captions[idx]
        cap_array[row, :len(cap)] = cap
    return [cap_array, sorted_lens, sorted_indices]
示例4: load_network
# 需要导入模块: from miscc.config import cfg [as 别名]
# 或者: from miscc.config.cfg import DATA_DIR [as 别名]
def load_network(gpus):
    """Build the generator/discriminators, optionally restore them, and load Inception.

    Discriminators ``D_NET64`` ... ``D_NET1024`` are created according to
    ``cfg.TREE.BRANCH_NUM``. All networks get ``weights_init`` and are wrapped
    in ``torch.nn.DataParallel`` over ``gpus``.

    If ``cfg.TRAIN.NET_G`` is set, the generator weights are loaded from it.
    The resume count is parsed from the text between the last ``_`` and the
    last ``.`` of that path (e.g. ``netG_<count>.pth``). If that is not an
    integer, the count is read from
    ``<cfg.DATA_DIR>/<cfg.LAST_RUN_DIR>/Model/count.txt``. Either way the
    value is incremented by one.

    If ``cfg.TRAIN.NET_D`` is set, discriminator ``i`` is loaded from
    ``'<NET_D><i>.pth'``.

    Returns:
        ``(netG, netsD, len(netsD), inception_model, count)``. ``count`` is 0
        when no generator checkpoint is given.
    """
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())
    if cfg.TREE.BRANCH_NUM > 3:
        netsD.append(D_NET512())
    if cfg.TREE.BRANCH_NUM > 4:
        netsD.append(D_NET1024())
    # TODO: if cfg.TREE.BRANCH_NUM > 5:
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)
        try:
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            count = int(cfg.TRAIN.NET_G[istart:iend])
        except ValueError:
            # The filename carries no numeric count (presumably the
            # un-numbered "netG.pth" copy in the last-run directory), so
            # fall back to the count recorded alongside it.
            last_run_dir = cfg.DATA_DIR + '/' + cfg.LAST_RUN_DIR + '/Model'
            with open(last_run_dir + '/count.txt', 'r') as f:
                count = int(f.read())
        count = count + 1

    if cfg.TRAIN.NET_D != '':
        for i in range(len(netsD)):
            netD_path = '%s%d.pth' % (cfg.TRAIN.NET_D, i)
            # Log the path that is actually loaded.
            print('Load %s' % netD_path)
            state_dict = torch.load(netD_path)
            netsD[i].load_state_dict(state_dict)

    inception_model = INCEPTION_V3()

    if cfg.CUDA:
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()
        inception_model = inception_model.cuda()
    inception_model.eval()

    return netG, netsD, len(netsD), inception_model, count
示例5: save_img_results
# 需要导入模块: from miscc.config import cfg [as 别名]
# 或者: from miscc.config.cfg import DATA_DIR [as 别名]
def save_img_results(imgs_tcpu, fake_imgs, num_imgs,
                     count, image_dir, summary_writer, rec_ids, im_ids):
    """Save real/fake sample grids, their recipe/image IDs, and TensorBoard summaries.

    The first ``cfg.TRAIN.VIS_COUNT`` images of ``imgs_tcpu[-1]`` and of each
    of the first ``num_imgs`` entries of ``fake_imgs`` are written twice:
    once as count-stamped files in ``image_dir``, and once (overwritten every
    call) in ``<cfg.DATA_DIR>/<cfg.LAST_RUN_DIR>/Image/``. The recipe and
    image IDs of the real batch go to matching ``*_IDs.txt`` files.

    Args:
        imgs_tcpu: sequence of real-image tensors; only the last one is used.
        fake_imgs: sequence of generated-image tensors.
        num_imgs: how many entries of ``fake_imgs`` to save.
        count: step used in filenames and as the summary step.
        image_dir: directory for the count-stamped files.
        summary_writer: writer receiving the image summaries.
        rec_ids, im_ids: tensors whose rows are UTF-8 encoded ID bytes
            (presumably fixed-width uint8 rows — confirm with the loader).
    """
    num = cfg.TRAIN.VIS_COUNT
    last_run_dir = cfg.DATA_DIR + '/' + cfg.LAST_RUN_DIR + '/Image/'

    real_img = imgs_tcpu[-1][0:num]
    vutils.save_image(
        real_img, '%s/count_%09d_real_samples.png' % (image_dir, count),
        normalize=True)
    vutils.save_image(
        real_img, last_run_dir + 'real_samples.png',
        normalize=True)

    # Write images' recipe and image IDs. tobytes() returns the same bytes
    # as the deprecated ndarray.tostring().
    rec_ids = [t.tobytes().decode('UTF-8') for t in rec_ids.numpy()]
    im_ids = [t.tobytes().decode('UTF-8') for t in im_ids.numpy()]
    id_text = ''.join("rec_id=%s, img_id=%s\n" % (rec_id, im_id)
                      for rec_id, im_id in zip(rec_ids, im_ids))
    for ids_path in ('%s/count_%09d_real_samples_IDs.txt' % (image_dir, count),
                     last_run_dir + 'real_samples_IDs.txt'):
        with open(ids_path, "w") as f:
            f.write(id_text)

    # Normalize explicitly: recent torchvision versions do not normalize
    # real_img in place inside save_image. Without this, values outside
    # [0, 1] would wrap around when cast to uint8. normalize=True applies
    # the same min-max scaling as the saved PNG.
    real_img_set = vutils.make_grid(real_img, normalize=True).numpy()
    real_img_set = np.transpose(real_img_set, (1, 2, 0))
    real_img_set = real_img_set * 255
    real_img_set = real_img_set.astype(np.uint8)
    sup_real_img = summary.image('real_img', real_img_set)
    summary_writer.add_summary(sup_real_img, count)

    for i in range(num_imgs):
        fake_img = fake_imgs[i][0:num]
        vutils.save_image(
            fake_img.data, '%s/count_%09d_fake_samples%d.png' %
            (image_dir, count, i), normalize=True)
        vutils.save_image(
            fake_img.data, last_run_dir + 'fake_samples%d.png' %
            (i), normalize=True)

        fake_img_set = vutils.make_grid(fake_img.data).cpu().numpy()
        fake_img_set = np.transpose(fake_img_set, (1, 2, 0))
        # Maps [-1, 1] to [0, 255]; presumes generator output in [-1, 1].
        fake_img_set = (fake_img_set + 1) * 255 / 2
        fake_img_set = fake_img_set.astype(np.uint8)

        sup_fake_img = summary.image('fake_img%d' % i, fake_img_set)
        summary_writer.add_summary(sup_fake_img, count)
        summary_writer.flush()
# ################## For unconditional tasks ######################### #