本文整理汇总了Python中tensorflow.python.ops.check_ops.assert_type方法的典型用法代码示例。如果您正苦于以下问题：Python check_ops.assert_type方法的具体用法？Python check_ops.assert_type怎么用？Python check_ops.assert_type使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.ops.check_ops的用法示例。
在下文中一共展示了check_ops.assert_type方法的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _count_condition
# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_type [as 别名]
def _count_condition(values, weights=None, metrics_collections=None,
                     updates_collections=None):
    """Accumulates the (optionally weighted) number of `True` entries.

    `values` is cast to float and, when `weights` is supplied, multiplied by
    the float-cast `weights`. The sum of the result is added to a variable
    created via `_create_local('count', shape=[])`. With `weights=None` every
    `True` entry contributes 1; a weight of 0 masks the matching entry.

    Args:
      values: A `bool` `Tensor` of arbitrary size.
      weights: An optional `Tensor` whose shape is broadcastable to `values`.
      metrics_collections: An optional list of collections that the returned
        value tensor should be added to.
      updates_collections: An optional list of collections that the returned
        update op should be added to.

    Returns:
      value_tensor: A tensor representing the current value of the metric.
      update_op: An operation that adds the current batch's sum to the count.

    Raises:
      ValueError: Per the original contract, if `weights` is not `None` and
        its shape doesn't match `values`, or if either `metrics_collections`
        or `updates_collections` are not a list or tuple. This function does
        no explicit validation of its own; any such error originates in the
        helpers it calls.
    """
    # Reject anything that is not a bool tensor before building the metric.
    check_ops.assert_type(values, dtypes.bool)
    running_total = _create_local('count', shape=[])

    contribution = math_ops.to_float(values)
    if weights is not None:
        float_weights = math_ops.to_float(weights)
        contribution = math_ops.mul(contribution, float_weights)

    value_tensor = array_ops.identity(running_total)
    batch_sum = math_ops.reduce_sum(contribution)
    update_op = state_ops.assign_add(running_total, batch_sum)

    # Falsy collection arguments (None or empty) are skipped.
    registrations = (
        (metrics_collections, value_tensor),
        (updates_collections, update_op),
    )
    for collection_names, tensor_or_op in registrations:
        if collection_names:
            ops.add_to_collections(collection_names, tensor_or_op)

    return value_tensor, update_op
示例2: _count_condition
# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_type [as 别名]
def _count_condition(values, weights=None, metrics_collections=None,
                     updates_collections=None):
    """Sums the weights of the cases where `values` is `True`.

    `values` is cast to float; if `weights` is given it is rank-checked, cast
    to float and multiplied into the cast values. The sum of the result is
    added to a variable created via `_create_local('count', shape=[])`. With
    `weights=None` each `True` entry counts as 1; use a weight of 0 to mask
    an entry.

    Args:
      values: A `bool` `Tensor` of arbitrary size.
      weights: Optional `Tensor` whose rank is either 0, or the same rank as
        `values`, and must be broadcastable to `values` (i.e., all dimensions
        must be either `1`, or the same as the corresponding `values`
        dimension).
      metrics_collections: An optional list of collections that the returned
        value tensor should be added to.
      updates_collections: An optional list of collections that the returned
        update op should be added to.

    Returns:
      value_tensor: A `Tensor` representing the current value of the metric.
      update_op: An operation that adds the current batch's sum to the count.

    Raises:
      ValueError: Per the original contract, if `weights` is not `None` and
        its shape doesn't match `values`, or if either `metrics_collections`
        or `updates_collections` are not a list or tuple. No explicit `raise`
        lives in this function; such errors come from the called helpers.
    """
    check_ops.assert_type(values, dtypes.bool)
    counter = _create_local('count', shape=[])

    as_float = math_ops.to_float(values)
    if weights is not None:
        # The weights must be a scalar or have exactly the rank of `values`.
        allowed_ranks = (0, array_ops.rank(as_float))
        rank_guard = check_ops.assert_rank_in(weights, allowed_ranks)
        with ops.control_dependencies((rank_guard,)):
            float_weights = math_ops.to_float(weights)
            as_float = math_ops.multiply(as_float, float_weights)

    value_tensor = array_ops.identity(counter)
    batch_sum = math_ops.reduce_sum(as_float)
    update_op = state_ops.assign_add(counter, batch_sum)

    # Only register with collections when a non-empty list was supplied.
    if metrics_collections:
        ops.add_to_collections(metrics_collections, value_tensor)
    if updates_collections:
        ops.add_to_collections(updates_collections, update_op)

    return value_tensor, update_op
示例3: _count_condition
# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_type [as 别名]
def _count_condition(values, weights=None, metrics_collections=None,
                     updates_collections=None):
    """Keeps a running, optionally weighted, count of `True` values.

    `values` is cast to float. When `weights` is provided it is cast to float
    as well, checked with `_assert_weights_rank`, and multiplied into the
    cast values. The sum of the result is added to a variable created via
    `_create_local('count', shape=[])`. With `weights=None` each `True` entry
    counts as 1; a weight of 0 masks an entry.

    Args:
      values: A `bool` `Tensor` of arbitrary size.
      weights: Optional `Tensor` whose rank is either 0, or the same rank as
        `values`, and must be broadcastable to `values` (i.e., all dimensions
        must be either `1`, or the same as the corresponding `values`
        dimension).
      metrics_collections: An optional list of collections that the returned
        value tensor should be added to.
      updates_collections: An optional list of collections that the returned
        update op should be added to.

    Returns:
      value_tensor: A `Tensor` representing the current value of the metric.
      update_op: An operation that adds the current batch's sum to the count.

    Raises:
      ValueError: Per the original contract, if `weights` is not `None` and
        its shape doesn't match `values`, or if either `metrics_collections`
        or `updates_collections` are not a list or tuple. This function has
        no explicit `raise`; such errors come from the called helpers.
    """
    check_ops.assert_type(values, dtypes.bool)
    tally = _create_local('count', shape=[])

    weighted_hits = math_ops.to_float(values)
    if weights is not None:
        float_weights = math_ops.to_float(weights)
        # Gate the multiplication on the weights-rank assertion.
        rank_guard = _assert_weights_rank(float_weights, weighted_hits)
        with ops.control_dependencies((rank_guard,)):
            weighted_hits = math_ops.multiply(weighted_hits, float_weights)

    value_tensor = array_ops.identity(tally)
    update_op = state_ops.assign_add(
        tally, math_ops.reduce_sum(weighted_hits))

    # Pair each optional collection list with the object it should receive.
    for names, member in ((metrics_collections, value_tensor),
                          (updates_collections, update_op)):
        if names:
            ops.add_to_collections(names, member)

    return value_tensor, update_op
示例4: _concatenate_context_input
# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_type [as 别名]
def _concatenate_context_input(sequence_input, context_input):
    """Appends `context_input` to every timestep of `sequence_input`.

    `context_input` gets a new dimension 1, is tiled `padded_length` times
    along that dimension, and is then concatenated onto `sequence_input`
    along dimension 2.

    Args:
      sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
        padded_length, d0]`.
      context_input: A `Tensor` of dtype `float32` and shape
        `[batch_size, d1]`.

    Returns:
      A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
      d0 + d1]`.

    Raises:
      ValueError: If `sequence_input` does not have rank 3 or `context_input`
        does not have rank 2.
    """
    # Build the rank/dtype assertions for both inputs, in a fixed order.
    sequence_shape = array_ops.shape(sequence_input)
    seq_rank_check = check_ops.assert_rank(
        sequence_input, 3,
        message='sequence_input must have rank 3',
        data=[sequence_shape])
    seq_type_message = (
        'sequence_input must have dtype float32; got {}.'.format(
            sequence_input.dtype))
    seq_type_check = check_ops.assert_type(
        sequence_input, dtypes.float32, message=seq_type_message)

    context_shape = array_ops.shape(context_input)
    ctx_rank_check = check_ops.assert_rank(
        context_input, 2,
        message='context_input must have rank 2',
        data=[context_shape])
    ctx_type_message = (
        'context_input must have dtype float32; got {}.'.format(
            context_input.dtype))
    ctx_type_check = check_ops.assert_type(
        context_input, dtypes.float32, message=ctx_type_message)

    # `padded_length` feeds everything below, so gating it on the assertions
    # gates the returned tensor as well.
    guards = [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]
    with ops.control_dependencies(guards):
        padded_length = array_ops.shape(sequence_input)[1]

    expanded_context = array_ops.expand_dims(context_input, 1)
    tile_multiples = array_ops.concat([[1], [padded_length], [1]], 0)
    tiled_context = array_ops.tile(expanded_context, tile_multiples)
    return array_ops.concat([sequence_input, tiled_context], 2)
示例5: _concatenate_context_input
# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_type [as 别名]
def _concatenate_context_input(sequence_input, context_input):
    """Replicates `context_input` across all timesteps of `sequence_input`.

    A new dimension 1 is inserted into `context_input`, which is then tiled
    `padded_length` times along it. The tiled value is appended to
    `sequence_input` on dimension 2 and the result is returned.

    Args:
      sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
        padded_length, d0]`.
      context_input: A `Tensor` of dtype `float32` and shape
        `[batch_size, d1]`.

    Returns:
      A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
      d0 + d1]`.

    Raises:
      ValueError: If `sequence_input` does not have rank 3 or `context_input`
        does not have rank 2.
    """
    # Assertions are created in the order: sequence rank, sequence dtype,
    # context rank, context dtype.
    assertions = []
    assertions.append(check_ops.assert_rank(
        sequence_input,
        3,
        message='sequence_input must have rank 3',
        data=[array_ops.shape(sequence_input)]))
    assertions.append(check_ops.assert_type(
        sequence_input,
        dtypes.float32,
        message='sequence_input must have dtype float32; got {}.'.format(
            sequence_input.dtype)))
    assertions.append(check_ops.assert_rank(
        context_input,
        2,
        message='context_input must have rank 2',
        data=[array_ops.shape(context_input)]))
    assertions.append(check_ops.assert_type(
        context_input,
        dtypes.float32,
        message='context_input must have dtype float32; got {}.'.format(
            context_input.dtype)))

    # The timestep count is read under the assertions; the tiling and the
    # final concat consume it.
    with ops.control_dependencies(assertions):
        num_steps = array_ops.shape(sequence_input)[1]

    context_per_step = array_ops.tile(
        array_ops.expand_dims(context_input, 1),
        array_ops.concat([[1], [num_steps], [1]], 0))
    return array_ops.concat([sequence_input, context_per_step], 2)
示例6: _concatenate_context_input
# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_type [as 别名]
def _concatenate_context_input(sequence_input, context_input):
    """Replicates `context_input` across all timesteps of `sequence_input`.

    Dimension 1 of `context_input` is expanded and the result tiled
    `padded_length` times. That value is appended to `sequence_input` on
    dimension 2 and returned.

    Args:
      sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
        padded_length, d0]`.
      context_input: A `Tensor` of dtype `float32` and shape
        `[batch_size, d1]`.

    Returns:
      A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
      d0 + d1]`.

    Raises:
      ValueError: If `sequence_input` does not have rank 3 or `context_input`
        does not have rank 2.
    """
    # --- validation of `sequence_input` ---
    seq_shape = array_ops.shape(sequence_input)
    seq_rank_check = check_ops.assert_rank(
        sequence_input, 3,
        message='sequence_input must have rank 3',
        data=[seq_shape])
    seq_dtype_text = 'sequence_input must have dtype float32; got {}.'.format(
        sequence_input.dtype)
    seq_type_check = check_ops.assert_type(
        sequence_input, dtypes.float32, message=seq_dtype_text)

    # --- validation of `context_input` ---
    ctx_shape = array_ops.shape(context_input)
    ctx_rank_check = check_ops.assert_rank(
        context_input, 2,
        message='context_input must have rank 2',
        data=[ctx_shape])
    ctx_dtype_text = 'context_input must have dtype float32; got {}.'.format(
        context_input.dtype)
    ctx_type_check = check_ops.assert_type(
        context_input, dtypes.float32, message=ctx_dtype_text)

    all_checks = [
        seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]
    with ops.control_dependencies(all_checks):
        padded_length = array_ops.shape(sequence_input)[1]

    # Note: `array_ops.concat` is called with the dimension as its first
    # argument and the list of values as its second throughout this function.
    context_with_time_axis = array_ops.expand_dims(context_input, 1)
    repeat_counts = array_ops.concat(0, [[1], [padded_length], [1]])
    tiled_context_input = array_ops.tile(
        context_with_time_axis, repeat_counts)
    return array_ops.concat(2, [sequence_input, tiled_context_input])