This page collects typical usage examples of data_utils.Dataset in Python. If you are wondering how to use data_utils.Dataset, what it is for, or what it looks like in real code, the curated examples below may help. You can also look further into usage examples of the data_utils module in which Dataset is defined.
The following shows 5 code examples of data_utils.Dataset, sorted by popularity by default.
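All five examples below share the same basic pattern: build a Vocabulary from a vocabulary file, construct a Dataset over a glob of tokenized text shards, and then draw fixed-size batches from it. The sketch below summarizes that pattern; the iterate_once method name and the (x, y, w) batch layout are assumptions based on the lm-style data_utils used in Examples 1-4 and should be checked against the data_utils.py you actually have.

# Minimal usage sketch (iterate_once and the batch layout are assumed, not guaranteed).
from data_utils import Vocabulary, Dataset

vocab = Vocabulary.from_file("1b_word_vocab.txt")
dataset = Dataset(vocab, "/path/to/training-monolingual.tokenized.shuffled/*",
                  deterministic=True)  # fixed shard order, useful for evaluation

for x, y, w in dataset.iterate_once(128, 20):  # (batch_size, num_steps)
    # x: input token ids, y: next-token targets, w: weights masking padded positions
    train_step(x, y, w)  # hypothetical training callback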
Example 1: main
# Required import: import data_utils [as alias]
# Or: from data_utils import Dataset [as alias]
def main(_):
    # Parse hyperparameters and the GPU count from the command-line flags.
    hps = LM.get_default_hparams().parse(FLAGS.hpconfig)
    hps.num_gpus = FLAGS.num_gpus

    vocab = Vocabulary.from_file("1b_word_vocab.txt")

    if FLAGS.mode == "train":
        hps.batch_size = 256
        dataset = Dataset(vocab, FLAGS.datadir + "/training-monolingual.tokenized.shuffled/*")
        run_train(dataset, hps, FLAGS.logdir + "/train", ps_device="/gpu:0")
    elif FLAGS.mode.startswith("eval_"):
        # Evaluate either on the training shards or on the held-out shard.
        if FLAGS.mode.startswith("eval_train"):
            data_dir = FLAGS.datadir + "/training-monolingual.tokenized.shuffled/*"
        else:
            data_dir = FLAGS.datadir + "/heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050"
        dataset = Dataset(vocab, data_dir, deterministic=True)
        run_eval(dataset, hps, FLAGS.logdir, FLAGS.mode, FLAGS.eval_steps)
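Example 1 reads several command-line flags (hpconfig, num_gpus, mode, datadir, logdir, eval_steps) that are defined elsewhere in the same script and parsed before main(_) is called. A sketch of that wiring, with flag names inferred from the code above and purely illustrative defaults, might look like this in TensorFlow 1.x:

# Hypothetical flag definitions matching what main(_) reads above (TF 1.x style).
import tensorflow as tf

flags = tf.flags
flags.DEFINE_string("logdir", "/tmp/lm1b", "Directory for logs and checkpoints.")
flags.DEFINE_string("datadir", None, "Root directory of the 1B Word Benchmark data.")
flags.DEFINE_string("mode", "train", "e.g. train, eval_train, or another eval_* mode.")
flags.DEFINE_string("hpconfig", "", "Comma-separated hyperparameter overrides.")
flags.DEFINE_integer("num_gpus", 1, "Number of GPUs to use.")
flags.DEFINE_integer("eval_steps", 70, "Number of evaluation batches to run.")
FLAGS = flags.FLAGS

if __name__ == "__main__":
    tf.app.run()  # parses the flags, then calls main(_)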
Example 2: test_dataset
# Required import: import data_utils [as alias]
# Or: from data_utils import Dataset [as alias]
def test_dataset(self):
    vocab = Vocabulary.from_file("testdata/test_vocab.txt")
    dataset = Dataset(vocab, "testdata/*")

    def generator():
        # Toy sequences of increasing length, bracketed by token 0.
        for i in range(1, 10):
            yield [0] + list(range(1, i + 1)) + [0]

    # Token counts taken directly from the generator ...
    counts = [0] * 10
    for seq in generator():
        for v in seq:
            counts[v] += 1

    # ... must match the token counts seen in the batches produced by _iterate.
    counts2 = [0] * 10
    for x, y, w in dataset._iterate(generator(), 2, 4):
        for v in x.ravel():
            counts2[v] += 1

    for i in range(1, 10):
        self.assertEqual(counts[i], counts2[i], "Mismatch at i=%d" % i)
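The test above only verifies that every token emitted by the toy generator reappears somewhere in the batches. If you also want to see how _iterate lays the tokens out, a quick shape check like the one below can help; the [batch_size, num_steps] shape and the role of w as a padding mask are assumptions about this data_utils variant, not something the test itself asserts.

# Hypothetical shape check for the batches yielded by _iterate (layout assumed).
for x, y, w in dataset._iterate(generator(), 2, 4):
    assert x.shape == (2, 4)   # batch_size x num_steps of input token ids
    assert y.shape == x.shape  # targets, typically the inputs shifted by one step
    assert w.shape == x.shape  # weights that zero out padded positions
    break  # one batch is enough for a sanity check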
Example 3: test_dataset
# Required import: import data_utils [as alias]
# Or: from data_utils import Dataset [as alias]
def test_dataset(self):
    vocab = Vocabulary.from_file("testdata/test_vocab.txt")
    dataset = Dataset(vocab, "testdata/*")

    def generator():
        for i in range(1, 10):
            yield [0] + list(range(1, i + 1)) + [0]

    counts = [0] * 10
    for seq in generator():
        for v in seq:
            counts[v] += 1

    # Same check as Example 2, but this variant of _iterate yields (x, y) pairs only.
    counts2 = [0] * 10
    for x, y in dataset._iterate(generator(), 2, 4):
        for v in x.ravel():
            counts2[v] += 1

    for i in range(1, 10):
        self.assertEqual(counts[i], counts2[i],
                         "Mismatch at i=%d. counts[i]=%s, counts2[i]=%s" % (i, counts[i], counts2[i]))
Example 4: main
# Required import: import data_utils [as alias]
# Or: from data_utils import Dataset [as alias]
def main(_):
    """Start either training or evaluation. Note the hardcoded parts of the
    paths to the training and eval data."""
    hps = LM.get_default_hparams().parse(FLAGS.hpconfig)
    hps._set("num_gpus", FLAGS.num_gpus)
    print('*****HYPER PARAMETERS*****')
    print(hps)
    print('**************************')

    vocab = Vocabulary.from_file(os.path.join(FLAGS.datadir, "1b_word_vocab.txt"))

    if FLAGS.mode == "train":
        # hps.batch_size = 256
        dataset = Dataset(vocab, os.path.join(FLAGS.datadir,
                                              "training-monolingual.tokenized.shuffled/*"))
        run_train(dataset, hps, os.path.join(FLAGS.logdir, "train"), ps_device="/gpu:0")
    elif FLAGS.mode.startswith("eval_"):
        # eval_train evaluates on the training shards; every other eval_* mode uses the held-out shard.
        if FLAGS.mode.startswith("eval_train"):
            data_dir = os.path.join(FLAGS.datadir, "training-monolingual.tokenized.shuffled/*")
        elif FLAGS.mode.startswith("eval_full"):
            data_dir = os.path.join(FLAGS.datadir,
                                    "heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050")
        else:
            data_dir = os.path.join(FLAGS.datadir,
                                    "heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050")
        dataset = Dataset(vocab, data_dir, deterministic=True)
        run_eval(dataset, hps, FLAGS.logdir, FLAGS.mode, FLAGS.eval_steps)
    elif FLAGS.mode.startswith("infer"):
        data_dir = os.path.join(FLAGS.datadir,
                                "heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050")
        dataset = Dataset(vocab, data_dir, deterministic=True)
        run_infer(dataset, hps, FLAGS.logdir, FLAGS.mode, vocab)
Example 5: main
# Required import: import data_utils [as alias]
# Or: from data_utils import Dataset [as alias]
def main():
    # configuration
    config = Config()
    config.parse_arg(FLAGS)
    config.setup_path()
    config.print_arg()

    # dataset: wikibio is table-to-text, everything else uses the plain Dataset
    if config.dataset == 'wikibio':
        dset = DatasetTable2text(config)
        dset.load()
        config.key_size = len(dset.key2id)
    else:
        dset = Dataset(config)
        dset.build()
    config.vocab_size = len(dset.word2id)
    config.dec_start_id = dset.word2id["_GOO"]
    config.dec_end_id = dset.word2id["_EOS"]
    config.pad_id = dset.pad_id
    config.stop_words = dset.stop_words

    # model: pick the model class from the configured name
    if config.model_name == "seq2seq":
        if config.dataset == 'wikibio':
            Model = Seq2seqData2text
        else:
            Model = Seq2seq
    elif config.model_name == "bow_seq2seq":
        Model = BowSeq2seq
    elif config.model_name == "vae":
        Model = Vae
    elif config.model_name == "hierarchical_vae":
        Model = Hierarchical_Vae
    elif config.model_name == "latent_bow":
        if config.dataset == 'wikibio':
            Model = LatentBowData2text
        else:
            Model = LatentBow
    elif config.model_name == "lm":
        Model = LM
    else:
        msg = "the model name should be in ['seq2seq', 'bow_seq2seq', 'vae', 'hierarchical_vae', 'latent_bow', 'lm'], "
        msg += "current name: %s" % config.model_name
        raise Exception(msg)

    model = Model(config)
    with tf.variable_scope(config.model_name):
        model.build()

    # controller: drives training and, if needed, builds an auxiliary LM for evaluation metrics
    controller = Controller(config)
    if config.model_name != "lm":
        if "lm" in controller.eval_metrics_list:
            controller.build_lm(LM, config)
    controller.train(model, dset)
    return
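The long if/elif chain in Example 5 that maps config.model_name to a model class could equally be written as a dictionary lookup, which keeps the set of valid names and the error message in one place. The sketch below is an illustrative alternative that reuses the class names from the example, not code from the original project:

# Illustrative dictionary-based dispatch for Example 5's model selection.
MODEL_REGISTRY = {
    "seq2seq": Seq2seq,
    "bow_seq2seq": BowSeq2seq,
    "vae": Vae,
    "hierarchical_vae": Hierarchical_Vae,
    "latent_bow": LatentBow,
    "lm": LM,
}

def pick_model(config):
    # wikibio swaps in the data-to-text variants of seq2seq and latent_bow.
    if config.dataset == "wikibio":
        if config.model_name == "seq2seq":
            return Seq2seqData2text
        if config.model_name == "latent_bow":
            return LatentBowData2text
    try:
        return MODEL_REGISTRY[config.model_name]
    except KeyError:
        raise ValueError("the model name should be in %s, current name: %s"
                         % (sorted(MODEL_REGISTRY), config.model_name))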