本文整理汇总了Python中tensorflow.python.ops.array_ops.broadcast_static_shape方法的典型用法代码示例。如果您正苦于以下问题：Python array_ops.broadcast_static_shape方法的具体用法？Python array_ops.broadcast_static_shape怎么用？Python array_ops.broadcast_static_shape使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.ops.array_ops的用法示例。
在下文中一共展示了array_ops.broadcast_static_shape方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: prefer_static_broadcast_shape
# 需要导入模块: from tensorflow.python.ops import array_ops [as 别名]
# 或者: from tensorflow.python.ops.array_ops import broadcast_static_shape [as 别名]
def prefer_static_broadcast_shape(
        shape1, shape2, name="prefer_static_broadcast_shape"):
    """Broadcast two shapes statically when possible, dynamically otherwise.

    Args:
        shape1: `1-D` integer `Tensor`. Already converted to tensor!
        shape2: `1-D` integer `Tensor`. Already converted to tensor!
        name: A string name to prepend to created ops.

    Returns:
        The result of `array_ops.broadcast_static_shape` (a `TensorShape`)
        when `tensor_util.constant_value` yields a value for both inputs;
        otherwise the result of `array_ops.broadcast_dynamic_shape`
        (a `Tensor`).
    """
    with ops.name_scope(name, values=[shape1, shape2]):
        static1 = tensor_util.constant_value(shape1)
        if static1 is not None:
            # Only look at shape2 once shape1 is known to be constant.
            static2 = tensor_util.constant_value(shape2)
            if static2 is not None:
                return array_ops.broadcast_static_shape(
                    tensor_shape.TensorShape(static1),
                    tensor_shape.TensorShape(static2))
        # At least one shape is not statically known: defer to graph time.
        return array_ops.broadcast_dynamic_shape(shape1, shape2)
示例2: _check_shapes
# 需要导入模块: from tensorflow.python.ops import array_ops [as 别名]
# 或者: from tensorflow.python.ops.array_ops import broadcast_static_shape [as 别名]
def _check_shapes(self):
    """Statically check that the shapes held by this object are compatible.

    Nothing is returned; every call below is made for its checking effect,
    so an incompatibility surfaces from the shape helper that detects it.
    """
    u_shape = self.u.get_shape()
    v_shape = self.v.get_shape()
    # Broadcasting u against v also checks that the two are compatible.
    uv_shape = array_ops.broadcast_static_shape(u_shape, v_shape)

    # Everything except the trailing two dimensions of u/v is treated as
    # batch shape and must broadcast with the base operator's batch shape.
    batch_shape = array_ops.broadcast_static_shape(
        self.base_operator.batch_shape, uv_shape[:-2])

    self.base_operator.domain_dimension.assert_is_compatible_with(
        uv_shape[-2])

    if self._diag_update is None:
        return

    diag_shape = self._diag_update.get_shape()
    uv_shape[-1].assert_is_compatible_with(diag_shape[-1])
    # Result is discarded: the call is only used as a broadcast check.
    array_ops.broadcast_static_shape(batch_shape, diag_shape[:-1])
示例3: _batch_shape
# 需要导入模块: from tensorflow.python.ops import array_ops [as 别名]
# 或者: from tensorflow.python.ops.array_ops import broadcast_static_shape [as 别名]
def _batch_shape(self):
    """Return the static broadcast of the `concentration` and `rate` shapes."""
    concentration_shape = self.concentration.get_shape()
    rate_shape = self.rate.get_shape()
    return array_ops.broadcast_static_shape(concentration_shape, rate_shape)
示例4: _log_prob
# 需要导入模块: from tensorflow.python.ops import array_ops [as 别名]
# 或者: from tensorflow.python.ops.array_ops import broadcast_static_shape [as 别名]
def _log_prob(self, y):
    """Return `ildj + distribution.log_prob(x)` for `x = bijector.inverse(y)`.

    Args:
        y: Value passed to `self.bijector.inverse` and
            `self.bijector.inverse_log_det_jacobian`.

    Returns:
        The sum of the inverse log-det-Jacobian and the wrapped
        distribution's log-probability at the (possibly dim-rotated)
        inverse image of `y`. When `self._is_maybe_event_override` is truthy
        the log-probability is first summed over
        `self._reduce_event_indices`, and the result's static shape is set
        at the end.
    """
    x = self.bijector.inverse(y)
    ildj = self.bijector.inverse_log_det_jacobian(y)
    rotated_x = self._maybe_rotate_dims(x, rotate_right=True)
    base_log_prob = self.distribution.log_prob(rotated_x)

    event_override = self._is_maybe_event_override
    if event_override:
        base_log_prob = math_ops.reduce_sum(
            base_log_prob, self._reduce_event_indices)

    result = ildj + base_log_prob

    if event_override:
        # Static shape: y's shape without its last dimension, broadcast
        # against this object's batch shape.
        static_shape = array_ops.broadcast_static_shape(
            y.get_shape().with_rank_at_least(1)[:-1], self.batch_shape)
        result.set_shape(static_shape)
    return result
示例5: _prob
# 需要导入模块: from tensorflow.python.ops import array_ops [as 别名]
# 或者: from tensorflow.python.ops.array_ops import broadcast_static_shape [as 别名]
def _prob(self, y):
    """Return `distribution.prob(x) * exp(ildj)` for `x = bijector.inverse(y)`.

    Args:
        y: Value passed to `self.bijector.inverse` and
            `self.bijector.inverse_log_det_jacobian`.

    Returns:
        The wrapped distribution's probability at the (possibly dim-rotated)
        inverse image of `y`, scaled by the exponentiated inverse
        log-det-Jacobian. When `self._is_maybe_event_override` is truthy the
        probability is first multiplied out over
        `self._reduce_event_indices`, and the result's static shape is set
        at the end.
    """
    x = self.bijector.inverse(y)
    ildj = self.bijector.inverse_log_det_jacobian(y)
    rotated_x = self._maybe_rotate_dims(x, rotate_right=True)
    density = self.distribution.prob(rotated_x)

    event_override = self._is_maybe_event_override
    if event_override:
        density = math_ops.reduce_prod(density, self._reduce_event_indices)

    density *= math_ops.exp(ildj)

    if event_override:
        # Static shape: y's shape without its last dimension, broadcast
        # against this object's batch shape.
        static_shape = array_ops.broadcast_static_shape(
            y.get_shape().with_rank_at_least(1)[:-1], self.batch_shape)
        density.set_shape(static_shape)
    return density
示例6: _batch_shape
# 需要导入模块: from tensorflow.python.ops import array_ops [as 别名]
# 或者: from tensorflow.python.ops.array_ops import broadcast_static_shape [as 别名]
def _batch_shape(self):
    """Return the static broadcast of the `df`, `loc` and `scale` shapes."""
    # Broadcast pairwise: (df, loc) first, then fold in scale.
    df_loc_shape = array_ops.broadcast_static_shape(
        self.df.get_shape(), self.loc.get_shape())
    return array_ops.broadcast_static_shape(
        df_loc_shape, self.scale.get_shape())
示例7: _batch_shape
# 需要导入模块: from tensorflow.python.ops import array_ops [as 别名]
# 或者: from tensorflow.python.ops.array_ops import broadcast_static_shape [as 别名]
def _batch_shape(self):
    """Return the static broadcast of the `low` and `high` shapes."""
    low_shape = self.low.get_shape()
    high_shape = self.high.get_shape()
    return array_ops.broadcast_static_shape(low_shape, high_shape)
示例8: _batch_shape
# 需要导入模块: from tensorflow.python.ops import array_ops [as 别名]
# 或者: from tensorflow.python.ops.array_ops import broadcast_static_shape [as 别名]
def _batch_shape(self):
    """Return the static broadcast of the `loc` and `scale` shapes."""
    loc_shape = self.loc.get_shape()
    scale_shape = self.scale.get_shape()
    return array_ops.broadcast_static_shape(loc_shape, scale_shape)
示例9: _batch_shape
# 需要导入模块: from tensorflow.python.ops import array_ops [as 别名]
# 或者: from tensorflow.python.ops.array_ops import broadcast_static_shape [as 别名]
def _batch_shape(self):
    """Statically broadcast the shape of `loc` against the shape of `scale`."""
    static_shapes = (self.loc.get_shape(), self.scale.get_shape())
    return array_ops.broadcast_static_shape(*static_shapes)
示例10: _batch_shape
# 需要导入模块: from tensorflow.python.ops import array_ops [as 别名]
# 或者: from tensorflow.python.ops.array_ops import broadcast_static_shape [as 别名]
def _batch_shape(self):
    """Return the static broadcast of the `total_count` and `probs` shapes."""
    total_count_shape = self.total_count.get_shape()
    probs_shape = self.probs.get_shape()
    return array_ops.broadcast_static_shape(total_count_shape, probs_shape)
示例11: _static_check_for_broadcastable_batch_shape
# 需要导入模块: from tensorflow.python.ops import array_ops [as 别名]
# 或者: from tensorflow.python.ops.array_ops import broadcast_static_shape [as 别名]
def _static_check_for_broadcastable_batch_shape(operators):
"""ValueError if operators determined to have non-broadcastable shapes."""
if len(operators) < 2:
return
# This will fail if they cannot be broadcast together.
batch_shape = operators[0].batch_shape
for op in operators[1:]:
batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)
示例12: _possibly_broadcast_batch_shape
# 需要导入模块: from tensorflow.python.ops import array_ops [as 别名]
# 或者: from tensorflow.python.ops.array_ops import broadcast_static_shape [as 别名]
def _possibly_broadcast_batch_shape(self, x):
"""Return 'x', possibly after broadcasting the leading dimensions."""
# If we have no batch shape, our batch shape broadcasts with everything!
if self._batch_shape_arg is None:
return x
# Static attempt:
# If we determine that no broadcast is necessary, pass x through
# If we need a broadcast, add to an array of zeros.
#
# special_shape is the shape that, when broadcast with x's shape, will give
# the correct broadcast_shape. Note that
# We have already verified the second to last dimension of self.shape
# matches x's shape in assert_compatible_matrix_dimensions.
# Also, the final dimension of 'x' can have any shape.
# Therefore, the final two dimensions of special_shape are 1's.
special_shape = self.batch_shape.concatenate([1, 1])
bshape = array_ops.broadcast_static_shape(x.get_shape(), special_shape)
if special_shape.is_fully_defined():
# bshape.is_fully_defined iff special_shape.is_fully_defined.
if bshape == x.get_shape():
return x
# Use the built in broadcasting of addition.
zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
return x + zeros
# Dynamic broadcast:
# Always add to an array of zeros, rather than using a "cond", since a
# cond would require copying data from GPU --> CPU.
special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0)
zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
return x + zeros
示例13: _shape
# 需要导入模块: from tensorflow.python.ops import array_ops [as 别名]
# 或者: from tensorflow.python.ops.array_ops import broadcast_static_shape [as 别名]
def _shape(self):
    """Return the broadcast batch shape followed by the base operator's last two dims.

    The batch part is the static broadcast of the base operator's batch
    shape with all but the last two dimensions of `self.u`'s shape.
    """
    u_batch_shape = self.u.get_shape()[:-2]
    batch_shape = array_ops.broadcast_static_shape(
        self.base_operator.batch_shape, u_batch_shape)
    trailing_dims = self.base_operator.shape[-2:]
    return batch_shape.concatenate(trailing_dims)
示例14: _get_batch_shape
# 需要导入模块: from tensorflow.python.ops import array_ops [as 别名]
# 或者: from tensorflow.python.ops.array_ops import broadcast_static_shape [as 别名]
def _get_batch_shape(self):
    """Return the static broadcast of the `alpha` and `beta` shapes."""
    alpha_shape = self.alpha.get_shape()
    beta_shape = self.beta.get_shape()
    return array_ops.broadcast_static_shape(alpha_shape, beta_shape)
示例15: _get_batch_shape
# 需要导入模块: from tensorflow.python.ops import array_ops [as 别名]
# 或者: from tensorflow.python.ops.array_ops import broadcast_static_shape [as 别名]
def _get_batch_shape(self):
    """Return the static broadcast of the `loc` and `scale` shapes."""
    loc_shape = self.loc.get_shape()
    scale_shape = self.scale.get_shape()
    return array_ops.broadcast_static_shape(loc_shape, scale_shape)