本文整理汇总了Python中torch.distributions.utils.broadcast_all函数的典型用法代码示例。如果您正苦于以下问题:Python broadcast_all函数的具体用法?Python broadcast_all怎么用?Python broadcast_all使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了broadcast_all函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __init__
def __init__(self, probs=None, logits=None):
    """Initialize a Bernoulli distribution from exactly one of `probs` or `logits`.

    Args:
        probs: event probabilities (Number or tensor-like); mutually
            exclusive with `logits`.
        logits: event log-odds; mutually exclusive with `probs`.

    Raises:
        ValueError: if neither or both of `probs`/`logits` are given.
    """
    if (probs is None) == (logits is None):
        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
    # broadcast_all promotes the supplied parameterization to a tensor.
    if probs is None:
        self.logits, = broadcast_all(logits)
    else:
        self.probs, = broadcast_all(probs)
    raw_param = logits if probs is None else probs
    # A plain Python number means a scalar (empty) batch shape.
    batch_shape = torch.Size() if isinstance(raw_param, Number) else raw_param.size()
    super(Bernoulli, self).__init__(batch_shape)
示例2: __init__
def __init__(self, probs=None, logits=None, validate_args=None):
    """Initialize a Bernoulli distribution from exactly one of `probs` or `logits`.

    Args:
        probs: event probabilities (Number or tensor-like); mutually
            exclusive with `logits`.
        logits: event log-odds; mutually exclusive with `probs`.
        validate_args: forwarded to the base Distribution constructor.

    Raises:
        ValueError: if neither or both of `probs`/`logits` are given.
    """
    if (probs is None) == (logits is None):
        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
    # Record scalar-ness of the raw input before broadcast_all turns it
    # into a tensor.
    supplied = probs if logits is None else logits
    scalar_input = isinstance(supplied, Number)
    if logits is None:
        self.probs, = broadcast_all(probs)
        self._param = self.probs
    else:
        self.logits, = broadcast_all(logits)
        self._param = self.logits
    batch_shape = torch.Size() if scalar_input else self._param.size()
    super(Bernoulli, self).__init__(batch_shape, validate_args=validate_args)
示例3: __init__
def __init__(self, probs=None, logits=None, validate_args=None):
    """Initialize a Geometric distribution from exactly one of `probs` or `logits`.

    Args:
        probs: success probabilities, all strictly positive; mutually
            exclusive with `logits`.
        logits: success log-odds; mutually exclusive with `probs`.
        validate_args: forwarded to the base Distribution constructor.

    Raises:
        ValueError: if neither or both of `probs`/`logits` are given, or
            if any element of `probs` is <= 0.
    """
    if (probs is None) == (logits is None):
        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
    if probs is None:
        self.logits, = broadcast_all(logits)
    else:
        self.probs, = broadcast_all(probs)
        # The probs parameterization additionally requires strictly
        # positive values (checked unconditionally, not via validate_args).
        if not self.probs.gt(0).all():
            raise ValueError('All elements of probs must be greater than 0')
    supplied = logits if probs is None else probs
    batch_shape = torch.Size() if isinstance(supplied, Number) else supplied.size()
    super(Geometric, self).__init__(batch_shape, validate_args=validate_args)
示例4: __init__
def __init__(self, loc, scale):
    """Initialize a Laplace distribution with location `loc` and scale `scale`.

    `loc` and `scale` are broadcast against each other; two plain Python
    numbers yield a scalar (empty) batch shape.
    """
    self.loc, self.scale = broadcast_all(loc, scale)
    both_scalar = isinstance(loc, Number) and isinstance(scale, Number)
    batch_shape = torch.Size() if both_scalar else self.loc.size()
    super(Laplace, self).__init__(batch_shape)
示例5: __init__
def __init__(self, rate, validate_args=None):
    """Initialize a Poisson distribution with the given `rate`.

    Args:
        rate: event rate (Number or tensor-like); a plain number yields a
            scalar (empty) batch shape.
        validate_args: forwarded to the base Distribution constructor.
    """
    self.rate, = broadcast_all(rate)
    batch_shape = torch.Size() if isinstance(rate, Number) else self.rate.size()
    super(Poisson, self).__init__(batch_shape, validate_args=validate_args)
示例6: __init__
def __init__(self, alpha, beta):
    """Initialize a Gamma distribution with shape `alpha` and rate `beta`.

    The two parameters are broadcast against each other; two plain Python
    numbers yield a scalar (empty) batch shape.
    """
    self.alpha, self.beta = broadcast_all(alpha, beta)
    both_scalar = isinstance(alpha, Number) and isinstance(beta, Number)
    batch_shape = torch.Size() if both_scalar else self.alpha.size()
    super(Gamma, self).__init__(batch_shape)
示例7: __init__
def __init__(self, scale, alpha):
    """Initialize a Pareto distribution with scale `scale` and shape `alpha`.

    The two parameters are broadcast against each other; two plain Python
    numbers yield a scalar (empty) batch shape.
    """
    self.scale, self.alpha = broadcast_all(scale, alpha)
    both_scalar = isinstance(scale, Number) and isinstance(alpha, Number)
    batch_shape = torch.Size() if both_scalar else self.scale.size()
    super(Pareto, self).__init__(batch_shape)
示例8: __init__
def __init__(self, low, high):
    """Initialize a Uniform distribution on the interval [`low`, `high`].

    The bounds are broadcast against each other; two plain Python numbers
    yield a scalar (empty) batch shape.
    """
    self.low, self.high = broadcast_all(low, high)
    both_scalar = isinstance(low, Number) and isinstance(high, Number)
    batch_shape = torch.Size() if both_scalar else self.low.size()
    super(Uniform, self).__init__(batch_shape)
示例9: __init__
def __init__(self, concentration, rate):
    """Initialize a Gamma distribution with `concentration` (shape) and `rate`.

    The two parameters are broadcast against each other; two plain Python
    numbers yield a scalar (empty) batch shape.
    """
    self.concentration, self.rate = broadcast_all(concentration, rate)
    both_scalar = isinstance(concentration, Number) and isinstance(rate, Number)
    batch_shape = torch.Size() if both_scalar else self.concentration.size()
    super(Gamma, self).__init__(batch_shape)
示例10: __init__
def __init__(self, loc, scale, validate_args=None):
    """Initialize a Normal distribution with mean `loc` and std-dev `scale`.

    Args:
        loc: mean; broadcast against `scale`.
        scale: standard deviation; broadcast against `loc`.
        validate_args: forwarded to the base Distribution constructor.
    """
    self.loc, self.scale = broadcast_all(loc, scale)
    both_scalar = isinstance(loc, Number) and isinstance(scale, Number)
    batch_shape = torch.Size() if both_scalar else self.loc.size()
    super(Normal, self).__init__(batch_shape, validate_args=validate_args)
示例11: __init__
def __init__(self, alpha, beta):
    """Initialize a Beta distribution, backed by a 2-component Dirichlet.

    Args:
        alpha: first concentration parameter.
        beta: second concentration parameter.
    """
    if isinstance(alpha, Number) and isinstance(beta, Number):
        # Two scalars: pack them directly into a length-2 tensor.
        packed = torch.Tensor([alpha, beta])
    else:
        alpha, beta = broadcast_all(alpha, beta)
        packed = torch.stack([alpha, beta], -1)
    self._dirichlet = Dirichlet(packed)
    super(Beta, self).__init__(self._dirichlet._batch_shape)
示例12: __init__
def __init__(self, concentration1, concentration0, validate_args=None):
    """Initialize a Beta distribution, backed by a 2-component Dirichlet.

    Args:
        concentration1: first concentration parameter (often called alpha).
        concentration0: second concentration parameter (often called beta).
        validate_args: forwarded to the base Distribution constructor.
    """
    if isinstance(concentration1, Number) and isinstance(concentration0, Number):
        # Two scalars: pack them into a length-2 float tensor.
        packed = torch.tensor([float(concentration1), float(concentration0)])
    else:
        concentration1, concentration0 = broadcast_all(concentration1, concentration0)
        packed = torch.stack([concentration1, concentration0], -1)
    self._dirichlet = Dirichlet(packed)
    super(Beta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)
示例13: log_prob
def log_prob(self, value):
    """Return the multinomial log-probability of the count vector `value`.

    Computes log(n!) - sum(log(x_i!)) + sum(x_i * logit_i) along the last
    dimension, where n = sum(x_i).
    """
    self._validate_log_prob_arg(value)
    # Clone so the in-place -inf masking below cannot corrupt self.logits.
    logits, value = broadcast_all(self.logits.clone(), value)
    # log(n!) for the total count of each batch entry.
    log_n_factorial = torch.lgamma(value.sum(-1) + 1)
    # Sum of log(x_i!) per batch entry.
    log_counts_factorial = torch.lgamma(value + 1).sum(-1)
    # A zero count with a -inf logit must contribute 0, not 0 * -inf = nan.
    logits[(value == 0) & (logits == -float('inf'))] = 0
    return log_n_factorial - log_counts_factorial + (logits * value).sum(-1)
示例14: __init__
def __init__(self, concentration1, concentration0):
    """Initialize a Beta distribution, backed by a 2-component Dirichlet.

    Args:
        concentration1: first concentration parameter.
        concentration0: second concentration parameter.
    """
    if isinstance(concentration1, Number) and isinstance(concentration0, Number):
        # NOTE(review): `variable` looks like a legacy (pre-0.4 torch)
        # tensor-construction helper — confirm it is imported in the full file.
        packed = variable([concentration1, concentration0])
    else:
        concentration1, concentration0 = broadcast_all(concentration1, concentration0)
        packed = torch.stack([concentration1, concentration0], -1)
    self._dirichlet = Dirichlet(packed)
    super(Beta, self).__init__(self._dirichlet._batch_shape)
示例15: __init__
def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
    """Initialize a Binomial distribution from `total_count` and exactly one
    of `probs` or `logits`.

    Args:
        total_count: number of Bernoulli trials (Number or tensor-like);
            broadcast against the chosen parameterization.
        probs: success probabilities; mutually exclusive with `logits`.
        logits: success log-odds; mutually exclusive with `probs`.
        validate_args: forwarded to the base Distribution constructor.

    Raises:
        ValueError: if neither or both of `probs`/`logits` are given.
    """
    if (probs is None) == (logits is None):
        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
    if probs is not None:
        self.total_count, self.probs, = broadcast_all(total_count, probs)
        # BUG FIX: this branch previously called `type_as(self.logits)`,
        # but `self.logits` is never set on the probs path and raised
        # AttributeError; cast against `self.probs` instead.
        self.total_count = self.total_count.type_as(self.probs)
        is_scalar = isinstance(self.probs, Number)
    else:
        self.total_count, self.logits, = broadcast_all(total_count, logits)
        self.total_count = self.total_count.type_as(self.logits)
        is_scalar = isinstance(self.logits, Number)
    self._param = self.probs if probs is not None else self.logits
    if is_scalar:
        batch_shape = torch.Size()
    else:
        # broadcast_all returns tensors, so this is the common path;
        # 0-dim tensors still yield torch.Size([]).
        batch_shape = self._param.size()
    super(Binomial, self).__init__(batch_shape, validate_args=validate_args)