当前位置: 首页>>代码示例>>Python>>正文


Python data_reader.DataReader类代码示例

本文整理汇总了 Python 中 data_reader.DataReader 类的典型用法代码示例，可以帮助您了解 data_reader.DataReader 的具体用法以及如何在实际项目中使用它。您也可以进一步了解该类所在模块 data_reader 的其他用法示例。


下文共展示了 data_reader.DataReader 类的 5 个代码示例，默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Python 代码示例。

示例1: save

# 需要导入模块: import data_reader [as 别名]
# 或者: from data_reader import DataReader [as 别名]
def save(artist, model_path, num_save):
    """
    Restore a trained lyric model and write generated samples to disk.

    Each sample is written to ``<samples dir>/<artist>/<i>.txt`` for
    ``i`` in ``0 .. num_save - 1``.

    @param artist: Name of the artist whose vocabulary is loaded via DataReader;
                   also used as the per-artist output subdirectory name.
    @param model_path: Checkpoint path passed to ``tf.train.Saver.restore``.
    @param num_save: Number of samples to generate and save.
    """
    sample_save_dir = c.get_dir('../save/samples/')

    # Use the session as a context manager so it is closed once sampling
    # finishes; the previous version never released it.
    with tf.Session() as sess:
        print(artist)

        data_reader = DataReader(artist)
        vocab = data_reader.get_vocab()

        print('Init model...')
        model = LSTMModel(sess,
                          vocab,
                          c.BATCH_SIZE,
                          c.SEQ_LEN,
                          c.CELL_SIZE,
                          c.NUM_LAYERS,
                          test=True)

        saver = tf.train.Saver()
        # tf.initialize_all_variables() is a deprecated alias of this call;
        # the project's runner already uses the non-deprecated name.
        sess.run(tf.global_variables_initializer())

        # Restoring overwrites the freshly initialized values with the checkpoint.
        saver.restore(sess, model_path)
        print('Model restored from ' + model_path)

        artist_save_dir = c.get_dir(join(sample_save_dir, artist))
        for i in range(num_save):
            print(i)

            path = join(artist_save_dir, str(i) + '.txt')
            sample = model.generate()
            processed_sample = process_sample(sample)

            with open(path, 'w') as f:
                f.write(processed_sample)
开发者ID:dyelax,项目名称:encore.ai,代码行数:36,代码来源:save_samples.py

示例2: __init__

# 需要导入模块: import data_reader [as 别名]
# 或者: from data_reader import DataReader [as 别名]
def __init__(self, model_load_path, artist_name, test, prime_text):
        """
        Set up the lyric-generation runner, then immediately test or train.

        Creates a TensorFlow session, reads the artist's data and vocabulary,
        builds the LSTM model, initializes its variables, optionally restores
        a checkpoint, and finally calls ``self.test(prime_text)`` or
        ``self.train()``.

        No parameter has a default in this signature; any defaults come from
        the caller.

        @param model_load_path: Checkpoint to restore, or None to keep the freshly
                                initialized variables.
        @param artist_name: The artist whose data is loaded via DataReader.
        @param test: If truthy, generate a sequence via self.test; otherwise train.
                     Also forwarded to LSTMModel as its ``test`` flag.
        @param prime_text: The text with which to start the test sequence; only
                           used when ``test`` is truthy.
        """
        session = tf.Session()
        self.sess = session
        self.artist_name = artist_name

        print('Process data...')
        reader = DataReader(artist_name)
        self.data_reader = reader
        self.vocab = reader.get_vocab()

        print('Init model...')
        hyperparams = (c.BATCH_SIZE, c.SEQ_LEN, c.CELL_SIZE, c.NUM_LAYERS)
        self.model = LSTMModel(session, self.vocab, *hyperparams, test=test)

        print('Init variables...')
        # max_to_keep=None: the Saver does not delete older checkpoint files.
        self.saver = tf.train.Saver(max_to_keep=None)
        session.run(tf.global_variables_initializer())

        if model_load_path is not None:
            # Overwrite the freshly initialized variables with the saved values.
            self.saver.restore(session, model_load_path)
            print('Model restored from ' + model_load_path)

        if not test:
            self.train()
        else:
            self.test(prime_text)
开发者ID:dyelax,项目名称:encore.ai,代码行数:45,代码来源:runner.py

示例3: main

# 需要导入模块: import data_reader [as 别名]
# 或者: from data_reader import DataReader [as 别名]
def main(args):
  """Dispatch to training, validation/testing or prediction based on ``args.mode``.

  Builds the data reader(s) for the selected mode inside a ``create_inputs``
  name scope, all sharing one ``tf.train.Coordinator``, and hands them to
  ``train_fn``, ``valid_fn`` or ``pred_fn``. An unrecognised mode only prints
  a message.

  Args:
    args: Parsed command-line arguments. Fields read here: ``mode``,
      ``batch_size``, ``train_dir``, ``train_list``, ``valid_dir``,
      ``valid_list``, ``data_dir``, ``data_list``, ``input_mseed``,
      ``input_length`` and ``output_dir``.
  """
  logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
  coord = tf.train.Coordinator()

  if args.mode == "train":
    with tf.name_scope('create_inputs'):
      data_reader = DataReader(
          data_dir=args.train_dir,
          data_list=args.train_list,
          mask_window=0.4,
          queue_size=args.batch_size*3,
          coord=coord)
      if args.valid_list is not None:
        data_reader_valid = DataReader(
            data_dir=args.valid_dir,
            data_list=args.valid_list,
            mask_window=0.4,
            queue_size=args.batch_size*2,
            coord=coord)
        logging.info("Dataset size: train {}, valid {}".format(data_reader.num_data, data_reader_valid.num_data))
      else:
        # Indented with spaces only; this branch previously mixed tabs and spaces.
        data_reader_valid = None
        logging.info("Dataset size: train {}".format(data_reader.num_data))
    train_fn(args, data_reader, data_reader_valid)

  elif args.mode in ("valid", "test"):
    with tf.name_scope('create_inputs'):
      data_reader = DataReader_test(
          data_dir=args.data_dir,
          data_list=args.data_list,
          mask_window=0.4,
          queue_size=args.batch_size*10,
          coord=coord)
    valid_fn(args, data_reader)

  elif args.mode == "pred":
    with tf.name_scope('create_inputs'):
      # Choose the reader class by input format; both take identical arguments.
      if args.input_mseed:
        data_reader = DataReader_mseed(
            data_dir=args.data_dir,
            data_list=args.data_list,
            queue_size=args.batch_size*10,
            coord=coord,
            input_length=args.input_length)
      else:
        data_reader = DataReader_pred(
            data_dir=args.data_dir,
            data_list=args.data_list,
            queue_size=args.batch_size*10,
            coord=coord,
            input_length=args.input_length)
    pred_fn(args, data_reader, log_dir=args.output_dir)

  else:
    # Only the modes handled above are valid; "debug" was previously listed
    # here although no branch supports it.
    print("mode should be: train, valid, test or pred")
开发者ID:wayneweiqiang,项目名称:PhaseNet,代码行数:60,代码来源:run.py

示例4: test

# 需要导入模块: import data_reader [as 别名]
# 或者: from data_reader import DataReader [as 别名]
def test(self):
        """Check DataReader batching and the character-embedding lookup/reshape.

        Loads the data from ``data/``, takes the first batch produced by
        ``DataReader`` and compares it against ``X``/``Y``. It then embeds that
        batch through a ``[char_vocab_size, char_embed_size]`` variable set to
        ``EMBEDDING`` and compares the first unroll step with ``IE3``.
        ``X``, ``Y``, ``EMBEDDING`` and ``IE3`` are presumably module-level
        fixtures defined elsewhere in this test file.
        """
        batch_size = 4
        num_unroll_steps = 3
        char_vocab_size = 51
        max_word_length = 11
        char_embed_size = 3

        _, _, word_data, char_data, _ = load_data('data/', max_word_length)
        dataset = char_data['train']
        self.assertEqual(dataset.shape, (929589, max_word_length))

        reader = DataReader(word_data['train'], char_data['train'], batch_size=batch_size, num_unroll_steps=num_unroll_steps)
        # Only the first batch is checked. An empty reader now fails the test
        # explicitly instead of raising NameError on x/y further down.
        first_batch = next(iter(reader.iter()), None)
        self.assertIsNotNone(first_batch, 'DataReader produced no batches')
        x, y = first_batch
        # unittest assertion rather than a bare assert, which -O would strip.
        self.assertEqual(x.shape, (batch_size, num_unroll_steps, max_word_length))

        self.assertAllClose(X, x)
        self.assertAllClose(Y, y)

        with self.test_session() as session:
            input_ = tf.placeholder(tf.int32, shape=[batch_size, num_unroll_steps, max_word_length], name="input")

            # First, embed characters.
            with tf.variable_scope('Embedding'):
                char_embedding = tf.get_variable('char_embedding', [char_vocab_size, char_embed_size])

                # [batch_size, num_unroll_steps, max_word_length, char_embed_size]
                input_embedded = tf.nn.embedding_lookup(char_embedding, input_)

                # -> [batch_size * num_unroll_steps, max_word_length, char_embed_size]
                input_embedded = tf.reshape(input_embedded, [-1, max_word_length, char_embed_size])

            session.run(tf.assign(char_embedding, EMBEDDING))
            ie = session.run(input_embedded, {
                input_: x
            })

            # Undo the flattening, then move the unroll axis first so that
            # ie[0] holds the first time step for every batch entry.
            ie = ie.reshape([batch_size, num_unroll_steps, max_word_length, char_embed_size])
            ie = np.transpose(ie, (1, 0, 2, 3))

            self.assertAllClose(IE3, ie[0, :, :, :])
开发者ID:mkroutikov,项目名称:tf-lstm-char-cnn,代码行数:47,代码来源:test_embedding.py

示例5: main

# 需要导入模块: import data_reader [as 别名]
# 或者: from data_reader import DataReader [as 别名]
def main(_):
  """Train VistaNet, checkpointing whenever the validation loss improves.

  After each epoch the model is evaluated on the validation set read by
  ``DataReader.read_valid_set``. When the validation loss beats the best so
  far, a checkpoint named after the epoch/loss/accuracy is saved under
  ``FLAGS.checkpoint_dir``, and ``test(...)`` is run with an open results
  file under ``FLAGS.log_dir``.

  Args:
    _: Unused positional argument (presumably the argv list from ``tf.app.run``).
  """
  config = tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement)
  with tf.Session(config=config) as sess:
    print('\n{} Model initializing'.format(datetime.now()))

    model = VistaNet(FLAGS.hidden_dim, FLAGS.att_dim, FLAGS.emb_size, FLAGS.num_images, FLAGS.num_classes)
    loss = loss_fn(model.labels, model.logits)
    train_op = train_fn(loss, model.global_step)
    accuracy = eval_fn(model.labels, model.logits)
    summary_op = tf.summary.merge_all()

    sess.run(tf.global_variables_initializer())
    # train_summary_writer is not defined in this function; presumably a
    # module-level summary writer.
    train_summary_writer.add_graph(sess.graph)
    saver = tf.train.Saver(max_to_keep=FLAGS.num_checkpoints)
    data_reader = DataReader(num_images=FLAGS.num_images, train_shuffle=True)

    print('\n{} Start training'.format(datetime.now()))

    epoch = 0
    best_loss = float('inf')
    while epoch < FLAGS.num_epochs:
      epoch += 1
      print('\n=> Epoch: {}'.format(epoch))

      train(sess, data_reader, model, train_op, loss, accuracy, summary_op)

      print('=> Evaluation')
      print('best_loss={:.4f}'.format(best_loss))
      valid_loss, valid_acc = evaluate(sess, data_reader.read_valid_set(batch_size=FLAGS.batch_size),
                                       model, loss, accuracy, summary_op)
      print('valid_loss={:.4f}, valid_acc={:.4f}'.format(valid_loss, valid_acc))

      if valid_loss < best_loss:
        best_loss = valid_loss
        save_path = os.path.join(FLAGS.checkpoint_dir,
                                 'epoch={}-loss={:.4f}-acc={:.4f}'.format(epoch, valid_loss, valid_acc))
        saver.save(sess, save_path)
        print('Best model saved @ {}'.format(save_path))

        print('=> Testing')
        result_path = os.path.join(
          FLAGS.log_dir, 'loss={:.4f},acc={:.4f},epoch={}'.format(valid_loss, valid_acc, epoch))
        # The with-block closes the results file even if test() raises; the
        # previous version leaked one open handle per improved epoch. Closing
        # an already-closed file is a no-op, so this is safe if test() closes it.
        with open(result_path, 'w') as result_file:
          test(sess, data_reader, model, loss, accuracy, epoch, result_file)

  print("{} Optimization Finished!".format(datetime.now()))
开发者ID:PreferredAI,项目名称:vista-net,代码行数:47,代码来源:train.py


注:本文中的data_reader.DataReader方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。