当前位置: 首页>>代码示例>>Python>>正文


Python utils.broadcast_all方法代码示例

本文整理汇总了Python中torch.distributions.utils.broadcast_all方法的典型用法代码示例。如果您正苦于以下问题:Python utils.broadcast_all方法的具体用法?Python utils.broadcast_all怎么用?Python utils.broadcast_all使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.distributions.utils的用法示例。


在下文中一共展示了utils.broadcast_all方法的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: from torch.distributions import utils [as 别名]
# 或者: from torch.distributions.utils import broadcast_all [as 别名]
def __init__(self, a, b, sigma=0.01, validate_args=False, transform=None):
        TModule.__init__(self)
        _a = torch.tensor(float(a)) if isinstance(a, Number) else a
        _a = _a.view(-1) if _a.dim() < 1 else _a
        _a, _b, _sigma = broadcast_all(_a, b, sigma)
        if not torch.all(constraints.less_than(_b).check(_a)):
            raise ValueError("must have that a < b (element-wise)")
        # TODO: Proper argument validation including broadcasting
        batch_shape, event_shape = _a.shape[:-1], _a.shape[-1:]
        # need to assign values before registering as buffers to make argument validation work
        self.a, self.b, self.sigma = _a, _b, _sigma
        super(SmoothedBoxPrior, self).__init__(batch_shape, event_shape, validate_args=validate_args)
        # now need to delete to be able to register buffer
        del self.a, self.b, self.sigma
        self.register_buffer("a", _a)
        self.register_buffer("b", _b)
        self.register_buffer("sigma", _sigma.clone())
        self.tails = NormalPrior(torch.zeros_like(_a), _sigma, validate_args=validate_args)
        self._transform = transform 
开发者ID:cornellius-gp,项目名称:gpytorch,代码行数:21,代码来源:smoothed_box_prior.py

示例2: compile

# 需要导入模块: from torch.distributions import utils [as 别名]
# 或者: from torch.distributions.utils import broadcast_all [as 别名]
def compile(self) -> 'DynamicMatrix':
        """
        Consolidate assignments then apply the link function. Some assignments can be "frozen" into a pre-computed
        matrix, while others must remain as lists to be evaluated in DesignForBatch as needed.
        """
        from torch_kalman.process.utils.design_matrix.dynamic_matrix import DynamicMatrix

        base_mat = torch.zeros(self.num_groups, len(self.dim1_names), len(self.dim2_names))
        dynamic_assignments = {}
        for (dim1, dim2), values in self._assignments.items():
            r = self.dim1_names.index(dim1)
            c = self.dim2_names.index(dim2)
            ilink = self._ilinks[(dim1, dim2)] or identity
            dynamic, base = bifurcate(values, _is_dynamic_assignment)
            if dynamic:
                # if any dynamic, then all dynamic:
                dynamic = dynamic + [[x] * self.num_timesteps for x in base]
                per_timestep = list(zip(*dynamic))  # invert
                assert len(per_timestep) == self.num_timesteps
                assert (r, c) not in dynamic_assignments.keys()
                dynamic_assignments[(r, c)] = [
                    ilink(torch.sum(torch.stack(broadcast_all(*x), dim=0), dim=0)) for x in per_timestep
                ]
            else:
                base_mat[:, r, c] = ilink(torch.sum(torch.stack(broadcast_all(*base), dim=0), dim=0))

        return DynamicMatrix(base_mat, dynamic_assignments)

    # utils ------------------------------------------ 
开发者ID:strongio,项目名称:torch-kalman,代码行数:31,代码来源:base.py

示例3: __init__

# 需要导入模块: from torch.distributions import utils [as 别名]
# 或者: from torch.distributions.utils import broadcast_all [as 别名]
def __init__(self, loc, concentration, validate_args=None):
        self.loc, self.concentration = broadcast_all(loc, concentration)
        batch_shape = self.loc.shape
        event_shape = torch.Size()

        # Parameters for sampling
        tau = 1 + (1 + 4 * self.concentration ** 2).sqrt()
        rho = (tau - (2 * tau).sqrt()) / (2 * self.concentration)
        self._proposal_r = (1 + rho ** 2) / (2 * rho)

        super(VonMises, self).__init__(batch_shape, event_shape, validate_args) 
开发者ID:mjendrusch,项目名称:torchsupport,代码行数:13,代码来源:von_mises.py

示例4: __init__

# 需要导入模块: from torch.distributions import utils [as 别名]
# 或者: from torch.distributions.utils import broadcast_all [as 别名]
def __init__(self, loc, scale, manifold, validate_args=None, softplus=False):
        self.dtype = loc.dtype
        self.softplus = softplus
        self.loc, self._scale = broadcast_all(loc, scale)
        self.manifold = manifold
        self.manifold.assert_check_point_on_manifold(self.loc)
        self.device = loc.device
        if isinstance(loc, Number) and isinstance(scale, Number):
            batch_shape, event_shape = torch.Size(), torch.Size()
        else:
            batch_shape = self.loc.shape[:-1]
            event_shape = torch.Size([self.manifold.dim])
        super(WrappedNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args) 
开发者ID:emilemathieu,项目名称:pvae,代码行数:15,代码来源:wrapped_normal.py

示例5: __init__

# 需要导入模块: from torch.distributions import utils [as 别名]
# 或者: from torch.distributions.utils import broadcast_all [as 别名]
def __init__(self, loc, scale, validate_args=None):
        self.loc, self.scale = broadcast_all(loc, scale)
        if isinstance(loc, Number) and isinstance(scale, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.loc.size()
        super(Normal_, self).__init__(batch_shape, validate_args=validate_args) 
开发者ID:conormdurkan,项目名称:autoregressive-energy-machines,代码行数:9,代码来源:distributions_.py


注:本文中的torch.distributions.utils.broadcast_all方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。