当前位置: 首页>>代码示例>>Python>>正文


Python checkpoint.checkpoint方法代码示例

本文整理汇总了Python中torch.utils.checkpoint.checkpoint方法的典型用法代码示例。如果您正苦于以下问题:Python checkpoint.checkpoint方法的具体用法?Python checkpoint.checkpoint怎么用?Python checkpoint.checkpoint使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.utils.checkpoint的用法示例。


在下文中一共展示了checkpoint.checkpoint方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: forward

# 需要导入模块: from torch.utils import checkpoint [as 别名]
# 或者: from torch.utils.checkpoint import checkpoint [as 别名]
def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == self.num_ins
        outs = [inputs[0]]
        for i in range(1, self.num_ins):
            outs.append(
                F.interpolate(inputs[i], scale_factor=2**i, mode='bilinear'))
        out = torch.cat(outs, dim=1)
        if out.requires_grad and self.with_cp:
            out = checkpoint(self.reduction_conv, out)
        else:
            out = self.reduction_conv(out)
        outs = [out]
        for i in range(1, self.num_outs):
            outs.append(self.pooling(out, kernel_size=2**i, stride=2**i))
        outputs = []

        for i in range(self.num_outs):
            if outs[i].requires_grad and self.with_cp:
                tmp_out = checkpoint(self.fpn_convs[i], outs[i])
            else:
                tmp_out = self.fpn_convs[i](outs[i])
            outputs.append(tmp_out)
        return tuple(outputs) 
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:26,代码来源:hrfpn.py

示例2: forward

# 需要导入模块: from torch.utils import checkpoint [as 别名]
# 或者: from torch.utils.checkpoint import checkpoint [as 别名]
def forward(self, inputs):
        assert len(inputs) == self.num_ins
        outs = [inputs[0]]
        for i in range(1, self.num_ins):
            outs.append(
                F.interpolate(inputs[i], scale_factor=2**i, mode='bilinear'))
        out = torch.cat(outs, dim=1)
        if out.requires_grad and self.with_cp:
            out = checkpoint(self.reduction_conv, out)
        else:
            out = self.reduction_conv(out)
        outs = [out]
        for i in range(1, self.num_outs):
            outs.append(self.pooling(out, kernel_size=2**i, stride=2**i))
        outputs = []

        for i in range(self.num_outs):
            if outs[i].requires_grad and self.with_cp:
                tmp_out = checkpoint(self.fpn_convs[i], outs[i])
            else:
                tmp_out = self.fpn_convs[i](outs[i])
            outputs.append(tmp_out)
        return tuple(outputs) 
开发者ID:dingjiansw101,项目名称:AerialDetection,代码行数:25,代码来源:hrfpn.py

示例3: forward

# 需要导入模块: from torch.utils import checkpoint [as 别名]
# 或者: from torch.utils.checkpoint import checkpoint [as 别名]
def forward(self, x):
        if self.downsample is not None:
            residual = self.downsample(x)
        else:
            residual = x

        if self.use_checkpoint:
            out = checkpoint(self.group1, x)
        else:
            out = self.group1(x)

        if self.use_se:
            weight = F.adaptive_avg_pool2d(out, output_size=1)
            weight = self.se_block(weight)
            out = out * weight

        out = out + residual

        return self.act2(out) 
开发者ID:liuyuisanai,项目名称:trojans-face-recognizer,代码行数:21,代码来源:PolyFace.py

示例4: forward

# 需要导入模块: from torch.utils import checkpoint [as 别名]
# 或者: from torch.utils.checkpoint import checkpoint [as 别名]
def forward(self, *prev_features):
        """原有的两次BN层需要消耗的两块显存空间,
        通过使用checkpoint,实现了只开辟一块空间用来存储中间特征
        """
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        # requires_grad is True means that model is in train status
        # checkpoint implement shared memory storage function
        
        if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        else:
            bottleneck_output = bn_function(*prev_features)
        
        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        
        if self.drop_rate > 0:
            new_features = F.dropout(new_features,
                                    p=self.drop_rate,
                                    training=self.training
                                    )
        return new_features 
开发者ID:zhouyuangan,项目名称:SE_DenseNet,代码行数:23,代码来源:se_efficient_densenet.py

示例5: forward

# 需要导入模块: from torch.utils import checkpoint [as 别名]
# 或者: from torch.utils.checkpoint import checkpoint [as 别名]
def forward(self, x):
        """Forward function."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out 
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:30,代码来源:resnet.py

示例6: forward

# 需要导入模块: from torch.utils import checkpoint [as 别名]
# 或者: from torch.utils.checkpoint import checkpoint [as 别名]
def forward(self, inputs):
        return checkpoint.checkpoint(self.transform, inputs) 
开发者ID:bayesiains,项目名称:nsf,代码行数:4,代码来源:autils.py

示例7: forward_checkpoint

# 需要导入模块: from torch.utils import checkpoint [as 别名]
# 或者: from torch.utils.checkpoint import checkpoint [as 别名]
def forward_checkpoint(self, x):
        with self.set_activation_inplace():
            return checkpoint(self.forward, x) 
开发者ID:yu45020,项目名称:Text_Segmentation_Image_Inpainting,代码行数:5,代码来源:MobileNetV2.py

示例8: forward

# 需要导入模块: from torch.utils import checkpoint [as 别名]
# 或者: from torch.utils.checkpoint import checkpoint [as 别名]
def forward(
            self, x, layer_past=None, 
            attention_mask=None, head_mask=None):

        return checkpoint(
            self.forward_wrapper, x, layer_past,
            attention_mask, head_mask) 
开发者ID:bme-chatbots,项目名称:dialogue-generation,代码行数:9,代码来源:model.py

示例9: layer_function

# 需要导入模块: from torch.utils import checkpoint [as 别名]
# 或者: from torch.utils.checkpoint import checkpoint [as 别名]
def layer_function(self, data, layer, previous_layer, i):
        if self.use_checkpointing:
            data = checkpoint(layer, data)
        else:
            data = layer(data)
        if self.residuals[i]:
            if previous_layer is not None:
                data += previous_layer
            previous_layer = data
        return data, previous_layer 
开发者ID:nussl,项目名称:nussl,代码行数:12,代码来源:blocks.py

示例10: forward

# 需要导入模块: from torch.utils import checkpoint [as 别名]
# 或者: from torch.utils.checkpoint import checkpoint [as 别名]
def forward(self, x):
        feat4, feat16 = self.entry_flow(x)
        #  feat_mid = ckpt.checkpoint(self.middle_flow, feat16)
        feat_mid = self.middle_flow(feat16)
        feat_exit = self.exit_flow(feat_mid)
        return feat4, feat_exit 
开发者ID:CoinCheung,项目名称:DeepLab-v3-plus-cityscapes,代码行数:8,代码来源:xception.py

示例11: do_efficient_fwd

# 需要导入模块: from torch.utils import checkpoint [as 别名]
# 或者: from torch.utils.checkpoint import checkpoint [as 别名]
def do_efficient_fwd(block, x, efficient):
    # return block(x)
    if efficient and x.requires_grad:
        return cp.checkpoint(block, x)
    else:
        return block(x) 
开发者ID:orsic,项目名称:swiftnet,代码行数:8,代码来源:resnet_pyramid.py

示例12: do_efficient_fwd

# 需要导入模块: from torch.utils import checkpoint [as 别名]
# 或者: from torch.utils.checkpoint import checkpoint [as 别名]
def do_efficient_fwd(block, x, efficient):
    if efficient and x.requires_grad:
        return cp.checkpoint(block, x)
    else:
        return block(x) 
开发者ID:orsic,项目名称:swiftnet,代码行数:7,代码来源:resnet_single_scale.py

示例13: forward

# 需要导入模块: from torch.utils import checkpoint [as 别名]
# 或者: from torch.utils.checkpoint import checkpoint [as 别名]
def forward(self, input):
        stack = []
        x = input['kspace'].clone() if isinstance(input, dict) else input

        nmodules = len(self._modules.values())

        for module_idx, module in enumerate(self._modules.values()):
            last_module_flag = module_idx == nmodules - 1

            if isinstance(module, Push):
                stack.append(x)
            elif isinstance(module, Pop):
                if module.method == 'concat':
                    x = torch.cat((x, stack.pop()), 1)
                elif module.method == 'add':
                    x = x + stack.pop()
                else:
                    assert False
            else:
                if isinstance(module, (DC, SensExpand, SensReduce, SoftDC, GRAPPA, MaskCenter)):
                    x = module(x, input)
                else:
                    if args.gradient_checkpointing and not last_module_flag:
                        x.requires_grad_()
                        x = checkpoint(module, x)
                    else:
                        x = module(x)
        return x 
开发者ID:facebookresearch,项目名称:fastMRI,代码行数:30,代码来源:var_net.py

示例14: forward

# 需要导入模块: from torch.utils import checkpoint [as 别名]
# 或者: from torch.utils.checkpoint import checkpoint [as 别名]
def forward(self, x):

        def _inner_forward(x):
            residual = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.bn3(out)

            if self.downsample is not None:
                residual = self.downsample(x)

            out += residual

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out 
开发者ID:open-mmlab,项目名称:mmcv,代码行数:33,代码来源:resnet.py

示例15: forward

# 需要导入模块: from torch.utils import checkpoint [as 别名]
# 或者: from torch.utils.checkpoint import checkpoint [as 别名]
def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if not self.with_dcn:
                out = self.conv2(out)
            elif self.with_modulated_dcn:
                offset_mask = self.conv2_offset(out)
                offset = offset_mask[:, :18, :, :]
                mask = offset_mask[:, -9:, :, :].sigmoid()
                out = self.conv2(out, offset, mask)
            else:
                offset = self.conv2_offset(out)
                out = self.conv2(out, offset)
            out = self.norm2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out 
开发者ID:qixuxiang,项目名称:mmdetection_with_SENet154,代码行数:42,代码来源:resnet.py


注:本文中的torch.utils.checkpoint.checkpoint方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。