[CodeCamp2023-516]Add new configuration files for Cylinder3D (#2681)
ZhaoCake authored Sep 13, 2023
1 parent 48fd72f commit 6e6edfa
Showing 4 changed files with 181 additions and 0 deletions.
49 changes: 49 additions & 0 deletions mmdet3d/configs/_base_/models/cylinder3d.py
@@ -0,0 +1,49 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet3d.models import Cylinder3D
from mmdet3d.models.backbones import Asymm3DSpconv
from mmdet3d.models.data_preprocessors import Det3DDataPreprocessor
from mmdet3d.models.decode_heads.cylinder3d_head import Cylinder3DHead
from mmdet3d.models.losses import LovaszLoss
from mmdet3d.models.voxel_encoders import SegVFE

grid_shape = [480, 360, 32]
model = dict(
    type=Cylinder3D,
    data_preprocessor=dict(
        type=Det3DDataPreprocessor,
        voxel=True,
        voxel_type='cylindrical',
        voxel_layer=dict(
            grid_shape=grid_shape,
            point_cloud_range=[0, -3.14159265359, -4, 50, 3.14159265359, 2],
            max_num_points=-1,
            max_voxels=-1,
        ),
    ),
    voxel_encoder=dict(
        type=SegVFE,
        feat_channels=[64, 128, 256, 256],
        in_channels=6,
        with_voxel_center=True,
        feat_compression=16,
        return_point_feats=False),
    backbone=dict(
        type=Asymm3DSpconv,
        grid_size=grid_shape,
        input_channels=16,
        base_channels=32,
        norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.1)),
    decode_head=dict(
        type=Cylinder3DHead,
        channels=128,
        num_classes=20,
        loss_ce=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=False,
            class_weight=None,
            loss_weight=1.0),
        loss_lovasz=dict(type=LovaszLoss, loss_weight=1.0, reduction='none'),
    ),
    train_cfg=None,
    test_cfg=dict(mode='whole'),
)
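Because this base config uses MMEngine's new pure-Python config style, the `type` fields hold the imported classes themselves rather than registry strings. A minimal sketch (not part of the commit) of how such a config could be loaded and the segmentor instantiated, assuming the standard mmengine `Config` and mmdet3d `MODELS` registry entry points:

from mmengine.config import Config

from mmdet3d.registry import MODELS
from mmdet3d.utils import register_all_modules

register_all_modules()  # populate the mmdet3d registries before building
cfg = Config.fromfile('mmdet3d/configs/_base_/models/cylinder3d.py')
model = MODELS.build(cfg.model)  # builds Cylinder3D with the settings above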
1 change: 1 addition & 0 deletions mmdet3d/configs/cylinder3d/__init__.py
@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
43 changes: 43 additions & 0 deletions mmdet3d/configs/cylinder3d/cylinder3d_4xb4-3x_semantickitti.py
@@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import read_base

with read_base():
    from .._base_.datasets.semantickitti import *
    from .._base_.models.cylinder3d import *
    from .._base_.default_runtime import *

from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper
from mmengine.optim.scheduler.lr_scheduler import LinearLR, MultiStepLR
from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop
from torch.optim import AdamW

# optimizer
lr = 0.001
optim_wrapper = dict(
    type=OptimWrapper, optimizer=dict(type=AdamW, lr=lr, weight_decay=0.01))

train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=36, val_interval=1)
val_cfg = dict(type=ValLoop)
test_cfg = dict(type=TestLoop)

# learning rate
param_scheduler = [
    dict(type=LinearLR, start_factor=0.001, by_epoch=False, begin=0, end=1000),
    dict(
        type=MultiStepLR,
        begin=0,
        end=36,
        by_epoch=True,
        milestones=[30],
        gamma=0.1)
]

train_dataloader.update(dict(batch_size=4, ))

# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
# or not by default.
# - `base_batch_size` = (8 GPUs) x (4 samples per GPU).
# auto_scale_lr = dict(enable=False, base_batch_size=32)

default_hooks.update(dict(checkpoint=dict(type=CheckpointHook, interval=5)))
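Taken together, this config trains for 36 epochs with AdamW at lr=0.001, warms the learning rate up linearly over the first 1000 iterations, drops it by 10x at epoch 30, validates every epoch, and saves a checkpoint every 5 epochs. A minimal launch sketch (not part of the commit), assuming the usual MMEngine `Runner` entry point and this config's path in the repository:

from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'mmdet3d/configs/cylinder3d/cylinder3d_4xb4-3x_semantickitti.py')
cfg.work_dir = './work_dirs/cylinder3d_4xb4-3x_semantickitti'  # required by Runner

runner = Runner.from_cfg(cfg)
runner.train()  # 36 epochs, validating after every epoch (val_interval=1)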
88 changes: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import read_base

with read_base():
    from .._base_.datasets.semantickitti import *
    from .._base_.default_runtime import *
    from .._base_.models.cylinder3d import *
    from .._base_.schedules.schedule_3x import *

from mmcv.transforms.wrappers import RandomChoice

from mmdet3d.datasets.transforms.transforms_3d import LaserMix, PolarMix

train_pipeline = [
    dict(type=LoadPointsFromFile, coord_type='LIDAR', load_dim=4, use_dim=4),
    dict(
        type=LoadAnnotations3D,
        with_bbox_3d=False,
        with_label_3d=False,
        with_seg_3d=True,
        seg_3d_dtype='np.int32',
        seg_offset=2**16,
        dataset_type='semantickitti'),
    dict(type=PointSegClassMapping),
    dict(
        type=RandomChoice,
        transforms=[
            [
                dict(
                    type=LaserMix,
                    num_areas=[3, 4, 5, 6],
                    pitch_angles=[-25, 3],
                    pre_transform=[
                        dict(
                            type=LoadPointsFromFile,
                            coord_type='LIDAR',
                            load_dim=4,
                            use_dim=4),
                        dict(
                            type=LoadAnnotations3D,
                            with_bbox_3d=False,
                            with_label_3d=False,
                            with_seg_3d=True,
                            seg_3d_dtype='np.int32',
                            seg_offset=2**16,
                            dataset_type='semantickitti'),
                        dict(type=PointSegClassMapping)
                    ],
                    prob=1)
            ],
            [
                dict(
                    type=PolarMix,
                    instance_classes=[0, 1, 2, 3, 4, 5, 6, 7],
                    swap_ratio=0.5,
                    rotate_paste_ratio=1.0,
                    pre_transform=[
                        dict(
                            type=LoadPointsFromFile,
                            coord_type='LIDAR',
                            load_dim=4,
                            use_dim=4),
                        dict(
                            type=LoadAnnotations3D,
                            with_bbox_3d=False,
                            with_label_3d=False,
                            with_seg_3d=True,
                            seg_3d_dtype='np.int32',
                            seg_offset=2**16,
                            dataset_type='semantickitti'),
                        dict(type=PointSegClassMapping)
                    ],
                    prob=1)
            ],
        ],
        prob=[0.5, 0.5]),
    dict(
        type=GlobalRotScaleTrans,
        rot_range=[0., 6.28318531],
        scale_ratio_range=[0.95, 1.05],
        translation_std=[0, 0, 0],
    ),
    dict(type=Pack3DDetInputs, keys=['points', 'pts_semantic_mask'])
]

train_dataloader.update(dict(dataset=dict(pipeline=train_pipeline)))

default_hooks.update(dict(checkpoint=dict(type=CheckpointHook, interval=1)))
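In this pipeline, the `RandomChoice` wrapper applies exactly one of the two augmentation branches to each training sample: LaserMix or PolarMix, each chosen with probability 0.5, and each branch reloads the paired scan through its own `pre_transform` chain. A conceptual sketch of that per-sample selection (illustration only; the actual logic lives in `mmcv.transforms.wrappers.RandomChoice`):

import numpy as np

rng = np.random.default_rng()
branches = ['LaserMix branch', 'PolarMix branch']
prob = [0.5, 0.5]

def pick_branch() -> str:
    # draw one branch index according to `prob`; in the real wrapper every
    # transform inside the chosen branch then runs in order on the sample
    idx = rng.choice(len(branches), p=prob)
    return branches[idx]

print(pick_branch())  # prints either branch name, each about half of the time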
