本文整理汇总了Python中tensorflow.python.ops.list_ops.tensor_list_set_item函数的典型用法代码示例。如果您正苦于以下问题：Python tensor_list_set_item函数的具体用法？Python tensor_list_set_item怎么用？Python tensor_list_set_item使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了tensor_list_set_item函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: testAddTensorLists
def testAddTensorLists(self):
    """Check that `add_n` over two reserved tensor lists sums them per slot."""
    with self.cached_session(), self.test_scope():
        first = list_ops.tensor_list_reserve(
            element_shape=[], element_dtype=dtypes.float32, num_elements=3)
        second = list_ops.tensor_list_reserve(
            element_shape=[], element_dtype=dtypes.float32, num_elements=3)
        # Each list gets exactly one written slot; the rest stay unset.
        first = list_ops.tensor_list_set_item(first, 0, 5.)
        second = list_ops.tensor_list_set_item(second, 2, 10.)
        summed = math_ops.add_n([first, second])
        stacked = list_ops.tensor_list_stack(
            summed, element_dtype=dtypes.float32)
        # The slot written by neither list is expected to come back as zero.
        self.assertAllEqual(stacked, [5.0, 0.0, 10.0])
示例2: testSerializeListWithInvalidTensors
def testSerializeListWithInvalidTensors(self):
    """Pass a partially filled tensor list between worker and ps devices.

    NOTE(review): the "invalid tensors" in the test name presumably refers to
    the reserved-but-unset slot that exists when the list crosses devices --
    confirm against the op implementation.
    """
    cluster = test_util.create_local_cluster(num_workers=1, num_ps=1)
    worker = cluster[0][0]
    with ops.Graph().as_default(), session.Session(target=worker.target):
        # Create the list on the worker and fill only slot 0.
        with ops.device("/job:worker"):
            worker_list = list_ops.tensor_list_reserve(
                element_dtype=dtypes.float32,
                element_shape=[],
                num_elements=2)
            worker_list = list_ops.tensor_list_set_item(worker_list, 0, 1.)
        # Ops below are placed on the ps job: fill slot 1 and stack there.
        with ops.device("/job:ps"):
            ps_list = array_ops.identity(worker_list)
            ps_list = list_ops.tensor_list_set_item(ps_list, 1, 2.)
            stacked = list_ops.tensor_list_stack(
                ps_list, element_dtype=dtypes.float32)
        # Bring the stacked tensor back to the worker before evaluating it.
        with ops.device("/job:worker"):
            fetched = array_ops.identity(stacked)
        self.assertAllEqual(self.evaluate(fetched), [1.0, 2.0])
示例3: testSetStackReservedUnknownElementShape
def testSetStackReservedUnknownElementShape(self):
    """Stack a reserved list created with `element_shape=None`.

    Only slot 0 is written; the assertion expects the unset slot to stack as
    zeros with the same shape as the written element.
    """
    with self.cached_session(), self.test_scope():
        reserved = list_ops.tensor_list_reserve(
            element_dtype=dtypes.float32, element_shape=None, num_elements=2)
        reserved = list_ops.tensor_list_set_item(reserved, 0, [3.0, 4.0])
        stacked = list_ops.tensor_list_stack(
            reserved, element_dtype=dtypes.float32)
        self.assertAllEqual(stacked, [[3.0, 4.0], [0., 0.]])
示例4: testZerosLikeUninitialized
def testZerosLikeUninitialized(self):
    """Apply `zeros_like` to tensor lists that still have unset slots.

    Only the indices that were written before `zeros_like` are gathered, and
    each is expected to read back as zero.
    """
    reserved = list_ops.tensor_list_reserve(
        [], 3, element_dtype=dtypes.float32)
    one_set = list_ops.tensor_list_set_item(reserved, 0, 1.)  # [1., _, _]
    zeros_one_set = array_ops.zeros_like(one_set)  # [0., _, _]
    two_set = list_ops.tensor_list_set_item(one_set, 2, 2.)  # [1., _, 2.]
    zeros_two_set = array_ops.zeros_like(two_set)  # [0., _, 0.]

    # Gather the single index that holds a zero in `zeros_one_set`.
    gathered_one = list_ops.tensor_list_gather(
        zeros_one_set, [0], element_dtype=dtypes.float32)
    # Gather both indices that hold zeros in `zeros_two_set`.
    gathered_two = list_ops.tensor_list_gather(
        zeros_two_set, [0, 2], element_dtype=dtypes.float32)

    self.assertAllEqual(self.evaluate(gathered_one), [0.])
    self.assertAllEqual(self.evaluate(gathered_two), [0., 0.])
示例5: testSetOnEmptyListWithMaxNumElementsFails
def testSetOnEmptyListWithMaxNumElementsFails(self):
    """Setting an item on an empty list fails even with `max_num_elements`.

    The list is created empty with `max_num_elements=3`; writing index 0 is
    expected to raise `InvalidArgumentError` once the op is evaluated.
    """
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=[], max_num_elements=3)
    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
    # which was removed from `unittest` in Python 3.12.
    with self.assertRaisesRegex(
            errors.InvalidArgumentError,
            "Trying to modify element 0 in a list with 0 elements."):
        l = list_ops.tensor_list_set_item(l, 0, 1.)
        # Evaluate inside the context so the failure is caught whether it
        # surfaces when the op is created or when it is run.
        self.evaluate(l)
示例6: testGetSetItem
def testGetSetItem(self):
    """Read slot 0 of a list built from a tensor, overwrite it, and restack."""
    source = constant_op.constant([1.0, 2.0])
    handle = list_ops.tensor_list_from_tensor(source, element_shape=[])

    # The first element should match the source tensor before any write.
    first = list_ops.tensor_list_get_item(
        handle, 0, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(first), 1.0)

    # Overwrite slot 0; slot 1 must be left as it was.
    handle = list_ops.tensor_list_set_item(handle, 0, 3.0)
    stacked = list_ops.tensor_list_stack(handle, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(stacked), [3.0, 2.0])
示例7: testGetSetReserved
def testGetSetReserved(self):
    """Get and set items on a reserved list of scalar elements.

    The assertions expect an unset slot to read as 0.0, both through
    `tensor_list_get_item` and after stacking.
    """
    with self.cached_session(), self.test_scope():
        reserved = list_ops.tensor_list_reserve(
            element_dtype=dtypes.float32, element_shape=[], num_elements=2)
        unset_item = list_ops.tensor_list_get_item(
            reserved, 0, element_dtype=dtypes.float32)
        self.assertAllEqual(unset_item, 0.0)

        reserved = list_ops.tensor_list_set_item(reserved, 0, 3.0)
        stacked = list_ops.tensor_list_stack(
            reserved, element_dtype=dtypes.float32)
        self.assertAllEqual(stacked, [3.0, 0.0])
示例8: testGetSet
def testGetSet(self):
    """Get then set slot 0 of a list built from a tensor, under test_scope."""
    with self.cached_session(), self.test_scope():
        source = constant_op.constant([1.0, 2.0])
        handle = list_ops.tensor_list_from_tensor(source, element_shape=[])

        first = list_ops.tensor_list_get_item(
            handle, 0, element_dtype=dtypes.float32)
        self.assertAllEqual(first, 1.0)

        # Replacing slot 0 must leave slot 1 unchanged in the stacked result.
        handle = list_ops.tensor_list_set_item(handle, 0, 3.0)
        stacked = list_ops.tensor_list_stack(
            handle, element_dtype=dtypes.float32)
        self.assertAllEqual(stacked, [3.0, 2.0])
示例9: testSetGetGrad
def testSetGetGrad(self):
    """Gradients flow through a set_item followed by a get_item.

    Writes `2 * t` into slot 1 and reads it back; the read value should be
    10.0 and its gradient with respect to `t` should be 2.0.
    """
    with backprop.GradientTape() as tape:
        watched = constant_op.constant(5.)
        tape.watch(watched)
        reserved = list_ops.tensor_list_reserve(
            element_dtype=dtypes.float32, element_shape=[], num_elements=3)
        reserved = list_ops.tensor_list_set_item(reserved, 1, 2. * watched)
        read_back = list_ops.tensor_list_get_item(
            reserved, 1, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(read_back), 10.0)
    gradient = tape.gradient(read_back, watched)
    self.assertAllEqual(self.evaluate(gradient), 2.0)
示例10: testSetDoesNotUpdatePushIndex
def testSetDoesNotUpdatePushIndex(self):
    """A set_item on an empty bounded list must not affect later pushes.

    After the set_item at index 1, two push_backs follow; the stacked result
    is expected to contain only the two pushed values.
    """
    with self.cached_session(), self.test_scope():
        handle = list_ops.empty_tensor_list(
            element_shape=[], element_dtype=dtypes.float32, max_num_elements=2)
        # SetItem should not change the push index.
        handle = list_ops.tensor_list_set_item(handle, 1, 3.)
        for pushed in (5., 7.):
            handle = list_ops.tensor_list_push_back(handle, pushed)
        stacked = list_ops.tensor_list_stack(
            handle, element_dtype=dtypes.float32)
        self.assertAllEqual(stacked, [5., 7.])
示例11: testSkipEagerTensorListGetItemGradAggregation
def testSkipEagerTensorListGetItemGradAggregation(self):
    """Gradients from two reads of the same list slot are aggregated.

    `x` is written to slot 0 and read back twice; the gradient of both reads
    with respect to `x` is expected to be 2.0.
    """
    l = list_ops.tensor_list_reserve(
        element_shape=[], num_elements=1, element_dtype=dtypes.float32)
    x = constant_op.constant(1.0)
    l = list_ops.tensor_list_set_item(l, 0, x)
    l_read1 = list_ops.tensor_list_get_item(
        l, 0, element_dtype=dtypes.float32)
    l_read2 = list_ops.tensor_list_get_item(
        l, 0, element_dtype=dtypes.float32)
    grad = gradients_impl.gradients([l_read1, l_read2], [x])
    # The session object itself is never used (evaluation goes through
    # `self.evaluate`), so the context manager is entered without binding it.
    with self.cached_session():
        self.assertSequenceEqual(self.evaluate(grad), [2.])
示例12: write
def write(self, index, value, name=None):
    """See TensorArray.

    Converts `value` to a tensor, sets it at `index` in the tensor list held
    in `self._flow`, and returns a new `TensorArray` wrapping the resulting
    flow. `self` is not given a new flow.

    Args:
      index: Position in the list to write to.
      value: Value to write; converted with `ops.convert_to_tensor`.
      name: Optional name, used for both the name scope and the set-item op.

    Returns:
      A new `TensorArray` carrying the updated flow and this array's
      `_infer_shape` / `_element_shape` settings.
    """
    scope_values = [self._flow, index, value]
    with ops.name_scope(name, "TensorArrayV2Write", scope_values):
        item = ops.convert_to_tensor(value, name="value")
        if self._infer_shape:
            # Fold the written value's static shape into the element shape.
            self._merge_element_shape(item.shape)
        updated_flow = list_ops.tensor_list_set_item(
            input_handle=self._flow, index=index, item=item, name=name)
        result = TensorArray(
            dtype=self._dtype, handle=None, flow=updated_flow)
        # Carry the shape-inference state over to the returned array.
        result._infer_shape = self._infer_shape
        result._element_shape = self._element_shape
        return result
示例13: testGetSetReservedNonScalar
def testGetSetReservedNonScalar(self):
    """Get and set items on a reserved list whose elements are (7, 15).

    Slot 0 is filled with ones; slot 1 is left unset and is expected to read
    back as zeros of the same shape.
    """
    element_shape = (7, 15)
    with self.cached_session() as sess, self.test_scope():
        reserved = list_ops.tensor_list_reserve(
            element_dtype=dtypes.float32,
            element_shape=element_shape,
            num_elements=2)
        ones = constant_op.constant(1.0, shape=element_shape)
        reserved = list_ops.tensor_list_set_item(reserved, 0, ones)
        written = list_ops.tensor_list_get_item(
            reserved, 0, element_dtype=dtypes.float32)
        unset = list_ops.tensor_list_get_item(
            reserved, 1, element_dtype=dtypes.float32)
        self.assertAllEqual(sess.run(written), np.ones(element_shape))
        self.assertAllEqual(sess.run(unset), np.zeros(element_shape))
示例14: testSkipEagerSetItemWithMismatchedShapeFails
def testSkipEagerSetItemWithMismatchedShapeFails(self):
    """Feeding a wrongly shaped value into set_item fails at run time.

    The list holds scalar elements (`element_shape=[]`); the placeholder is
    fed a length-1 vector, which is expected to raise `InvalidArgumentError`.
    """
    with self.cached_session() as sess:
        ph = array_ops.placeholder(dtypes.float32)
        c = constant_op.constant([1.0, 2.0])
        l = list_ops.tensor_list_from_tensor(c, element_shape=[])
        # Set a placeholder with unknown shape to satisfy the shape inference
        # at graph building time.
        l = list_ops.tensor_list_set_item(l, 0, ph)
        l_0 = list_ops.tensor_list_get_item(
            l, 0, element_dtype=dtypes.float32)
        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which was removed from `unittest` in Python 3.12.
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    "incompatible shape"):
            sess.run(l_0, {ph: [3.0]})
示例15: write
def write(self, index, value, name=None):
    """See TensorArray.

    Converts `value` to a tensor, sets it at `index` in the tensor list held
    in `self._flow`, and returns the result of `build_ta_with_new_flow` for
    the updated flow.

    Args:
      index: Position in the list to write to.
      value: Value to write; converted with `ops.convert_to_tensor`.
      name: Optional name, used for both the name scope and the set-item op.

    Returns:
      Whatever `build_ta_with_new_flow(self, new_flow)` produces.
    """
    scope_values = [self._flow, index, value]
    with ops.name_scope(name, "TensorArrayV2Write", scope_values):
        item = ops.convert_to_tensor(value, name="value")
        if self._infer_shape:
            # Fold the written value's static shape into the element shape.
            self._merge_element_shape(item.shape)
        # `self._dynamic_size` is forwarded as the op's
        # `resize_if_index_out_of_bounds` argument.
        updated_flow = list_ops.tensor_list_set_item(
            input_handle=self._flow,
            index=index,
            item=item,
            resize_if_index_out_of_bounds=self._dynamic_size,
            name=name)
        return build_ta_with_new_flow(self, updated_flow)