

Python utils.broadcast_all function code examples

This article collects typical usage examples of the torch.distributions.utils.broadcast_all function in Python. If you are wondering what broadcast_all does, how to call it, or what real uses of it look like, the curated code examples below may help.


The following presents 15 code examples of the broadcast_all function, sorted by popularity by default.
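
Before diving into the examples, here is a minimal, hedged sketch of what broadcast_all itself does (not taken from any of the projects below): it accepts a mix of Python numbers and tensors, converts the numbers to tensors, and broadcasts everything to a common shape. The exact return type has varied across PyTorch versions, so treat the output shown in the comment as approximate:

 import torch
 from torch.distributions.utils import broadcast_all

 # A Python scalar and a (2, 3) tensor come back as two (2, 3) tensors.
 loc, scale = broadcast_all(0.0, torch.ones(2, 3))
 print(loc.shape, scale.shape)  # torch.Size([2, 3]) torch.Size([2, 3])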

Example 1: __init__

 def __init__(self, probs=None, logits=None):
     if (probs is None) == (logits is None):
         raise ValueError("Either `probs` or `logits` must be specified, but not both.")
     if probs is not None:
         self.probs, = broadcast_all(probs)
     else:
         self.logits, = broadcast_all(logits)
     probs_or_logits = probs if probs is not None else logits
     if isinstance(probs_or_logits, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = probs_or_logits.size()
     super(Bernoulli, self).__init__(batch_shape)
Developer: lxlhh, Project: pytorch, Lines: 13, Source: bernoulli.py

Example 2: __init__

 def __init__(self, probs=None, logits=None, validate_args=None):
     if (probs is None) == (logits is None):
         raise ValueError("Either `probs` or `logits` must be specified, but not both.")
     if probs is not None:
         is_scalar = isinstance(probs, Number)
         self.probs, = broadcast_all(probs)
     else:
         is_scalar = isinstance(logits, Number)
         self.logits, = broadcast_all(logits)
     self._param = self.probs if probs is not None else self.logits
     if is_scalar:
         batch_shape = torch.Size()
     else:
         batch_shape = self._param.size()
     super(Bernoulli, self).__init__(batch_shape, validate_args=validate_args)
Developer: gtgalone, Project: pytorch, Lines: 15, Source: bernoulli.py

Example 3: __init__

 def __init__(self, probs=None, logits=None, validate_args=None):
     if (probs is None) == (logits is None):
         raise ValueError("Either `probs` or `logits` must be specified, but not both.")
     if probs is not None:
         self.probs, = broadcast_all(probs)
         if not self.probs.gt(0).all():
             raise ValueError('All elements of probs must be greater than 0')
     else:
         self.logits, = broadcast_all(logits)
     probs_or_logits = probs if probs is not None else logits
     if isinstance(probs_or_logits, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = probs_or_logits.size()
     super(Geometric, self).__init__(batch_shape, validate_args=validate_args)
Developer: RichieMay, Project: pytorch, Lines: 15, Source: geometric.py

Example 4: __init__

 def __init__(self, loc, scale):
     self.loc, self.scale = broadcast_all(loc, scale)
     if isinstance(loc, Number) and isinstance(scale, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.loc.size()
     super(Laplace, self).__init__(batch_shape)
Developer: MaheshBhosale, Project: pytorch, Lines: 7, Source: laplace.py

Example 5: __init__

 def __init__(self, rate, validate_args=None):
     self.rate, = broadcast_all(rate)
     if isinstance(rate, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.rate.size()
     super(Poisson, self).__init__(batch_shape, validate_args=validate_args)
Developer: RichieMay, Project: pytorch, Lines: 7, Source: poisson.py

Example 6: __init__

 def __init__(self, alpha, beta):
     self.alpha, self.beta = broadcast_all(alpha, beta)
     if isinstance(alpha, Number) and isinstance(beta, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.alpha.size()
     super(Gamma, self).__init__(batch_shape)
Developer: lxlhh, Project: pytorch, Lines: 7, Source: gamma.py

Example 7: __init__

 def __init__(self, scale, alpha):
     self.scale, self.alpha = broadcast_all(scale, alpha)
     if isinstance(scale, Number) and isinstance(alpha, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.scale.size()
     super(Pareto, self).__init__(batch_shape)
Developer: bhuWenDongchao, Project: pytorch, Lines: 7, Source: pareto.py

Example 8: __init__

 def __init__(self, low, high):
     self.low, self.high = broadcast_all(low, high)
     if isinstance(low, Number) and isinstance(high, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.low.size()
     super(Uniform, self).__init__(batch_shape)
Developer: bhuWenDongchao, Project: pytorch, Lines: 7, Source: uniform.py

Example 9: __init__

 def __init__(self, concentration, rate):
     self.concentration, self.rate = broadcast_all(concentration, rate)
     if isinstance(concentration, Number) and isinstance(rate, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.concentration.size()
     super(Gamma, self).__init__(batch_shape)
Developer: MaheshBhosale, Project: pytorch, Lines: 7, Source: gamma.py

Example 10: __init__

 def __init__(self, loc, scale, validate_args=None):
     self.loc, self.scale = broadcast_all(loc, scale)
     if isinstance(loc, Number) and isinstance(scale, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.loc.size()
     super(Normal, self).__init__(batch_shape, validate_args=validate_args)
Developer: RichieMay, Project: pytorch, Lines: 7, Source: normal.py
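
As a hedged illustration of what the constructor above buys the caller (assuming a reasonably recent PyTorch; this usage sketch is not part of the original snippet): because broadcast_all expands loc and scale against each other, the two parameters may have different but broadcast-compatible shapes.

 import torch
 from torch.distributions import Normal

 # loc has shape (3, 1) and scale has shape (4,); broadcast_all
 # expands both to (3, 4), which becomes the batch_shape.
 d = Normal(torch.zeros(3, 1), torch.ones(4))
 print(d.batch_shape)  # torch.Size([3, 4])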

Example 11: __init__

 def __init__(self, alpha, beta):
     if isinstance(alpha, Number) and isinstance(beta, Number):
         alpha_beta = torch.Tensor([alpha, beta])
     else:
         alpha, beta = broadcast_all(alpha, beta)
         alpha_beta = torch.stack([alpha, beta], -1)
     self._dirichlet = Dirichlet(alpha_beta)
     super(Beta, self).__init__(self._dirichlet._batch_shape)
Developer: lxlhh, Project: pytorch, Lines: 8, Source: beta.py

Example 12: __init__

 def __init__(self, concentration1, concentration0, validate_args=None):
     if isinstance(concentration1, Number) and isinstance(concentration0, Number):
         concentration1_concentration0 = torch.tensor([float(concentration1), float(concentration0)])
     else:
         concentration1, concentration0 = broadcast_all(concentration1, concentration0)
         concentration1_concentration0 = torch.stack([concentration1, concentration0], -1)
     self._dirichlet = Dirichlet(concentration1_concentration0)
     super(Beta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)
Developer: gtgalone, Project: pytorch, Lines: 8, Source: beta.py

Example 13: log_prob

 def log_prob(self, value):
     self._validate_log_prob_arg(value)
     logits, value = broadcast_all(self.logits.clone(), value)
     log_factorial_n = torch.lgamma(value.sum(-1) + 1)
     log_factorial_xs = torch.lgamma(value + 1).sum(-1)
     logits[(value == 0) & (logits == -float('inf'))] = 0
     log_powers = (logits * value).sum(-1)
     return log_factorial_n - log_factorial_xs + log_powers
Developer: bhuWenDongchao, Project: pytorch, Lines: 8, Source: multinomial.py

Example 14: __init__

 def __init__(self, concentration1, concentration0):
     if isinstance(concentration1, Number) and isinstance(concentration0, Number):
         concentration1_concentration0 = variable([concentration1, concentration0])
     else:
         concentration1, concentration0 = broadcast_all(concentration1, concentration0)
         concentration1_concentration0 = torch.stack([concentration1, concentration0], -1)
     self._dirichlet = Dirichlet(concentration1_concentration0)
     super(Beta, self).__init__(self._dirichlet._batch_shape)
Developer: Jsmilemsj, Project: pytorch, Lines: 8, Source: beta.py

Example 15: __init__

    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
        if (probs is None) == (logits is None):
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        if probs is not None:
            self.total_count, self.probs, = broadcast_all(total_count, probs)
            self.total_count = self.total_count.type_as(self.probs)
            is_scalar = isinstance(self.probs, Number)
        else:
            self.total_count, self.logits, = broadcast_all(total_count, logits)
            self.total_count = self.total_count.type_as(self.logits)
            is_scalar = isinstance(self.logits, Number)

        self._param = self.probs if probs is not None else self.logits
        if is_scalar:
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        super(Binomial, self).__init__(batch_shape, validate_args=validate_args)
Developer: RichieMay, Project: pytorch, Lines: 18, Source: binomial.py
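
Seen from the caller's side, the broadcast_all call in the Binomial constructor above lets a scalar total_count be combined with a tensor of probs. A small hedged usage sketch (behavior as in recent PyTorch releases, not taken from the project above):

 import torch
 from torch.distributions import Binomial

 # The scalar total_count is broadcast to match probs, so batch_shape is (3,).
 d = Binomial(total_count=10, probs=torch.tensor([0.2, 0.5, 0.8]))
 print(d.batch_shape)  # torch.Size([3])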


Note: The torch.distributions.utils.broadcast_all examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are drawn from open-source projects contributed by their respective developers; copyright of the source code remains with the original authors, and redistribution or use should follow the corresponding project's License. Do not reproduce without permission.