當前位置: 首頁>>代碼示例>>Python>>正文


Python skip_thoughts_model.SkipThoughtsModel方法代碼示例

本文整理匯總了Python中skip_thoughts.skip_thoughts_model.SkipThoughtsModel方法的典型用法代碼示例。如果您正苦於以下問題:Python skip_thoughts_model.SkipThoughtsModel方法的具體用法?Python skip_thoughts_model.SkipThoughtsModel怎麼用?Python skip_thoughts_model.SkipThoughtsModel使用的例子?那麼, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在skip_thoughts.skip_thoughts_model的用法示例。


在下文中一共展示了skip_thoughts_model.SkipThoughtsModel方法的7個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: build_inputs

# 需要導入模塊: from skip_thoughts import skip_thoughts_model [as 別名]
# 或者: from skip_thoughts.skip_thoughts_model import SkipThoughtsModel [as 別名]
def build_inputs(self):
    """Create the model's input tensors without reading from disk.

    In "encode" mode this defers to the parent class (per the original
    comment, encode mode does not read from disk). In every other mode the
    disk-backed inputs are replaced by random int64 token ids of shape
    [batch_size, 15] drawn from [0, vocab_size), with all-ones masks.

    NOTE(review): ``SkipThoughtsModel`` in the ``super`` call presumably
    names the test-local subclass that defines this override -- confirm.

    Returns:
      Whatever the parent ``build_inputs`` returns in "encode" mode;
      otherwise None.
    """
    if self.mode == "encode":
      return super(SkipThoughtsModel, self).build_inputs()

    # Fixed fake sequence length of 15 tokens for every input.
    ids_shape = [self.config.batch_size, 15]

    def make_random_ids():
      # Fresh random token-id tensor; called once per input, in the same
      # order as the attributes are assigned below.
      return tf.random_uniform(
          ids_shape,
          minval=0,
          maxval=self.config.vocab_size,
          dtype=tf.int64)

    self.encode_ids = make_random_ids()
    self.decode_pre_ids = make_random_ids()
    self.decode_post_ids = make_random_ids()

    # Every position is treated as a real (non-padding) token.
    self.encode_mask = tf.ones_like(self.encode_ids)
    self.decode_pre_mask = tf.ones_like(self.decode_pre_ids)
    self.decode_post_mask = tf.ones_like(self.decode_post_ids)
開發者ID:ringringyi,項目名稱:DOTA_models,代碼行數:26,代碼來源:skip_thoughts_model_test.py

示例2: build_graph_from_config

# 需要導入模塊: from skip_thoughts import skip_thoughts_model [as 別名]
# 或者: from skip_thoughts.skip_thoughts_model import SkipThoughtsModel [as 別名]
def build_graph_from_config(self, model_config, checkpoint_path):
    """Construct the encode-mode inference graph described by a config.

    Args:
      model_config: Configuration object used to build the model.
      checkpoint_path: A checkpoint file, or a directory containing one.

    Returns:
      The value of ``self._create_restore_fn(checkpoint_path, saver)``: a
      ``restore_fn`` such that ``restore_fn(sess)`` loads the model
      variables from the checkpoint.
    """
    tf.logging.info("Building model.")
    encoder = skip_thoughts_model.SkipThoughtsModel(
        model_config, mode="encode")
    encoder.build()

    # The Saver is created only after build() so it is constructed after the
    # model's graph has been built.
    return self._create_restore_fn(checkpoint_path, tf.train.Saver())
開發者ID:ringringyi,項目名稱:DOTA_models,代碼行數:20,代碼來源:skip_thoughts_encoder.py

示例3: testBuildForTraining

# 需要導入模塊: from skip_thoughts import skip_thoughts_model [as 別名]
# 或者: from skip_thoughts.skip_thoughts_model import SkipThoughtsModel [as 別名]
def testBuildForTraining(self):
    """Build the model in "train" mode and verify parameters and output shapes.

    The expected shapes encode batch size 128, sequence length 15, word
    embedding width 620 and thought-vector width 2400.
    """
    model = SkipThoughtsModel(self._model_config, mode="train")
    model.build()

    self._checkModelParameters()

    batch_size, length = 128, 15
    word_embedding_dim, encoder_dim = 620, 2400
    flat_shape = (batch_size * length,)  # (1920,)

    # (tensors, expected shape) groups, listed in the order the shapes are
    # checked.
    shape_groups = [
        # [batch_size, length]
        ((model.encode_ids, model.decode_pre_ids, model.decode_post_ids,
          model.encode_mask, model.decode_pre_mask, model.decode_post_mask),
         (batch_size, length)),
        # [batch_size, length, word_embedding_dim]
        ((model.encode_emb, model.decode_pre_emb, model.decode_post_emb),
         (batch_size, length, word_embedding_dim)),
        # [batch_size, encoder_dim]
        ((model.thought_vectors,), (batch_size, encoder_dim)),
        # [batch_size * length]: cross-entropy losses, then their weights.
        ((model.target_cross_entropy_losses[0],
          model.target_cross_entropy_losses[1],
          model.target_cross_entropy_loss_weights[0],
          model.target_cross_entropy_loss_weights[1]),
         flat_shape),
        # Scalar
        ((model.total_loss,), ()),
    ]
    expected_shapes = {
        tensor: shape for tensors, shape in shape_groups for tensor in tensors
    }
    self._checkOutputs(expected_shapes)
開發者ID:ringringyi,項目名稱:DOTA_models,代碼行數:32,代碼來源:skip_thoughts_model_test.py

示例4: testBuildForEval

# 需要導入模塊: from skip_thoughts import skip_thoughts_model [as 別名]
# 或者: from skip_thoughts.skip_thoughts_model import SkipThoughtsModel [as 別名]
def testBuildForEval(self):
    """Build the model in "eval" mode and verify parameters and output shapes.

    The expected shapes encode batch size 128, sequence length 15, word
    embedding width 620 and thought-vector width 2400.
    """
    model = SkipThoughtsModel(self._model_config, mode="eval")
    model.build()

    self._checkModelParameters()

    batch_size, length = 128, 15
    word_embedding_dim, encoder_dim = 620, 2400

    expected_shapes = {}

    # [batch_size, length]
    for tensor in (model.encode_ids, model.decode_pre_ids,
                   model.decode_post_ids, model.encode_mask,
                   model.decode_pre_mask, model.decode_post_mask):
      expected_shapes[tensor] = (batch_size, length)

    # [batch_size, length, word_embedding_dim]
    for tensor in (model.encode_emb, model.decode_pre_emb,
                   model.decode_post_emb):
      expected_shapes[tensor] = (batch_size, length, word_embedding_dim)

    # [batch_size, encoder_dim]
    expected_shapes[model.thought_vectors] = (batch_size, encoder_dim)

    # [batch_size * length]: cross-entropy losses first, then their weights.
    for tensor in (model.target_cross_entropy_losses[0],
                   model.target_cross_entropy_losses[1],
                   model.target_cross_entropy_loss_weights[0],
                   model.target_cross_entropy_loss_weights[1]):
      expected_shapes[tensor] = (batch_size * length,)

    # Scalar
    expected_shapes[model.total_loss] = ()

    self._checkOutputs(expected_shapes)
開發者ID:ringringyi,項目名稱:DOTA_models,代碼行數:32,代碼來源:skip_thoughts_model_test.py

示例5: testBuildForEncode

# 需要導入模塊: from skip_thoughts import skip_thoughts_model [as 別名]
# 或者: from skip_thoughts.skip_thoughts_model import SkipThoughtsModel [as 別名]
def testBuildForEncode(self):
    """Build in "encode" mode and feed word embeddings to get thought vectors.

    A random [64, 15, 620] embedding batch and an all-ones int64 mask are fed.
    The test checks that the thought vectors come out as [64, 2400].
    """
    model = SkipThoughtsModel(self._model_config, mode="encode")
    model.build()

    batch_size, length, word_embedding_dim = 64, 15, 620
    feed_dict = {
        model.encode_emb: np.random.rand(batch_size, length,
                                         word_embedding_dim),
        model.encode_mask: np.ones((batch_size, length), dtype=np.int64),
    }
    # [batch_size, encoder_dim]
    self._checkOutputs({model.thought_vectors: (batch_size, 2400)}, feed_dict)
開發者ID:ringringyi,項目名稱:DOTA_models,代碼行數:15,代碼來源:skip_thoughts_model_test.py

示例6: main

# 需要導入模塊: from skip_thoughts import skip_thoughts_model [as 別名]
# 或者: from skip_thoughts.skip_thoughts_model import SkipThoughtsModel [as 別名]
def main(unused_argv):
  """Validate flags, build the skip-thoughts training graph, and train it.

  Args:
    unused_argv: Ignored.

  Raises:
    ValueError: If --input_file_pattern or --train_dir is empty or unset.
  """
  # Check required flags before building anything.
  if not FLAGS.input_file_pattern:
    raise ValueError("--input_file_pattern is required.")
  if not FLAGS.train_dir:
    raise ValueError("--train_dir is required.")

  model_config = configuration.model_config(
      input_file_pattern=FLAGS.input_file_pattern)
  training_config = configuration.training_config()

  tf.logging.info("Building training graph.")
  graph = tf.Graph()
  with graph.as_default():
    model = skip_thoughts_model.SkipThoughtsModel(model_config, mode="train")
    model.build()
    global_step = model.global_step

    # Adam driven by the learning rate from _setup_learning_rate (defined
    # elsewhere in this module).
    optimizer = tf.train.AdamOptimizer(
        _setup_learning_rate(training_config, global_step))
    train_op = tf.contrib.slim.learning.create_train_op(
        total_loss=model.total_loss,
        optimizer=optimizer,
        global_step=global_step,
        clip_gradient_norm=training_config.clip_gradient_norm)

    # Built under graph.as_default() so it belongs to this training graph.
    saver = tf.train.Saver()

  tf.contrib.slim.learning.train(
      train_op=train_op,
      logdir=FLAGS.train_dir,
      graph=graph,
      global_step=global_step,
      number_of_steps=training_config.number_of_steps,
      save_summaries_secs=training_config.save_summaries_secs,
      saver=saver,
      save_interval_secs=training_config.save_model_secs)
開發者ID:ringringyi,項目名稱:DOTA_models,代碼行數:38,代碼來源:train.py

示例7: main

# 需要導入模塊: from skip_thoughts import skip_thoughts_model [as 別名]
# 或者: from skip_thoughts.skip_thoughts_model import SkipThoughtsModel [as 別名]
def main(unused_argv):
    """Validate flags, then build and run the skip-thoughts training loop.

    Args:
        unused_argv: Ignored.

    Raises:
        ValueError: If --input_file_pattern or --train_dir is empty or unset.
    """
    # Required flags are checked up front, in this order.
    if not FLAGS.input_file_pattern:
        raise ValueError("--input_file_pattern is required.")
    if not FLAGS.train_dir:
        raise ValueError("--train_dir is required.")

    model_config = configuration.model_config(
        input_file_pattern=FLAGS.input_file_pattern)
    training_config = configuration.training_config()

    tf.logging.info("Building training graph.")
    train_graph = tf.Graph()
    with train_graph.as_default():
        model = skip_thoughts_model.SkipThoughtsModel(model_config,
                                                      mode="train")
        model.build()
        step = model.global_step

        # Learning rate comes from _setup_learning_rate (defined elsewhere
        # in this module).
        optimizer = tf.train.AdamOptimizer(
            _setup_learning_rate(training_config, step))
        train_op = tf.contrib.slim.learning.create_train_op(
            total_loss=model.total_loss,
            optimizer=optimizer,
            global_step=step,
            clip_gradient_norm=training_config.clip_gradient_norm)

        # Created inside the graph context so it is part of train_graph.
        saver = tf.train.Saver()

    tf.contrib.slim.learning.train(
        train_op=train_op,
        logdir=FLAGS.train_dir,
        graph=train_graph,
        global_step=step,
        number_of_steps=training_config.number_of_steps,
        save_summaries_secs=training_config.save_summaries_secs,
        saver=saver,
        save_interval_secs=training_config.save_model_secs)
開發者ID:snuspl,項目名稱:parallax,代碼行數:39,代碼來源:train.py


注:本文中的skip_thoughts.skip_thoughts_model.SkipThoughtsModel方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。