Skip to content

Commit

Permalink
[Fix] Update pre-commit-config-zh-cn.yaml and add typehints for Point…
Browse files Browse the repository at this point in the history
…Net2SAMSG (#2396)
  • Loading branch information
chriscarving authored Apr 12, 2023
1 parent 9f61eff commit 4ff1361
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 29 deletions.
14 changes: 4 additions & 10 deletions .pre-commit-config-zh-cn.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
exclude: ^tests/data/
repos:
- repo: https://gitee.com/openmmlab/mirrors-flake8
rev: 5.0.4
Expand All @@ -25,6 +24,10 @@ repos:
args: ["--remove"]
- id: mixed-line-ending
args: ["--fix=lf"]
- repo: https://gitee.com/openmmlab/mirrors-codespell
rev: v2.2.1
hooks:
- id: codespell
- repo: https://gitee.com/openmmlab/mirrors-mdformat
rev: 0.7.9
hooks:
Expand All @@ -34,20 +37,11 @@ repos:
- mdformat-openmmlab
- mdformat_frontmatter
- linkify-it-py
- repo: https://gitee.com/openmmlab/mirrors-codespell
rev: v2.2.1
hooks:
- id: codespell
- repo: https://gitee.com/openmmlab/mirrors-docformatter
rev: v1.3.1
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
- repo: https://gitee.com/openmmlab/mirrors-pyupgrade
rev: v3.0.0
hooks:
- id: pyupgrade
args: ["--py36-plus"]
- repo: https://gitee.com/openmmlab/pre-commit-hooks
rev: v0.2.0
hooks:
Expand Down
54 changes: 36 additions & 18 deletions mmdet3d/models/backbones/pointnet2_sa_msg.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
from mmcv.cnn import ConvModule
from torch import nn as nn

from mmdet3d.models.layers.pointnet_modules import build_sa_module
from mmdet3d.registry import MODELS
from mmdet3d.utils import OptConfigType
from .base_pointnet import BasePointNet

ThreeTupleIntType = Tuple[Tuple[Tuple[int, int, int]]]
TwoTupleIntType = Tuple[Tuple[int, int, int]]
TwoTupleStrType = Tuple[Tuple[str]]


@MODELS.register_module()
class PointNet2SAMSG(BasePointNet):
Expand All @@ -22,7 +29,7 @@ class PointNet2SAMSG(BasePointNet):
sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module.
aggregation_channels (tuple[int]): Out channels of aggregation
multi-scale grouping features.
fps_mods (tuple[int]): Mod of FPS for each SA module.
fps_mods Sequence[Tuple[str]]: Mod of FPS for each SA module.
fps_sample_range_lists (tuple[tuple[int]]): The number of sampling
points which each SA module samples.
dilated_group (tuple[bool]): Whether to use dilated ball query for
Expand All @@ -38,26 +45,37 @@ class PointNet2SAMSG(BasePointNet):
"""

def __init__(self,
in_channels,
num_points=(2048, 1024, 512, 256),
radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
num_samples=((32, 32, 64), (32, 32, 64), (32, 32, 32)),
sa_channels=(((16, 16, 32), (16, 16, 32), (32, 32, 64)),
((64, 64, 128), (64, 64, 128), (64, 96, 128)),
((128, 128, 256), (128, 192, 256), (128, 256,
256))),
aggregation_channels=(64, 128, 256),
fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
fps_sample_range_lists=((-1), (-1), (512, -1)),
dilated_group=(True, True, True),
out_indices=(2, ),
norm_cfg=dict(type='BN2d'),
sa_cfg=dict(
in_channels: int,
num_points: Tuple[int] = (2048, 1024, 512, 256),
radii: Tuple[Tuple[float, float, float]] = (
(0.2, 0.4, 0.8),
(0.4, 0.8, 1.6),
(1.6, 3.2, 4.8),
),
num_samples: TwoTupleIntType = ((32, 32, 64), (32, 32, 64),
(32, 32, 32)),
sa_channels: ThreeTupleIntType = (((16, 16, 32), (16, 16, 32),
(32, 32, 64)),
((64, 64, 128),
(64, 64, 128), (64, 96,
128)),
((128, 128, 256),
(128, 192, 256), (128, 256,
256))),
aggregation_channels: Tuple[int] = (64, 128, 256),
fps_mods: TwoTupleStrType = (('D-FPS'), ('FS'), ('F-FPS',
'D-FPS')),
fps_sample_range_lists: TwoTupleIntType = ((-1), (-1), (512,
-1)),
dilated_group: Tuple[bool] = (True, True, True),
out_indices: Tuple[int] = (2, ),
norm_cfg: dict = dict(type='BN2d'),
sa_cfg: dict = dict(
type='PointSAModuleMSG',
pool_mod='max',
use_xyz=True,
normalize_xyz=False),
init_cfg=None):
init_cfg: OptConfigType = None):
super().__init__(init_cfg=init_cfg)
self.num_sa = len(sa_channels)
self.out_indices = out_indices
Expand Down Expand Up @@ -123,7 +141,7 @@ def __init__(self,
bias=True))
sa_in_channel = cur_aggregation_channel

def forward(self, points):
def forward(self, points: torch.Tensor):
"""Forward pass.
Args:
Expand Down
4 changes: 3 additions & 1 deletion mmdet3d/models/layers/pointnet_modules/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from mmengine.registry import Registry
from torch import nn as nn

SA_MODULES = Registry('point_sa_module')
SA_MODULES = Registry(
name='point_sa_module',
locations=['mmdet3d.models.layers.pointnet_modules'])


def build_sa_module(cfg: Union[dict, None], *args, **kwargs) -> nn.Module:
Expand Down

0 comments on commit 4ff1361

Please sign in to comment.