Skip to content

Commit

Permalink
[Feature] Support DSVT training (#2738)
Browse files Browse the repository at this point in the history
Co-authored-by: JingweiZhang12 <[email protected]>
Co-authored-by: sjh <sunjiahao1999>
  • Loading branch information
sunjiahao1999 and JingweiZhang12 authored Dec 28, 2023
1 parent 5b88c7b commit 762e3b5
Show file tree
Hide file tree
Showing 14 changed files with 875 additions and 86 deletions.
4 changes: 2 additions & 2 deletions mmdet3d/models/dense_heads/centerpoint_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def forward(self, x):
Returns:
dict[str: torch.Tensor]: contains the following keys:
-reg torch.Tensor): 2D regression value with the
-reg (torch.Tensor): 2D regression value with the
shape of [B, 2, H, W].
-height (torch.Tensor): Height value with the
shape of [B, 1, H, W].
Expand Down Expand Up @@ -217,7 +217,7 @@ def forward(self, x):
Returns:
dict[str: torch.Tensor]: contains the following keys:
-reg torch.Tensor): 2D regression value with the
-reg (torch.Tensor): 2D regression value with the
shape of [B, 2, H, W].
-height (torch.Tensor): Height value with the
shape of [B, 1, H, W].
Expand Down
18 changes: 11 additions & 7 deletions mmdet3d/models/necks/second_fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class SECONDFPN(BaseModule):
upsample_cfg (dict): Config dict of upsample layers.
conv_cfg (dict): Config dict of conv layers.
use_conv_for_no_stride (bool): Whether to use conv when stride is 1.
init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`],
optional): Initialization config dict. Defaults to
[dict(type='Kaiming', layer='ConvTranspose2d'),
dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0)].
"""

def __init__(self,
Expand All @@ -31,7 +35,13 @@ def __init__(self,
upsample_cfg=dict(type='deconv', bias=False),
conv_cfg=dict(type='Conv2d', bias=False),
use_conv_for_no_stride=False,
init_cfg=None):
init_cfg=[
dict(type='Kaiming', layer='ConvTranspose2d'),
dict(
type='Constant',
layer='NaiveSyncBatchNorm2d',
val=1.0)
]):
# if for GroupNorm,
# cfg is dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True)
super(SECONDFPN, self).__init__(init_cfg=init_cfg)
Expand Down Expand Up @@ -64,12 +74,6 @@ def __init__(self,
deblocks.append(deblock)
self.deblocks = nn.ModuleList(deblocks)

if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='ConvTranspose2d'),
dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0)
]

def forward(self, x):
"""Forward function.
Expand Down
13 changes: 7 additions & 6 deletions mmdet3d/structures/bbox_3d/base_box3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,13 @@ def in_range_3d(
Tensor: A binary vector indicating whether each point is inside the
reference range.
"""
in_range_flags = ((self.tensor[:, 0] > box_range[0])
& (self.tensor[:, 1] > box_range[1])
& (self.tensor[:, 2] > box_range[2])
& (self.tensor[:, 0] < box_range[3])
& (self.tensor[:, 1] < box_range[4])
& (self.tensor[:, 2] < box_range[5]))
gravity_center = self.gravity_center
in_range_flags = ((gravity_center[:, 0] > box_range[0])
& (gravity_center[:, 1] > box_range[1])
& (gravity_center[:, 2] > box_range[2])
& (gravity_center[:, 0] < box_range[3])
& (gravity_center[:, 1] < box_range[4])
& (gravity_center[:, 2] < box_range[5]))
return in_range_flags

@abstractmethod
Expand Down
18 changes: 13 additions & 5 deletions projects/DSVT/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,25 @@ python tools/test.py projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-

### Training commands

The support of training DSVT is on the way.
In MMDetection3D's root directory, run the following command to train the model:

```bash
tools/dist_train.sh projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py 8 --sync_bn torch
```

## Results and models

### Waymo

| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :------: | :------------: | :----: | :-----: | :----: | :---------: | :------: |
| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) || × | | | 75.2 | 72.2 | 68.9 | 66.1 | |
| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :----: | :-----: | :----: | :---------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✓ | × | 75.5 | 72.4 | 69.2 | 66.3 | [log](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/dsvt/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class_20230917_102130.log) |

**Note**:

- `ResSECOND` denotes that the base block in SECOND has residual layers.

**Note** that `ResSECOND` denotes the base block in SECOND has residual layers.
- Regrettably, we are unable to provide the pre-trained model weights due to the [Waymo Dataset License Agreement](https://waymo.com/open/terms/), so we only provide the training logs as shown above.

## Citation

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,25 +88,28 @@
loss_cls=dict(
type='mmdet.GaussianFocalLoss', reduction='mean', loss_weight=1.0),
loss_bbox=dict(type='mmdet.L1Loss', reduction='mean', loss_weight=2.0),
loss_iou=dict(type='mmdet.L1Loss', reduction='sum', loss_weight=1.0),
loss_reg_iou=dict(
type='mmdet3d.DIoU3DLoss', reduction='mean', loss_weight=2.0),
norm_bbox=True),
# model training and testing settings
train_cfg=dict(
pts=dict(
grid_size=grid_size,
voxel_size=voxel_size,
out_size_factor=4,
dense_reg=1,
gaussian_overlap=0.1,
max_objs=500,
min_radius=2,
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])),
grid_size=grid_size,
voxel_size=voxel_size,
point_cloud_range=point_cloud_range,
out_size_factor=1,
dense_reg=1,
gaussian_overlap=0.1,
max_objs=500,
min_radius=2,
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
test_cfg=dict(
max_per_img=500,
max_pool_nms=False,
min_radius=[4, 12, 10, 1, 0.85, 0.175],
iou_rectifier=[[0.68, 0.71, 0.65]],
pc_range=[-80, -80],
out_size_factor=4,
out_size_factor=1,
voxel_size=voxel_size[:2],
nms_type='rotate',
multi_class_nms=True,
Expand All @@ -128,6 +131,8 @@
coord_type='LIDAR',
load_dim=6,
use_dim=[0, 1, 2, 3, 4],
norm_intensity=True,
norm_elongation=True,
backend_args=backend_args),
backend_args=backend_args)

Expand All @@ -138,25 +143,22 @@
load_dim=6,
use_dim=5,
norm_intensity=True,
norm_elongation=True,
backend_args=backend_args),
# Add this if using `MultiFrameDeformableDecoderRPN`
# dict(
# type='LoadPointsFromMultiSweeps',
# sweeps_num=9,
# load_dim=6,
# use_dim=[0, 1, 2, 3, 4],
# pad_empty_sweeps=True,
# remove_close=True),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05],
translation_std=[0.5, 0.5, 0]),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectNameFilter', classes=class_names),
translation_std=[0.5, 0.5, 0.5]),
dict(type='PointsRangeFilter3D', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter3D', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'),
dict(
type='Pack3DDetInputs',
Expand All @@ -172,25 +174,34 @@
norm_intensity=True,
norm_elongation=True,
backend_args=backend_args),
dict(type='PointsRangeFilter3D', point_cloud_range=point_cloud_range),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range)
]),
dict(type='Pack3DDetInputs', keys=['points'])
type='Pack3DDetInputs',
keys=['points'],
meta_keys=['box_type_3d', 'sample_idx', 'context_name', 'timestamp'])
]

dataset_type = 'WaymoDataset'
train_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='waymo_infos_train.pkl',
data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'),
pipeline=train_pipeline,
modality=input_modality,
test_mode=False,
metainfo=metainfo,
# we use box_type_3d='LiDAR' in kitti and nuscenes dataset
# and box_type_3d='Depth' in sunrgbd and scannet dataset.
box_type_3d='LiDAR',
# load one frame every five frames
load_interval=5,
backend_args=backend_args))
val_dataloader = dict(
batch_size=4,
num_workers=4,
Expand All @@ -212,18 +223,59 @@

val_evaluator = dict(
type='WaymoMetric',
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
data_root='./data/waymo/waymo_format',
backend_args=backend_args,
convert_kitti_format=False,
idx2metainfo='./data/waymo/waymo_format/idx2metainfo.pkl')
result_prefix='./dsvt_pred')
test_evaluator = val_evaluator

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')

# schedules
# Optimizer plus two-phase cosine schedules for learning rate and momentum.
# Every scheduler pair splits the run at epoch 1.2: [0, 1.2) and [1.2, 12).
# Boundaries are written in (fractional) epochs; `convert_to_iter_based=True`
# presumably lets the runner translate them into iteration counts -- confirm
# against the mmengine scheduler docs.
lr = 1e-5
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=lr, weight_decay=0.05, betas=(0.9, 0.99)),
# gradient clipping: norm_type=2 with a maximum norm of 10
clip_grad=dict(max_norm=10, norm_type=2))
param_scheduler = [
# LR phase 1, epochs [0, 1.2): the cosine target `eta_min` is lr * 100
# (1e-3), i.e. above the base lr of 1e-5, so this phase presumably ramps the
# learning rate up rather than down -- TODO confirm scheduler semantics.
dict(
type='CosineAnnealingLR',
T_max=1.2,
eta_min=lr * 100,
begin=0,
end=1.2,
by_epoch=True,
convert_to_iter_based=True),
# LR phase 2, epochs [1.2, 12): cosine target is lr * 1e-4 (1e-9).
dict(
type='CosineAnnealingLR',
T_max=10.8,
eta_min=lr * 1e-4,
begin=1.2,
end=12,
by_epoch=True,
convert_to_iter_based=True),
# momentum scheduler
# Same two phases as the learning rate: target 0.85 over [0, 1.2), then
# target 0.95 over [1.2, 12). Which optimizer field counts as "momentum" for
# AdamW is not visible here (presumably the first beta, 0.9) -- verify.
dict(
type='CosineAnnealingMomentum',
T_max=1.2,
eta_min=0.85,
begin=0,
end=1.2,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingMomentum',
T_max=10.8,
eta_min=0.95,
begin=1.2,
end=12,
by_epoch=True,
convert_to_iter_based=True)
]

# runtime settings
train_cfg = dict(by_epoch=True, max_epochs=12, val_interval=1)

# runtime settings
val_cfg = dict()
test_cfg = dict()
Expand All @@ -236,4 +288,12 @@

default_hooks = dict(
logger=dict(type='LoggerHook', interval=50),
checkpoint=dict(type='CheckpointHook', interval=5))
checkpoint=dict(type='CheckpointHook', interval=1))
custom_hooks = [
dict(
type='DisableAugHook',
disable_after_epoch=11,
disable_aug_list=[
'GlobalRotScaleTrans', 'RandomFlip3D', 'ObjectSample'
])
]
5 changes: 4 additions & 1 deletion projects/DSVT/dsvt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from .disable_aug_hook import DisableAugHook
from .dsvt import DSVT
from .dsvt_head import DSVTCenterHead
from .dsvt_transformer import DSVTMiddleEncoder
from .dynamic_pillar_vfe import DynamicPillarVFE3D
from .map2bev import PointPillarsScatter3D
from .res_second import ResSECOND
from .transforms_3d import ObjectRangeFilter3D, PointsRangeFilter3D
from .utils import DSVTBBoxCoder

__all__ = [
'DSVTCenterHead', 'DSVT', 'DSVTMiddleEncoder', 'DynamicPillarVFE3D',
'PointPillarsScatter3D', 'ResSECOND', 'DSVTBBoxCoder'
'PointPillarsScatter3D', 'ResSECOND', 'DSVTBBoxCoder',
'ObjectRangeFilter3D', 'PointsRangeFilter3D', 'DisableAugHook'
]
Loading

0 comments on commit 762e3b5

Please sign in to comment.