当前位置: 首页>>代码示例>>Python>>正文


Python check.Eq方法代码示例

本文整理汇总了Python中syntaxnet.util.check.Eq方法的典型用法代码示例。如果您正苦于以下问题：Python check.Eq方法的具体用法？Python check.Eq怎么用？Python check.Eq使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在syntaxnet.util.check的用法示例。


在下文中一共展示了check.Eq方法的11个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: from syntaxnet.util import check [as 别名]
# 或者: from syntaxnet.util.check import Eq [as 别名]
def __init__(self, component):
    """Sets up the output layer and, optionally, a trainable padding variable.

    A linked feature named 'indices' with dimension 1 is required. The
    'outputs' layer gets the concatenated input dimension minus that single
    'indices' slot. When the 'trainable_padding' attribute is set, a
    'pre_padding' variable of shape [1, 1, dim] is added to the parameters.

    Args:
      component: Parent ComponentBuilderBase object.
    """
    super(GatherNetwork, self).__init__(component)

    attr_defaults = {'trainable_padding': False}
    self._attrs = get_attrs_with_defaults(
        component.spec.network_unit.parameters, attr_defaults)

    linked_dims = self._linked_feature_dims
    check.In('indices', linked_dims, 'Missing required linked feature')
    check.Eq(linked_dims['indices'], 1,
             'Wrong dimension for "indices" feature')

    # The 'indices' feature is excluded from the output dimension.
    self._dim = self._concatenated_input_dim - 1
    output_layer = Layer(component, 'outputs', self._dim)
    self._layers.append(output_layer)

    if not self._attrs['trainable_padding']:
      return

    padding_shape = [1, 1, self._dim]
    padding = tf.get_variable(
        'pre_padding',
        padding_shape,
        initializer=tf.random_normal_initializer(stddev=1e-4),
        dtype=tf.float32)
    self._params.append(padding)
开发者ID:rky0930,项目名称:yolo_v2,代码行数:25,代码来源:network_units.py

示例2: calculate_parse_metrics

# 需要导入模块: from syntaxnet.util import check [as 别名]
# 或者: from syntaxnet.util.check import Eq [as 别名]
def calculate_parse_metrics(gold_corpus, annotated_corpus):
  """Calculates POS/UAS/LAS accuracy based on gold and annotated sentences.

  Args:
    gold_corpus: Sequence of serialized Sentence protos with the gold
      annotations.
    annotated_corpus: Sequence of serialized Sentence protos with the
      predicted annotations, aligned one-to-one with |gold_corpus|.

  Returns:
    Tuple (pos, uas, las) of percentages over all tokens: matching tags,
    matching heads, and matching heads-plus-labels respectively.

  Raises:
    ValueError: If the corpora, a sentence's text, or a sentence's token
      count are not aligned (raised by check.Eq).
    ZeroDivisionError: If the corpora contain no tokens at all.
  """
  check.Eq(len(gold_corpus), len(annotated_corpus), 'Corpora are not aligned')
  num_tokens = 0
  num_correct_pos = 0
  num_correct_uas = 0
  num_correct_las = 0
  for gold_str, annotated_str in zip(gold_corpus, annotated_corpus):
    gold = sentence_pb2.Sentence()
    annotated = sentence_pb2.Sentence()
    gold.ParseFromString(gold_str)
    annotated.ParseFromString(annotated_str)
    check.Eq(gold.text, annotated.text, 'Text is not aligned')
    check.Eq(len(gold.token), len(annotated.token), 'Tokens are not aligned')
    # Materialize the pairs. On Python 3, zip() returns a one-shot iterator
    # that has no len() and would be exhausted by the first sum() below,
    # leaving the remaining counts at zero.
    tokens = list(zip(gold.token, annotated.token))
    num_tokens += len(tokens)
    num_correct_pos += sum(1 for x, y in tokens if x.tag == y.tag)
    num_correct_uas += sum(1 for x, y in tokens if x.head == y.head)
    num_correct_las += sum(1 for x, y in tokens
                           if x.head == y.head and x.label == y.label)

  tf.logging.info('Total num documents: %d', len(annotated_corpus))
  tf.logging.info('Total num tokens: %d', num_tokens)
  # Multiply by 100.0 first so the division is floating-point on Python 2 too.
  pos = num_correct_pos * 100.0 / num_tokens
  uas = num_correct_uas * 100.0 / num_tokens
  las = num_correct_las * 100.0 / num_tokens
  tf.logging.info('POS: %.2f%%', pos)
  tf.logging.info('UAS: %.2f%%', uas)
  tf.logging.info('LAS: %.2f%%', las)
  return pos, uas, las
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:32,代码来源:evaluation.py

示例3: create

# 需要导入模块: from syntaxnet.util import check [as 别名]
# 或者: from syntaxnet.util.check import Eq [as 别名]
def create(self,
           fixed_embeddings,
           linked_embeddings,
           context_tensor_arrays,
           attention_tensor,
           during_training,
           stride=None):
    """Returns the tensors of the fixed embeddings unchanged.

    Verifies that there is exactly one fixed embedding per layer and that
    each embedding's name equals the name of the layer at the same position.
    See the base class for the meaning of the arguments; only
    |fixed_embeddings| is read here.

    Returns:
      List with the .tensor of every entry in |fixed_embeddings|, in order.
    """
    num_embeddings = len(fixed_embeddings)
    check.Eq(len(self.layers), num_embeddings)
    for position, embedding in enumerate(fixed_embeddings):
      check.Eq(self.layers[position].name, embedding.name)
    return [embedding.tensor for embedding in fixed_embeddings]
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:14,代码来源:network_units.py

示例4: extract_fixed_feature_ids

# 需要导入模块: from syntaxnet.util import check [as 别名]
# 或者: from syntaxnet.util.check import Eq [as 别名]
def extract_fixed_feature_ids(comp, state, stride):
  """Pulls the raw fixed feature IDs of a component, one tensor per channel.

  Every fixed feature of the component must have size 1 and pass the
  check.Lt(embedding_dim, 0) non-embedded check.

  Args:
    comp: Component whose fixed feature IDs we wish to extract.
    state: Live MasterState object for the component.  Its handle is
      overwritten with the handle returned by the bulk extraction op.
    stride: Tensor containing current batch * beam size.

  Returns:
    state handle: Updated state handle to be used after this call.
    ids: List of [stride * num_steps, 1] feature IDs per channel.  Missing IDs
         (e.g., due to batch padding) are set to -1.  Empty if the component
         has no fixed features.
  """
  num_channels = len(comp.spec.fixed_feature)
  if num_channels == 0:
    # Nothing to extract; leave the state untouched.
    return state.handle, []

  for spec in comp.spec.fixed_feature:
    check.Eq(spec.size, 1, 'All features must have size=1')
    check.Lt(spec.embedding_dim, 0, 'All features must be non-embedded')

  bulk_outputs = dragnn_ops.bulk_fixed_features(
      state.handle, component=comp.name, num_channels=num_channels)
  state.handle, indices, ids, _, num_steps = bulk_outputs
  size = stride * num_steps

  fixed_ids = []
  for channel, spec in enumerate(comp.spec.fixed_feature):
    tf.logging.info('[%s] Adding fixed feature IDs "%s"', comp.name,
                    spec.name)

    # Shift IDs up by one before summing and back down afterwards, so that
    # segments that received no ID end up as -1 instead of 0.
    #
    # TODO(googleuser): This formula breaks if multiple IDs are extracted at some
    # step.  Try using tf.unique() to enforce the unique-IDS precondition.
    shifted_ids = ids[channel] + 1
    segment_ids = indices[channel]
    summed = tf.unsorted_segment_sum(shifted_ids, segment_ids, size) - 1
    column = tf.expand_dims(summed, axis=1)
    named = network_units.NamedTensor(column, spec.name, dim=1)
    fixed_ids.append(named)
  return state.handle, fixed_ids
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:40,代码来源:bulk_component.py

示例5: __init__

# 需要导入模块: from syntaxnet.util import check [as 别名]
# 或者: from syntaxnet.util.check import Eq [as 别名]
def __init__(self, master, component_spec):
    """Builds the feature ID extractor component after validating its spec.

    The spec must declare no linked features, and every fixed feature's
    embedding_dim must pass check.Lt(..., 0), i.e. the feature must be
    non-embedded.

    Args:
      master: dragnn.MasterBuilder object.
      component_spec: dragnn.ComponentSpec proto to be built.
    """
    super(BulkFeatureIdExtractorComponentBuilder, self).__init__(
        master, component_spec)

    num_linked = len(self.spec.linked_feature)
    check.Eq(num_linked, 0, 'Linked features are forbidden')

    for fixed_spec in self.spec.fixed_feature:
      # Include the offending spec in the message to ease debugging.
      failure_message = 'Features must be non-embedded: %s' % fixed_spec
      check.Lt(fixed_spec.embedding_dim, 0, failure_message)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:15,代码来源:bulk_component.py

示例6: __init__

# 需要导入模块: from syntaxnet.util import check [as 别名]
# 或者: from syntaxnet.util.check import Eq [as 别名]
def __init__(self, component):
    """Initializes weights and layers.

    Validates that the unit has no fixed features and exactly two linked
    features named 'sources' and 'targets', creates the 'weights_arc',
    'weights_source' and 'root' variables, and registers the 'adjacency'
    output layer.

    Args:
      component: Parent ComponentBuilderBase object.

    Raises:
      ValueError: If the number of fixed or linked features is wrong (raised
        by check.Eq; check.In presumably raises likewise when 'sources' or
        'targets' is missing).
    """
    super(BiaffineDigraphNetwork, self).__init__(component)

    check.Eq(len(self._fixed_feature_dims), 0, 'Expected no fixed features')
    check.Eq(len(self._linked_feature_dims), 2,
             'Expected two linked features')

    check.In('sources', self._linked_feature_dims,
             'Missing required linked feature')
    check.In('targets', self._linked_feature_dims,
             'Missing required linked feature')
    self._source_dim = self._linked_feature_dims['sources']
    self._target_dim = self._linked_feature_dims['targets']

    # TODO(googleuser): Make parameter initialization configurable.
    self._weights = []
    self._weights.append(tf.get_variable(
        'weights_arc', [self._source_dim, self._target_dim], tf.float32,
        tf.random_normal_initializer(stddev=1e-4)))
    self._weights.append(tf.get_variable(
        'weights_source', [self._source_dim], tf.float32,
        tf.random_normal_initializer(stddev=1e-4)))
    self._weights.append(tf.get_variable(
        'root', [self._source_dim], tf.float32,
        tf.random_normal_initializer(stddev=1e-4)))

    # All weights are both trainable parameters and subject to regularization.
    self._params.extend(self._weights)
    self._regularized_weights.extend(self._weights)

    # Negative Layer.dim indicates that the dimension is dynamic.  The layer
    # is owned by the parent component, not by this network unit, so pass
    # |component| rather than |self|.
    self._layers.append(network_units.Layer(component, 'adjacency', -1))
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:39,代码来源:biaffine_units.py

示例7: CombineArcAndRootPotentials

# 需要导入模块: from syntaxnet.util import check [as 别名]
# 或者: from syntaxnet.util.check import Eq [as 别名]
def CombineArcAndRootPotentials(arcs, roots):
  """Merges root potentials onto the diagonal of the arc potentials.

  Args:
    arcs: [B,N,N] tensor of batched arc potentials.
    roots: [B,N] matrix of batched root potentials.

  Returns:
    [B,N,N] tensor P of combined potentials where
      P_{b,s,t} = s == t ? roots[b,t] : arcs[b,s,t]
  """
  # Ranks must be known when the graph is built.
  check.Eq(arcs.get_shape().ndims, 3, 'arcs must be rank 3')
  check.Eq(roots.get_shape().ndims, 2, 'roots must be a matrix')

  # Both inputs must have the same base dtype.
  arcs_dtype = arcs.dtype.base_dtype
  roots_dtype = roots.dtype.base_dtype
  check.Same([arcs_dtype, roots_dtype], 'dtype mismatch')

  # The actual dimension sizes are compared when the graph runs: |arcs|
  # must be [B,N,N] for the B and N taken from |roots|.
  roots_dims = tf.shape(roots)
  arcs_dims = tf.shape(arcs)
  batch_size = roots_dims[0]
  num_tokens = roots_dims[1]
  shape_assertions = [
      tf.assert_equal(batch_size, arcs_dims[0]),
      tf.assert_equal(num_tokens, arcs_dims[1]),
      tf.assert_equal(num_tokens, arcs_dims[2]),
  ]
  with tf.control_dependencies(shape_assertions):
    return tf.matrix_set_diag(arcs, roots)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:30,代码来源:digraph_ops.py

示例8: Eq

# 需要导入模块: from syntaxnet.util import check [as 别名]
# 或者: from syntaxnet.util.check import Eq [as 别名]
def Eq(lhs, rhs, message='', error=ValueError):
  """Verifies that |lhs| equals |rhs|, raising |error| otherwise.

  Args:
    lhs: Left-hand operand of the comparison.
    rhs: Right-hand operand of the comparison.
    message: Text appended to the generated error description.
    error: Callable, normally an exception class, that builds the error to
      raise from the description.  Defaults to ValueError.

  Raises:
    error: If |lhs| != |rhs|.
  """
  operands_differ = lhs != rhs
  if not operands_differ:
    return
  description = 'Expected (%s) == (%s): %s' % (lhs, rhs, message)
  raise error(description)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:6,代码来源:check.py

示例9: testCheckEq

# 需要导入模块: from syntaxnet.util import check [as 别名]
# 或者: from syntaxnet.util.check import Eq [as 别名]
def testCheckEq(self):
    """Checks that check.Eq passes on equal values and raises otherwise."""
    # Equal operands must not raise.
    check.Eq(1, 1, 'foo')
    # Unequal operands raise ValueError by default, carrying the message.
    self.assertRaisesRegexp(ValueError, 'bar', check.Eq, 1, 2, 'bar')
    # The error type can be overridden through the fourth argument.
    self.assertRaisesRegexp(RuntimeError, 'baz', check.Eq, 1, 2, 'baz',
                            RuntimeError)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:8,代码来源:check_test.py

示例10: __init__

# 需要导入模块: from syntaxnet.util import check [as 别名]
# 或者: from syntaxnet.util.check import Eq [as 别名]
def __init__(self, component):
    """Creates the biaffine weights and the 'adjacency' output layer.

    The unit accepts no fixed features and exactly two linked features,
    which must be named 'sources' and 'targets'.

    Args:
      component: Parent ComponentBuilderBase object.
    """
    super(BiaffineDigraphNetwork, self).__init__(component)

    num_fixed = len(self._fixed_feature_dims.items())
    check.Eq(num_fixed, 0, 'Expected no fixed features')
    num_linked = len(self._linked_feature_dims.items())
    check.Eq(num_linked, 2, 'Expected two linked features')

    for required_name in ('sources', 'targets'):
      check.In(required_name, self._linked_feature_dims,
               'Missing required linked feature')
    self._source_dim = self._linked_feature_dims['sources']
    self._target_dim = self._linked_feature_dims['targets']

    # TODO(googleuser): Make parameter initialization configurable.
    # (variable name, shape) pairs, listed in creation order.
    variable_shapes = [
        ('weights_arc', [self._source_dim, self._target_dim]),
        ('weights_source', [self._source_dim]),
        ('root', [self._source_dim]),
    ]
    self._weights = []
    for variable_name, shape in variable_shapes:
      variable = tf.get_variable(
          variable_name, shape, tf.float32,
          tf.random_normal_initializer(stddev=1e-4))
      self._weights.append(variable)

    # Every weight is both a parameter and a regularized weight.
    self._params.extend(self._weights)
    self._regularized_weights.extend(self._weights)

    # Negative Layer.dim indicates that the dimension is dynamic.
    adjacency_layer = network_units.Layer(component, 'adjacency', -1)
    self._layers.append(adjacency_layer)
开发者ID:rky0930,项目名称:yolo_v2,代码行数:39,代码来源:biaffine_units.py

示例11: get_segmenter_corpus

# 需要导入模块: from syntaxnet.util import check [as 别名]
# 或者: from syntaxnet.util.check import Eq [as 别名]
def get_segmenter_corpus(input_data_path, use_text_format):
  """Loads the character corpus that is fed to the segmenter.

  Args:
    input_data_path: Path handed to the sentence reader.
    use_text_format: If true, read the input with the 'untokenized-text'
      format reader.  Otherwise read it with ConllSentenceReader and convert
      the result with the char_token_generator op.

  Returns:
    The character corpus.
  """
  tf.logging.info('Reading documents...')

  if use_text_format:
    text_reader = sentence_io.FormatSentenceReader(input_data_path,
                                                   'untokenized-text')
    return text_reader.corpus()

  conll_reader = sentence_io.ConllSentenceReader(input_data_path)
  input_corpus = conll_reader.corpus()
  # Build and run the conversion op in its own temporary graph and session.
  with tf.Session(graph=tf.Graph()) as tmp_session:
    char_input = gen_parser_ops.char_token_generator(input_corpus)
    char_corpus = tmp_session.run(char_input)
  # The conversion must yield one output per input document.
  check.Eq(len(input_corpus), len(char_corpus))
  return char_corpus
开发者ID:rky0930,项目名称:yolo_v2,代码行数:17,代码来源:parse_to_conll.py


注:本文中的syntaxnet.util.check.Eq方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。