当前位置: 首页>>代码示例>>Python>>正文


Python state_ops.scatter_nd_update函数代码示例

本文整理汇总了Python中tensorflow.python.ops.state_ops.scatter_nd_update函数的典型用法代码示例。如果您正苦于以下问题:Python scatter_nd_update函数的具体用法?Python scatter_nd_update怎么用?Python scatter_nd_update使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了scatter_nd_update函数的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testRank3InvalidShape2

 def testRank3InvalidShape2(self):
   indices = array_ops.zeros([2, 2, 1], dtypes.int32)
   updates = array_ops.zeros([2, 2], dtypes.int32)
   shape = np.array([2, 2, 2])
   ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
   with self.assertRaisesWithPredicateMatch(
       ValueError, "The inner \\d+ dimensions of input\\.shape="):
     state_ops.scatter_nd_update(ref, indices, updates)
开发者ID:abhinav-upadhyay,项目名称:tensorflow,代码行数:8,代码来源:scatter_nd_ops_test.py

示例2: testResVarInvalidOutputShape

 def testResVarInvalidOutputShape(self):
   res = variables.Variable(
       initial_value=lambda: array_ops.zeros(shape=[], dtype=dtypes.float32),
       dtype=dtypes.float32)
   with self.cached_session():
     res.initializer.run()
     with self.assertRaisesOpError("Output must be at least 1-D"):
       state_ops.scatter_nd_update(res, [[0]], [0.22]).eval()
开发者ID:abhinav-upadhyay,项目名称:tensorflow,代码行数:8,代码来源:scatter_nd_ops_test.py

示例3: testRank3ValidShape

 def testRank3ValidShape(self):
   indices = array_ops.zeros([2, 2, 2], dtypes.int32)
   updates = array_ops.zeros([2, 2, 2], dtypes.int32)
   shape = np.array([2, 2, 2])
   ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
   self.assertAllEqual(
       state_ops.scatter_nd_update(ref, indices,
                                   updates).get_shape().as_list(), shape)
开发者ID:abhinav-upadhyay,项目名称:tensorflow,代码行数:8,代码来源:scatter_nd_ops_test.py

示例4: testExtraIndicesDimensions

  def testExtraIndicesDimensions(self):
    indices = array_ops.zeros([1, 1, 2], dtypes.int32)
    updates = array_ops.zeros([1, 1], dtypes.int32)
    shape = np.array([2, 2])
    ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
    scatter_update = state_ops.scatter_nd_update(ref, indices, updates)
    self.assertAllEqual(scatter_update.get_shape().as_list(), shape)

    expected_result = np.zeros([2, 2], dtype=np.int32)
    with self.cached_session():
      ref.initializer.run()
      self.assertAllEqual(expected_result, scatter_update.eval())
开发者ID:abhinav-upadhyay,项目名称:tensorflow,代码行数:12,代码来源:scatter_nd_ops_test.py

示例5: testSimple

  def testSimple(self):
    indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
    updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
    ref = variables.Variable([0, 0, 0, 0, 0, 0, 0, 0], dtype=dtypes.float32)
    expected = np.array([0, 11, 0, 10, 9, 0, 0, 12])
    scatter = state_ops.scatter_nd_update(ref, indices, updates)
    init = variables.global_variables_initializer()

    with self.session(use_gpu=True) as sess:
      sess.run(init)
      result = sess.run(scatter)
      self.assertAllClose(result, expected)
开发者ID:abhinav-upadhyay,项目名称:tensorflow,代码行数:12,代码来源:scatter_nd_ops_test.py

示例6: testSimple3

  def testSimple3(self):
    indices = constant_op.constant([[1]], dtype=dtypes.int32)
    updates = constant_op.constant([[11., 12.]], dtype=dtypes.float32)
    ref = variables.Variable(
        [[0., 0.], [0., 0.], [0., 0.]], dtype=dtypes.float32)
    expected = np.array([[0., 0.], [11., 12.], [0., 0.]])
    scatter = state_ops.scatter_nd_update(ref, indices, updates)
    init = variables.global_variables_initializer()

    with self.test_session(use_gpu=True) as sess:
      sess.run(init)
      result = sess.run(scatter)
      self.assertAllClose(result, expected)
开发者ID:Jackiefan,项目名称:tensorflow,代码行数:13,代码来源:scatter_nd_ops_test.py


注:本文中的tensorflow.python.ops.state_ops.scatter_nd_update函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。