This article collects typical usage examples of the torch.utils.checkpoint.checkpoint method in Python. If you are asking how checkpoint.checkpoint is used, what it does, or what real-world calls look like, the curated examples below may help. You can also explore the other members of its home module, torch.utils.checkpoint.
The following shows 15 code examples of the checkpoint.checkpoint method, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
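For context, here is a minimal, self-contained sketch of how torch.utils.checkpoint.checkpoint is typically called. The wrapped module and tensor shapes are illustrative assumptions, not taken from the examples below.

import torch
from torch.utils.checkpoint import checkpoint

# Any nn.Module or plain callable can be wrapped.
block = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
)

x = torch.randn(4, 128, requires_grad=True)

# checkpoint runs the forward pass without storing intermediate
# activations; they are recomputed during the backward pass,
# trading extra compute for lower memory use.
out = checkpoint(block, x)
out.sum().backward()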
Example 1: forward
# Required import: from torch.utils import checkpoint [as alias]
# Or: from torch.utils.checkpoint import checkpoint [as alias]
def forward(self, inputs):
    """Forward function."""
    assert len(inputs) == self.num_ins
    outs = [inputs[0]]
    for i in range(1, self.num_ins):
        outs.append(
            F.interpolate(inputs[i], scale_factor=2**i, mode='bilinear'))
    out = torch.cat(outs, dim=1)
    if out.requires_grad and self.with_cp:
        out = checkpoint(self.reduction_conv, out)
    else:
        out = self.reduction_conv(out)
    outs = [out]
    for i in range(1, self.num_outs):
        outs.append(self.pooling(out, kernel_size=2**i, stride=2**i))
    outputs = []
    for i in range(self.num_outs):
        if outs[i].requires_grad and self.with_cp:
            tmp_out = checkpoint(self.fpn_convs[i], outs[i])
        else:
            tmp_out = self.fpn_convs[i](outs[i])
        outputs.append(tmp_out)
    return tuple(outputs)
Example 2: forward
# Required import: from torch.utils import checkpoint [as alias]
# Or: from torch.utils.checkpoint import checkpoint [as alias]
def forward(self, inputs):
    assert len(inputs) == self.num_ins
    outs = [inputs[0]]
    for i in range(1, self.num_ins):
        outs.append(
            F.interpolate(inputs[i], scale_factor=2**i, mode='bilinear'))
    out = torch.cat(outs, dim=1)
    if out.requires_grad and self.with_cp:
        out = checkpoint(self.reduction_conv, out)
    else:
        out = self.reduction_conv(out)
    outs = [out]
    for i in range(1, self.num_outs):
        outs.append(self.pooling(out, kernel_size=2**i, stride=2**i))
    outputs = []
    for i in range(self.num_outs):
        if outs[i].requires_grad and self.with_cp:
            tmp_out = checkpoint(self.fpn_convs[i], outs[i])
        else:
            tmp_out = self.fpn_convs[i](outs[i])
        outputs.append(tmp_out)
    return tuple(outputs)
Example 3: forward
# Required import: from torch.utils import checkpoint [as alias]
# Or: from torch.utils.checkpoint import checkpoint [as alias]
def forward(self, x):
    if self.downsample is not None:
        residual = self.downsample(x)
    else:
        residual = x
    if self.use_checkpoint:
        out = checkpoint(self.group1, x)
    else:
        out = self.group1(x)
    if self.use_se:
        weight = F.adaptive_avg_pool2d(out, output_size=1)
        weight = self.se_block(weight)
        out = out * weight
    out = out + residual
    return self.act2(out)
Example 4: forward
# Required import: from torch.utils import checkpoint [as alias]
# Or: from torch.utils.checkpoint import checkpoint [as alias]
def forward(self, *prev_features):
    """Instead of the two activation buffers that the two BN layers
    would normally require, checkpointing allocates a single shared
    buffer for the intermediate features and recomputes them in the
    backward pass.
    """
    bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
    # requires_grad being True implies the model is in training mode;
    # checkpoint then recomputes the bottleneck during backward
    # instead of storing its intermediate activations.
    if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
        bottleneck_output = cp.checkpoint(bn_function, *prev_features)
    else:
        bottleneck_output = bn_function(*prev_features)
    new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
    if self.drop_rate > 0:
        new_features = F.dropout(new_features,
                                 p=self.drop_rate,
                                 training=self.training)
    return new_features
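Example 4 relies on a _bn_function_factory helper that is not shown above. As a point of reference, a minimal sketch in the style of torchvision's memory-efficient DenseNet follows; the exact helper in the source project may differ.

import torch

def _bn_function_factory(norm, relu, conv):
    # Returns a closure that concatenates all incoming feature maps
    # and applies norm -> relu -> conv. Wrapping this closure in
    # checkpoint avoids storing the large concatenated tensor for
    # the backward pass.
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = conv(relu(norm(concated_features)))
        return bottleneck_output
    return bn_function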
Example 5: forward
# Required import: from torch.utils import checkpoint [as alias]
# Or: from torch.utils.checkpoint import checkpoint [as alias]
def forward(self, x):
    """Forward function."""

    def _inner_forward(x):
        identity = x
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out

    if self.with_cp and x.requires_grad:
        out = cp.checkpoint(_inner_forward, x)
    else:
        out = _inner_forward(x)
    out = self.relu(out)
    return out
Example 6: forward
# Required import: from torch.utils import checkpoint [as alias]
# Or: from torch.utils.checkpoint import checkpoint [as alias]
def forward(self, inputs):
    return checkpoint.checkpoint(self.transform, inputs)
Example 7: forward_checkpoint
# Required import: from torch.utils import checkpoint [as alias]
# Or: from torch.utils.checkpoint import checkpoint [as alias]
def forward_checkpoint(self, x):
    with self.set_activation_inplace():
        return checkpoint(self.forward, x)
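Example 7 calls a set_activation_inplace context manager that is not shown. A plausible sketch follows, assuming it is a method on the module and that the activations expose an inplace flag the way torch.nn.ReLU does; the actual helper in the source project may differ.

import contextlib
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(inplace=False))

    @contextlib.contextmanager
    def set_activation_inplace(self):
        # Temporarily flip every activation with an `inplace` flag
        # (e.g. nn.ReLU) to in-place mode. Because checkpoint recomputes
        # the forward pass during backward anyway, in-place activations
        # are safe here and save additional memory.
        acts = [m for m in self.modules() if hasattr(m, 'inplace')]
        prev = [m.inplace for m in acts]
        for m in acts:
            m.inplace = True
        try:
            yield
        finally:
            for m, flag in zip(acts, prev):
                m.inplace = flag

    def forward(self, x):
        return self.body(x)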
Example 8: forward
# Required import: from torch.utils import checkpoint [as alias]
# Or: from torch.utils.checkpoint import checkpoint [as alias]
def forward(self, x, layer_past=None,
            attention_mask=None, head_mask=None):
    return checkpoint(
        self.forward_wrapper, x, layer_past,
        attention_mask, head_mask)
Example 9: layer_function
# Required import: from torch.utils import checkpoint [as alias]
# Or: from torch.utils.checkpoint import checkpoint [as alias]
def layer_function(self, data, layer, previous_layer, i):
    if self.use_checkpointing:
        data = checkpoint(layer, data)
    else:
        data = layer(data)
    if self.residuals[i]:
        if previous_layer is not None:
            data += previous_layer
        previous_layer = data
    return data, previous_layer
Example 10: forward
# Required import: from torch.utils import checkpoint [as alias]
# Or: from torch.utils.checkpoint import checkpoint [as alias]
def forward(self, x):
    feat4, feat16 = self.entry_flow(x)
    # feat_mid = ckpt.checkpoint(self.middle_flow, feat16)
    feat_mid = self.middle_flow(feat16)
    feat_exit = self.exit_flow(feat_mid)
    return feat4, feat_exit
Example 11: do_efficient_fwd
# Required import: from torch.utils import checkpoint [as alias]
# Or: from torch.utils.checkpoint import checkpoint [as alias]
def do_efficient_fwd(block, x, efficient):
    # return block(x)
    if efficient and x.requires_grad:
        return cp.checkpoint(block, x)
    else:
        return block(x)
Example 12: do_efficient_fwd
# Required import: from torch.utils import checkpoint [as alias]
# Or: from torch.utils.checkpoint import checkpoint [as alias]
def do_efficient_fwd(block, x, efficient):
    if efficient and x.requires_grad:
        return cp.checkpoint(block, x)
    else:
        return block(x)
Example 13: forward
# Required import: from torch.utils import checkpoint [as alias]
# Or: from torch.utils.checkpoint import checkpoint [as alias]
def forward(self, input):
    stack = []
    x = input['kspace'].clone() if isinstance(input, dict) else input
    nmodules = len(self._modules.values())
    for module_idx, module in enumerate(self._modules.values()):
        last_module_flag = module_idx == nmodules - 1
        if isinstance(module, Push):
            stack.append(x)
        elif isinstance(module, Pop):
            if module.method == 'concat':
                x = torch.cat((x, stack.pop()), 1)
            elif module.method == 'add':
                x = x + stack.pop()
            else:
                assert False
        else:
            if isinstance(module, (DC, SensExpand, SensReduce, SoftDC, GRAPPA, MaskCenter)):
                x = module(x, input)
            else:
                if args.gradient_checkpointing and not last_module_flag:
                    # checkpoint only propagates gradients through inputs
                    # that require them, so mark x accordingly first.
                    x.requires_grad_()
                    x = checkpoint(module, x)
                else:
                    x = module(x)
    return x
Example 14: forward
# Required import: from torch.utils import checkpoint [as alias]
# Or: from torch.utils.checkpoint import checkpoint [as alias]
def forward(self, x):

    def _inner_forward(x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        return out

    if self.with_cp and x.requires_grad:
        out = cp.checkpoint(_inner_forward, x)
    else:
        out = _inner_forward(x)
    out = self.relu(out)
    return out
Example 15: forward
# Required import: from torch.utils import checkpoint [as alias]
# Or: from torch.utils.checkpoint import checkpoint [as alias]
def forward(self, x):

    def _inner_forward(x):
        identity = x
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)
        if not self.with_dcn:
            out = self.conv2(out)
        elif self.with_modulated_dcn:
            offset_mask = self.conv2_offset(out)
            offset = offset_mask[:, :18, :, :]
            mask = offset_mask[:, -9:, :, :].sigmoid()
            out = self.conv2(out, offset, mask)
        else:
            offset = self.conv2_offset(out)
            out = self.conv2(out, offset)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.norm3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out

    if self.with_cp and x.requires_grad:
        out = cp.checkpoint(_inner_forward, x)
    else:
        out = _inner_forward(x)
    out = self.relu(out)
    return out