当前位置: 首页>>代码示例>>Python>>正文


Python pytorch_utils.SharedMLP方法代码示例

本文整理汇总了Python中pytorch_utils.SharedMLP方法的典型用法代码示例。如果您正苦于以下问题:Python pytorch_utils.SharedMLP方法的具体用法?Python pytorch_utils.SharedMLP怎么用?Python pytorch_utils.SharedMLP使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在pytorch_utils的用法示例。


在下文中一共展示了pytorch_utils.SharedMLP方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: import pytorch_utils [as 别名]
# 或者: from pytorch_utils import SharedMLP [as 别名]
def __init__(
            self,
            *,
            npoint: int,
            radii: List[float],
            nsamples: List[int],
            mlps: List[List[int]],
            bn: bool = True,
            use_xyz: bool = True, 
            sample_uniformly: bool = False
    ):
        super().__init__()

        assert len(radii) == len(nsamples) == len(mlps)

        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for i in range(len(radii)):
            radius = radii[i]
            nsample = nsamples[i]
            self.groupers.append(
                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz, sample_uniformly=sample_uniformly)
                if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
            )
            mlp_spec = mlps[i]
            if use_xyz:
                mlp_spec[0] += 3

            self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn)) 
开发者ID:zaiweizhang,项目名称:H3DNet,代码行数:32,代码来源:pointnet2_modules.py

示例2: __init__

# 需要导入模块: import pytorch_utils [as 别名]
# 或者: from pytorch_utils import SharedMLP [as 别名]
def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
                 use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
        """
        :param npoint: int
        :param radii: list of float, list of radii to group with
        :param nsamples: list of int, number of samples in each ball query
        :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
        :param bn: whether to use batchnorm
        :param use_xyz:
        :param pool_method: max_pool / avg_pool
        :param instance_norm: whether to use instance_norm
        """
        super().__init__()

        assert len(radii) == len(nsamples) == len(mlps)

        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for i in range(len(radii)):
            radius = radii[i]
            nsample = nsamples[i]
            self.groupers.append(
                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
                if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
            )
            mlp_spec = mlps[i]
            if use_xyz:
                mlp_spec[0] += 3

            self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
        self.pool_method = pool_method 
开发者ID:daveredrum,项目名称:Pointnet2.ScanNet,代码行数:34,代码来源:pointnet2_modules.py

示例3: __init__

# 需要导入模块: import pytorch_utils [as 别名]
# 或者: from pytorch_utils import SharedMLP [as 别名]
def __init__(self, *, mlp: List[int], bn: bool = True):
        super().__init__()
        self.mlp = pt_utils.SharedMLP(mlp, bn=bn) 
开发者ID:Yochengliu,项目名称:Relation-Shape-CNN,代码行数:5,代码来源:pointnet2_modules.py

示例4: __init__

# 需要导入模块: import pytorch_utils [as 别名]
# 或者: from pytorch_utils import SharedMLP [as 别名]
def __init__(
        self,
        *,
        npoint: int,
        radii: List[float],
        nsamples: List[int],
        mlps: List[List[int]],
        bn: bool = True,
        use_xyz: bool = True,
        sample_uniformly: bool = False
    ):
        super().__init__()

        assert len(radii) == len(nsamples) == len(mlps)

        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for i in range(len(radii)):
            radius = radii[i]
            nsample = nsamples[i]
            self.groupers.append(
                pointnet2_utils.QueryAndGroup(
                    radius, nsample, use_xyz=use_xyz, sample_uniformly=sample_uniformly
                )
                if npoint is not None
                else pointnet2_utils.GroupAll(use_xyz)
            )
            mlp_spec = mlps[i]
            if use_xyz:
                mlp_spec[0] += 3

            self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn)) 
开发者ID:poodarchu,项目名称:Det3D,代码行数:35,代码来源:pointnet2_modules.py


注:本文中的pytorch_utils.SharedMLP方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。