本文整理汇总了 Python 中 tensorflow.python.ops.list_ops.tensor_list_get_item 函数的典型用法代码示例。如果您正苦于以下问题：Python tensor_list_get_item 函数的具体用法？Python tensor_list_get_item 怎么用？Python tensor_list_get_item 使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了 tensor_list_get_item 函数的 15 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: testSkipEagerTensorListGetItemGradAggregation
def testSkipEagerTensorListGetItemGradAggregation(self):
    """Check that gradients from two reads of one list element are summed.

    A scalar `x` is written to slot 0 of a one-element float32 list, that
    slot is read twice, and the gradient of both reads with respect to `x`
    is asserted to equal `[2.]`.
    """
    l = list_ops.tensor_list_reserve(
        element_shape=[], num_elements=1, element_dtype=dtypes.float32)
    x = constant_op.constant(1.0)
    l = list_ops.tensor_list_set_item(l, 0, x)
    l_read1 = list_ops.tensor_list_get_item(
        l, 0, element_dtype=dtypes.float32)
    l_read2 = list_ops.tensor_list_get_item(
        l, 0, element_dtype=dtypes.float32)
    grad = gradients_impl.gradients([l_read1, l_read2], [x])
    # The session object is never referenced (evaluation goes through
    # self.evaluate), so the context manager is entered without binding it.
    with self.cached_session():
        self.assertSequenceEqual(self.evaluate(grad), [2.])
示例2: testUnevenSplit
def testUnevenSplit(self):
    """Check `tensor_list_split` with pieces of different lengths.

    Five values are split with `lengths=[3, 2]`; the resulting list is
    asserted to have length 2, with item 0 equal to the first three values
    and item 1 equal to the last two.
    """
    pieces = list_ops.tensor_list_split(
        [1., 2., 3., 4., 5], element_shape=None, lengths=[3, 2])
    self.assertAllEqual(list_ops.tensor_list_length(pieces), 2)

    head = list_ops.tensor_list_get_item(
        pieces, 0, element_dtype=dtypes.float32)
    self.assertAllEqual(head, [1., 2., 3.])

    tail = list_ops.tensor_list_get_item(
        pieces, 1, element_dtype=dtypes.float32)
    self.assertAllEqual(tail, [4., 5.])
示例3: testGetSetReservedNonScalar
def testGetSetReservedNonScalar(self):
    """Check get/set on a reserved list whose elements have shape (7, 15).

    Slot 0 is filled with ones and slot 1 is left unset; the test asserts
    that slot 0 reads back as all ones and slot 1 as all zeros.
    """
    with self.cached_session() as sess, self.test_scope():
        shape = (7, 15)
        reserved = list_ops.tensor_list_reserve(
            element_dtype=dtypes.float32,
            element_shape=shape,
            num_elements=2)
        ones = constant_op.constant(1.0, shape=shape)
        reserved = list_ops.tensor_list_set_item(reserved, 0, ones)
        written = list_ops.tensor_list_get_item(
            reserved, 0, element_dtype=dtypes.float32)
        unset = list_ops.tensor_list_get_item(
            reserved, 1, element_dtype=dtypes.float32)
        self.assertAllEqual(sess.run(written), np.ones(shape))
        self.assertAllEqual(sess.run(unset), np.zeros(shape))
示例4: testScatterGrad
def testScatterGrad(self):
    """Check values and gradient of a list built by `tensor_list_scatter`.

    `[1.0, 2.0]` is scattered with indices `[1, 0]`; items 0 and 1 are
    asserted to read back as 2.0 and 1.0, and the gradient of the sum of
    their squares with respect to the source tensor is asserted to be
    `[2., 4.]`.
    """
    # NOTE(review): the extent of the tape block is inferred (the source
    # had no indentation); the gradient is taken after leaving the tape
    # context -- confirm against the upstream test.
    with backprop.GradientTape() as tape:
        source = constant_op.constant([1.0, 2.0])
        tape.watch(source)
        scalar_shape = ops.convert_to_tensor([], dtype=dtypes.int32)
        scattered = list_ops.tensor_list_scatter(
            source, [1, 0], scalar_shape)
        item0 = list_ops.tensor_list_get_item(
            scattered, 0, element_dtype=dtypes.float32)
        item1 = list_ops.tensor_list_get_item(
            scattered, 1, element_dtype=dtypes.float32)
        self.assertAllEqual(self.evaluate(item0), 2.0)
        self.assertAllEqual(self.evaluate(item1), 1.0)
        loss = item0 * item0 + item1 * item1
    source_grad = tape.gradient(loss, source)
    self.assertAllEqual(self.evaluate(source_grad), [2., 4.])
示例5: testGetSetGradients
def testGetSetGradients(self):
    """Check gradients through `tensor_list_set_item` / `tensor_list_get_item`.

    A list is built from `c = [1.0, 2.0]`, slot 0 is overwritten with
    `c2 = 3.0`, and `y` is the sum of squares of items 0 and 1. The test
    asserts the gradient of `y` is `[0.0, 4.0]` with respect to `c` and
    `6.0` with respect to `c2`.
    """
    # NOTE(review): the extent of the tape block is inferred (the source
    # had no indentation); the gradient is taken after leaving the tape
    # context -- confirm against the upstream test.
    with backprop.GradientTape() as tape:
        c = constant_op.constant([1.0, 2.0])
        tape.watch(c)
        tensor_list = list_ops.tensor_list_from_tensor(c, element_shape=[])
        c2 = constant_op.constant(3.0)
        tape.watch(c2)
        tensor_list = list_ops.tensor_list_set_item(tensor_list, 0, c2)
        item0 = list_ops.tensor_list_get_item(
            tensor_list, 0, element_dtype=dtypes.float32)
        item1 = list_ops.tensor_list_get_item(
            tensor_list, 1, element_dtype=dtypes.float32)
        y = item0 * item0 + item1 * item1
    grads = tape.gradient(y, [c, c2])
    grad_c, grad_c2 = grads
    self.assertAllEqual(self.evaluate(grad_c), [0.0, 4.0])
    self.assertAllEqual(self.evaluate(grad_c2), 6.0)
示例6: _tf_tensor_list_get_item
def _tf_tensor_list_get_item(target, i, opts):
    """Stage a read of element `i` from the tensor list `target`.

    Args:
      target: The tensor list handle passed to
        `list_ops.tensor_list_get_item`.
      i: Index of the element to read.
      opts: Object whose `element_dtype` attribute gives the dtype of the
        list elements.

    Returns:
      The result of `list_ops.tensor_list_get_item` for the given index.

    Raises:
      ValueError: If `opts.element_dtype` is None.
    """
    dtype = opts.element_dtype
    if dtype is None:
        raise ValueError('cannot retrieve from a list without knowing its '
                         'element type; use set_element_type to annotate it')
    return list_ops.tensor_list_get_item(target, i, element_dtype=dtype)
示例7: testGetSetItem
def testGetSetItem(self):
    """Check reading item 0 and then overwriting it.

    A list is built from `[1.0, 2.0]`; item 0 is asserted to be 1.0, then
    set to 3.0, and the stacked list is asserted to be `[3.0, 2.0]`.
    """
    values = constant_op.constant([1.0, 2.0])
    tensor_list = list_ops.tensor_list_from_tensor(values, element_shape=[])

    first = list_ops.tensor_list_get_item(
        tensor_list, 0, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(first), 1.0)

    tensor_list = list_ops.tensor_list_set_item(tensor_list, 0, 3.0)
    stacked = list_ops.tensor_list_stack(
        tensor_list, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(stacked), [3.0, 2.0])
示例8: testGetSetReserved
def testGetSetReserved(self):
    """Check get/set on a reserved two-element scalar list.

    The unset item 0 is asserted to read as 0.0; after setting it to 3.0
    the stacked list is asserted to be `[3.0, 0.0]`.
    """
    with self.cached_session(), self.test_scope():
        reserved = list_ops.tensor_list_reserve(
            element_dtype=dtypes.float32, element_shape=[], num_elements=2)

        first = list_ops.tensor_list_get_item(
            reserved, 0, element_dtype=dtypes.float32)
        self.assertAllEqual(first, 0.0)

        reserved = list_ops.tensor_list_set_item(reserved, 0, 3.0)
        stacked = list_ops.tensor_list_stack(
            reserved, element_dtype=dtypes.float32)
        self.assertAllEqual(stacked, [3.0, 0.0])
示例9: testGetSet
def testGetSet(self):
    """Check get/set on a list built from a tensor, inside the test scope.

    Item 0 of a list built from `[1.0, 2.0]` is asserted to be 1.0; after
    setting it to 3.0 the stacked list is asserted to be `[3.0, 2.0]`.
    """
    with self.cached_session(), self.test_scope():
        values = constant_op.constant([1.0, 2.0])
        tensor_list = list_ops.tensor_list_from_tensor(
            values, element_shape=[])

        first = list_ops.tensor_list_get_item(
            tensor_list, 0, element_dtype=dtypes.float32)
        self.assertAllEqual(first, 1.0)

        tensor_list = list_ops.tensor_list_set_item(tensor_list, 0, 3.0)
        stacked = list_ops.tensor_list_stack(
            tensor_list, element_dtype=dtypes.float32)
        self.assertAllEqual(stacked, [3.0, 2.0])
示例10: testSetGetGrad
def testSetGetGrad(self):
    """Check the gradient of a value written to and read from a list.

    `2. * t` (with `t = 5.`) is written to slot 1 of a reserved
    three-element list and read back; the read is asserted to equal 10.0
    and its gradient with respect to `t` is asserted to equal 2.0.
    """
    # NOTE(review): the extent of the tape block is inferred (the source
    # had no indentation); the gradient is taken after leaving the tape
    # context -- confirm against the upstream test.
    with backprop.GradientTape() as tape:
        t = constant_op.constant(5.)
        tape.watch(t)
        reserved = list_ops.tensor_list_reserve(
            element_dtype=dtypes.float32, element_shape=[], num_elements=3)
        doubled = 2. * t
        reserved = list_ops.tensor_list_set_item(reserved, 1, doubled)
        read_back = list_ops.tensor_list_get_item(
            reserved, 1, element_dtype=dtypes.float32)
        self.assertAllEqual(self.evaluate(read_back), 10.0)
    grad = tape.gradient(read_back, t)
    self.assertAllEqual(self.evaluate(grad), 2.0)
示例11: read
def read(self, index, name=None):
    """See TensorArray.

    Reads the element at `index` from the list held in `self._flow` using
    `self._dtype` as the element dtype. When `self._element_shape` is
    truthy, the static shape of the result is set from the dims of its
    first entry before the value is returned.
    """
    element = list_ops.tensor_list_get_item(
        input_handle=self._flow,
        index=index,
        element_dtype=self._dtype,
        name=name)
    if not self._element_shape:
        return element
    known_dims = self._element_shape[0].dims
    element.set_shape(known_dims)
    return element
示例12: testListFromTensor
def testListFromTensor(self):
    """Check reading and popping from a list built from a tensor.

    For a list built from `[1.0, 2.0]`, item 0 is asserted to be 1.0, two
    successive `tensor_list_pop_back` calls are asserted to yield 2.0 and
    then 1.0, and the remaining list is asserted to have length 0.
    """
    with self.cached_session(), self.test_scope():
        values = constant_op.constant([1.0, 2.0])
        tensor_list = list_ops.tensor_list_from_tensor(
            values, element_shape=[])

        first = list_ops.tensor_list_get_item(
            tensor_list, 0, element_dtype=dtypes.float32)
        self.assertAllEqual(first, 1.0)

        tensor_list, popped_last = list_ops.tensor_list_pop_back(
            tensor_list, element_dtype=dtypes.float32)
        self.assertAllEqual(popped_last, 2.0)

        tensor_list, popped_first = list_ops.tensor_list_pop_back(
            tensor_list, element_dtype=dtypes.float32)
        self.assertAllEqual(popped_first, 1.0)

        remaining = list_ops.tensor_list_length(tensor_list)
        self.assertAllEqual(remaining, 0)
示例13: testSkipEagerSetItemWithMismatchedShapeFails
def testSkipEagerSetItemWithMismatchedShapeFails(self):
    """Check that setting an item of the wrong shape fails at run time.

    A list with scalar elements (`element_shape=[]`) has slot 0 set from a
    placeholder; feeding the placeholder `[3.0]` and running the read of
    slot 0 is asserted to raise `InvalidArgumentError` matching
    "incompatible shape".
    """
    with self.cached_session() as sess:
        ph = array_ops.placeholder(dtypes.float32)
        c = constant_op.constant([1.0, 2.0])
        l = list_ops.tensor_list_from_tensor(c, element_shape=[])
        # Set a placeholder with unknown shape to satisfy the shape inference
        # at graph building time.
        l = list_ops.tensor_list_set_item(l, 0, ph)
        l_0 = list_ops.tensor_list_get_item(
            l, 0, element_dtype=dtypes.float32)
        # assertRaisesRegexp is a deprecated unittest alias that newer
        # Python releases no longer provide; assertRaisesRegex is the
        # supported name for the same check.
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    "incompatible shape"):
            sess.run(l_0, {ph: [3.0]})
示例14: testScatterUpdateVariant
def testScatterUpdateVariant(self):
    """Check `scatter_update` on a resource variable holding a tensor list.

    In eager mode, a variable is created from a one-entry list containing
    an empty tensor list; entry 0 is replaced via `scatter_update` with a
    list built from `[1., 2.]`, and item 0 of `v[0]` is asserted to be 1.
    """
    with context.eager_mode():
        empty = list_ops.empty_tensor_list(
            element_dtype=dtypes.float32, element_shape=[])
        v = resource_variable_ops.ResourceVariable([empty])

        replacement = list_ops.tensor_list_from_tensor(
            [1., 2.], element_shape=[])
        v.scatter_update(ops.IndexedSlices(replacement, 0))

        first = list_ops.tensor_list_get_item(
            v[0], 0, element_dtype=dtypes.float32)
        self.assertAllEqual(first, 1.)
示例15: testStack
def testStack(self):
    """Check `tensor_list_stack` on a list filled with `push_back`.

    Two scalars are pushed onto an empty list created with
    `max_num_elements=2`; item 0 is asserted to be 1.0 after the first
    push, and the stacked result is asserted to have static shape `[None]`
    and value `[1.0, 2.0]`.
    """
    with self.cached_session(), self.test_scope():
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=dtypes.float32,
            element_shape=[],
            max_num_elements=2)

        one = constant_op.constant(1.0)
        tensor_list = list_ops.tensor_list_push_back(tensor_list, one)
        first = list_ops.tensor_list_get_item(
            tensor_list, 0, element_dtype=dtypes.float32)
        self.assertAllEqual(first, 1.0)

        two = constant_op.constant(2.0)
        tensor_list = list_ops.tensor_list_push_back(tensor_list, two)
        stacked = list_ops.tensor_list_stack(
            tensor_list, element_dtype=dtypes.float32)
        self.assertAllEqual(stacked.shape.as_list(), [None])
        self.assertAllEqual(stacked, [1.0, 2.0])