當前位置: 首頁>>代碼示例>>Python>>正文


Python spec_pb2.GridPoint方法代碼示例

本文整理匯總了Python中dragnn.protos.spec_pb2.GridPoint方法的典型用法代碼示例。如果您正苦於以下問題:Python spec_pb2.GridPoint方法的具體用法?Python spec_pb2.GridPoint怎麼用?Python spec_pb2.GridPoint使用的例子?那麼, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在dragnn.protos.spec_pb2的用法示例。


在下文中一共展示了spec_pb2.GridPoint方法的10個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: _create_learning_rate

# 需要導入模塊: from dragnn.protos import spec_pb2 [as 別名]
# 或者: from dragnn.protos.spec_pb2 import GridPoint [as 別名]
def _create_learning_rate(hyperparams, step_var):
  """Builds the learning-rate tensor, including decay and composite switching.

  For the 'composite' learning method, the base rate is selected when the graph
  runs:
  - While `step_var` is below `composite_optimizer_spec.switch_after_steps`,
    method1's learning rate is used.
  - From then on, method2's learning rate is used.

  Any other learning method uses `hyperparams.learning_rate` directly. In both
  cases the base rate is fed through `tf.train.exponential_decay`, driven by
  `step_var`.

  Args:
    hyperparams: a GridPoint proto holding the optimizer configuration. Its
      `learning_method` chooses between the plain and composite paths.
    step_var: tf.Variable, the global training step.

  Returns:
    A scalar `Tensor` with the learning rate for the current step.
  """
  if hyperparams.learning_method == 'composite':
    composite = hyperparams.composite_optimizer_spec
    before_switch = tf.less(step_var, composite.switch_after_steps)

    def first_rate():
      return tf.constant(composite.method1.learning_rate)

    def second_rate():
      return tf.constant(composite.method2.learning_rate)

    base_rate = tf.cond(before_switch, first_rate, second_rate)
  else:
    base_rate = hyperparams.learning_rate

  return tf.train.exponential_decay(
      base_rate, step_var, hyperparams.decay_steps, hyperparams.decay_base,
      staircase=hyperparams.decay_staircase)
開發者ID:ringringyi,項目名稱:DOTA_models,代碼行數:26,代碼來源:graph_builder.py

示例2: __init__

# 需要導入模塊: from dragnn.protos import spec_pb2 [as 別名]
# 或者: from dragnn.protos.spec_pb2 import GridPoint [as 別名]
def __init__(self):
    # Default-constructed protos for the spec and hyperparameters, plus one
    # mock component registered under the key 'previous'.
    self.spec = spec_pb2.MasterSpec()
    self.hyperparams = spec_pb2.GridPoint()
    previous_component = MockComponent(self, spec_pb2.ComponentSpec())
    self.lookup_component = {'previous': previous_component}
開發者ID:ringringyi,項目名稱:DOTA_models,代碼行數:8,代碼來源:network_units_test.py

示例3: MakeHyperparams

# 需要導入模塊: from dragnn.protos import spec_pb2 [as 別名]
# 或者: from dragnn.protos.spec_pb2 import GridPoint [as 別名]
def MakeHyperparams(self, **kwargs):
    """Returns a GridPoint proto with every keyword argument set as a field.

    Each `name=value` pair is applied with `setattr`, in the order the keyword
    arguments were passed.
    """
    config = spec_pb2.GridPoint()
    for field_name, field_value in kwargs.items():
        setattr(config, field_name, field_value)
    return config
開發者ID:ringringyi,項目名稱:DOTA_models,代碼行數:7,代碼來源:graph_builder_test.py

示例4: getBuilderAndTarget

# 需要導入模塊: from dragnn.protos import spec_pb2 [as 別名]
# 或者: from dragnn.protos.spec_pb2 import GridPoint [as 別名]
def getBuilderAndTarget(
      self, test_name, master_spec_path='simple_parser_master_spec.textproto'):
    """Creates a MasterBuilder and a TrainTarget from a simple master spec.

    The target trains only the final component: it gets weight 1.0 and is
    unrolled with the oracle. Every earlier component gets weight 0 and no
    oracle unrolling.

    Args:
      test_name: used in the target name and as the builder's pool scope.
      master_spec_path: spec file passed to `self.LoadSpec`.

    Returns:
      A `(builder, target)` tuple.
    """
    master_spec = self.LoadSpec(master_spec_path)
    component_count = len(master_spec.component)

    target = spec_pb2.TrainTarget()
    target.name = 'test-%s-train' % test_name
    # Start with everything disabled, then enable only the last component.
    target.component_weights.extend([0] * component_count)
    target.unroll_using_oracle.extend([False] * component_count)
    target.component_weights[-1] = 1.0
    target.unroll_using_oracle[-1] = True

    builder = graph_builder.MasterBuilder(
        master_spec, spec_pb2.GridPoint(), pool_scope=test_name)
    return builder, target
開發者ID:ringringyi,項目名稱:DOTA_models,代碼行數:16,代碼來源:graph_builder_test.py

示例5: __init__

# 需要導入模塊: from dragnn.protos import spec_pb2 [as 別名]
# 或者: from dragnn.protos.spec_pb2 import GridPoint [as 別名]
def __init__(self):
    # Default-constructed protos plus a single mock component keyed 'mock'.
    self.spec = spec_pb2.MasterSpec()
    self.hyperparams = spec_pb2.GridPoint()
    self.lookup_component = dict(mock=MockComponent())
開發者ID:ringringyi,項目名稱:DOTA_models,代碼行數:6,代碼來源:bulk_component_test.py

示例6: _validate_grid_point

# 需要導入模塊: from dragnn.protos import spec_pb2 [as 別名]
# 或者: from dragnn.protos.spec_pb2 import GridPoint [as 別名]
def _validate_grid_point(hyperparams, is_sub_optimizer=False):
  """Validates that a grid point's configuration is reasonable.

  Args:
    hyperparams (spec_pb2.GridPoint): Grid point to validate.
    is_sub_optimizer (bool): Whether this optimizer is a sub-optimizer of
      a composite optimizer.

  Raises:
    ValueError: If the grid point is not valid.
  """
  valid_methods = ('gradient_descent', 'adam', 'lazyadam', 'momentum',
                   'composite')
  if hyperparams.learning_method not in valid_methods:
    raise ValueError('Unknown learning method (optimizer)')

  if is_sub_optimizer:
    for base_only_field in ('decay_steps', 'decay_base', 'decay_staircase'):
      if hyperparams.HasField(base_only_field):
        raise ValueError('Field {} is not valid for sub-optimizers of a '
                         'composite optimizer.'.format(base_only_field))

  if hyperparams.learning_method == 'composite':
    spec = hyperparams.composite_optimizer_spec
    if spec.switch_after_steps < 1:
      raise ValueError('switch_after_steps {} not valid for composite '
                       'optimizer!'.format(spec.switch_after_steps))
    for sub_optimizer in (spec.method1, spec.method2):
      _validate_grid_point(sub_optimizer, is_sub_optimizer=True) 
開發者ID:rky0930,項目名稱:yolo_v2,代碼行數:31,代碼來源:graph_builder.py

示例7: _create_learning_rate

# 需要導入模塊: from dragnn.protos import spec_pb2 [as 別名]
# 或者: from dragnn.protos.spec_pb2 import GridPoint [as 別名]
def _create_learning_rate(hyperparams, step_var):
  """Builds the learning-rate tensor, including decay and composite switching.

  For the 'composite' learning method:
  - While `step_var` is below `switch_after_steps`, method1's learning rate is
    the base rate.
  - After that, method2's learning rate is the base rate.
  - If `reset_learning_rate` is set, the step fed to the decay schedule is
    shifted back by `switch_after_steps` after the switch, so decay restarts
    from zero for method2.

  Any other learning method uses `hyperparams.learning_rate` and the raw step.
  The function does no validation of its own.

  Args:
    hyperparams: a GridPoint proto holding the optimizer configuration. Its
      `learning_method` chooses between the plain and composite paths.
    step_var: tf.Variable, the global training step.

  Returns:
    A scalar `Tensor` with the learning rate for the current step.
  """
  if hyperparams.learning_method == 'composite':
    composite = hyperparams.composite_optimizer_spec
    before_switch = tf.less(step_var, composite.switch_after_steps)
    base_rate = tf.cond(
        before_switch,
        lambda: tf.constant(composite.method1.learning_rate),
        lambda: tf.constant(composite.method2.learning_rate))
    if not composite.reset_learning_rate:
      decay_step = step_var
    else:
      # Once method2 takes over, count decay steps from the switch point.
      decay_step = tf.cond(
          before_switch,
          lambda: step_var,
          lambda: step_var - composite.switch_after_steps)
  else:
    base_rate = hyperparams.learning_rate
    decay_step = step_var

  return tf.train.exponential_decay(
      learning_rate=base_rate,
      global_step=decay_step,
      decay_steps=hyperparams.decay_steps,
      decay_rate=hyperparams.decay_base,
      staircase=hyperparams.decay_staircase)
開發者ID:rky0930,項目名稱:yolo_v2,代碼行數:36,代碼來源:graph_builder.py

示例8: RunTraining

# 需要導入模塊: from dragnn.protos import spec_pb2 [as 別名]
# 或者: from dragnn.protos.spec_pb2 import GridPoint [as 別名]
def RunTraining(self, hyperparam_config):
    """Builds a training graph for `hyperparam_config` and runs one step.

    The check passes when a single training iteration over two serialized
    dummy gold sentences completes without raising.
    """
    master_spec = self.LoadSpec('master_spec_link.textproto')
    self.assertTrue(isinstance(hyperparam_config, spec_pb2.GridPoint))

    parsed_sentences = []
    for dummy_text in (_DUMMY_GOLD_SENTENCE, _DUMMY_GOLD_SENTENCE_2):
      sentence = sentence_pb2.Sentence()
      text_format.Parse(dummy_text, sentence)
      parsed_sentences.append(sentence)
    reader_strings = [s.SerializeToString() for s in parsed_sentences]

    tf.logging.info('Generating graph with config: %s', hyperparam_config)
    with tf.Graph().as_default():
      builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
      target = spec_pb2.TrainTarget()
      target.name = 'testTraining-all'
      train = builder.add_training_from_config(target)

      with self.test_session() as sess:
        logging.info('Initializing')
        sess.run(tf.global_variables_initializer())

        # A single iteration is enough to show that nothing crashes.
        logging.info('Training')
        feed = {train['input_batch']: reader_strings}
        sess.run(train['run'], feed_dict=feed)
開發者ID:rky0930,項目名稱:yolo_v2,代碼行數:28,代碼來源:graph_builder_test.py

示例9: testTaggerParserNanDeath

# 需要導入模塊: from dragnn.protos import spec_pb2 [as 別名]
# 或者: from dragnn.protos.spec_pb2 import GridPoint [as 別名]
def testTaggerParserNanDeath(self):
    # A learning rate this large is expected to make the cost non-finite,
    # which should surface as a check_numerics failure during training.
    hyperparam_config = spec_pb2.GridPoint(learning_rate=1.0)

    with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
                                 'Cost is not finite'):
      self.RunFullTrainingAndInference(
          'tagger-parser',
          'tagger_parser_master_spec.textproto',
          hyperparam_config=hyperparam_config,
          component_weights=[0., 1., 1.],
          unroll_using_oracle=[False, True, True],
          expected_num_actions=12,
          expected=_TAGGER_PARSER_EXPECTED_SENTENCES)
開發者ID:rky0930,項目名稱:yolo_v2,代碼行數:17,代碼來源:graph_builder_test.py

示例10: load_model

# 需要導入模塊: from dragnn.protos import spec_pb2 [as 別名]
# 或者: from dragnn.protos.spec_pb2 import GridPoint [as 別名]
def load_model(base_dir, master_spec_name, checkpoint_name):
    """Load a SyntaxNet/DRAGNN model and return a sentence-annotation callable.

    Steps, with both file names resolved relative to `base_dir`:
    1. Read the text-format MasterSpec file `master_spec_name`.
    2. Complete it with `spec_builder.complete_master_spec`.
    3. Build an annotation graph with tracing enabled.
    4. Restore its variables from the checkpoint `checkpoint_name`.

    The expected file layout is specific to the SyntaxNet tutorial format.

    Args:
        base_dir: directory holding the master spec and the checkpoint.
        master_spec_name: file name of the text-format MasterSpec.
        checkpoint_name: checkpoint prefix to restore.

    Returns:
        A function ``annotate_sentence(sentence)``. It runs the annotator on a
        one-element input batch and returns the fetched values of
        ``annotator['annotations']`` and ``annotator['traces']``.
    """
    spec_path = os.path.join(base_dir, master_spec_name)
    master_spec = spec_pb2.MasterSpec()
    with open(spec_path, "r") as spec_file:
        text_format.Merge(spec_file.read(), master_spec)
    spec_builder.complete_master_spec(master_spec, None, base_dir)
    logging.set_verbosity(logging.WARN)  # Turn off TensorFlow spam.

    graph = tf.Graph()
    with graph.as_default():
        builder = graph_builder.MasterBuilder(master_spec, spec_pb2.GridPoint())
        annotator = builder.add_annotation(enable_tracing=True)
        # The saver is only used to restore the checkpoint below, never to save.
        builder.add_saver()

    sess = tf.Session(graph=graph)
    with graph.as_default():
        builder.saver.restore(sess, os.path.join(base_dir, checkpoint_name))

    def annotate_sentence(sentence):
        with graph.as_default():
            fetches = [annotator['annotations'], annotator['traces']]
            feed = {annotator['input_batch']: [sentence]}
            return sess.run(fetches, feed_dict=feed)

    return annotate_sentence
開發者ID:hltcoe,項目名稱:PredPatt,代碼行數:34,代碼來源:ParseyPredFace.py


注:本文中的dragnn.protos.spec_pb2.GridPoint方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。