本文整理汇总了Python中tensorflow.python.data.experimental.ops.map_defun.map_defun函数的典型用法代码示例。如果您正苦于以下问题：Python map_defun函数的具体用法？Python map_defun怎么用？Python map_defun使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了map_defun函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。下文中的示例均为测试类中的方法（带有self参数），所需的导入语句未在此列出。
示例1: testMapDefunWithInvalidInput
def testMapDefunWithInvalidInput(self):
    """Checks that map_defun rejects scalar inputs.

    Both inputs below are scalars (presumably invalid because there is no
    leading dimension to map over). With a statically known shape the error
    is raised while building the graph; with an unknown static shape it can
    only be detected at run time, when a scalar is actually fed.
    """
    @function.Defun(dtypes.int32)
    def simple_fn(x):
        return x * 2

    c = constant_op.constant(2)
    with self.assertRaises(ValueError):
        # Fails at graph construction time for inputs with known shapes.
        # The call is expected to raise, so its result is not bound.
        map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])

    # The placeholder's shape is unknown until it is fed, so graph
    # construction succeeds and the failure moves to sess.run.
    p = array_ops.placeholder(dtypes.int32)
    r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0]
    with session.Session() as sess:
        with self.assertRaises(errors.InvalidArgumentError):
            sess.run(r, feed_dict={p: 0})
示例2: testMapDefunPartialShapeInference
def testMapDefunPartialShapeInference(self):
    """Checks static shape inference when the batch dimension is unknown.

    Only graph-construction-time shape inference is checked; the result is
    never evaluated.
    """
    @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
    def fn(x):
        return x

    # int32 so the input matches fn's input signature and the declared output
    # dtype; because the op is never run here, a mismatch would go unnoticed.
    elems = array_ops.placeholder(dtypes.int32, (None, 2))
    result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
    # The unknown leading dimension is kept; each element has shape (2,).
    self.assertEqual(result[0].get_shape().as_list(), [None, 2])
示例3: testMapDefunShapeInference
def testMapDefunShapeInference(self):
    """Checks the static output shape when the input shape is fully known."""
    row_spec = tensor_spec.TensorSpec([2], dtypes.int32)

    @function.defun(input_signature=[row_spec])
    def fn(x):
        return x

    data = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    outputs = map_defun.map_defun(fn, [data], [dtypes.int32], [(2,)])
    # Three rows of two elements each: the static shape is (3, 2).
    self.assertEqual(outputs[0].get_shape(), (3, 2))
示例4: testMapDefunWithWrongOutputShape
def testMapDefunWithWrongOutputShape(self):
    """Checks that a mismatched declared output shape fails at run time."""
    row_spec = tensor_spec.TensorSpec([2], dtypes.int32)

    @function.defun(input_signature=[row_spec])
    def simple_fn(x):
        return x * 2 + 3

    data = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    # simple_fn keeps each row's length-2 shape, but (1,) is declared here.
    mapped = map_defun.map_defun(
        simple_fn, [data], [dtypes.int32], [(1,)])[0]
    with self.assertRaises(errors.InvalidArgumentError):
        self.evaluate(mapped)
示例5: testMapDefunSimple
def testMapDefunSimple(self):
    """Checks map_defun against the equivalent whole-batch computation."""
    row_spec = tensor_spec.TensorSpec([2], dtypes.int32)

    @function.defun(input_signature=[row_spec])
    def simple_fn(x):
        return x * 2 + 3

    data = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    mapped = map_defun.map_defun(
        simple_fn, [data], [dtypes.int32], [(2,)])[0]
    # The same arithmetic applied to the whole batch gives the reference.
    reference = data * 2 + 3
    self.assertAllEqual(self.evaluate(mapped), self.evaluate(reference))
示例6: testMapDefunMismatchedTypes
def testMapDefunMismatchedTypes(self):
    """Checks that a wrong declared output dtype fails at run time."""
    scalar_spec = tensor_spec.TensorSpec([], dtypes.int32)

    @function.defun(input_signature=[scalar_spec])
    def fn(x):
        return math_ops.cast(x, dtypes.float64)

    data = constant_op.constant(
        [1, 2, 3, 4, 5, 6], dtype=dtypes.int32, name="data")
    # fn produces float64, but int32 is declared as the output dtype.
    mapped = map_defun.map_defun(fn, [data], [dtypes.int32], [()])[0]
    with self.assertRaises(errors.InvalidArgumentError):
        self.evaluate(mapped)
示例7: testMapDefunRaisesDefunError
def testMapDefunRaisesDefunError(self):
    """Checks that an error raised inside the mapped function propagates."""
    scalar_spec = tensor_spec.TensorSpec([], dtypes.int32)

    @function.defun(input_signature=[scalar_spec])
    def fn(x):
        zero_check = check_ops.assert_equal(x, 0)
        with ops.control_dependencies([zero_check]):
            return array_ops.identity(x)

    # The fourth element (37) violates the assertion inside fn.
    data = constant_op.constant([0, 0, 0, 37, 0])
    outputs = map_defun.map_defun(fn, [data], [dtypes.int32], [()])
    with self.assertRaises(errors.InvalidArgumentError):
        self.evaluate(outputs)
示例8: testMapDefunWithCapturedInputs
def testMapDefunWithCapturedInputs(self):
    """Checks that a tensor captured by the function is used correctly."""
    offset = constant_op.constant(2)
    scalar_spec = tensor_spec.TensorSpec([], dtypes.int32)

    @function.defun(input_signature=[scalar_spec])
    def fn(x):
        # `offset` comes from the enclosing scope rather than an argument.
        return x + offset

    values = constant_op.constant([1, 2, 3, 4])
    mapped = map_defun.map_defun(fn, [values], [dtypes.int32], [()])[0]
    reference = values + offset
    self.assertAllEqual(self.evaluate(reference), self.evaluate(mapped))
示例9: testMapDefunReduceDim
def testMapDefunReduceDim(self):
    """Checks a function whose output has a lower rank than its input."""
    row_spec = tensor_spec.TensorSpec([2], dtypes.int32)

    @function.defun(input_signature=[row_spec])
    def fn(x):
        # Picks the first component: a length-2 vector becomes a scalar.
        return array_ops.gather(x, 0)

    data = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    mapped = map_defun.map_defun(fn, [data], [dtypes.int32], [()])[0]
    first_column = constant_op.constant([1, 3, 5])
    self.assertAllEqual(self.evaluate(mapped), self.evaluate(first_column))
示例10: testMapDefunWithDifferentOutputShapeEachRun
def testMapDefunWithDifferentOutputShapeEachRun(self):
    """Checks that one graph handles inputs of different ranks across runs."""
    @function.Defun(dtypes.int32)
    def simple_fn(x):
        return x * 2 + 3

    data = array_ops.placeholder(dtypes.int32, name="data")
    mapped = map_defun.map_defun(
        simple_fn, [data], [dtypes.int32], [None])[0]
    with session.Session() as sess:
        # A rank-1 feed first, then a rank-2 feed, against the same graph.
        first = sess.run(mapped, feed_dict={data: [0]})
        self.assertAllEqual(first, [3])
        second = sess.run(mapped, feed_dict={data: [[0], [1]]})
        self.assertAllEqual(second, [[3], [5]])
示例11: testMapDefunMultipleOutputs
def testMapDefunMultipleOutputs(self):
    """Checks a function returning several outputs of different dtypes."""
    row_spec = tensor_spec.TensorSpec([2], dtypes.int32)

    @function.defun(input_signature=[row_spec])
    def fn(x):
        scaled = math_ops.cast(x * 2 + 3, dtypes.float64)
        return (x, scaled)

    data = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    out_dtypes = [dtypes.int32, dtypes.float64]
    out_shapes = [(2,), (2,)]
    mapped = map_defun.map_defun(fn, [data], out_dtypes, out_shapes)
    reference = [data, data * 2 + 3]
    self.assertAllEqual(self.evaluate(mapped), self.evaluate(reference))
示例12: testMapDefunRaisesErrorOnRuntimeShapeMismatch
def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
    """Checks that inputs whose leading dimensions differ are rejected."""
    @function.Defun(dtypes.int32, dtypes.int32)
    def fn(x, y):
        return x, y

    first_in = array_ops.placeholder(dtypes.int32)
    second_in = array_ops.placeholder(dtypes.int32)
    outputs = map_defun.map_defun(
        fn, [first_in, second_in], [dtypes.int32, dtypes.int32], [(), ()])
    # Five elements vs. three: the leading dimensions disagree.
    feeds = {first_in: [1, 2, 3, 4, 5], second_in: [1, 2, 3]}
    with self.cached_session() as sess:
        with self.assertRaisesWithPredicateMatch(
                errors.InvalidArgumentError,
                "All inputs must have the same dimension 0."):
            sess.run(outputs, feed_dict=feeds)
示例13: testMapDefunCancelledCorrectly
def testMapDefunCancelledCorrectly(self):
    """Checks that an out-of-range gather inside the function is reported.

    The input has 100 identical rows, so the function runs for many elements;
    judging by the test name, the intent is that the remaining invocations
    are cancelled cleanly once one fails (TODO confirm against the kernel).
    """
    @function.defun(input_signature=[tensor_spec.TensorSpec([5], dtypes.int64)])
    def defun(x):
        # x is a length-5 vector, so index 10 is out of range and gather fails.
        return array_ops.gather(x, 10)

    # Tile one 5-element row 100 times, giving a [100, 5] int64 tensor.
    row = array_ops.expand_dims(
        constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0)
    c = array_ops.tile(row, [100, 1])
    map_defun_op = map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0]
    # assertRaisesRegex is the supported name; the assertRaisesRegexp alias
    # was deprecated and removed from unittest in Python 3.12.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                r"indices = 10 is not in \[0, 5\)"):
        self.evaluate(map_defun_op)
示例14: testMapDefunWithUnspecifiedOutputShape
def testMapDefunWithUnspecifiedOutputShape(self):
    """Checks outputs declared with unconstrained, partial and full shapes."""
    row_spec = tensor_spec.TensorSpec([2], dtypes.int32)

    @function.defun(input_signature=[row_spec])
    def simple_fn(x):
        base = x * 2 + 3
        return (base, base + 1, base + 2)

    data = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    # Declared shapes go from unconstrained (None) to partially known
    # ((None,)) to fully known ((2,)).
    mapped = map_defun.map_defun(
        simple_fn, [data], [dtypes.int32] * 3, [None, (None,), (2,)])
    reference = data * 2 + 3
    references = [reference, reference + 1, reference + 2]
    for index, wanted in enumerate(references):
        self.assertAllEqual(
            self.evaluate(mapped[index]), self.evaluate(wanted))
示例15: testMapDefunWithVariantTensorAsCaptured
def testMapDefunWithVariantTensorAsCaptured(self):
    """Checks that a captured variant tensor can be returned per element."""
    sparse = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    serialized = sparse_ops.serialize_sparse_v2(
        sparse, out_type=dtypes.variant)
    scalar_spec = tensor_spec.TensorSpec([], dtypes.int32)

    @function.defun(input_signature=[scalar_spec])
    def fn(x):
        del x  # Only the number of input elements matters.
        return serialized

    # Two scalar inputs yield two stacked copies of the 3x4 sparse tensor.
    two_elems = constant_op.constant([0, 0])
    mapped = map_defun.map_defun(
        fn, [two_elems], [dtypes.variant], [None])[0]
    deserialized = sparse_ops.deserialize_sparse(mapped, dtypes.int32)
    expected = sparse_tensor.SparseTensorValue(
        indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
        values=[1, 2, 1, 2],
        dense_shape=[2, 3, 4])
    self.assertSparseValuesEqual(expected, self.evaluate(deserialized))