当前位置: 首页>>代码示例>>Python>>正文


Python check.Same方法代码示例

本文整理汇总了Python中syntaxnet.util.check.Same方法的典型用法代码示例。如果您正苦于以下问题：Python check.Same方法的具体用法？Python check.Same怎么用？Python check.Same使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在syntaxnet.util.check的用法示例。


在下文中一共展示了check.Same方法的5个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: CombineArcAndRootPotentials

# 需要导入模块: from syntaxnet.util import check [as 别名]
# 或者: from syntaxnet.util.check import Same [as 别名]
def CombineArcAndRootPotentials(arcs, roots):
  """Merges arc and root potentials into one batched potential tensor.

  The root potentials replace the diagonal (s == t) of the arc potentials;
  all off-diagonal entries are taken unchanged from |arcs|.

  Args:
    arcs: [B,N,N] tensor of batched arc potentials.
    roots: [B,N] matrix of batched root potentials.

  Returns:
    [B,N,N] tensor P of combined potentials where
      P_{b,s,t} = s == t ? roots[b,t] : arcs[b,s,t]
  """
  # Ranks have to be known when the graph is built.
  check.Eq(arcs.get_shape().ndims, 3, 'arcs must be rank 3')
  check.Eq(roots.get_shape().ndims, 2, 'roots must be a matrix')

  # Both inputs must agree on their base dtype.
  check.Same([arcs.dtype.base_dtype, roots.dtype.base_dtype], 'dtype mismatch')

  # B and N may be dynamic, so they are compared by run-time assertions.
  dynamic_roots = tf.shape(roots)
  dynamic_arcs = tf.shape(arcs)
  batch, length = dynamic_roots[0], dynamic_roots[1]
  shape_asserts = [
      tf.assert_equal(batch, dynamic_arcs[0]),
      tf.assert_equal(length, dynamic_arcs[1]),
      tf.assert_equal(length, dynamic_arcs[2]),
  ]

  # The control dependencies make the assertions run before the result is
  # produced.
  with tf.control_dependencies(shape_asserts):
    return tf.matrix_set_diag(arcs, roots)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:30,代码来源:digraph_ops.py

示例2: testCheckSame

# 需要导入模块: from syntaxnet.util import check [as 别名]
# 或者: from syntaxnet.util.check import Same [as 别名]
def testCheckSame(self):
    """check.Same accepts uniform lists and raises the requested error type.

    Uses assertRaisesRegex: the deprecated assertRaisesRegexp alias was
    removed from unittest.TestCase in Python 3.12.
    """
    check.Same([], 'foo')  # empty OK
    # Ints and floats that compare equal are treated as the same value.
    check.Same([1, 1, 1.0, 1.0, 1], 'foo')
    check.Same(['hello', 'hello'], 'foo')
    # A mismatch raises ValueError by default, carrying the given message.
    with self.assertRaisesRegex(ValueError, 'bar'):
      check.Same(['hello', 'world'], 'bar')
    # A different exception type can be requested via the third argument.
    with self.assertRaisesRegex(RuntimeError, 'baz'):
      check.Same([1, 1.1], 'baz', RuntimeError)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:10,代码来源:check_test.py

示例3: ArcSourcePotentialsFromTokens

# 需要导入模块: from syntaxnet.util import check [as 别名]
# 或者: from syntaxnet.util.check import Same [as 别名]
def ArcSourcePotentialsFromTokens(tokens, weights):
  r"""Builds arc potentials that depend only on the source token.

  Each token's activations are dotted with |weights| to give one scalar score.
  That score is then used for every arc leaving the token:

    arc[b,s,:] = \sum_{i} weights[i] * tokens[b,s,i]

  Args:
    tokens: [B,N,S] tensor of batched activations for source tokens.
    weights: [S] vector of weights.

    B and N may be statically unknown, but S must be statically known.  All
    arguments must have compatible dtypes.

  Returns:
    [B,N,N] tensor A of arc potentials, with the same dtype as the arguments.
    A[b,s,t] does not depend on t.  Diagonal entries (s == t) correspond to
    self-loops and may not be meaningful.
  """
  # Ranks have to be known when the graph is built.
  check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
  check.Eq(weights.get_shape().ndims, 1, 'weights must be a vector')

  # S is taken from the weights and must match the token activation width.
  source_dim = weights.get_shape().as_list()[0]
  check.NotNone(source_dim, 'unknown source activation dimension')
  token_dim = tokens.get_shape().as_list()[2]
  check.Eq(token_dim, source_dim,
           'dimension mismatch between weights and tokens')

  # Both inputs must agree on their base dtype.
  check.Same([weights.dtype.base_dtype, tokens.dtype.base_dtype],
             'dtype mismatch')

  dynamic_shape = tf.shape(tokens)
  batch, length = dynamic_shape[0], dynamic_shape[1]

  # A single [B*N,S] x [S,1] matmul scores every token as an arc source...
  scores = tf.matmul(tf.reshape(tokens, [-1, source_dim]),
                     tf.expand_dims(weights, 1))
  # ...and each score is then copied across all N possible targets.
  scores = tf.tile(scores, [1, length])

  # Put the batch dimension back: [B*N,N] -> [B,N,N].
  return tf.reshape(scores, [batch, length, length])
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:51,代码来源:digraph_ops.py

示例4: RootPotentialsFromTokens

# 需要导入模块: from syntaxnet.util import check [as 别名]
# 或者: from syntaxnet.util.check import Same [as 别名]
def RootPotentialsFromTokens(root, tokens, weights):
  r"""Scores every token as a candidate root of the parse.

  Each score is a bilinear form between the artificial root token's
  activations and a token's activations, parameterized by |weights|:

    roots[b,r] = \sum_{i,j} root[i] * weights[i,j] * tokens[b,r,j]

  Args:
    root: [S] vector of activations for the artificial root token.
    tokens: [B,N,T] tensor of batched activations for root tokens.
    weights: [S,T] matrix of weights.

    B and N may be statically unknown, but S and T must be statically known.
    All arguments must have compatible dtypes.

  Returns:
    [B,N] matrix R of root-selection potentials, with the same dtype as the
    arguments.
  """
  # Ranks have to be known when the graph is built.
  check.Eq(root.get_shape().ndims, 1, 'root must be a vector')
  check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
  check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')

  # S and T are read from the weights and must agree with root and tokens.
  weights_dims = weights.get_shape().as_list()
  source_dim = weights_dims[0]
  target_dim = weights_dims[1]
  check.NotNone(source_dim, 'unknown source activation dimension')
  check.NotNone(target_dim, 'unknown target activation dimension')
  check.Eq(root.get_shape().as_list()[0], source_dim,
           'dimension mismatch between weights and root')
  check.Eq(tokens.get_shape().as_list()[2], target_dim,
           'dimension mismatch between weights and tokens')

  # All three inputs must agree on their base dtype.
  dtypes = [weights.dtype.base_dtype,
            root.dtype.base_dtype,
            tokens.dtype.base_dtype]
  check.Same(dtypes, 'dtype mismatch')

  root_row = tf.expand_dims(root, 0)  # [1,S]

  dynamic_shape = tf.shape(tokens)
  batch, length = dynamic_shape[0], dynamic_shape[1]

  # Project every token into the source space: [B*N,T] x [T,S] -> [B*N,S].
  flat_tokens = tf.reshape(tokens, [-1, target_dim])
  projected = tf.matmul(flat_tokens, weights, transpose_b=True)
  # Score each projection against the root: [1,S] x [S,B*N] -> [1,B*N].
  flat_roots = tf.matmul(root_row, projected, transpose_b=True)

  # Put the batch dimension back: [1,B*N] -> [B,N].
  return tf.reshape(flat_roots, [batch, length])
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:58,代码来源:digraph_ops.py

示例5: LabelPotentialsFromTokens

# 需要导入模块: from syntaxnet.util import check [as 别名]
# 或者: from syntaxnet.util.check import Same [as 别名]
def LabelPotentialsFromTokens(tokens, weights):
  r"""Computes label potentials from tokens and weights.

  For each batch of token activations, computes a scalar potential for each
  label as the product between the activations of each token and the
  |weights|.  Specifically,

    labels[b,t,l] = \sum_{i} weights[l,i] * tokens[b,t,i]

  Args:
    tokens: [B,N,T] tensor of batched token activations.
    weights: [L,T] matrix of weights.

    B,N may be dynamic, but L,T must be static.  The dtype of all arguments must
    be compatible.

  Returns:
    [B,N,L] tensor of label potentials as defined above, with the same dtype as
    the arguments.

  Raises:
    Whatever error syntaxnet.util.check raises on a failed check (ValueError
    by default for check.Same).  This happens if a rank, static dimension or
    dtype requirement above is violated.
  """
  # All arguments must have statically-known rank.
  check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
  check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')

  # All activation dimensions must be statically-known.
  num_labels = weights.get_shape().as_list()[0]
  num_activations = weights.get_shape().as_list()[1]
  check.NotNone(num_labels, 'unknown number of labels')
  check.NotNone(num_activations, 'unknown activation dimension')
  check.Eq(tokens.get_shape().as_list()[2], num_activations,
           'activation mismatch between weights and tokens')

  # All arguments must share the same type.  All static validation is done
  # before any graph ops are created, so a rejected call adds no ops.
  check.Same([tokens.dtype.base_dtype,
              weights.dtype.base_dtype],
             'dtype mismatch')

  tokens_shape = tf.shape(tokens)
  batch_size = tokens_shape[0]
  num_tokens = tokens_shape[1]

  # Flatten out the batch dimension so we can use one big matmul():
  # [B*N,T] x [T,L] -> [B*N,L].
  tokens_bnxt = tf.reshape(tokens, [-1, num_activations])
  labels_bnxl = tf.matmul(tokens_bnxt, weights, transpose_b=True)

  # Restore the batch dimension in the output.
  labels_bxnxl = tf.reshape(labels_bnxl, [batch_size, num_tokens, num_labels])
  return labels_bxnxl
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:46,代码来源:digraph_ops.py


注:本文中的syntaxnet.util.check.Same方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。