

Python torch.HalfTensor Method Code Examples

This article collects typical usage examples of the torch.HalfTensor method in Python. If you are unsure what torch.HalfTensor does, how to call it, or how it is used in practice, the curated examples below should help. You can also browse further usage examples from the torch package.


The following presents 15 code examples of the torch.HalfTensor method, sorted by popularity by default.
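Before diving into the examples, here is a minimal sketch of the method itself (a sketch for orientation, assuming a recent PyTorch build; torch.HalfTensor is the legacy constructor for fp16 CPU tensors, with torch.cuda.HalfTensor as its CUDA counterpart):

import torch

# Legacy fp16 constructor; the modern spelling is torch.tensor(..., dtype=torch.float16).
h = torch.HalfTensor([1.0, 2.0, 3.0])
print(h.dtype)                           # torch.float16
f = h.float()                            # upcast to fp32 before ops that fp16 handles poorly
print(isinstance(h, torch.HalfTensor))   # True -- the type check used throughout below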

Example 1: _worker_loop

# Required module: import torch [as alias]
# Or: from torch import HalfTensor [as alias]
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    global _use_shared_memory
    _use_shared_memory = True

    # torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))


# numpy_type_map = {
#     'float64': torch.DoubleTensor,
#     'float32': torch.FloatTensor,
#     'float16': torch.HalfTensor,
#     'int64': torch.LongTensor,
#     'int32': torch.IntTensor,
#     'int16': torch.ShortTensor,
#     'int8': torch.CharTensor,
#     'uint8': torch.ByteTensor,
# } 
Author: Lyken17 | Project: mxbox | Lines: 31 | Source: torchloader.py

Example 2: convert_tensor

# Required module: import torch [as alias]
# Or: from torch import HalfTensor [as alias]
def convert_tensor(tensor: torch.Tensor, clone: bool) -> torch.Tensor:
    tensor = tensor.detach().cpu()
    if isinstance(tensor, torch.HalfTensor):
        # We convert any fp16 params to fp32 to make sure operations like
        # division by a scalar value are supported.
        tensor = tensor.float()
    elif clone:
        # tensor.float() would have effectively cloned the fp16 tensor already,
        # so we don't need to do it again even if clone=True.
        tensor = tensor.clone()
    return tensor 
Author: pytorch | Project: translate | Lines: 13 | Source: checkpoint.py
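A hypothetical call site for the helper above (the tensor values are made up for illustration):

p = torch.HalfTensor([1.0, 2.0])
q = convert_tensor(p, clone=True)   # detached fp32 copy on the CPU
avg = q / 2.0                       # scalar division, now safely in fp32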

Example 3: setUp

# Required module: import torch [as alias]
# Or: from torch import HalfTensor [as alias]
def setUp(self):
        self._params_1 = OrderedDict(
            [
                ("double_tensor", torch.DoubleTensor([100.0])),
                ("float_tensor", torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
                ("long_tensor", torch.LongTensor([7, 8, 9])),
                ("half_tensor", torch.HalfTensor([10.0, 20.0])),
            ]
        )
        self._params_2 = OrderedDict(
            [
                ("double_tensor", torch.DoubleTensor([1.0])),
                ("float_tensor", torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])),
                # Any integer tensor must remain the same in all the params.
                ("long_tensor", torch.LongTensor([7, 8, 9])),
                ("half_tensor", torch.HalfTensor([50.0, 0.0])),
            ]
        )
        self._params_avg = OrderedDict(
            [
                ("double_tensor", torch.DoubleTensor([50.5])),
                ("float_tensor", torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])),
                ("long_tensor", torch.LongTensor([7, 8, 9])),
                # We convert fp16 to fp32 when averaging params.
                ("half_tensor", torch.FloatTensor([30.0, 10.0])),
            ]
        )

        self._fd_1, self._filename_1 = tempfile.mkstemp()
        self._fd_2, self._filename_2 = tempfile.mkstemp()
        torch.save(OrderedDict([("model", self._params_1)]), self._filename_1)
        torch.save(OrderedDict([("model", self._params_2)]), self._filename_2) 
Author: pytorch | Project: translate | Lines: 34 | Source: test_checkpoint.py
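The expected half_tensor average above is plain upcast-then-average arithmetic; a standalone check (assuming only torch):

import torch

a = torch.HalfTensor([10.0, 20.0]).float()
b = torch.HalfTensor([50.0, 0.0]).float()
print((a + b) / 2)   # tensor([30., 10.]) -- matches _params_avg's half_tensor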

Example 4: T

# Required module: import torch [as alias]
# Or: from torch import HalfTensor [as alias]
def T(a, half=False, cuda=True):
    if not torch.is_tensor(a):
        a = np.array(np.ascontiguousarray(a))
        if a.dtype in (np.int8, np.int16, np.int32, np.int64):
            a = torch.LongTensor(a.astype(np.int64))
        elif a.dtype in (np.float32, np.float64):
            a = torch.cuda.HalfTensor(a) if half else torch.FloatTensor(a)
        else: raise NotImplementedError(a.dtype)
    if cuda: a = to_gpu(a, non_blocking=True)  # 'async' became a reserved keyword in Python 3.7
    return a 
Author: Esri | Project: raster-deep-learning | Lines: 12 | Source: model.py
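A hypothetical use of T (np is numpy as imported in the source module; to_gpu is the surrounding library's helper, and CUDA is assumed available for the fp16 path):

x = T(np.array([1.5, 2.5], dtype=np.float32), half=True)   # fp16 tensor on the GPU
ix = T([1, 2, 3], cuda=False)                              # LongTensor kept on the CPU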

Example 5: to_np

# Required module: import torch [as alias]
# Or: from torch import HalfTensor [as alias]
def to_np(v):
    if isinstance(v, (np.ndarray, np.generic)): return v
    if isinstance(v, (list,tuple)): return [to_np(o) for o in v]
    if isinstance(v, Variable): v=v.data
    if USE_GPU:
        if isinstance(v, torch.cuda.HalfTensor): v=v.float()
    else:
        if isinstance(v, torch.HalfTensor): v=v.float()
    return v.cpu().numpy() 
Author: Esri | Project: raster-deep-learning | Lines: 11 | Source: model.py
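A sketch of to_np in action (USE_GPU is a module-level flag in the source; False is assumed here):

arr = to_np(torch.HalfTensor([1.0, 2.0]))
print(arr.dtype)   # float32 -- fp16 is widened before the numpy round-trip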

Example 6: default_collate_with_string

# Required module: import torch [as alias]
# Or: from torch import HalfTensor [as alias]
def default_collate_with_string(batch):
    "Puts each data field into a tensor with outer dimension batch size"
    _use_shared_memory = False
    numpy_type_map = {
        'float64': torch.DoubleTensor,
        'float32': torch.FloatTensor,
        'float16': torch.HalfTensor,
        'int64': torch.LongTensor,
        'int32': torch.IntTensor,
        'int16': torch.ShortTensor,
        'int8': torch.CharTensor,
        'uint8': torch.ByteTensor,
    }
    string_classes = (str, bytes)
    if torch.is_tensor(batch[0]):
        #print("IN","torch.is_tensor(batch[0])")
        #IPython.embed()
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        #print("batch:",[e.numpy().shape for e in batch])
        return torch.stack(batch, 0, out=out)
    elif type(batch[0]).__module__ == 'numpy':
        elem = batch[0]
        #print("IN", "type(batch[0]).__module__ == 'numpy'")
        #IPython.embed()
        if type(elem).__name__ == 'ndarray':
            if elem.dtype.kind in {'U', 'S'}:
                return np.stack(batch, 0)
            else:
                return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.FloatTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], collections.abc.Mapping):  # collections.Mapping was removed in Python 3.10
        return {key: default_collate_with_string([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], collections.abc.Sequence):
        transposed = zip(*batch)
        return [default_collate_with_string(samples) for samples in transposed]

    raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}"
                     .format(type(batch[0])))) 
Author: hrhodin | Project: UnsupervisedGeometryAwareRepresentationLearning | Lines: 54 | Source: datasets.py
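A hypothetical batch for the collate function above, mixing numpy arrays with strings (strings pass through unchanged):

batch = [{'img': np.zeros((3, 4), dtype=np.float32), 'name': 'a.png'},
         {'img': np.ones((3, 4), dtype=np.float32), 'name': 'b.png'}]
out = default_collate_with_string(batch)
# out['img'] is a 2x3x4 FloatTensor; out['name'] stays ['a.png', 'b.png']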

Example 7: average_checkpoints

# Required module: import torch [as alias]
# Or: from torch import HalfTensor [as alias]
def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Args:
      inputs: An iterable of string paths of checkpoints to load from.

    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    for f in inputs:
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
            ),
        )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        model_params = state['model']

        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            if k not in params_dict:
                params_dict[k] = []
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            params_dict[k].append(p)

    averaged_params = collections.OrderedDict()
    # v is a list of torch Tensors.
    for k, v in params_dict.items():
        summed_v = None
        for x in v:
            summed_v = summed_v + x if summed_v is not None else x
        averaged_params[k] = summed_v / len(v)
    new_state['model'] = averaged_params
    return new_state 
Author: nusnlp | Project: crosentgec | Lines: 55 | Source: average_checkpoints.py
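A hypothetical invocation (the checkpoint paths are placeholders):

new_state = average_checkpoints(['checkpoint1.pt', 'checkpoint2.pt'])
torch.save(new_state, 'checkpoint_avg.pt')   # the averaged weights live under the 'model' key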

Example 8: predict_transform_half

# Required module: import torch [as alias]
# Or: from torch import HalfTensor [as alias]
def predict_transform_half(prediction, inp_dim, anchors, num_classes, CUDA = True):
    batch_size = prediction.size(0)
    stride =  inp_dim // prediction.size(2)

    bbox_attrs = 5 + num_classes
    num_anchors = len(anchors)
    grid_size = inp_dim // stride

    
    prediction = prediction.view(batch_size, bbox_attrs*num_anchors, grid_size*grid_size)
    prediction = prediction.transpose(1,2).contiguous()
    prediction = prediction.view(batch_size, grid_size*grid_size*num_anchors, bbox_attrs)
    
    
    # Sigmoid the centre_X, centre_Y, and object confidence
    prediction[:,:,0] = torch.sigmoid(prediction[:,:,0])
    prediction[:,:,1] = torch.sigmoid(prediction[:,:,1])
    prediction[:,:,4] = torch.sigmoid(prediction[:,:,4])

    
    #Add the center offsets
    grid_len = np.arange(grid_size)
    a,b = np.meshgrid(grid_len, grid_len)
    
    x_offset = torch.FloatTensor(a).view(-1,1)
    y_offset = torch.FloatTensor(b).view(-1,1)
    
    if CUDA:
        x_offset = x_offset.cuda().half()
        y_offset = y_offset.cuda().half()
    
    x_y_offset = torch.cat((x_offset, y_offset), 1).repeat(1,num_anchors).view(-1,2).unsqueeze(0)
    
    prediction[:,:,:2] += x_y_offset
      
    # Log-space transform of the height and width
    anchors = torch.HalfTensor(anchors)
    
    if CUDA:
        anchors = anchors.cuda()
    
    anchors = anchors.repeat(grid_size*grid_size, 1).unsqueeze(0)
    prediction[:,:,2:4] = torch.exp(prediction[:,:,2:4])*anchors

    #Softmax the class scores
    prediction[:,:,5: 5 + num_classes] = nn.Softmax(-1)(Variable(prediction[:,:, 5 : 5 + num_classes])).data

    prediction[:,:,:4] *= stride
    
    
    return prediction 
Author: MagicChuyi | Project: SlowFast-Network-pytorch | Lines: 53 | Source: util.py
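A hypothetical call following the YOLO conventions the function assumes (CUDA is required for the fp16 path; the anchor sizes below are invented and, given the final *= stride, are expected in feature-map grid units):

raw = torch.randn(1, 255, 13, 13).cuda().half()   # 255 = 3 anchors * (5 + 80 classes)
anchors = [(3.6, 2.8), (4.9, 6.2), (11.7, 10.2)]  # hypothetical, in grid units
out = predict_transform_half(raw, inp_dim=416, anchors=anchors, num_classes=80)
print(out.shape)   # torch.Size([1, 507, 85]), where 507 = 13*13*3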

Example 9: average_checkpoints

# Required module: import torch [as alias]
# Or: from torch import HalfTensor [as alias]
def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Args:
      inputs: An iterable of string paths of checkpoints to load from.

    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)

    for f in inputs:
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
            ),
        )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        model_params = state['model']

        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if k not in params_dict:
                params_dict[k] = p.clone()
                # NOTE: clone() is needed in case p is a shared parameter
            else:
                params_dict[k] += p

    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        averaged_params[k].div_(num_models)
    new_state['model'] = averaged_params
    return new_state 
Author: LiyuanLucasLiu | Project: RAdam | Lines: 56 | Source: average_checkpoints.py

Example 10: average_checkpoints

# Required module: import torch [as alias]
# Or: from torch import HalfTensor [as alias]
def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Args:
      inputs: An iterable of string paths of checkpoints to load from.

    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)

    for fpath in inputs:
        with PathManager.open(fpath, 'rb') as f:
            state = torch.load(
                f,
                map_location=(
                    lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
                ),
            )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        model_params = state['model']

        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if k not in params_dict:
                params_dict[k] = p.clone()
                # NOTE: clone() is needed in case p is a shared parameter
            else:
                params_dict[k] += p

    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        if averaged_params[k].is_floating_point():
            averaged_params[k].div_(num_models)
        else:
            averaged_params[k] //= num_models
    new_state['model'] = averaged_params
    return new_state 
Author: pytorch | Project: fairseq | Lines: 60 | Source: average_checkpoints.py
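The is_floating_point branch matters because recent PyTorch versions refuse in-place true division (div_) on integer tensors; integer params (step counters and the like) are therefore floor-divided instead. A standalone illustration:

summed = torch.LongTensor([7, 8, 9]) * 2   # as if summed over two checkpoints
summed //= 2                               # floor division keeps dtype int64
print(summed)                              # tensor([7, 8, 9])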

Example 11: load_optimizer_state_dict

# Required module: import torch [as alias]
# Or: from torch import HalfTensor [as alias]
def load_optimizer_state_dict(optimizer, state_dict):
    # deepcopy, to be consistent with module API
    state_dict = deepcopy(state_dict)
    # Validate the state_dict
    groups = optimizer.param_groups
    saved_groups = state_dict['param_groups']

    if len(groups) != len(saved_groups):
        raise ValueError("loaded state dict has a different number of "
                         "parameter groups")
    param_lens = (len(g['params']) for g in groups)
    saved_lens = (len(g['params']) for g in saved_groups)
    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
        raise ValueError("loaded state dict contains a parameter group "
                         "that doesn't match the size of optimizer's group")

    # Update the state
    id_map = {old_id: p for old_id, p in
                zip(chain(*(g['params'] for g in saved_groups)),
                    chain(*(g['params'] for g in groups)))}

    def cast(param, value):
        """Make a deep copy of value, casting all tensors to device of param."""
        if torch.is_tensor(value):
            # Floating-point types are a bit special here. They are the only ones
            # that are assumed to always match the type of params.
            if isinstance(param.data, (torch.FloatTensor, torch.cuda.FloatTensor,
                                       torch.DoubleTensor, torch.cuda.DoubleTensor,
                                       torch.HalfTensor, torch.cuda.HalfTensor)):  # param.is_floating_point():
                value = value.type_as(param.data)
            value = value.cuda(param.get_device()) if param.is_cuda else value.cpu()
            return value
        elif isinstance(value, dict):
            return {k: cast(param, v) for k, v in value.items()}
        elif isinstance(value, Iterable):
            return type(value)(cast(param, v) for v in value)
        else:
            return value

    # Copy state assigned to params (and cast tensors to appropriate types).
    # State that is not assigned to params is copied as is (needed for
    # backward compatibility).
    state = defaultdict(dict)
    for k, v in state_dict['state'].items():
        if k in id_map:
            param = id_map[k]
            state[param] = cast(param, v)
        else:
            state[k] = v

    # Update parameter groups, setting their 'params' value
    def update_group(group, new_group):
        new_group['params'] = group['params']
        return new_group
    param_groups = [
        update_group(g, ng) for g, ng in zip(groups, saved_groups)]
    optimizer.__setstate__({'state': state, 'param_groups': param_groups}) 
Author: roytseng-tw | Project: Detectron.pytorch | Lines: 59 | Source: misc.py
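A minimal round-trip sketch for the loader above (the model and optimizer here are assumptions for illustration):

from copy import deepcopy
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
saved = deepcopy(opt.state_dict())      # stand-in for a state_dict read from a checkpoint
load_optimizer_state_dict(opt, saved)   # tensors are cast back onto each param's device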

Example 12: forward

# Required module: import torch [as alias]
# Or: from torch import HalfTensor [as alias]
def forward(self, input, running_mean, running_var, weight, bias, momentum,
                eps, group, group_size):
        self.momentum = momentum
        self.eps = eps
        self.group = group
        self.group_size = group_size

        assert isinstance(
                   input, (torch.HalfTensor, torch.FloatTensor,
                           torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
               f'only support Half or Float Tensor, but {input.type()}'
        output = torch.empty_like(input)
        input3d = input.view(input.size(0), input.size(1), -1)
        output3d = output.view_as(input3d)

        mean = torch.empty(
            input3d.size(1), dtype=torch.float, device=input3d.device)
        var = torch.empty(
            input3d.size(1), dtype=torch.float, device=input3d.device)
        if input3d.requires_grad or weight.requires_grad or bias.requires_grad:
            norm = torch.empty_like(
                input3d, dtype=torch.float, device=input3d.device)
            std = torch.empty(
                input3d.size(1), dtype=torch.float, device=input3d.device)
        else:
            norm = torch.empty(0, dtype=torch.float, device=input3d.device)
            std = torch.empty(0, dtype=torch.float, device=input3d.device)

        ext_module.sync_bn_forward_mean(input3d, mean)
        if self.group_size > 1:
            dist.all_reduce(mean, group=self.group)
            mean /= self.group_size
        ext_module.sync_bn_forward_var(input3d, mean, var)
        if self.group_size > 1:
            dist.all_reduce(var, group=self.group)
            var /= self.group_size
        ext_module.sync_bn_forward_output(
            input3d,
            mean,
            var,
            running_mean,
            running_var,
            weight,
            bias,
            norm,
            std,
            output3d,
            eps=self.eps,
            momentum=self.momentum,
            group_size=self.group_size)
        self.save_for_backward(norm, std, weight)
        return output 
Author: open-mmlab | Project: mmcv | Lines: 54 | Source: sync_bn.py
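This autograd Function is not called directly in user code; it backs mmcv's SyncBatchNorm module (a sketch, assuming mmcv built with its compiled extensions and an initialized process group):

from mmcv.ops import SyncBatchNorm

bn = SyncBatchNorm(64)   # module wrapper around the Function defined above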

Example 13: average_checkpoints

# Required module: import torch [as alias]
# Or: from torch import HalfTensor [as alias]
def average_checkpoints(inputs: List[str]) -> dict:
    """Loads checkpoints from inputs and returns a model with averaged weights.
    Args:
      inputs: An iterable of string paths of checkpoints to load from.
    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)

    for f in inputs:
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(
                    s, 'cpu')
            ),
        )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        # Averaging: only handle the network params. 
        model_params = state['model_state']

        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if k not in params_dict:
                params_dict[k] = p
            else:
                params_dict[k] += p

    averaged_params = collections.OrderedDict()
    # v is the running sum of this parameter across all checkpoints.
    for k, v in params_dict.items():
        averaged_params[k] = v / num_models
    new_state['model_state'] = averaged_params
    return new_state 
Author: joeynmt | Project: joeynmt | Lines: 55 | Source: average_checkpoints.py

Example 14: predict_transform_half

# Required module: import torch [as alias]
# Or: from torch import HalfTensor [as alias]
def predict_transform_half(prediction, inp_dim, anchors, num_classes, CUDA = True):
    batch_size = prediction.size(0)
    stride =  inp_dim // prediction.size(2)

    bbox_attrs = 5 + num_classes
    num_anchors = len(anchors)
    grid_size = inp_dim // stride


    prediction = prediction.view(batch_size, bbox_attrs*num_anchors, grid_size*grid_size)
    prediction = prediction.transpose(1,2).contiguous()
    prediction = prediction.view(batch_size, grid_size*grid_size*num_anchors, bbox_attrs)


    # Sigmoid the centre_X, centre_Y, and object confidence
    prediction[:,:,0] = torch.sigmoid(prediction[:,:,0])
    prediction[:,:,1] = torch.sigmoid(prediction[:,:,1])
    prediction[:,:,4] = torch.sigmoid(prediction[:,:,4])


    #Add the center offsets
    grid_len = np.arange(grid_size)
    a,b = np.meshgrid(grid_len, grid_len)

    x_offset = torch.FloatTensor(a).view(-1,1)
    y_offset = torch.FloatTensor(b).view(-1,1)

    if CUDA:
        x_offset = x_offset.cuda().half()
        y_offset = y_offset.cuda().half()

    x_y_offset = torch.cat((x_offset, y_offset), 1).repeat(1,num_anchors).view(-1,2).unsqueeze(0)

    prediction[:,:,:2] += x_y_offset

    # Log-space transform of the height and width
    anchors = torch.HalfTensor(anchors)

    if CUDA:
        anchors = anchors.cuda()

    anchors = anchors.repeat(grid_size*grid_size, 1).unsqueeze(0)
    prediction[:,:,2:4] = torch.exp(prediction[:,:,2:4])*anchors

    #Softmax the class scores
    prediction[:,:,5: 5 + num_classes] = nn.Softmax(-1)(Variable(prediction[:,:, 5 : 5 + num_classes])).data

    prediction[:,:,:4] *= stride


    return prediction 
Author: lxy5513 | Project: cvToolkit | Lines: 53 | Source: util.py

Example 15: average_checkpoints

# Required module: import torch [as alias]
# Or: from torch import HalfTensor [as alias]
def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Args:
      inputs: An iterable of string paths of checkpoints to load from.

    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)

    for fpath in inputs:
        with PathManager.open(fpath, 'rb') as f:
            state = torch.load(
                f,
                map_location=(
                    lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
                ),
            )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        model_params = state['model']

        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if k not in params_dict:
                params_dict[k] = p.clone()
                # NOTE: clone() is needed in case p is a shared parameter
            else:
                params_dict[k] += p

    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        averaged_params[k].div_(num_models)
    new_state['model'] = averaged_params
    return new_state 
Author: elbayadm | Project: attn2d | Lines: 57 | Source: average_checkpoints.py


Note: The torch.HalfTensor examples above were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are taken from open-source projects contributed by their respective authors; copyright remains with the original authors, and any use or redistribution must follow the corresponding project's license. Do not reproduce without permission.