當前位置: 首頁>>代碼示例>>Python>>正文


Python nn.module方法代碼示例

本文整理匯總了Python中torch.nn.module方法的典型用法代碼示例。如果您正苦於以下問題:Python nn.module方法的具體用法?Python nn.module怎麽用?Python nn.module使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch.nn的用法示例。


在下文中一共展示了nn.module方法的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: compute_average_flops_cost

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import module [as 別名]
def compute_average_flops_cost(self):
    """Compute average FLOPs cost.

    A method to compute average FLOPs cost, which will be available after
    `add_flops_counting_methods()` is called on a desired net object.

    Returns:
        float: Current mean flops consumption per image.
    """
    batches_count = self.__batch_counter__
    flops_sum = 0
    for module in self.modules():
        if is_supported_instance(module):
            flops_sum += module.__flops__
    params_sum = get_model_parameters_number(self)
    return flops_sum / batches_count, params_sum 
開發者ID:open-mmlab,項目名稱:mmcv,代碼行數:18,代碼來源:flops_counter.py

示例2: start_flops_count

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import module [as 別名]
def start_flops_count(self):
    """Activate the computation of mean flops consumption per image.

    A method to activate the computation of mean flops consumption per image.
    which will be available after ``add_flops_counting_methods()`` is called on
    a desired net object. It should be called before running the network.
    """
    add_batch_counter_hook_function(self)

    def add_flops_counter_hook_function(module):
        if is_supported_instance(module):
            if hasattr(module, '__flops_handle__'):
                return

            else:
                handle = module.register_forward_hook(
                    MODULES_MAPPING[type(module)])

            module.__flops_handle__ = handle

    self.apply(partial(add_flops_counter_hook_function)) 
開發者ID:open-mmlab,項目名稱:mmcv,代碼行數:23,代碼來源:flops_counter.py

示例3: conv2d

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import module [as 別名]
def conv2d(input, grad_output, layer):
        """
        :param input: batch_size * in_c * in_h * in_w
        :param grad_output: batch_size * out_c * h * w
        :param layer: nn.module batch_size * out_c * (in_c*k_h*k_w + [1 if with bias])
        :return:
        """
        with torch.no_grad():
            input = _extract_patches(input, layer.kernel_size, layer.stride, layer.padding)
            input = input.view(-1, input.size(-1))  # b * hw * in_c*kh*kw
            grad_output = grad_output.transpose(1, 2).transpose(2, 3)
            grad_output = try_contiguous(grad_output).view(grad_output.size(0), -1, grad_output.size(-1))
            # b * hw * out_c
            if layer.bias is not None:
                input = torch.cat([input, input.new(input.size(0), 1).fill_(1)], 1)
            input = input.view(grad_output.size(0), -1, input.size(-1))  # b * hw * in_c*kh*kw
            grad = torch.einsum('abm,abn->amn', (grad_output, input))
        return grad 
開發者ID:alecwangcq,項目名稱:EigenDamage-Pytorch,代碼行數:20,代碼來源:kfac_utils.py

示例4: test_modify_pool

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import module [as 別名]
def test_modify_pool(net, img_size):
    """Test ability to modify pooling module of network"""

    class AdaptiveMaxAvgPool(nn.Module):

        def __init__(self):
            super().__init__()
            self.ada_avgpool = nn.AdaptiveAvgPool2d(1)
            self.ada_maxpool = nn.AdaptiveMaxPool2d(1)

        def forward(self, x):
            avg_x = self.ada_avgpool(x)
            max_x = self.ada_maxpool(x)
            x = torch.cat((avg_x, max_x), dim=1)
            return x

    avg_pooling = AdaptiveMaxAvgPool()
    fc = nn.Linear(net._fc.in_features * 2, net._global_params.num_classes)

    net._avg_pooling = avg_pooling
    net._fc = fc

    data = torch.zeros((2, 3, img_size, img_size))
    output = net(data)
    assert not torch.isnan(output).any() 
開發者ID:lukemelas,項目名稱:EfficientNet-PyTorch,代碼行數:27,代碼來源:test_model.py

示例5: test_nn

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import module [as 別名]
def test_nn(self) -> None:
        """
        Test a model which is a pre-defined nn.module without defining a new
        customized network.
        """
        batch_size = 5
        input_dim = 8
        output_dim = 4
        x = torch.randn(batch_size, input_dim)
        flop_dict, _ = flop_count(nn.Linear(input_dim, output_dim), (x,))
        gt_flop = batch_size * input_dim * output_dim / 1e9
        gt_dict = defaultdict(float)
        gt_dict["addmm"] = gt_flop
        self.assertDictEqual(
            flop_dict, gt_dict, "nn.Linear failed to pass the flop count test."
        ) 
開發者ID:facebookresearch,項目名稱:fvcore,代碼行數:18,代碼來源:test_flop_count.py

示例6: init_weights

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import module [as 別名]
def init_weights(m):
    """Initialize weights using kaiming uniform initialization in place

    Parameters:
        m (nn.module): Linear module from torch.nn

    Returns:
        None
    """
    if type(m) == nn.Linear:
        nn.init.kaiming_uniform_(m.weight)
        m.bias.data.fill_(0.01) 
開發者ID:IntelAI,項目名稱:cerl,代碼行數:14,代碼來源:mod_utils.py

示例7: get_model_parameters_number

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import module [as 別名]
def get_model_parameters_number(model):
    """Calculate parameter number of a model.

    Args:
        model (nn.module): The model for parameter number calculation.

    Returns:
        float: Parameter number of the model.
    """
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return num_params 
開發者ID:open-mmlab,項目名稱:mmcv,代碼行數:13,代碼來源:flops_counter.py

示例8: add_flops_counting_methods

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import module [as 別名]
def add_flops_counting_methods(net_main_module):
    # adding additional methods to the existing module object,
    # this is done this way so that each function has access to self object
    net_main_module.start_flops_count = start_flops_count.__get__(
        net_main_module)
    net_main_module.stop_flops_count = stop_flops_count.__get__(
        net_main_module)
    net_main_module.reset_flops_count = reset_flops_count.__get__(
        net_main_module)
    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(  # noqa: E501
        net_main_module)

    net_main_module.reset_flops_count()

    return net_main_module 
開發者ID:open-mmlab,項目名稱:mmcv,代碼行數:17,代碼來源:flops_counter.py

示例9: empty_flops_counter_hook

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import module [as 別名]
def empty_flops_counter_hook(module, input, output):
    module.__flops__ += 0 
開發者ID:open-mmlab,項目名稱:mmcv,代碼行數:4,代碼來源:flops_counter.py

示例10: relu_flops_counter_hook

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import module [as 別名]
def relu_flops_counter_hook(module, input, output):
    active_elements_count = output.numel()
    module.__flops__ += int(active_elements_count) 
開發者ID:open-mmlab,項目名稱:mmcv,代碼行數:5,代碼來源:flops_counter.py

示例11: linear_flops_counter_hook

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import module [as 別名]
def linear_flops_counter_hook(module, input, output):
    input = input[0]
    output_last_dim = output.shape[
        -1]  # pytorch checks dimensions, so here we don't care much
    module.__flops__ += int(np.prod(input.shape) * output_last_dim) 
開發者ID:open-mmlab,項目名稱:mmcv,代碼行數:7,代碼來源:flops_counter.py

示例12: pool_flops_counter_hook

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import module [as 別名]
def pool_flops_counter_hook(module, input, output):
    input = input[0]
    module.__flops__ += int(np.prod(input.shape)) 
開發者ID:open-mmlab,項目名稱:mmcv,代碼行數:5,代碼來源:flops_counter.py

示例13: bn_flops_counter_hook

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import module [as 別名]
def bn_flops_counter_hook(module, input, output):
    input = input[0]

    batch_flops = np.prod(input.shape)
    if module.affine:
        batch_flops *= 2
    module.__flops__ += int(batch_flops) 
開發者ID:open-mmlab,項目名稱:mmcv,代碼行數:9,代碼來源:flops_counter.py

示例14: batch_counter_hook

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import module [as 別名]
def batch_counter_hook(module, input, output):
    batch_size = 1
    if len(input) > 0:
        # Can have multiple inputs, getting the first one
        input = input[0]
        batch_size = len(input)
    else:
        pass
        print('Warning! No positional inputs found for a module, '
              'assuming batch size is 1.')
    module.__batch_counter__ += batch_size 
開發者ID:open-mmlab,項目名稱:mmcv,代碼行數:13,代碼來源:flops_counter.py

示例15: add_batch_counter_hook_function

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import module [as 別名]
def add_batch_counter_hook_function(module):
    if hasattr(module, '__batch_counter_handle__'):
        return

    handle = module.register_forward_hook(batch_counter_hook)
    module.__batch_counter_handle__ = handle 
開發者ID:open-mmlab,項目名稱:mmcv,代碼行數:8,代碼來源:flops_counter.py


注:本文中的torch.nn.module方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。