本文整理汇总了Python中tensorflow.python.ops.distributions.util.embed_check_nonnegative_integer_form函数的典型用法代码示例。如果您正苦于以下问题:Python embed_check_nonnegative_integer_form函数的具体用法?Python embed_check_nonnegative_integer_form怎么用?Python embed_check_nonnegative_integer_form使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了embed_check_nonnegative_integer_form函数的13个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _log_unnormalized_prob
def _log_unnormalized_prob(self, x):
    """Return `x * log_rate - lgamma(1 + x)` for the (checked or floored) `x`.

    Args:
      x: Value(s) at which to evaluate the expression.

    Returns:
      The result of `x * self.log_rate - lgamma(1. + x)`, where `x` has first
      been either validated or floored (see below).
    """
    if not self.validate_args:
        # Without validation the input is floored, for consistency with cdf.
        counts = math_ops.floor(x)
    else:
        # With validation the input goes through the non-negative
        # integer-form check helper instead of being floored.
        counts = distribution_util.embed_check_nonnegative_integer_form(x)
    scaled = counts * self.log_rate
    return scaled - math_ops.lgamma(1. + counts)
示例2: _cdf
def _cdf(self, x):
    """Return `igammac(1 + x, rate)` for the (checked or floored) `x`.

    Args:
      x: Value(s) at which to evaluate the expression.

    Returns:
      `math_ops.igammac(1. + x, self.rate)` after `x` has been validated or
      floored.
    """
    if not self.validate_args:
        # The expression below is well-defined for non-integer inputs too,
        # but scipy takes the floor, so the same is done here.
        counts = math_ops.floor(x)
    else:
        counts = distribution_util.embed_check_nonnegative_integer_form(x)
    shifted = 1. + counts
    return math_ops.igammac(shifted, self.rate)
示例3: Multinomial.__init__
def __init__(self,
             total_count,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      total_count: Non-negative floating point tensor with shape broadcastable
        to `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
        `N1 x ... x Nm` different Multinomial distributions. Its components
        should be equal to integer values.
      logits: Floating point tensor representing unnormalized
        log-probabilities of a positive event with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0`, and the same dtype as `total_count`.
        Defines this as a batch of `N1 x ... x Nm` different `K` class
        Multinomial distributions. Only one of `logits` or `probs` should be
        passed in.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0` and same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. `probs`'s components in the last portion of its shape
        should sum to `1`. Only one of `logits` or `probs` should be passed
        in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading
        runtime performance. When `False` invalid inputs may silently render
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Take a snapshot of the constructor arguments. `locals()` can return the
    # frame's own mapping, which CPython may refresh in place later (another
    # `locals()` call, a tracer/debugger); copying keeps `parameters` stable.
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, logits, probs]):
        self._total_count = ops.convert_to_tensor(
            total_count, name="total_count")
        if validate_args:
            self._total_count = (
                distribution_util.embed_check_nonnegative_integer_form(
                    self._total_count))
        self._logits, self._probs = distribution_util.get_logits_and_probs(
            logits=logits,
            probs=probs,
            multidimensional=True,
            validate_args=validate_args,
            name=name)
        # A trailing axis is added to total_count before multiplying by probs.
        self._mean_val = (
            self._total_count[..., array_ops.newaxis] * self._probs)
    super(Multinomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count,
                       self._logits,
                       self._probs],
        name=name)
示例4: DirichletMultinomial.__init__
def __init__(self,
             total_count,
             concentration,
             validate_args=False,
             allow_nan_stats=True,
             name="DirichletMultinomial"):
    """Initialize a batch of DirichletMultinomial distributions.

    Args:
      total_count: Non-negative floating point tensor, whose dtype is the same
        as `concentration`. The shape is broadcastable to `[N1,..., Nm]` with
        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different
        Dirichlet multinomial distributions. Its components should be equal
        to integer values.
      concentration: Positive floating point tensor, whose dtype is the
        same as `total_count` with shape broadcastable to `[N1,..., Nm, K]`
        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different `K`
        class Dirichlet multinomial distributions.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading
        runtime performance. When `False` invalid inputs may silently render
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Take a snapshot of the constructor arguments. `locals()` can return the
    # frame's own mapping, which CPython may refresh in place later (another
    # `locals()` call, a tracer/debugger); copying keeps `parameters` stable.
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, concentration]):
        # Broadcasting works because:
        # * The broadcasting convention is to prepend dimensions of size [1],
        #   and we use the last dimension for the distribution, whereas
        #   the batch dimensions are the leading dimensions, which forces the
        #   distribution dimension to be defined explicitly (i.e. it cannot
        #   be created automatically by prepending). This forces enough
        #   explicitness.
        # * All calls involving `counts` eventually require a broadcast
        #   between `counts` and concentration.
        self._total_count = ops.convert_to_tensor(
            total_count, name="total_count")
        if validate_args:
            self._total_count = (
                distribution_util.embed_check_nonnegative_integer_form(
                    self._total_count))
        self._concentration = self._maybe_assert_valid_concentration(
            ops.convert_to_tensor(concentration,
                                  name="concentration"),
            validate_args)
        # Sum of the concentration values over the last axis.
        self._total_concentration = math_ops.reduce_sum(
            self._concentration, -1)
    super(DirichletMultinomial, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._total_count,
                       self._concentration],
        name=name)
示例5: _maybe_assert_valid_sample
def _maybe_assert_valid_sample(self, counts):
    """Optionally attach validity checks to `counts` and return it.

    Args:
      counts: Sample value(s) to check.

    Returns:
      `counts` unchanged when `self.validate_args` is falsy. Otherwise the
      integer-form-checked `counts`, returned through `with_dependencies` on
      an assertion that `self.total_count` equals the sum of `counts` over
      its last axis.
    """
    if self.validate_args:
        checked = distribution_util.embed_check_nonnegative_integer_form(
            counts)
        sums_to_total = check_ops.assert_equal(
            self.total_count,
            math_ops.reduce_sum(checked, -1),
            message="counts last-dimension must sum to `self.total_count`")
        return control_flow_ops.with_dependencies([sums_to_total], checked)
    return counts
示例6: _maybe_assert_valid_sample
def _maybe_assert_valid_sample(self, counts):
    """Optionally attach validity checks to `counts` and return it.

    Args:
      counts: Sample value(s) to check.

    Returns:
      `counts` unchanged when `self.validate_args` is falsy. Otherwise the
      integer-form-checked `counts`, returned through `with_dependencies` on
      an assertion that `counts <= self.total_count`.
    """
    if self.validate_args:
        checked = distribution_util.embed_check_nonnegative_integer_form(
            counts)
        within_total = check_ops.assert_less_equal(
            checked,
            self.total_count,
            message="counts are not less than or equal to n.")
        return control_flow_ops.with_dependencies([within_total], checked)
    return counts
示例7: _log_prob
def _log_prob(self, x):
    """Return `x * log1p(-probs) + log(probs)` for the checked/floored `x`.

    Args:
      x: Value(s) at which to evaluate the expression.

    Returns:
      `x * log1p(-p) + log(probs)`, where `p` equals `probs` except that it
      is replaced by zero wherever `x == 0`.
    """
    if not self.validate_args:
        # For consistency with cdf, we take the floor.
        x = tf.floor(x)
    else:
        x = distribution_util.embed_check_nonnegative_integer_form(x)
    # Multiply each operand by ones shaped like the other so that both end up
    # with a common shape.
    x *= tf.ones_like(self.probs)
    probs = self.probs * tf.ones_like(x)
    # Where x == 0, zero is substituted for probs inside log1p -- presumably
    # to keep 0 * log1p(-1) from producing NaN when probs == 1.
    is_zero = tf.equal(x, 0.)
    zero_fill = tf.zeros_like(probs)
    safe_domain = tf.where(is_zero, zero_fill, probs)
    scaled_log1p = x * tf.log1p(-safe_domain)
    return scaled_log1p + tf.log(probs)
示例8: _cdf
def _cdf(self, x):
    """Return `-expm1((1 + x) * log1p(-probs))`, or zero where `x < 0`.

    Args:
      x: Value(s) at which to evaluate the expression.

    Returns:
      Zeros where the (checked or floored) `x` is negative, and
      `-expm1((1. + x) * log1p(-self.probs))` elsewhere.
    """
    if not self.validate_args:
        # The expression below is well-defined for non-integer inputs too,
        # but scipy takes the floor, so the same is done here.
        x = tf.floor(x)
    else:
        x = distribution_util.embed_check_nonnegative_integer_form(x)
    # Multiplying by ones shaped like probs gives x the combined shape.
    x *= tf.ones_like(self.probs)
    below_zero = x < 0.
    zero_fill = tf.zeros_like(x)
    exponent = (1. + x) * tf.log1p(-self.probs)
    return tf.where(below_zero, zero_fill, -tf.expm1(exponent))
示例9: _log_normalization
def _log_normalization(self, x):
    """Return `-lgamma(total_count + x) + lgamma(1 + x) + lgamma(total_count)`.

    Args:
      x: Value(s) at which to evaluate the expression. Passed through the
        non-negative integer-form check helper when `self.validate_args` is
        truthy; used as given otherwise.

    Returns:
      The sum of the three `lgamma` terms above.
    """
    if self.validate_args:
        x = distribution_util.embed_check_nonnegative_integer_form(x)
    neg_lgamma_sum = -math_ops.lgamma(self.total_count + x)
    partial = neg_lgamma_sum + math_ops.lgamma(1. + x)
    return partial + math_ops.lgamma(self.total_count)
示例10: _log_unnormalized_prob
def _log_unnormalized_prob(self, x):
    """Return `total_count * log_sigmoid(-logits) + x * log_sigmoid(logits)`.

    Args:
      x: Value(s) at which to evaluate the expression. Passed through the
        non-negative integer-form check helper when `self.validate_args` is
        truthy; used as given otherwise.

    Returns:
      The sum of the two `log_sigmoid` terms above.
    """
    if self.validate_args:
        x = distribution_util.embed_check_nonnegative_integer_form(x)
    total_term = self.total_count * math_ops.log_sigmoid(-self.logits)
    x_term = x * math_ops.log_sigmoid(self.logits)
    return total_term + x_term
示例11: _cdf
def _cdf(self, x):
    """Return `betainc(total_count, 1 + x, sigmoid(-logits))`.

    Args:
      x: Value(s) at which to evaluate the expression. Passed through the
        non-negative integer-form check helper when `self.validate_args` is
        truthy; used as given otherwise.

    Returns:
      The result of `math_ops.betainc` on the three arguments above.
    """
    if self.validate_args:
        x = distribution_util.embed_check_nonnegative_integer_form(x)
    first = self.total_count
    second = 1. + x
    upper = math_ops.sigmoid(-self.logits)
    return math_ops.betainc(first, second, upper)
示例12: _log_unnormalized_prob
def _log_unnormalized_prob(self, x):
    """Return `x * log_rate - lgamma(1 + x)`.

    Args:
      x: Value(s) at which to evaluate the expression. Passed through the
        non-negative integer-form check helper when `self.validate_args` is
        truthy; used as given otherwise (no flooring is applied here).

    Returns:
      The result of `x * self.log_rate - lgamma(1. + x)`.
    """
    if self.validate_args:
        x = distribution_util.embed_check_nonnegative_integer_form(x)
    scaled = x * self.log_rate
    return scaled - math_ops.lgamma(1. + x)
示例13: _cdf
def _cdf(self, x):
    """Return `igammac(1 + x, rate)`.

    Args:
      x: Value(s) at which to evaluate the expression. Passed through the
        non-negative integer-form check helper when `self.validate_args` is
        truthy; used as given otherwise (no flooring is applied here).

    Returns:
      The result of `math_ops.igammac(1. + x, self.rate)`.
    """
    if self.validate_args:
        x = distribution_util.embed_check_nonnegative_integer_form(x)
    shifted = 1. + x
    return math_ops.igammac(shifted, self.rate)