当前位置: 首页>>代码示例>>Python>>正文


Python torch.get_default_dtype方法代码示例

本文整理汇总了Python中torch.get_default_dtype方法的典型用法代码示例。如果您正苦于以下问题：Python torch.get_default_dtype方法的具体用法是什么？torch.get_default_dtype怎么用？有哪些使用示例？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的用法示例。


在下文中一共展示了torch.get_default_dtype方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_mu_law_companding

# 需要导入模块: import torch [as 别名]
# 或者: from torch import get_default_dtype [as 别名]
def test_mu_law_companding(self):
        """Round-trip a waveform through mu-law encoding and decoding.

        A copy of ``self.waveform`` is divided by its peak absolute value,
        encoded with ``MuLawEncoding`` and decoded with ``MuLawDecoding``.
        Each stage is checked to stay inside its expected value range.
        """
        n_channels = 256

        signal = self.waveform.clone()
        # Integer-typed samples must be promoted before the in-place division.
        if not signal.is_floating_point():
            signal = signal.to(torch.get_default_dtype())
        peak = torch.abs(signal).max()
        signal /= peak

        self.assertTrue(signal.min() >= -1. and signal.max() <= 1.)

        encoded = transforms.MuLawEncoding(n_channels)(signal)
        self.assertTrue(encoded.min() >= 0. and encoded.max() <= n_channels)

        decoded = transforms.MuLawDecoding(n_channels)(encoded)
        self.assertTrue(decoded.min() >= -1. and decoded.max() <= 1.)
开发者ID:pytorch,项目名称:audio,代码行数:18,代码来源:test_transforms.py

示例2: _test_istft_of_sine

# 需要导入模块: import torch [as 别名]
# 或者: from torch import get_default_dtype [as 别名]
def _test_istft_of_sine(self, amplitude, L, n):
        """Check ``istft`` against a hand-built STFT of a pure sine.

        The signal is ``amplitude * sin(2*pi/L * n * x)`` sampled at
        ``x = 0 .. 2*L``. The STFT is written out by hand for hop length and
        window size both equal to ``L``. It is then inverted with
        ``torchaudio.functional.istft`` and compared against the signal.
        """
        x = torch.arange(2 * L + 1, dtype=torch.get_default_dtype())
        sound = amplitude * torch.sin(2 * math.pi / L * x * n)

        # Hand-built stand-in for:
        # torch.stft(sound, L, hop_length=L, win_length=L,
        #            window=torch.ones(L), center=False, normalized=False)
        # NOTE(review): layout presumably (freq bin, frame, real/imag) — confirm.
        n_bins = L // 2 + 1
        stft = torch.zeros((n_bins, 2, 2))
        peak = (amplitude * L) / 2.0
        if n < n_bins:
            stft[n, :, 1] = -peak
        # The mirrored bin L - n (symmetric about L // 2) carries +peak.
        if 0 <= L - n < n_bins:
            stft[L - n, :, 1] = peak

        estimate = torchaudio.functional.istft(
            stft, L, hop_length=L, win_length=L,
            window=torch.ones(L), center=False, normalized=False)
        # Amplitude scaling makes the error larger, hence the loose tolerance.
        _compare_estimate(sound, estimate, atol=1e-3)
开发者ID:pytorch,项目名称:audio,代码行数:21,代码来源:functional_cpu_test.py

示例3: __init__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import get_default_dtype [as 别名]
def __init__(self, manifold: Manifold, scale=1.0, learnable=False):
        """Wrap ``manifold`` and bind rescaled versions of its methods.

        :param manifold: base manifold; the methods listed in its
            ``__scaling__`` mapping are rescaled and bound to this instance
        :param scale: initial scale, converted to a tensor of torch's default dtype
        :param learnable: if True, ``log(scale)`` is registered as the
            parameter ``_log_scale``; otherwise ``scale`` is stored as the
            buffer ``_scale``
        """
        super().__init__()
        self.base = manifold
        scale_t = torch.as_tensor(scale, dtype=torch.get_default_dtype()).requires_grad_(False)
        if learnable:
            self.register_buffer("_scale", None)
            self.register_parameter("_log_scale", torch.nn.Parameter(scale_t.log()))
        else:
            self.register_buffer("_scale", scale_t)
            self.register_buffer("_log_scale", None)

        # Build each rescaled function once here and bind it to this
        # instance, rather than rebuilding it on every call.
        for name, info in self.base.__scaling__.items():
            raw_func = getattr(self.base, name).__func__  # unbound function
            setattr(self, name, types.MethodType(rescale(raw_func, info), self))
开发者ID:geoopt,项目名称:geoopt,代码行数:21,代码来源:scaled.py

示例4: penalty

# 需要导入模块: import torch [as 别名]
# 或者: from torch import get_default_dtype [as 别名]
def penalty(self, s):
        """
        Calculate L1 Penalty.

        The base term is ``sum(s) / self.D``. When ``self.soft_groups`` is
        set, the base term is averaged 50/50 with ``sum(group_max) / self.soft_D``,
        where ``group_max[g]`` is the largest entry of ``s`` belonging to group ``g``.
        This extra term penalises using more groups.
        """
        base = s.sum() / self.D
        if self.soft_groups is None:
            return base

        group_max = torch.zeros(self.soft_D, 1,
                                dtype=torch.get_default_dtype(),
                                device=self.device)
        # Groups are expected to be indexed 0 .. n_group - 1.
        # TODO: consider other reductions than max here
        for g in self.soft_groups.unique():
            group_max[g] = s[self.soft_groups == g].max()
        # TODO: the 0.5 mixing weight could become a user-given parameter
        return .5 * (base + group_max.sum() / self.soft_D)
开发者ID:microsoft,项目名称:nni,代码行数:21,代码来源:learnability.py

示例5: forward_and_backward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import get_default_dtype [as 别名]
def forward_and_backward(self, s, xsub, ysub, retain_graph=False):
        """
        Run the forward pass and differentiate both objectives w.r.t. ``self.x``.

        Returns ``(f_train, pen, g1, g2)``. ``g1`` is the gradient of
        ``self.f_train(s, xsub, ysub)`` and ``g2`` is the gradient of
        ``self.penalty(s)``. ``retain_graph`` controls whether the graph
        survives the second gradient call.
        """
        f_train = self.f_train(s, xsub, ysub)
        pen = self.penalty(s)
        # One (1, 1) tensor of ones is the seed gradient for both calls.
        # pylint: disable=E1102
        seed = torch.tensor([[1]], dtype=torch.get_default_dtype(), device=self.device)
        # The graph is kept after the first call so the penalty can still be
        # differentiated afterwards.
        g1, = torch.autograd.grad([f_train], [self.x], seed, retain_graph=True)
        g2, = torch.autograd.grad([pen], [self.x], seed, retain_graph=retain_graph)
        return f_train, pen, g1, g2
开发者ID:microsoft,项目名称:nni,代码行数:19,代码来源:learnability.py

示例6: scale

# 需要导入模块: import torch [as 别名]
# 或者: from torch import get_default_dtype [as 别名]
def scale(self, waveform, factor=2.0**31):
        """Return ``waveform / factor``.

        Integer-typed waveforms are first cast to torch's default dtype;
        floating-point waveforms keep their own dtype.
        """
        as_float = waveform if waveform.is_floating_point() else waveform.to(torch.get_default_dtype())
        return as_float / factor
开发者ID:pytorch,项目名称:audio,代码行数:7,代码来源:test_transforms.py

示例7: use_floatX

# 需要导入模块: import torch [as 别名]
# 或者: from torch import get_default_dtype [as 别名]
def use_floatX(request):
    """Temporarily switch torch's default dtype to ``request.param``.

    Presumably used as a parametrized pytest fixture (it reads
    ``request.param``). It yields the requested dtype and afterwards restores
    the default dtype that was active before. The restore sits in ``finally``,
    so it also runs when the generator is closed early or an exception is
    thrown into it at the ``yield``.
    """
    dtype_old = torch.get_default_dtype()
    torch.set_default_dtype(request.param)
    try:
        yield request.param
    finally:
        torch.set_default_dtype(dtype_old)
开发者ID:geoopt,项目名称:geoopt,代码行数:7,代码来源:test_manifold_basic.py

示例8: __enter__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import get_default_dtype [as 别名]
def __enter__(self):
        """Set new dtype.

        Records the current default dtype in ``self.old_dtype`` and then
        switches torch's default dtype to ``self.new_dtype``. Returns ``self``
        so the manager can be bound with ``with ... as ctx``.
        """
        self.old_dtype = torch.get_default_dtype()
        torch.set_default_dtype(self.new_dtype)
        return self
开发者ID:befelix,项目名称:safe-exploration,代码行数:6,代码来源:utilities.py

示例9: test_set_torch_dtype

# 需要导入模块: import torch [as 别名]
# 或者: from torch import get_default_dtype [as 别名]
def test_set_torch_dtype():
    """Test dtype context manager.

    With the default set to float32, tensors created inside
    ``SetTorchDtype(torch.float64)`` must be float64, and tensors created
    afterwards must be float32 again. The caller's original default dtype is
    restored in ``finally``, so a failing assertion does not leak a changed
    global default into later tests.
    """
    dtype = torch.get_default_dtype()
    try:
        torch.set_default_dtype(torch.float32)
        with SetTorchDtype(torch.float64):
            a = torch.zeros(1)

        assert a.dtype is torch.float64
        b = torch.zeros(1)
        assert b.dtype is torch.float32
    finally:
        torch.set_default_dtype(dtype)
开发者ID:befelix,项目名称:safe-exploration,代码行数:15,代码来源:test_utilities.py

示例10: __init__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import get_default_dtype [as 别名]
def __init__(self, *sizes, dtype=None, device=None):
        """Initialise a ``ZeroLazyTensor`` of the given sizes.

        :param sizes: tensor dimensions, also stored as the list ``self.sizes``
        :param dtype: element dtype; ``None`` means torch's current default dtype
        :param device: device; ``None`` means CPU
        """
        super(ZeroLazyTensor, self).__init__(*sizes)
        self.sizes = list(sizes)

        # Compare against None explicitly: ``device or ...`` would silently
        # turn the valid device ordinal ``0`` into CPU.
        self._dtype = torch.get_default_dtype() if dtype is None else dtype
        self._device = torch.device("cpu") if device is None else device
开发者ID:cornellius-gp,项目名称:gpytorch,代码行数:8,代码来源:zero_lazy_tensor.py

示例11: dtype

# 需要导入模块: import torch [as 别名]
# 或者: from torch import get_default_dtype [as 别名]
def dtype(self):
        """Return the dtype this module works in.

        The rules are tried in order:

        1. The lengthscale's dtype, when ``self.has_lengthscale`` is true.
        2. Otherwise, the dtype of the first entry yielded by ``self.parameters()``.
        3. Otherwise, torch's default dtype when there are no parameters.
        """
        if self.has_lengthscale:
            return self.lengthscale.dtype
        first_param = next(iter(self.parameters()), None)
        if first_param is None:
            return torch.get_default_dtype()
        return first_param.dtype
开发者ID:cornellius-gp,项目名称:gpytorch,代码行数:9,代码来源:kernel.py

示例12: __init__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import get_default_dtype [as 别名]
def __init__(self, network, heading_dim=2, pre_norm=False, separate=True, get_prediction=False, weight=None):
        """
        Calculate heading angle with/out loss.
        The heading angle is the absolute heading measured from the point where the person starts moving.

        :param network: Network model
            - input - imu features
            - output - absolute_heading, prenorm_abs
        :param heading_dim: stored as ``self.h_dim`` (presumably the size of the heading output — confirm)
        :param pre_norm: force network to output normalized values by adding a loss
        :param separate: report errors separately with total (the number of losses can be obtained from get_channels()).
                        If False, only the (weighted) sum of all losses will be returned.
        :param get_prediction: For testing phase, to get prediction values. In training only losses will be returned.
                        When set, ``weight`` is ignored and all losses are weighted equally.
        :param weight: weight for each loss type; its length must equal the channel count from get_channels().
                        ``None``, ``'None'``, ``'none'`` or ``False`` means all losses are weighted equally.
                        Explicit weights are converted to torch's default dtype.
        """
        super(HeadingNetwork, self).__init__()
        self.network = network
        self.h_dim = heading_dim
        self.pre_norm = pre_norm
        self.concat_loss = not separate

        self.predict = get_prediction

        _, channels = self.get_channels()
        # Only scalars are compared against the "no weight" markers. Running
        # ``weight in (...)`` on a numpy array or tensor compares element-wise
        # against ``False`` and then fails on an ambiguous truth value.
        is_sequence = hasattr(weight, '__len__') and not isinstance(weight, str)
        no_weight = weight is None or (not is_sequence and weight in ('None', 'none', False))
        if self.predict or no_weight:
            self.weights = torch.ones(channels, dtype=torch.get_default_dtype(), device=_device)
        else:
            assert len(weight) == channels
            # Use the same dtype as the uniform branch so that integer weight
            # lists do not yield an int64 tensor. as_tensor also accepts
            # existing tensors and numpy arrays directly.
            self.weights = torch.as_tensor(weight, dtype=torch.get_default_dtype(), device=_device)
开发者ID:Sachini,项目名称:ronin,代码行数:31,代码来源:ronin_body_heading.py

示例13: get_init

# 需要导入模块: import torch [as 别名]
# 或者: from torch import get_default_dtype [as 别名]
def get_init(data_train, init_type='on', rng=None, prev_score=None):
    """
    Initialize the 'x' variable with different settings.

    :param data_train: training data. It must provide ``n_features``; the
        sklearn initialisation also uses ``return_raw``, ``get_dense_data``,
        ``set_return_raw`` and ``classification``.
    :param init_type: one of the following:
        - a non-string index/mask of the features set to the "on" value;
        - a string starting with ``constants.Initialization.RANDOM`` and
          followed by how many randomly chosen features to switch on;
        - ``constants.Initialization.SKLEARN``;
        - any key of ``constants.Initialization.VALUE_DICT``.
    :param rng: RandomState used by the random initialisation. ``None`` (the
        default) means a fresh ``np.random.RandomState(0)`` for this call.
    :param prev_score: if not None, used as the initial value as-is
    :returns: the initial values reshaped to a single column, as a tensor of
        torch's default dtype
    :raises NotImplementedError: for an unrecognised string ``init_type``
    """
    # A RandomState created in the signature is built once and shared by
    # every call, so results would depend on how often get_init ran before.
    # Build the seeded default per call instead.
    if rng is None:
        rng = np.random.RandomState(0)

    D = data_train.n_features
    value_off = constants.Initialization.VALUE_DICT[
        constants.Initialization.OFF]
    value_on = constants.Initialization.VALUE_DICT[
        constants.Initialization.ON]

    if prev_score is not None:
        x0 = prev_score
    elif not isinstance(init_type, str):
        # init_type selects the features to switch on
        x0 = value_off * np.ones(D)
        x0[init_type] = value_on
    elif init_type.startswith(constants.Initialization.RANDOM):
        # the text after the RANDOM prefix is the number of features to switch on
        d = int(init_type.replace(constants.Initialization.RANDOM, ''))
        x0 = value_off * np.ones(D)
        x0[rng.permutation(D)[:d]] = value_on
    elif init_type == constants.Initialization.SKLEARN:
        # return_raw is saved and restored around get_dense_data
        # (presumably because get_dense_data changes it — confirm)
        B = data_train.return_raw
        X, y = data_train.get_dense_data()
        data_train.set_return_raw(B)
        ix = train_sk_dense(init_type, X, y, data_train.classification)
        x0 = value_off * np.ones(D)
        x0[ix] = value_on
    elif init_type in constants.Initialization.VALUE_DICT:
        x0 = constants.Initialization.VALUE_DICT[init_type] * np.ones(D)
    else:
        raise NotImplementedError(
            'init_type {0} not supported yet'.format(init_type))
    # pylint: disable=E1102
    return torch.tensor(x0.reshape((-1, 1)),
                        dtype=torch.get_default_dtype())
开发者ID:microsoft,项目名称:nni,代码行数:37,代码来源:fgtrain.py

示例14: __init__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import get_default_dtype [as 别名]
def __init__(self, Nminibatch, D, coeff, groups=None, binary=False,
                 device=constants.Device.CPU):
        """Precompute the minibatch learnability coefficients.

        :param Nminibatch: minibatch size used in the binomial normalisation
        :param D: not used by this constructor
        :param coeff: raw coefficients; needs a ``.size`` attribute
            (presumably a numpy array — confirm)
        :param groups: not used by this constructor
        :param binary: stored as ``self.binary``
        :param device: device the coefficient tensor ``self.a`` is moved to
        """
        super(LearnabilityMB, self).__init__()

        # Entry k is divided by C(Nminibatch, k + 2).
        scaled = coeff / scipy.special.binom(Nminibatch, np.arange(coeff.size) + 2)
        self.order = scaled.size
        # pylint: disable=E1102
        self.a = torch.tensor(scaled, dtype=torch.get_default_dtype(),
                              requires_grad=False).to(device)
        self.binary = binary
开发者ID:microsoft,项目名称:nni,代码行数:13,代码来源:learnability.py

示例15: set_dense_X

# 需要导入模块: import torch [as 别名]
# 或者: from torch import get_default_dtype [as 别名]
def set_dense_X(self):
        """Convert ``self.X`` to a dense torch tensor when it fits in memory.

        Does nothing when the storage level is DISK or when
        ``self.dense_size_gb`` is not within ``self.MAXMEMGB``. Sparse storage
        is densified with ``toarray()`` first. Afterwards ``self.X`` is a
        tensor of torch's default dtype and the storage level is DENSE.
        """
        levels = constants.StorageLevel
        if not (self.storage_level != levels.DISK and self.dense_size_gb <= self.MAXMEMGB):
            return
        if self.storage_level == levels.SPARSE:
            self.X = self.X.toarray()
        self.X = torch.as_tensor(self.X, dtype=torch.get_default_dtype())
        self.storage_level = levels.DENSE
开发者ID:microsoft,项目名称:nni,代码行数:10,代码来源:fginitialize.py


注:本文中的torch.get_default_dtype方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。