

Python torch.typename Function Code Examples

This article collects typical usage examples of the Python function torch.typename, drawn from open-source code. If you are unsure what torch.typename does, how to call it, or what it looks like in real projects, the hand-picked examples below should help.


The following shows 15 code examples of the typename function, sorted by popularity by default.
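
Before the examples, a minimal sketch of what torch.typename returns for a tensor, a module, and a plain Python object (the exact strings depend on your PyTorch version and default dtype):

import torch
import torch.nn as nn

x = torch.randn(3, 4)
print(torch.typename(x))                 # e.g. 'torch.FloatTensor'
print(torch.typename(nn.Linear(2, 3)))   # e.g. 'torch.nn.modules.linear.Linear'
print(torch.typename([1, 2, 3]))         # 'list'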

Example 1: __repr__

    def __repr__(self):
        if self.is_sparse:
            data_str = ' \n{} with indices:\n{}and values:\n{}'.format(
                torch.typename(self.data), self._indices().data,
                self._values().data)
        else:
            data_str = torch._tensor_str._str(self.data, False)
        strt = 'Variable containing:' + data_str
        # let's make our own Variable-specific footer
        size_str = '(' + ','.join(str(size) for size in self.size()) + (',)' if len(self.size()) == 1 else ')')
        device_str = '' if not self.is_cuda else \
            ' (GPU {})'.format(self.get_device())
        strt += '[{} of size {}{}]\n'.format(torch.typename(self.data),
                                             size_str, device_str)

        # All strings are unicode in Python 3, while we have to encode unicode
        # strings in Python2. If we can't, let python decide the best
        # characters to replace unicode characters with.
        if sys.version_info > (3,):
            return strt
        else:
            if hasattr(sys.stdout, 'encoding'):
                return strt.encode(
                    sys.stdout.encoding or 'UTF-8', 'replace')
            else:
                return strt.encode('UTF-8', 'replace')
Contributor: bhuWenDongchao, Project: pytorch, Lines: 26, Source file: variable.py
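
As a standalone illustration of the size/device footer assembled above, here is a small sketch on a plain tensor; the helper name describe is ours, not part of PyTorch:

import torch

def describe(t):
    # Build the same "[<type> of size (...)]" footer as Variable.__repr__ above.
    size_str = '(' + ','.join(str(s) for s in t.size()) + (',)' if len(t.size()) == 1 else ')')
    device_str = '' if not t.is_cuda else ' (GPU {})'.format(t.get_device())
    return '[{} of size {}{}]'.format(torch.typename(t), size_str, device_str)

print(describe(torch.randn(3, 4)))   # e.g. '[torch.FloatTensor of size (3,4)]'
print(describe(torch.zeros(5)))      # e.g. '[torch.FloatTensor of size (5,)]'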

Example 2: _reinforce

    def _reinforce(self, reward):
        is_number = isinstance(reward, Number)
        if not is_number and type(reward) != self.reward_info[0]:
            raise TypeError("mismatch between reward and output type: got {}, "
                            "but expected {}".format(torch.typename(reward),
                                                     torch.typename(self.reward_info[0])))
        if not is_number and reward.size() != self.reward_info[1]:
            raise ValueError("got reward of size {}, but expected a tensor of size {}".format(
                             'x'.join(map(str, reward.size())),
                             'x'.join(map(str, self.reward_info[1]))))
        if self.reward is not _NOT_PROVIDED:
            raise RuntimeError("you can only reinforce a stochastic Function once")
        self.reward = reward
Contributor: Northrend, Project: pytorch, Lines: 13, Source file: stochastic_function.py

Example 3: forward

    def forward(ctx, input, *params):
        ctx._backend = type2backend[input.type()]

        ctx.additional_args = []
        tensor_param_list = []
        for param in params:
            if torch.is_tensor(param):
                if type(param) != type(input):
                    raise RuntimeError("input type ({}) doesn't match the type of "
                                       "a parameter tensor ({})".format(torch.typename(input),
                                                                        torch.typename(param)))
                tensor_param_list.append(param)
            else:
                ctx.additional_args.append(param)

        tensor_params = tuple(tensor_param_list)
        if is_inplace:
            ctx.inplace = params[-1]
        # Allocate temporary buffers and insert them into additional_args
        ctx.buffers = defaultdict(type(input))
        additional_args = _initialize_buffers(ctx, 'update_output')

        # Fill in optional params with None
        args = tensor_params
        for i in range(len(params), len(expected_params)):
            param = expected_params[i]
            if param.is_optional:
                args += (None,)
            else:
                raise ValueError("missing required argument '%s'" % param.name)

        args += tuple(additional_args)

        # If the module is working in-place its output will be set to the
        # same storage as input, but its variable won't be dirty.
        if is_inplace and ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.new()

        if save_output:
            ctx.save_for_backward(input, output, *tensor_params)
        else:
            ctx.save_for_backward(input, *tensor_params)

        if not ctx.requires_grad:
            del ctx.buffers

        getattr(ctx._backend, update_output.name)(ctx._backend.library_state, input, output, *args)
        return output
Contributor: Jsmilemsj, Project: pytorch, Lines: 51, Source file: auto.py

Example 4: __str__

    def __str__(self):
        if not self.__dict__:
            return 'Empty {} instance'.format(torch.typename(self))

        fields_to_index = filter(lambda field: field is not None, self.fields)
        var_strs = '\n'.join(['\t[.' + name + ']' + ":" + _short_str(getattr(self, name))
                              for name in fields_to_index if hasattr(self, name)])

        data_str = (' from {}'.format(self.dataset.name.upper())
                    if hasattr(self.dataset, 'name') and
                    isinstance(self.dataset.name, str) else '')

        strt = '[{} of size {}{}]\n{}'.format(torch.typename(self),
                                              self.batch_size, data_str, var_strs)
        return '\n' + strt
Contributor: tu-artem, Project: text, Lines: 15, Source file: batch.py

Example 5: __bool__

    def __bool__(self):
        if self.numel() == 0:
            return False
        elif self.numel() == 1:
            return torch.squeeze(self)[0] != 0
        raise RuntimeError("bool value of " + torch.typename(self) +
                           " containing more than one value is ambiguous")
Contributor: lxlhh, Project: pytorch, Lines: 7, Source file: tensor.py
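
The same ambiguity check can be observed directly on a present-day torch.Tensor; the error wording varies across versions, but the behaviour mirrors the method above:

import torch

print(bool(torch.tensor([0.0])))      # False: a single element is compared to zero
try:
    bool(torch.tensor([1.0, 2.0]))    # more than one element
except RuntimeError as e:
    print(e)                          # "... is ambiguous"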

Example 6: _lazyInit

    def _lazyInit(self):
        if self._output is None:
            self._output = self.output.new()
        if self._indices is None:
            self._indices = \
                (torch.cuda.LongTensor() if torch.typename(self.output) == 'torch.cuda.FloatTensor'
                 else torch.LongTensor())
Contributor: Northrend, Project: pytorch, Lines: 7, Source file: Min.py

Example 7: recursiveType

def recursiveType(param, type, tensorCache={}):
    from .Criterion import Criterion
    from .Module import Module
    if isinstance(param, list):
        for i, p in enumerate(param):
            param[i] = recursiveType(p, type, tensorCache)
    elif isinstance(param, Module) or isinstance(param, Criterion):
        param.type(type, tensorCache)
    elif torch.is_tensor(param):
        if torch.typename(param) != type:
            key = param._cdata
            if key in tensorCache:
                newparam = tensorCache[key]
            else:
                newparam = torch.Tensor().type(type)
                storageType = type.replace('Tensor', 'Storage')
                param_storage = param.storage()
                if param_storage:
                    storage_key = param_storage._cdata
                    if storage_key not in tensorCache:
                        tensorCache[storage_key] = torch._import_dotted_name(
                            storageType)(param_storage.size()).copy_(param_storage)
                    newparam.set_(
                        tensorCache[storage_key],
                        param.storage_offset(),
                        param.size(),
                        param.stride()
                    )
                tensorCache[key] = newparam
            param = newparam
    return param
Contributor: Northrend, Project: pytorch, Lines: 31, Source file: utils.py
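
The heart of the conversion above is that the type-name string returned by torch.typename can be fed back into Tensor.type(), and turned into the matching storage name with a plain string replace. A minimal sketch:

import torch

t = torch.randn(2, 3)
target = 'torch.DoubleTensor'
if torch.typename(t) != target:
    t = t.type(target)                       # convert using the type-name string
print(torch.typename(t))                     # 'torch.DoubleTensor'
print(target.replace('Tensor', 'Storage'))   # 'torch.DoubleStorage'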

Example 8: _str

def _str(self):
    if self.ndimension() == 0:
        return '[{} with no dimension]\n'.format(torch.typename(self))
    elif self.ndimension() == 1:
        strt = _vector_str(self)
    elif self.ndimension() == 2:
        strt = _matrix_str(self)
    else:
        strt = _tensor_str(self)

    size_str = 'x'.join(str(size) for size in self.size())
    device_str = '' if not self.is_cuda else \
        ' (GPU {})'.format(self.get_device())
    strt += '[{} of size {}{}]\n'.format(torch.typename(self),
                                         size_str, device_str)
    return '\n' + strt
Contributor: Northrend, Project: pytorch, Lines: 16, Source file: _tensor_str.py

Example 9: _check_container_source

    def _check_container_source(container_type, source_file, original_source):
        current_source = inspect.getsource(container_type)
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as f:
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            raise IOError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except IOError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = ("source code of class '{}' has changed. {}"
                   .format(torch.typename(container_type), msg))
            warnings.warn(msg, SourceChangeWarning)
Contributor: Northrend, Project: pytorch, Lines: 34, Source file: serialization.py
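
For reference, a tiny standalone sketch of the reverse-patch idea: diffing the current source against the saved original with difflib.unified_diff (the source strings below are made up):

import difflib

current_source = "def forward(self, x):\n    return self.linear(x) * 2\n"
original_source = "def forward(self, x):\n    return self.linear(x)\n"

# current -> original, so applying the resulting patch reverts local changes
diff = difflib.unified_diff(current_source.split('\n'),
                            original_source.split('\n'),
                            'model.py', 'model.py', lineterm="")
print('\n'.join(diff))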

Example 10: register_parameter

    def register_parameter(self, name, param):
        r"""Adds a parameter to the module.

        The parameter can be accessed as an attribute using given name.

        Args:
            name (string): name of the parameter. The parameter can be accessed
                from this module using the given name
            parameter (Parameter): parameter to be added to the module.
        """
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")

        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError("attribute '{}' already exists".format(name))
        elif '.' in name:
            raise KeyError("parameter name can't contain \".\"")
        elif name == '':
            raise KeyError("parameter name can't be empty string \"\"")

        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError("cannot assign '{}' object to parameter '{}' "
                            "(torch.nn.Parameter or None required)"
                            .format(torch.typename(param), name))
        elif param.grad_fn:
            raise ValueError(
                "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
                "parameters must be created explicitly. To express '{0}' "
                "as a function of another Tensor, compute the value in "
                "the forward() method.".format(name))
        else:
            self._parameters[name] = param
Contributor: RichieMay, Project: pytorch, Lines: 35, Source file: module.py
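
A quick way to see the torch.typename-based error message in practice (the exact wording differs between PyTorch releases):

import torch
import torch.nn as nn

m = nn.Linear(2, 3)
m.register_parameter('extra', nn.Parameter(torch.zeros(3)))   # accepted
try:
    m.register_parameter('bad', torch.randn(3))               # plain tensor, not a Parameter
except TypeError as e:
    print(e)   # e.g. "cannot assign 'torch.FloatTensor' object to parameter 'bad' ..."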

Example 11: test_Copy

    def test_Copy(self):
        input = torch.randn(3, 4).double()
        c = nn.Copy(torch.DoubleTensor, torch.FloatTensor)
        output = c.forward(input)
        self.assertEqual(torch.typename(output), 'torch.FloatTensor')
        self.assertEqual(output, input.float(), 1e-6)
        gradInput = c.backward(input, output.fill_(1))
        self.assertEqual(torch.typename(gradInput), 'torch.DoubleTensor')
        self.assertEqual(gradInput, output.double(), 1e-6)
        c.dontCast = True
        c.double()
        self.assertEqual(torch.typename(output), 'torch.FloatTensor')

        # Check that these don't raise errors
        c.__repr__()
        str(c)
Contributor: bhuWenDongchao, Project: pytorch, Lines: 16, Source file: test_legacy_nn.py

Example 12: location_tag

def location_tag(storage):
    for _, tagger, _ in _package_registry:
        location = tagger(storage)
        if location:
            return location
    raise RuntimeError("don't know how to determine data location of " +
                       torch.typename(storage))
Contributor: Northrend, Project: pytorch, Lines: 7, Source file: serialization.py

Example 13: vector_to_parameters

def vector_to_parameters(vec, parameters):
    r"""Convert one vector to the parameters

    Arguments:
        vec (Variable): a single vector represents the parameters of a model.
        parameters (Iterable[Variable]): an iterator of Variables that are the
            parameters of a model.
    """
    # Ensure vec of type Variable
    if not isinstance(vec, Variable):
        raise TypeError('expected torch.autograd.Variable, but got: {}'
                        .format(torch.typename(vec)))
    # Flag for the device where the parameter is located
    param_device = None

    # Pointer for slicing the vector for each parameter
    pointer = 0
    for param in parameters:
        # Ensure the parameters are located in the same device
        param_device = _check_param_device(param, param_device)

        # The length of the parameter
        num_param = torch.prod(torch.LongTensor(list(param.size())))
        # Slice the vector, reshape it, and replace the old data of the parameter
        param.data = vec[pointer:pointer + num_param].view(param.size()).data

        # Increment the pointer
        pointer += num_param
Contributor: Jsmilemsj, Project: pytorch, Lines: 28, Source file: convert_parameters.py
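
The slicing logic above can be exercised on plain tensors; this round-trip sketch flattens two "parameters" into one vector and reshapes the slices back:

import torch

params = [torch.randn(2, 3), torch.randn(4)]
vec = torch.cat([p.view(-1) for p in params])    # the flat "parameter vector"

pointer = 0
for p in params:
    num_param = p.numel()
    chunk = vec[pointer:pointer + num_param].view(p.size())
    print(chunk.size())                           # torch.Size([2, 3]) then torch.Size([4])
    pointer += num_param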

Example 14: register_buffer

    def register_buffer(self, name, tensor):
        r"""Adds a persistent buffer to the module.

        This is typically used to register a buffer that should not to be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the persistent state.

        Buffers can be accessed as attributes using given names.

        Args:
            name (string): name of the buffer. The buffer can be accessed
                from this module using the given name
            tensor (Tensor): buffer to be registered.

        Example::

            >>> self.register_buffer('running_mean', torch.zeros(num_features))

        """
        if hasattr(self, name) and name not in self._buffers:
            raise KeyError("attribute '{}' already exists".format(name))
        elif '.' in name:
            raise KeyError("buffer name can't contain \".\"")
        elif name == '':
            raise KeyError("buffer name can't be empty string \"\"")
        elif tensor is not None and not isinstance(tensor, torch.Tensor):
            raise TypeError("cannot assign '{}' object to buffer '{}' "
                            "(torch Tensor or None required)"
                            .format(torch.typename(tensor), name))
        else:
            self._buffers[name] = tensor
Contributor: RichieMay, Project: pytorch, Lines: 31, Source file: module.py
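
A short usage sketch: registering a valid buffer and then provoking the torch.typename-based TypeError with a non-tensor (exact message wording depends on the version):

import torch
import torch.nn as nn

m = nn.Module()
m.register_buffer('running_mean', torch.zeros(4))   # persistent state, not a parameter
print(torch.typename(m.running_mean))               # e.g. 'torch.FloatTensor'
try:
    m.register_buffer('bad', [1, 2, 3])              # a list, not a Tensor
except TypeError as e:
    print(e)   # e.g. "cannot assign 'list' object to buffer 'bad' ..."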

Example 15: default_restore_location

def default_restore_location(storage, location):
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError("don't know how to restore data location of " +
                       torch.typename(storage) + " (tagged with " +
                       location + ")")
Contributor: Northrend, Project: pytorch, Lines: 8, Source file: serialization.py
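
Examples 12 and 15 both walk the same _package_registry of (priority, tagger, deserializer) entries. Below is a stripped-down sketch of that registry pattern; the names and the single CPU entry are illustrative, not PyTorch's actual implementation:

_registry = []

def register_package(priority, tagger, deserializer):
    # Lower priority values are tried first after sorting.
    _registry.append((priority, tagger, deserializer))
    _registry.sort()

def location_tag(storage):
    for _, tagger, _ in _registry:
        location = tagger(storage)
        if location:
            return location
    raise RuntimeError("don't know how to determine data location of " + repr(storage))

# A CPU-only entry: everything is tagged 'cpu' and restored as-is.
register_package(10, lambda storage: 'cpu',
                 lambda storage, location: storage if location == 'cpu' else None)

print(location_tag(object()))   # 'cpu'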


Note: The torch.typename examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by their respective developers, and copyright of the source code remains with the original authors. Please follow each project's license when using or redistributing the code, and do not republish this compilation without permission.