本文整理汇总了Python中tensorflow.python.framework.ops.convert_n_to_tensor函数的典型用法代码示例。如果您正苦于以下问题：Python convert_n_to_tensor函数的具体用法？Python convert_n_to_tensor怎么用？Python convert_n_to_tensor使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了convert_n_to_tensor函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: piecewise_constant
def piecewise_constant(x, boundaries, values, name=None):
  """Piecewise constant from boundaries and interval values.

  Example: use a learning rate that's 1.0 for the first 100000 steps, 0.5
  for steps 100001 to 110000, and 0.1 for any additional steps.

  ```python
  global_step = tf.Variable(0, trainable=False)
  boundaries = [100000, 110000]
  values = [1.0, 0.5, 0.1]
  learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)

  # Later, whenever we perform an optimization step, we increment global_step.
  ```

  Args:
    x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
      `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
    boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
      increasing entries, and with all elements having the same type as `x`.
    values: A list of `Tensor`s or `float`s or `int`s that specifies the values
      for the intervals defined by `boundaries`. It should have one more element
      than `boundaries`, and all elements should have the same type.
    name: A string. Optional name of the operation. Defaults to
      'PiecewiseConstant'.

  Returns:
    A 0-D Tensor. Its value is `values[0]` when `x <= boundaries[0]`,
    `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ...,
    and `values[-1]` when `x > boundaries[-1]`.

  Raises:
    ValueError: if `len(values) != len(boundaries) + 1`, if any boundary's
      base dtype differs from the base dtype of `x`, or if the elements of
      `values` do not all share the same base dtype.
  """
  # Without this check, zip() below silently truncates mismatched lists and
  # some intervals end up uncovered, falling through to `default`.
  if len(boundaries) != len(values) - 1:
    raise ValueError(
        'The length of boundaries should be 1 less than the length of values')
  with ops.name_scope(name, 'PiecewiseConstant',
                      [x, boundaries, values, name]) as name:
    x = ops.convert_to_tensor(x)
    # Avoid explicit conversion to x's dtype. This could result in faulty
    # comparisons, for example if floats are converted to integers.
    boundaries = ops.convert_n_to_tensor(boundaries)
    # Compare base dtypes so that a reference dtype (e.g. `int32_ref`) is
    # treated the same as its value dtype.
    x_dtype = x.dtype.base_dtype
    if not all(b.dtype.base_dtype == x_dtype for b in boundaries):
      raise ValueError('boundaries must have the same dtype as x.')
    # TODO(rdipietro): Ensure that boundaries' elements are strictly increasing.
    values = ops.convert_n_to_tensor(values)
    if not all(v.dtype.base_dtype == values[0].dtype.base_dtype
               for v in values):
      raise ValueError('values must have elements all with the same dtype.')
    # A list of (predicate, fn) pairs rather than a dict. This keeps the
    # order deterministic and avoids using Tensors as dict keys.
    pred_fn_pairs = [
        (x <= boundaries[0], lambda: values[0]),
        (x > boundaries[-1], lambda: values[-1]),
    ]
    for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
      # Bind v as a default argument; a plain closure would late-bind to the
      # value from the last loop iteration.
      pred = (x > low) & (x <= high)
      pred_fn_pairs.append((pred, lambda v=v: v))
    # The default isn't needed here because our conditions are mutually
    # exclusive and exhaustive, but tf.case requires it.
    default = lambda: values[0]
    return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
示例2: testInt
def testInt(self):
  """accumulate_n over int tensors matches the numpy sum, incl. a repeated input."""
  np.random.seed(54321)
  count = 6
  arrays = [np.random.randint(-128, 128, (5, 4, 3, 2, 1)) for _ in range(count)]
  tensors = ops.convert_n_to_tensor(arrays)
  with self.test_session(use_gpu=True):
    accumulated = math_ops.accumulate_n(tensors).eval()
    self.assertAllEqual(sum(arrays), accumulated)
    accumulated_repeat = math_ops.accumulate_n([tensors[0]] * count).eval()
    self.assertAllEqual(arrays[0] * count, accumulated_repeat)
示例3: testFloat
def testFloat(self):
  """accumulate_n over float tensors matches the numpy sum, incl. a repeated input."""
  np.random.seed(12345)
  count = 5
  arrays = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(count)]
  tensors = ops.convert_n_to_tensor(arrays)
  with self.test_session(use_gpu=True):
    accumulated = math_ops.accumulate_n(tensors).eval()
    self.assertAllClose(sum(arrays), accumulated)
    accumulated_repeat = math_ops.accumulate_n([tensors[0]] * count).eval()
    self.assertAllClose(arrays[0] * count, accumulated_repeat)
示例4: test_mean
def test_mean(self):
  """Exercises metrics.Mean: config, __call__, update_state/result and reset."""
  mean = metrics.Mean(name='my_mean')

  def assert_state(expected_total, expected_count):
    # Evaluate the two accumulator variables, total first, then count.
    self.assertEqual(self.evaluate(mean.total), expected_total)
    self.assertEqual(self.evaluate(mean.count), expected_count)

  # Configuration of the freshly built metric.
  self.assertEqual(mean.name, 'my_mean')
  self.assertTrue(mean.stateful)
  self.assertEqual(mean.dtype, dtypes.float32)
  self.assertEqual(len(mean.variables), 2)
  self.evaluate(variables.global_variables_initializer())

  # Nothing accumulated yet.
  assert_state(0, 0)

  # Calling the metric updates state and yields the running mean.
  self.assertEqual(self.evaluate(mean(100)), 100)
  assert_state(100, 1)

  # update_state() with tensor input accumulates on top of the prior state.
  self.evaluate(mean.update_state(ops.convert_n_to_tensor([1, 5])))
  self.assertAlmostEqual(self.evaluate(mean.result()), 106 / 3, 2)
  assert_state(106, 3)  # 100 + 1 + 5

  # reset_states() brings both accumulators back to zero.
  mean.reset_states()
  assert_state(0, 0)
示例5: xla_launch_eager_fallback
def xla_launch_eager_fallback(constants, args, resources, Tresults, function, name=None, ctx=None):
  r"""This is the slowpath function for Eager mode.

  This is for function xla_launch. It validates and converts the inputs,
  assembles the `XlaLaunch` attrs, and runs the op via `_execute.execute`.

  Args:
    constants: values converted with `_execute.convert_to_mixed_eager_tensors`;
      the first value that call returns becomes the `Tconstants` attr.
    args: values converted the same way; the first returned value becomes the
      `Targs` attr.
    resources: list or tuple of values converted to `resource` tensors; its
      length becomes the `Nresources` attr.
    Tresults: list or tuple of output types, each normalized with
      `_execute.make_type`; `len(Tresults)` is passed to `_execute.execute`
      as the number of outputs.
    function: passed through unchanged as the `function` attr.
    name: optional name for the op.
    ctx: context to execute in; `_context.context()` is used when falsy.

  Returns:
    The result of `_execute.execute(b"XlaLaunch", len(Tresults), ...)`.

  Raises:
    TypeError: if `resources` or `Tresults` is not a list or tuple.
  """
  _ctx = ctx if ctx else _context.context()
  # List-valued arguments must be real lists/tuples; their lengths feed attrs.
  if not isinstance(resources, (list, tuple)):
    raise TypeError(
        "Expected list for 'resources' argument to "
        "'xla_launch' Op, not %r." % resources)
  _attr_Nresources = len(resources)
  if not isinstance(Tresults, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tresults' argument to "
        "'xla_launch' Op, not %r." % Tresults)
  Tresults = [_execute.make_type(_t, "Tresults") for _t in Tresults]
  _attr_Tconstants, constants = _execute.convert_to_mixed_eager_tensors(constants, _ctx)
  _attr_Targs, args = _execute.convert_to_mixed_eager_tensors(args, _ctx)
  resources = _ops.convert_n_to_tensor(resources, _dtypes.resource)
  # Flatten inputs in the order: constants, args, resources.
  _inputs_flat = list(constants) + list(args) + list(resources)
  # Attrs are passed as a flat (name, value, name, value, ...) tuple.
  _attrs = ("Tconstants", _attr_Tconstants, "Targs", _attr_Targs,
  "Nresources", _attr_Nresources, "Tresults", Tresults, "function", function)
  _result = _execute.execute(b"XlaLaunch", len(Tresults), inputs=_inputs_flat,
                             attrs=_attrs, ctx=_ctx, name=name)
  # Hand inputs, attrs and outputs to `_execute.record_gradient`.
  _execute.record_gradient(
      "XlaLaunch", _inputs_flat, _attrs, _result, name)
  return _result
示例6: testFloat
def testFloat(self):
  """Eager accumulate_n over float tensors matches the numpy sum."""
  np.random.seed(12345)
  arrays = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)]
  tensors = ops.convert_n_to_tensor(arrays)
  expected_sum = sum(arrays)
  self.assertAllClose(expected_sum, math_ops.accumulate_n(tensors))
  first = tensors[0]
  self.assertAllClose(arrays[0] * 5, math_ops.accumulate_n([first] * 5))
示例7: testBuild
def testBuild(self):
  """remote_fused_graph_execute yields one output per requested output node."""
  graph = graph_pb2.GraphDef()
  for node_name, node_op in (("a", "op0"), ("b", "op1")):
    new_node = graph.node.add()
    new_node.name = node_name
    new_node.op = node_op

  output_nodes = remote_fused_graph_ops.remote_fused_graph_execute(
      [ops.convert_n_to_tensor([1], dtypes.int64)],  # inputs
      [np.int64, np.int64],  # output_types
      graph,
      ["a"],  # graph_input_node_names
      ["a", "b"],  # graph_output_node_names
      "",  # executor_name
      b"",  # serialized_executor_parameters
      [[dtypes.int64, [1]]],  # default input tensor type/shapes
      [[dtypes.int64, [1]], [dtypes.int64, [1]]])  # default output ones
  self.assertEqual(2, len(output_nodes))
  for output_node in output_nodes:
    with self.test_session(use_gpu=False):
      output_node.eval()
示例8: grow_tree_ensemble_eager_fallback
def grow_tree_ensemble_eager_fallback(tree_ensemble_handle, stamp_token, next_stamp_token, learning_rate, dropout_seed, max_tree_depth, weak_learner_type, partition_ids, gains, splits, learner_config, center_bias, name=None, ctx=None):
  r"""This is the slowpath function for Eager mode.

  This is for function grow_tree_ensemble. It validates the list arguments,
  converts every input to a tensor of the fixed dtype listed below, and runs
  the `GrowTreeEnsemble` op via `_execute.execute`.

  Args:
    tree_ensemble_handle: converted to a `resource` tensor.
    stamp_token: converted to an `int64` tensor.
    next_stamp_token: converted to an `int64` tensor.
    learning_rate: converted to a `float32` tensor.
    dropout_seed: converted to an `int64` tensor.
    max_tree_depth: converted to an `int32` tensor.
    weak_learner_type: converted to an `int32` tensor.
    partition_ids: list or tuple converted to `int32` tensors; its length
      becomes the `num_handlers` attr.
    gains: list or tuple, same length as `partition_ids`, converted to
      `float32` tensors.
    splits: list or tuple, same length as `partition_ids`, converted to
      `string` tensors.
    learner_config: normalized with `_execute.make_str`.
    center_bias: normalized with `_execute.make_bool`.
    name: optional name for the op.
    ctx: context to execute in; `_context.context()` is used when falsy.

  Returns:
    None. The op is executed with an output count of 0.

  Raises:
    TypeError: if `partition_ids`, `gains` or `splits` is not a list or tuple.
    ValueError: if `gains` or `splits` differs in length from `partition_ids`.
  """
  _ctx = ctx if ctx else _context.context()
  # `partition_ids` fixes num_handlers; `gains` and `splits` must match it.
  if not isinstance(partition_ids, (list, tuple)):
    raise TypeError(
        "Expected list for 'partition_ids' argument to "
        "'grow_tree_ensemble' Op, not %r." % partition_ids)
  _attr_num_handlers = len(partition_ids)
  if not isinstance(gains, (list, tuple)):
    raise TypeError(
        "Expected list for 'gains' argument to "
        "'grow_tree_ensemble' Op, not %r." % gains)
  if len(gains) != _attr_num_handlers:
    raise ValueError(
        "List argument 'gains' to 'grow_tree_ensemble' Op with length %d "
        "must match length %d of argument 'partition_ids'." %
        (len(gains), _attr_num_handlers))
  if not isinstance(splits, (list, tuple)):
    raise TypeError(
        "Expected list for 'splits' argument to "
        "'grow_tree_ensemble' Op, not %r." % splits)
  if len(splits) != _attr_num_handlers:
    raise ValueError(
        "List argument 'splits' to 'grow_tree_ensemble' Op with length %d "
        "must match length %d of argument 'partition_ids'." %
        (len(splits), _attr_num_handlers))
  learner_config = _execute.make_str(learner_config, "learner_config")
  center_bias = _execute.make_bool(center_bias, "center_bias")
  # Each input is converted to the fixed dtype listed in the docstring.
  tree_ensemble_handle = _ops.convert_to_tensor(tree_ensemble_handle, _dtypes.resource)
  stamp_token = _ops.convert_to_tensor(stamp_token, _dtypes.int64)
  next_stamp_token = _ops.convert_to_tensor(next_stamp_token, _dtypes.int64)
  learning_rate = _ops.convert_to_tensor(learning_rate, _dtypes.float32)
  dropout_seed = _ops.convert_to_tensor(dropout_seed, _dtypes.int64)
  max_tree_depth = _ops.convert_to_tensor(max_tree_depth, _dtypes.int32)
  weak_learner_type = _ops.convert_to_tensor(weak_learner_type, _dtypes.int32)
  partition_ids = _ops.convert_n_to_tensor(partition_ids, _dtypes.int32)
  gains = _ops.convert_n_to_tensor(gains, _dtypes.float32)
  splits = _ops.convert_n_to_tensor(splits, _dtypes.string)
  # Scalar inputs first, then the three per-handler lists in order.
  _inputs_flat = [tree_ensemble_handle, stamp_token, next_stamp_token, learning_rate, dropout_seed, max_tree_depth, weak_learner_type] + list(partition_ids) + list(gains) + list(splits)
  _attrs = ("learner_config", learner_config, "num_handlers",
  _attr_num_handlers, "center_bias", center_bias)
  _result = _execute.execute(b"GrowTreeEnsemble", 0, inputs=_inputs_flat,
                             attrs=_attrs, ctx=_ctx, name=name)
  # Executed with 0 outputs; the function always returns None.
  _result = None
  return _result
示例9: testFloat
def testFloat(self):
  """accumulate_n over float tensors matches the numpy sum, incl. a repeated input."""
  np.random.seed(12345)
  x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)]
  tf_x = ops.convert_n_to_tensor(x)
  # Assert the static shapes instead of printing them. The old debug print
  # added log noise and verified nothing.
  for arr, u in zip(x, tf_x):
    self.assertEqual(list(arr.shape), u.get_shape().as_list())
  with self.test_session():
    self.assertAllClose(sum(x), math_ops.accumulate_n(tf_x).eval())
    self.assertAllClose(x[0] * 5, math_ops.accumulate_n([tf_x[0]] * 5).eval())
示例10: decayed_lr
def decayed_lr(x, boundaries, values, name):
  """Helper to recompute learning rate; most helpful in eager-mode.

  Selects `values[0]` when `x <= boundaries[0]`, `values[i]` when
  `boundaries[i-1] < x <= boundaries[i]`, and `values[-1]` when
  `x > boundaries[-1]`.

  Args:
    x: value compared against `boundaries` (converted with
      `ops.convert_to_tensor`).
    boundaries: list of boundary values. Each must share the base dtype of
      `x`, except that `int32` boundaries are cast to `int64` when `x` is
      `int64`.
    values: list of interval values with one more element than `boundaries`,
      all sharing the same base dtype.
    name: name passed to `ops.name_scope`, which falls back to
      'PiecewiseConstant'.

  Returns:
    The result of `control_flow_ops.case` over the interval predicates.

  Raises:
    ValueError: if `len(values) != len(boundaries) + 1`, if a boundary's dtype
      is incompatible with the dtype of `x`, or if `values` mix dtypes.
  """
  # Without this check, zip() below silently truncates mismatched lists and
  # some intervals fall through to `default`, returning the wrong value.
  if len(boundaries) != len(values) - 1:
    raise ValueError(
        "The length of boundaries should be 1 less than the length of values")
  with ops.name_scope(name, "PiecewiseConstant",
                      [x, boundaries, values, name]) as name:
    boundaries = ops.convert_n_to_tensor(boundaries)
    values = ops.convert_n_to_tensor(values)
    x_recomp = ops.convert_to_tensor(x)
    # Avoid explicit conversion to x's dtype. This could result in faulty
    # comparisons, for example if floats are converted to integers.
    for i, b in enumerate(boundaries):
      if b.dtype.base_dtype != x_recomp.dtype.base_dtype:
        # We can promote int32 boundaries to int64 without loss of precision.
        # This covers the most common case where the user passes in boundaries
        # as an array of Python integers.
        if (b.dtype.base_dtype == dtypes.int32 and
            x_recomp.dtype.base_dtype == dtypes.int64):
          b = math_ops.cast(b, x_recomp.dtype.base_dtype)
          boundaries[i] = b
        else:
          raise ValueError(
              "Boundaries (%s) must have the same dtype as x (%s)." %
              (b.dtype.base_dtype, x_recomp.dtype.base_dtype))
    # TODO(rdipietro): Ensure that boundaries' elements strictly increases.
    for v in values[1:]:
      if v.dtype.base_dtype != values[0].dtype.base_dtype:
        raise ValueError(
            "Values must have elements all with the same dtype (%s vs %s)." %
            (values[0].dtype.base_dtype, v.dtype.base_dtype))
    pred_fn_pairs = []
    pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0]))
    pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1]))
    for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
      # Need to bind v here; can do this with lambda v=v: ...
      pred = (x_recomp > low) & (x_recomp <= high)
      pred_fn_pairs.append((pred, lambda v=v: v))
    # The default isn't needed here because our conditions are mutually
    # exclusive and exhaustive, but tf.case requires it.
    default = lambda: values[0]
    return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
示例11: sparse_feature_cross_v2_eager_fallback
def sparse_feature_cross_v2_eager_fallback(indices, values, shapes, dense, hashed_output, num_buckets, hash_key, out_type, internal_type, name=None, ctx=None):
  r"""This is the slowpath function for Eager mode.

  This is for function sparse_feature_cross_v2. It validates and converts the
  inputs, assembles the `SparseFeatureCrossV2` attrs, and runs the op via
  `_execute.execute` with 3 outputs.

  Args:
    indices: list or tuple converted to `int64` tensors; its length becomes
      the `N` attr.
    values: converted with `_execute.convert_to_mixed_eager_tensors`; the
      first value that call returns becomes the `sparse_types` attr.
    shapes: list or tuple, same length as `indices`, converted to `int64`
      tensors.
    dense: converted with `_execute.convert_to_mixed_eager_tensors`; the first
      returned value becomes the `dense_types` attr.
    hashed_output: normalized with `_execute.make_bool`.
    num_buckets: normalized with `_execute.make_int`.
    hash_key: normalized with `_execute.make_int`.
    out_type: normalized with `_execute.make_type`.
    internal_type: normalized with `_execute.make_type`.
    name: optional name for the op.
    ctx: context to execute in; `_context.context()` is used when falsy.

  Returns:
    The 3 op outputs wrapped with `_SparseFeatureCrossV2Output._make`.

  Raises:
    TypeError: if `indices` or `shapes` is not a list or tuple.
    ValueError: if `shapes` differs in length from `indices`.
  """
  _ctx = ctx if ctx else _context.context()
  # `indices` fixes N; `shapes` must have the same length.
  if not isinstance(indices, (list, tuple)):
    raise TypeError(
        "Expected list for 'indices' argument to "
        "'sparse_feature_cross_v2' Op, not %r." % indices)
  _attr_N = len(indices)
  if not isinstance(shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'shapes' argument to "
        "'sparse_feature_cross_v2' Op, not %r." % shapes)
  if len(shapes) != _attr_N:
    raise ValueError(
        "List argument 'shapes' to 'sparse_feature_cross_v2' Op with length %d "
        "must match length %d of argument 'indices'." %
        (len(shapes), _attr_N))
  hashed_output = _execute.make_bool(hashed_output, "hashed_output")
  num_buckets = _execute.make_int(num_buckets, "num_buckets")
  hash_key = _execute.make_int(hash_key, "hash_key")
  out_type = _execute.make_type(out_type, "out_type")
  internal_type = _execute.make_type(internal_type, "internal_type")
  _attr_sparse_types, values = _execute.convert_to_mixed_eager_tensors(values, _ctx)
  _attr_dense_types, dense = _execute.convert_to_mixed_eager_tensors(dense, _ctx)
  indices = _ops.convert_n_to_tensor(indices, _dtypes.int64)
  shapes = _ops.convert_n_to_tensor(shapes, _dtypes.int64)
  # Flatten inputs in the order: indices, values, shapes, dense.
  _inputs_flat = list(indices) + list(values) + list(shapes) + list(dense)
  # Attrs are passed as a flat (name, value, name, value, ...) tuple.
  _attrs = ("N", _attr_N, "hashed_output", hashed_output, "num_buckets",
  num_buckets, "hash_key", hash_key, "sparse_types", _attr_sparse_types,
  "dense_types", _attr_dense_types, "out_type", out_type, "internal_type",
  internal_type)
  _result = _execute.execute(b"SparseFeatureCrossV2", 3, inputs=_inputs_flat,
                             attrs=_attrs, ctx=_ctx, name=name)
  # Hand inputs, attrs and outputs to `_execute.record_gradient`.
  _execute.record_gradient(
      "SparseFeatureCrossV2", _inputs_flat, _attrs, _result, name)
  _result = _SparseFeatureCrossV2Output._make(_result)
  return _result
示例12: apply_op
def apply_op(self, op_type_name, name=None, **keywords):
  # pylint: disable=g-doc-args
  """Adds a node that invokes a registered Op to a graph.

  Example usage:
     # input1 and input2 may be Tensors or anything ops.convert_to_tensor()
     # can turn into a Tensor.
     op_def_library.apply_op("op", input1=input1, input2=input2)
     # A node name may be given explicitly.
     op_def_library.apply_op("op", input1=input1, name="node_name")
     # Arguments must be passed by keyword, using the names in the OpDef.
     op_def_library.apply_op("op", input_name=input, attr_name=attr)

  Every attr must be either inferred from an input or passed explicitly
  (an inferred attr must not also be passed). If the OpDef declares a default
  for an attr, passing None for it selects that default.

  Args:
    op_type_name: string. Must match the name field of a registered Op.
    name: string. Optional name of the created op.
    **keywords: input Tensor and attr arguments specified by name,
      and optional parameters to pass when constructing the Operation.

  Returns:
    The Tensor(s) the operation outputs, or the Operation itself when it
    has no outputs.

  Raises:
    RuntimeError: On some errors.
    TypeError: On some errors.
    ValueError: On some errors.
  """
  output_structure, is_stateful, op = self._apply_op_helper(
      op_type_name, name, **keywords)
  if not output_structure:
    return op
  restructured = _Restructure(ops.convert_n_to_tensor(op.outputs),
                              output_structure)
  # A stateful op whose restructured outputs form an empty list returns the
  # Operation itself instead of the empty list.
  if is_stateful and isinstance(restructured, list) and not restructured:
    return op
  return restructured
示例13: apply_op
#.........这里部分代码省略.........
# the type indicated by the attrs (if they have already been
# inferred via an earlier input).
# * If the input_arg has an explicit type, make sure the input
# conforms.
if _IsListParameter(input_arg):
if not _IsListValue(values):
raise TypeError(
"Expected list for '%s' argument to '%s' Op, not %s." %
(input_name, op_type_name, values))
# In cases where we expect all elements of the list to have the
# same dtype, try to cast non-Tensor elements to that type.
dtype = None
default_dtype = None
if input_arg.type != types_pb2.DT_INVALID:
dtype = input_arg.type
elif input_arg.number_attr:
if input_arg.type_attr in attrs:
dtype = attrs[input_arg.type_attr]
else:
for t in values:
if isinstance(t, ops.Tensor):
dtype = t.dtype
break
# dtype still not found, prefer using the default dtype
# from the attr.
if dtype is None and input_arg.type_attr in default_type_attr_map:
default_dtype = default_type_attr_map[input_arg.type_attr]
try:
if not input_arg.is_ref and dtype:
dtype = dtypes.as_dtype(dtype).base_dtype
values = ops.convert_n_to_tensor(
values,
name=input_arg.name,
dtype=dtype if dtype else None,
preferred_dtype=default_dtype,
as_ref=input_arg.is_ref)
if input_arg.number_attr and len(
set(v.dtype.base_dtype for v in values)) > 1:
raise TypeError() # All types should match.
except (TypeError, ValueError):
# What types does the conversion function think values have?
observed_types = []
for value in values:
try:
converted_value = ops.convert_to_tensor(
value, as_ref=input_arg.is_ref)
observed_types.append(converted_value.dtype.base_dtype.name)
except (TypeError, ValueError):
observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
observed = ", ".join(observed_types)
prefix = (
"Tensors in list passed to '%s' of '%s' Op have types [%s]" %
(input_name, op_type_name, observed))
if input_arg.number_attr:
if input_arg.type != types_pb2.DT_INVALID:
raise TypeError("%s that do not match expected type %s." %
(prefix, dtype.name))
elif input_arg.type_attr in attrs:
raise TypeError("%s that do not match type %s inferred from "
"earlier arguments." %
(prefix, dtype.name))
else:
示例14: __init__
def __init__(self,
filenames,
record_defaults,
buffer_size=None,
header=False,
field_delim=",",
use_quote_delim=True,
na_value="",
select_cols=None):
"""Creates a `CsvDataset` by reading and decoding CSV files.
The elements of this dataset correspond to records from the file(s).
RFC 4180 format is expected for CSV files
(https://tools.ietf.org/html/rfc4180)
Note that we allow leading and trailing spaces with int or float field.
For example, suppose we have a file 'my_file0.csv' with four CSV columns of
different data types:
```
abcdefg,4.28E10,5.55E6,12
hijklmn,-5.3E14,,2
```
We can construct a CsvDataset from it as follows:
```python
dataset = tf.contrib.data.CsvDataset(
"my_file*.csv",
[tf.float32, # Required field, use dtype or empty tensor
tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0
tf.int32, # Required field, use dtype or empty tensor
],
select_cols=[1,2,3] # Only parse last three columns
)
```
The expected output of its iterations is:
```python
next = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
while True:
try:
print(sess.run(nxt))
except tf.errors.OutOfRangeError:
break
>> (4.28e10, 5.55e6, 12)
>> (-5.3e14, 0.0, 2)
```
Args:
filenames: A `tf.string` tensor containing one or more filenames.
record_defaults: A list of default values for the CSV fields. Each item in
the list is either a valid CSV `DType` (float32, float64, int32, int64,
string), or a `Tensor` object with one of the above types. One per
column of CSV data, with either a scalar `Tensor` default value for the
column if it is optional, or `DType` or empty `Tensor` if required. If
both this and `select_columns` are specified, these must have the same
lengths, and `column_defaults` is assumed to be sorted in order of
increasing column index.
buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
to buffer while reading files. Defaults to 4MB.
header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
have header line(s) that should be skipped when parsing. Defaults to
`False`.
field_delim: (Optional.) A `tf.string` scalar containing the delimiter
character that separates fields in a record. Defaults to `","`.
use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
double quotation marks as regular characters inside of string fields
(ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
na_value: (Optional.) A `tf.string` scalar indicating a value that will
be treated as NA/NaN.
select_cols: (Optional.) A sorted list of column indices to select from
the input data. If specified, only this subset of columns will be
parsed. Defaults to parsing all columns.
"""
super(CsvDataset, self).__init__()
self._filenames = ops.convert_to_tensor(
filenames, dtype=dtypes.string, name="filenames")
record_defaults = [
constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
for x in record_defaults
]
self._record_defaults = ops.convert_n_to_tensor(
record_defaults, name="record_defaults")
self._buffer_size = convert.optional_param_to_tensor(
"buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
self._header = ops.convert_to_tensor(
header, dtype=dtypes.bool, name="header")
self._field_delim = ops.convert_to_tensor(
field_delim, dtype=dtypes.string, name="field_delim")
self._use_quote_delim = ops.convert_to_tensor(
use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
self._na_value = ops.convert_to_tensor(
na_value, dtype=dtypes.string, name="na_value")
self._select_cols = convert.optional_param_to_tensor(
"select_cols",
select_cols,
argument_default=[],
argument_dtype=dtypes.int64,
#.........这里部分代码省略.........
示例15: piecewise_constant
def piecewise_constant(x, boundaries, values, name=None):
  """Builds a piecewise-constant function of `x` from boundaries and values.

  Example: a learning rate of 1.0 for the first 100000 steps, 0.5 for steps
  100001 to 110000, and 0.1 for every step after that.

  ```python
  global_step = tf.Variable(0, trainable=False)
  boundaries = [100000, 110000]
  values = [1.0, 0.5, 0.1]
  learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)

  # Later, whenever we perform an optimization step, we increment global_step.
  ```

  Args:
    x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
      `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
    boundaries: A list of `Tensor`s, `int`s or `float`s with strictly
      increasing entries, all of the same type as `x`. `int32` boundaries are
      cast to `int64` when `x` is `int64`.
    values: A list of `Tensor`s, `float`s or `int`s giving the value for each
      interval defined by `boundaries`. It must have exactly one more element
      than `boundaries`, and all elements must share a type.
    name: A string. Optional name of the operation. Defaults to
      'PiecewiseConstant'.

  Returns:
    A 0-D Tensor equal to `values[0]` when `x <= boundaries[0]`,
    `values[1]` when `boundaries[0] < x <= boundaries[1]`, ..., and
    `values[-1]` when `x > boundaries[-1]`.

  Raises:
    ValueError: if the types of `x` and `boundaries` do not match, if the
      types of `values` are not all equal, or if the list lengths do not
      match.
  """
  if len(values) - 1 != len(boundaries):
    raise ValueError(
        "The length of boundaries should be 1 less than the length of values")
  with ops.name_scope(name, "PiecewiseConstant",
                      [x, boundaries, values, name]) as name:
    x = ops.convert_to_tensor(x)
    target_dtype = x.dtype.base_dtype
    # Keep each boundary's own dtype instead of casting everything to x's
    # dtype; a float -> int cast could silently change comparison results.
    boundaries = ops.convert_n_to_tensor(boundaries)
    for idx, boundary in enumerate(boundaries):
      boundary_dtype = boundary.dtype.base_dtype
      if boundary_dtype == target_dtype:
        continue
      # Python ints convert to int32; widening them to int64 is lossless, so
      # that single mismatch is allowed and fixed with a cast.
      widenable = (boundary_dtype == dtypes.int32 and
                   target_dtype == dtypes.int64)
      if not widenable:
        raise ValueError(
            "Boundaries (%s) must have the same dtype as x (%s)." % (
                boundary_dtype, target_dtype))
      boundaries[idx] = math_ops.cast(boundary, target_dtype)
    # TODO(rdipietro): Ensure that boundaries' elements are strictly increasing.
    values = ops.convert_n_to_tensor(values)
    first_dtype = values[0].dtype.base_dtype
    for value in values[1:]:
      if value.dtype.base_dtype != first_dtype:
        raise ValueError(
            "Values must have elements all with the same dtype (%s vs %s)." % (
                first_dtype, value.dtype.base_dtype))
    # The two open-ended intervals come first, then each closed interval.
    branches = [
        (x <= boundaries[0], lambda: values[0]),
        (x > boundaries[-1], lambda: values[-1]),
    ]
    for low, high, mid_value in zip(boundaries[:-1], boundaries[1:],
                                    values[1:-1]):
      in_interval = (x > low) & (x <= high)
      # Default-argument binding freezes this iteration's value.
      branches.append((in_interval, lambda v=mid_value: v))
    # The predicates are mutually exclusive and exhaustive, so the default is
    # never taken; control_flow_ops.case still requires one.
    return control_flow_ops.case(branches, lambda: values[0], exclusive=True)