當前位置: 首頁>>代碼示例>>Python>>正文


Python checkpoint.checkpoint方法代碼示例

本文整理匯總了Python中torch.utils.checkpoint.checkpoint方法的典型用法代碼示例。如果您正苦於以下問題:Python checkpoint.checkpoint方法的具體用法?Python checkpoint.checkpoint怎麽用?Python checkpoint.checkpoint使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch.utils.checkpoint的用法示例。


在下文中一共展示了checkpoint.checkpoint方法的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: forward

# 需要導入模塊: from torch.utils import checkpoint [as 別名]
# 或者: from torch.utils.checkpoint import checkpoint [as 別名]
def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == self.num_ins
        outs = [inputs[0]]
        for i in range(1, self.num_ins):
            outs.append(
                F.interpolate(inputs[i], scale_factor=2**i, mode='bilinear'))
        out = torch.cat(outs, dim=1)
        if out.requires_grad and self.with_cp:
            out = checkpoint(self.reduction_conv, out)
        else:
            out = self.reduction_conv(out)
        outs = [out]
        for i in range(1, self.num_outs):
            outs.append(self.pooling(out, kernel_size=2**i, stride=2**i))
        outputs = []

        for i in range(self.num_outs):
            if outs[i].requires_grad and self.with_cp:
                tmp_out = checkpoint(self.fpn_convs[i], outs[i])
            else:
                tmp_out = self.fpn_convs[i](outs[i])
            outputs.append(tmp_out)
        return tuple(outputs) 
開發者ID:open-mmlab,項目名稱:mmdetection,代碼行數:26,代碼來源:hrfpn.py

示例2: forward

# 需要導入模塊: from torch.utils import checkpoint [as 別名]
# 或者: from torch.utils.checkpoint import checkpoint [as 別名]
def forward(self, inputs):
        assert len(inputs) == self.num_ins
        outs = [inputs[0]]
        for i in range(1, self.num_ins):
            outs.append(
                F.interpolate(inputs[i], scale_factor=2**i, mode='bilinear'))
        out = torch.cat(outs, dim=1)
        if out.requires_grad and self.with_cp:
            out = checkpoint(self.reduction_conv, out)
        else:
            out = self.reduction_conv(out)
        outs = [out]
        for i in range(1, self.num_outs):
            outs.append(self.pooling(out, kernel_size=2**i, stride=2**i))
        outputs = []

        for i in range(self.num_outs):
            if outs[i].requires_grad and self.with_cp:
                tmp_out = checkpoint(self.fpn_convs[i], outs[i])
            else:
                tmp_out = self.fpn_convs[i](outs[i])
            outputs.append(tmp_out)
        return tuple(outputs) 
開發者ID:dingjiansw101,項目名稱:AerialDetection,代碼行數:25,代碼來源:hrfpn.py

示例3: forward

# 需要導入模塊: from torch.utils import checkpoint [as 別名]
# 或者: from torch.utils.checkpoint import checkpoint [as 別名]
def forward(self, x):
        if self.downsample is not None:
            residual = self.downsample(x)
        else:
            residual = x

        if self.use_checkpoint:
            out = checkpoint(self.group1, x)
        else:
            out = self.group1(x)

        if self.use_se:
            weight = F.adaptive_avg_pool2d(out, output_size=1)
            weight = self.se_block(weight)
            out = out * weight

        out = out + residual

        return self.act2(out) 
開發者ID:liuyuisanai,項目名稱:trojans-face-recognizer,代碼行數:21,代碼來源:PolyFace.py

示例4: forward

# 需要導入模塊: from torch.utils import checkpoint [as 別名]
# 或者: from torch.utils.checkpoint import checkpoint [as 別名]
def forward(self, *prev_features):
        """原有的兩次BN層需要消耗的兩塊顯存空間,
        通過使用checkpoint,實現了隻開辟一塊空間用來存儲中間特征
        """
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        # requires_grad is True means that model is in train status
        # checkpoint implement shared memory storage function
        
        if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        else:
            bottleneck_output = bn_function(*prev_features)
        
        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        
        if self.drop_rate > 0:
            new_features = F.dropout(new_features,
                                    p=self.drop_rate,
                                    training=self.training
                                    )
        return new_features 
開發者ID:zhouyuangan,項目名稱:SE_DenseNet,代碼行數:23,代碼來源:se_efficient_densenet.py

示例5: forward

# 需要導入模塊: from torch.utils import checkpoint [as 別名]
# 或者: from torch.utils.checkpoint import checkpoint [as 別名]
def forward(self, x):
        """Forward function."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out 
開發者ID:open-mmlab,項目名稱:mmdetection,代碼行數:30,代碼來源:resnet.py

示例6: forward

# 需要導入模塊: from torch.utils import checkpoint [as 別名]
# 或者: from torch.utils.checkpoint import checkpoint [as 別名]
def forward(self, inputs):
        return checkpoint.checkpoint(self.transform, inputs) 
開發者ID:bayesiains,項目名稱:nsf,代碼行數:4,代碼來源:autils.py

示例7: forward_checkpoint

# 需要導入模塊: from torch.utils import checkpoint [as 別名]
# 或者: from torch.utils.checkpoint import checkpoint [as 別名]
def forward_checkpoint(self, x):
        with self.set_activation_inplace():
            return checkpoint(self.forward, x) 
開發者ID:yu45020,項目名稱:Text_Segmentation_Image_Inpainting,代碼行數:5,代碼來源:MobileNetV2.py

示例8: forward

# 需要導入模塊: from torch.utils import checkpoint [as 別名]
# 或者: from torch.utils.checkpoint import checkpoint [as 別名]
def forward(
            self, x, layer_past=None, 
            attention_mask=None, head_mask=None):

        return checkpoint(
            self.forward_wrapper, x, layer_past,
            attention_mask, head_mask) 
開發者ID:bme-chatbots,項目名稱:dialogue-generation,代碼行數:9,代碼來源:model.py

示例9: layer_function

# 需要導入模塊: from torch.utils import checkpoint [as 別名]
# 或者: from torch.utils.checkpoint import checkpoint [as 別名]
def layer_function(self, data, layer, previous_layer, i):
        if self.use_checkpointing:
            data = checkpoint(layer, data)
        else:
            data = layer(data)
        if self.residuals[i]:
            if previous_layer is not None:
                data += previous_layer
            previous_layer = data
        return data, previous_layer 
開發者ID:nussl,項目名稱:nussl,代碼行數:12,代碼來源:blocks.py

示例10: forward

# 需要導入模塊: from torch.utils import checkpoint [as 別名]
# 或者: from torch.utils.checkpoint import checkpoint [as 別名]
def forward(self, x):
        feat4, feat16 = self.entry_flow(x)
        #  feat_mid = ckpt.checkpoint(self.middle_flow, feat16)
        feat_mid = self.middle_flow(feat16)
        feat_exit = self.exit_flow(feat_mid)
        return feat4, feat_exit 
開發者ID:CoinCheung,項目名稱:DeepLab-v3-plus-cityscapes,代碼行數:8,代碼來源:xception.py

示例11: do_efficient_fwd

# 需要導入模塊: from torch.utils import checkpoint [as 別名]
# 或者: from torch.utils.checkpoint import checkpoint [as 別名]
def do_efficient_fwd(block, x, efficient):
    # return block(x)
    if efficient and x.requires_grad:
        return cp.checkpoint(block, x)
    else:
        return block(x) 
開發者ID:orsic,項目名稱:swiftnet,代碼行數:8,代碼來源:resnet_pyramid.py

示例12: do_efficient_fwd

# 需要導入模塊: from torch.utils import checkpoint [as 別名]
# 或者: from torch.utils.checkpoint import checkpoint [as 別名]
def do_efficient_fwd(block, x, efficient):
    if efficient and x.requires_grad:
        return cp.checkpoint(block, x)
    else:
        return block(x) 
開發者ID:orsic,項目名稱:swiftnet,代碼行數:7,代碼來源:resnet_single_scale.py

示例13: forward

# 需要導入模塊: from torch.utils import checkpoint [as 別名]
# 或者: from torch.utils.checkpoint import checkpoint [as 別名]
def forward(self, input):
        stack = []
        x = input['kspace'].clone() if isinstance(input, dict) else input

        nmodules = len(self._modules.values())

        for module_idx, module in enumerate(self._modules.values()):
            last_module_flag = module_idx == nmodules - 1

            if isinstance(module, Push):
                stack.append(x)
            elif isinstance(module, Pop):
                if module.method == 'concat':
                    x = torch.cat((x, stack.pop()), 1)
                elif module.method == 'add':
                    x = x + stack.pop()
                else:
                    assert False
            else:
                if isinstance(module, (DC, SensExpand, SensReduce, SoftDC, GRAPPA, MaskCenter)):
                    x = module(x, input)
                else:
                    if args.gradient_checkpointing and not last_module_flag:
                        x.requires_grad_()
                        x = checkpoint(module, x)
                    else:
                        x = module(x)
        return x 
開發者ID:facebookresearch,項目名稱:fastMRI,代碼行數:30,代碼來源:var_net.py

示例14: forward

# 需要導入模塊: from torch.utils import checkpoint [as 別名]
# 或者: from torch.utils.checkpoint import checkpoint [as 別名]
def forward(self, x):

        def _inner_forward(x):
            residual = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.bn3(out)

            if self.downsample is not None:
                residual = self.downsample(x)

            out += residual

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out 
開發者ID:open-mmlab,項目名稱:mmcv,代碼行數:33,代碼來源:resnet.py

示例15: forward

# 需要導入模塊: from torch.utils import checkpoint [as 別名]
# 或者: from torch.utils.checkpoint import checkpoint [as 別名]
def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if not self.with_dcn:
                out = self.conv2(out)
            elif self.with_modulated_dcn:
                offset_mask = self.conv2_offset(out)
                offset = offset_mask[:, :18, :, :]
                mask = offset_mask[:, -9:, :, :].sigmoid()
                out = self.conv2(out, offset, mask)
            else:
                offset = self.conv2_offset(out)
                out = self.conv2(out, offset)
            out = self.norm2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out 
開發者ID:qixuxiang,項目名稱:mmdetection_with_SENet154,代碼行數:42,代碼來源:resnet.py


注:本文中的torch.utils.checkpoint.checkpoint方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。