This page collects typical usage examples of container_abcs.Iterable from Python's torch._six module. If you are unsure what container_abcs.Iterable does, how to call it, or where to find working examples, the curated snippets below should help. You can also read up on the containing module, torch._six, for more context.
Four code examples of container_abcs.Iterable are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the site surface better Python code examples.
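Note: torch._six (including container_abcs) was removed from recent PyTorch releases (roughly 1.9 and later), so the imports below fail there. A common compatibility shim, offered here as a suggestion rather than as part of the original examples:

try:
    from torch._six import container_abcs  # older PyTorch
except ImportError:
    import collections.abc as container_abcs  # drop-in on modern PyTorch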
Example 1: _ntuple
# Required import: from torch._six import container_abcs [as alias]
# Or: from torch._six.container_abcs import Iterable [as alias]
# Also requires: from itertools import repeat
def _ntuple(n):
    def parse(x):
        # Iterables pass through unchanged; scalars are broadcast to an n-tuple.
        if isinstance(x, container_abcs.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse
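A quick usage sketch (a minimal example; the name _pair follows the convention in torch.nn.modules.utils and is not defined on this page):

# Assumes _ntuple as defined above, with `repeat` imported from itertools
# and container_abcs bound to collections.abc on modern PyTorch.
_pair = _ntuple(2)
assert _pair(3) == (3, 3)        # a scalar kernel_size becomes (3, 3)
assert _pair((3, 5)) == (3, 5)   # an explicit pair passes through unchanged
# Caveat: strings are also Iterable, so _pair('ab') returns 'ab' untouched.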
Example 2: __init__
# Required import: from torch._six import container_abcs [as alias]
# Or: from torch._six.container_abcs import Iterable [as alias]
# Also requires: import torch, plus the module-local helpers _pair and get_padding_value
def __init__(self, in_channels, out_channels, kernel_size=3,
             stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
    super(CondConv2d, self).__init__()
    assert num_experts > 1
    # Normalize a length-1 iterable stride (e.g. [2]) to a scalar.
    if isinstance(stride, container_abcs.Iterable) and len(stride) == 1:
        stride = stride[0]
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = _pair(kernel_size)
    self.stride = _pair(stride)
    padding_val, is_padding_dynamic = get_padding_value(
        padding, kernel_size, stride=stride, dilation=dilation)
    self.dynamic_padding = is_padding_dynamic  # padding is resolved in forward() to stay TorchScript-compatible
    self.padding = _pair(padding_val)
    self.dilation = _pair(dilation)
    self.groups = groups
    self.num_experts = num_experts
    # Each expert's kernel is stored flattened: one row per expert.
    self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
    weight_num_param = 1
    for wd in self.weight_shape:
        weight_num_param *= wd
    self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
    if bias:
        self.bias_shape = (self.out_channels,)
        self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
    else:
        self.register_parameter('bias', None)
    self.reset_parameters()
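The constructor above only stores flattened per-expert kernels. Here is a hedged sketch of how a forward pass can combine them; it mirrors the grouped-convolution trick from the CondConv paper, but the function name and the routing_weights argument are illustrative, not part of the snippet above:

import torch
import torch.nn.functional as F

def condconv_forward_sketch(x, weight, weight_shape, routing_weights,
                            stride=1, padding=0, dilation=1, groups=1):
    # x: (B, C_in, H, W); weight: (num_experts, num_params) as stored above;
    # routing_weights: (B, num_experts), e.g. sigmoid outputs of a small router.
    B = x.shape[0]
    # Mix the experts' flattened kernels per sample...
    mixed = torch.matmul(routing_weights, weight)  # (B, num_params)
    # ...then fold the batch into the group dimension so a single grouped
    # convolution applies a different mixed kernel to every sample.
    mixed = mixed.view((B * weight_shape[0],) + tuple(weight_shape[1:]))
    x = x.reshape(1, B * x.shape[1], *x.shape[2:])
    out = F.conv2d(x, mixed, bias=None, stride=stride, padding=padding,
                   dilation=dilation, groups=groups * B)
    return out.view(B, weight_shape[0], *out.shape[2:])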
Example 3: update
# Required import: from torch._six import container_abcs [as alias]
# Or: from torch._six.container_abcs import Iterable [as alias]
# Also requires: from collections import OrderedDict
def update(self, buffers):
    r"""Update the :class:`~torch.nn.BufferDict` with the key-value pairs from a
    mapping or an iterable, overwriting existing keys.

    .. note::
        If :attr:`buffers` is an ``OrderedDict``, a :class:`~torch.nn.BufferDict`,
        or an iterable of key-value pairs, the order of new elements in it is
        preserved.

    Arguments:
        buffers (iterable): a mapping (dictionary) from string to
            :class:`~torch.Tensor`, or an iterable of
            key-value pairs of type (string, :class:`~torch.Tensor`)
    """
    if not isinstance(buffers, container_abcs.Iterable):
        raise TypeError(
            "BufferDict.update should be called with an "
            "iterable of key/value pairs, but got " + type(buffers).__name__
        )
    if isinstance(buffers, container_abcs.Mapping):
        if isinstance(buffers, (OrderedDict, BufferDict)):
            # Ordered mappings: preserve insertion order.
            for key, buffer in buffers.items():
                self[key] = buffer
        else:
            # Plain dicts: sort keys for a deterministic order.
            for key, buffer in sorted(buffers.items()):
                self[key] = buffer
    else:
        for j, p in enumerate(buffers):
            if not isinstance(p, container_abcs.Iterable):
                raise TypeError(
                    "BufferDict update sequence element "
                    "#" + str(j) + " should be Iterable; is " + type(p).__name__
                )
            if not len(p) == 2:
                raise ValueError(
                    "BufferDict update sequence element "
                    "#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
                )
            self[p[0]] = p[1]
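A hedged usage sketch; the BufferDict class itself is not shown on this page, so this assumes it behaves like torch.nn.ParameterDict, per the docstring above:

import torch

# Assumes BufferDict is defined alongside update() above, with a
# ParameterDict-style constructor that accepts a mapping.
bd = BufferDict({'running_mean': torch.zeros(8)})
bd.update({'running_var': torch.ones(8)})      # mapping: plain dicts are sorted by key
bd.update([('num_batches', torch.tensor(0))])  # sequence of (name, tensor) pairs, order kept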
Example 4: load_state_dict
# Required import: from torch._six import container_abcs [as alias]
# Or: from torch._six.container_abcs import Iterable [as alias]
# Also requires: from copy import deepcopy; from itertools import chain;
#                from collections import defaultdict; import torch
def load_state_dict(self, state_dict):
    r"""Loads the optimizer state.

    Arguments:
        state_dict (dict): optimizer state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    # deepcopy, to be consistent with module API
    state_dict = deepcopy(state_dict)
    # Validate the state_dict
    groups = self.param_groups
    saved_groups = state_dict['param_groups']

    if len(groups) != len(saved_groups):
        raise ValueError("loaded state dict has a different number of "
                         "parameter groups")
    param_lens = (len(g['params']) for g in groups)
    saved_lens = (len(g['params']) for g in saved_groups)
    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
        raise ValueError("loaded state dict contains a parameter group "
                         "that doesn't match the size of optimizer's group")

    # Map saved parameter ids onto the optimizer's current parameters.
    id_map = {old_id: p for old_id, p in
              zip(chain(*(g['params'] for g in saved_groups)),
                  chain(*(g['params'] for g in groups)))}

    def cast(param, value):
        r"""Make a deep copy of value, casting all tensors to device of param."""
        if isinstance(value, torch.Tensor):
            # Floating-point types are a bit special here. They are the only ones
            # that are assumed to always match the type of params.
            if param.is_floating_point():
                value = value.to(param.dtype)
            value = value.to(param.device)
            return value
        elif isinstance(value, dict):
            return {k: cast(param, v) for k, v in value.items()}
        elif isinstance(value, container_abcs.Iterable):
            # Rebuild containers (lists, tuples, ...) with their elements cast.
            return type(value)(cast(param, v) for v in value)
        else:
            return value

    # Copy state assigned to params (and cast tensors to appropriate types).
    # State that is not assigned to params is copied as is (needed for
    # backward compatibility).
    state = defaultdict(dict)
    for k, v in state_dict['state'].items():
        if k in id_map:
            param = id_map[k]
            state[param] = cast(param, v)
        else:
            state[k] = v

    # Update parameter groups, setting their 'params' value
    def update_group(group, new_group):
        new_group['params'] = group['params']
        return new_group
    param_groups = [
        update_group(g, ng) for g, ng in zip(groups, saved_groups)]
    self.__setstate__({'state': state, 'param_groups': param_groups})
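For context, a minimal round-trip with a built-in optimizer; the method above belongs to a custom optimizer not shown here, and torch.optim.SGD is used as a stand-in with the same state_dict/load_state_dict contract:

import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model(torch.randn(1, 4)).sum().backward()
opt.step()                     # populates momentum buffers in opt.state

snapshot = opt.state_dict()    # plain dict: {'state': ..., 'param_groups': ...}
opt2 = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
opt2.load_state_dict(snapshot) # state tensors are cast to each param's device/dtype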