本文整理匯總了Python中torch.nn.Module的典型用法代碼示例。如果您正苦於以下問題:Python nn.Module的具體用法?Python nn.Module怎麼用?Python nn.Module使用的例子?那麼, 這裏精選的代碼示例或許可以為您提供幫助。您也可以進一步了解其所在模塊torch.nn的用法示例。
在下文中一共展示了nn.Module的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: compute_average_flops_cost
# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import Module [as 別名]
def compute_average_flops_cost(self):
    """Compute the average FLOPs cost and the trainable parameter count.

    A method that becomes available after ``add_flops_counting_methods()``
    is called on a desired net object. It sums the ``__flops__`` counters of
    every supported submodule and divides by the network's
    ``__batch_counter__``.

    Returns:
        tuple: ``(mean_flops, params)`` where ``mean_flops`` (float) is the
        summed FLOPs divided by ``__batch_counter__`` and ``params`` is the
        value of ``get_model_parameters_number(self)``.

    Raises:
        ZeroDivisionError: If ``__batch_counter__`` is 0, i.e. no batch has
            been counted yet.
    """
    batches_count = self.__batch_counter__
    if batches_count == 0:
        # Same exception type the bare division would raise, but with a
        # message that tells the caller what to do about it.
        raise ZeroDivisionError(
            'No batches have been counted yet; run at least one forward '
            'pass after start_flops_count() before computing FLOPs.')
    flops_sum = sum(
        module.__flops__
        for module in self.modules()
        if is_supported_instance(module))
    params_sum = get_model_parameters_number(self)
    return flops_sum / batches_count, params_sum
示例2: start_flops_count
# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import Module [as 別名]
def start_flops_count(self):
    """Activate the computation of mean FLOPs consumption per image.

    Available after ``add_flops_counting_methods()`` is called on a desired
    net object; it should be called before running the network. It installs
    the batch-counter hook on ``self`` and a FLOPs-counting forward hook
    (looked up in ``MODULES_MAPPING`` by exact module type) on every
    supported submodule.
    """
    add_batch_counter_hook_function(self)

    def add_flops_counter_hook_function(module):
        # Skip unsupported modules, and modules that already carry a handle,
        # so calling start_flops_count() twice never registers duplicates.
        if not is_supported_instance(module) or hasattr(
                module, '__flops_handle__'):
            return
        handle = module.register_forward_hook(MODULES_MAPPING[type(module)])
        module.__flops_handle__ = handle

    # apply() already calls the function with each submodule (and self);
    # no functools.partial wrapper is needed.
    self.apply(add_flops_counter_hook_function)
示例3: conv2d
# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import Module [as 別名]
def conv2d(input, grad_output, layer):
    """Compute per-sample gradients of a conv layer's weights (and bias).

    :param input: layer input, batch_size * in_c * in_h * in_w
    :param grad_output: gradient w.r.t. the layer output,
        batch_size * out_c * h * w
    :param layer: the conv module; its ``kernel_size``, ``stride``,
        ``padding`` and ``bias`` attributes are read
    :return: batch_size * out_c * (in_c*k_h*k_w + [1 if with bias])
    """
    with torch.no_grad():
        # NOTE(review): assumes _extract_patches returns
        # b * h * w * (in_c*kh*kw), so its flattened rows line up with the
        # b * hw rows of grad_output below -- confirm against the helper.
        input = _extract_patches(input, layer.kernel_size, layer.stride,
                                 layer.padding)
        input = input.view(-1, input.size(-1))  # (b*hw) * in_c*kh*kw
        # b * out_c * h * w -> b * h * w * out_c
        grad_output = grad_output.transpose(1, 2).transpose(2, 3)
        grad_output = try_contiguous(grad_output).view(
            grad_output.size(0), -1, grad_output.size(-1))  # b * hw * out_c
        if layer.bias is not None:
            # A constant-1 column makes the bias gradient fall out of the
            # same product as the weight gradient.
            input = torch.cat([input, input.new_ones(input.size(0), 1)], 1)
        # b * hw * (in_c*kh*kw [+1])
        input = input.view(grad_output.size(0), -1, input.size(-1))
        # grad[a] = grad_output[a]^T @ input[a], summed over positions b.
        grad = torch.einsum('abm,abn->amn', grad_output, input)
        return grad
示例4: test_modify_pool
# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import Module [as 別名]
def test_modify_pool(net, img_size):
    """Check that the network's pooling module can be swapped out.

    ``net._avg_pooling`` is replaced by a module that concatenates adaptive
    average and adaptive max pooling along dim 1 (doubling that dimension),
    ``net._fc`` is widened to match, and a forward pass on zeros must not
    produce NaNs.
    """

    class AdaptiveMaxAvgPool(nn.Module):
        def __init__(self):
            super().__init__()
            self.ada_avgpool = nn.AdaptiveAvgPool2d(1)
            self.ada_maxpool = nn.AdaptiveMaxPool2d(1)

        def forward(self, x):
            pooled_avg = self.ada_avgpool(x)
            pooled_max = self.ada_maxpool(x)
            return torch.cat((pooled_avg, pooled_max), dim=1)

    replacement_pool = AdaptiveMaxAvgPool()
    # Both pooled maps are concatenated, so the classifier sees twice the
    # original feature width.
    widened_fc = nn.Linear(
        net._fc.in_features * 2, net._global_params.num_classes)
    net._avg_pooling = replacement_pool
    net._fc = widened_fc

    dummy_batch = torch.zeros((2, 3, img_size, img_size))
    logits = net(dummy_batch)
    assert not torch.isnan(logits).any()
示例5: test_nn
# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import Module [as 別名]
def test_nn(self) -> None:
    """
    Check flop counting on a stock nn.Linear, without defining a custom
    network class.
    """
    batch_size, input_dim, output_dim = 5, 8, 4
    x = torch.randn(batch_size, input_dim)
    model = nn.Linear(input_dim, output_dim)
    flop_dict, _ = flop_count(model, (x,))
    # A single addmm of batch * in * out multiply-adds, scaled by 1e9
    # (flop_count presumably reports GFLOPs -- the scale is taken from the
    # original expectation).
    expected = defaultdict(
        float, {"addmm": batch_size * input_dim * output_dim / 1e9})
    self.assertDictEqual(
        flop_dict, expected, "nn.Linear failed to pass the flop count test."
    )
示例6: init_weights
# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import Module [as 別名]
def init_weights(m):
    """Initialize a Linear layer in place with Kaiming-uniform weights.

    Suitable for ``model.apply(init_weights)``: any module whose exact type
    is not ``nn.Linear`` is left untouched.

    Parameters:
        m (nn.Module): Module to initialize; only exact ``nn.Linear``
            instances are modified (weight -> Kaiming uniform, bias -> 0.01).
    Returns:
        None
    """
    # Exact type check (not isinstance) on purpose: subclasses such as
    # nn.LazyLinear may still hold uninitialized parameters.
    if type(m) is nn.Linear:
        nn.init.kaiming_uniform_(m.weight)
        # Layers built with bias=False have m.bias set to None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)
示例7: get_model_parameters_number
# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import Module [as 別名]
def get_model_parameters_number(model):
    """Count the trainable parameters of a model.

    Args:
        model (nn.Module): The model whose parameters are counted.

    Returns:
        int: Total number of elements over all parameters with
        ``requires_grad=True``; frozen parameters are excluded.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
示例8: add_flops_counting_methods
# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import Module [as 別名]
def add_flops_counting_methods(net_main_module):
    """Attach the FLOPs-counting helpers to a network as bound methods.

    Each module-level helper is bound to ``net_main_module`` through the
    descriptor protocol (``func.__get__(obj)``), so it receives the network
    as ``self``. The counters are then reset.

    Returns:
        The same ``net_main_module`` object.
    """
    bindings = (
        ('start_flops_count', start_flops_count),
        ('stop_flops_count', stop_flops_count),
        ('reset_flops_count', reset_flops_count),
        ('compute_average_flops_cost', compute_average_flops_cost),
    )
    for attr_name, helper in bindings:
        setattr(net_main_module, attr_name, helper.__get__(net_main_module))
    net_main_module.reset_flops_count()
    return net_main_module
示例9: empty_flops_counter_hook
# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import Module [as 別名]
def empty_flops_counter_hook(module, input, output):
    """Forward hook for modules whose cost is counted as zero FLOPs.

    It still reads and rewrites ``module.__flops__``, so that attribute must
    already exist on the module.
    """
    no_cost = 0
    module.__flops__ += no_cost
示例10: relu_flops_counter_hook
# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import Module [as 別名]
def relu_flops_counter_hook(module, input, output):
    """Forward hook charging one FLOP per element of ``output``."""
    module.__flops__ += int(output.numel())
示例11: linear_flops_counter_hook
# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import Module [as 別名]
def linear_flops_counter_hook(module, input, output):
    """Forward hook: FLOPs = elements of the first input * output's last dim.

    ``input`` is the tuple of positional forward arguments; only the first
    entry is used. Shape consistency between input and output is left to
    PyTorch's own checks inside the layer.
    """
    in_elements = input[0].numel()
    out_features = output.shape[-1]
    module.__flops__ += int(in_elements * out_features)
示例12: pool_flops_counter_hook
# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import Module [as 別名]
def pool_flops_counter_hook(module, input, output):
    """Forward hook charging one FLOP per element of the first positional input."""
    first_input = input[0]
    module.__flops__ += int(first_input.numel())
示例13: bn_flops_counter_hook
# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import Module [as 別名]
def bn_flops_counter_hook(module, input, output):
    """Forward hook: one FLOP per element of the first positional input,
    doubled when ``module.affine`` is truthy."""
    per_element = 2 if module.affine else 1
    module.__flops__ += int(input[0].numel() * per_element)
示例14: batch_counter_hook
# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import Module [as 別名]
def batch_counter_hook(module, input, output):
    """Forward hook adding the batch size to ``module.__batch_counter__``.

    The batch size is ``len()`` of the first positional input (a module can
    take several inputs). If the module received no positional inputs, a
    warning is printed and a batch size of 1 is assumed.
    """
    if len(input) > 0:
        # Can have multiple inputs; the first one determines the batch size.
        batch_size = len(input[0])
    else:
        batch_size = 1
        print('Warning! No positional inputs found for a module, '
              'assuming batch size is 1.')
    module.__batch_counter__ += batch_size
示例15: add_batch_counter_hook_function
# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import Module [as 別名]
def add_batch_counter_hook_function(module):
    """Register ``batch_counter_hook`` as a forward hook on ``module`` once.

    The hook handle is stored on ``module.__batch_counter_handle__``; if that
    attribute already exists, the module is left unchanged.
    """
    if not hasattr(module, '__batch_counter_handle__'):
        module.__batch_counter_handle__ = module.register_forward_hook(
            batch_counter_hook)