Draft
38 commits
1621dd8
support for squeeze backbone with fpn neck
Feb 15, 2025
dc4446d
Trial commit
Feb 18, 2025
632c058
Before coord_type change
Feb 18, 2025
2b69aa7
WIP
Feb 19, 2025
eebe279
mvxnet kitti with squeeze+fpn
Feb 20, 2025
061d627
working configs
Feb 22, 2025
9716b1a
working configs
Feb 22, 2025
bb4b0bc
working configs
Feb 23, 2025
a6ee97d
efficientnet with edge arch
Feb 26, 2025
8b0489e
more config and fix in squeeze
Mar 4, 2025
942aab5
Create mvxnet_sqeezefpn_rpfnet_kitti-3d-3class.py
aravindbisht Apr 30, 2025
41deecd
Create rpfnet.py
aravindbisht Apr 30, 2025
62301b4
Update __init__.py
aravindbisht Apr 30, 2025
df50970
Update rpfnet.py
aravindbisht Apr 30, 2025
34b9fc5
Create fire_rpfnet.py
aravindbisht Apr 30, 2025
5af8cd4
Update __init__.py
aravindbisht Apr 30, 2025
45f0d1c
Create mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class.py
aravindbisht Apr 30, 2025
aa0fe77
Update mvxnet_sqeezefpn_rpfnet_kitti-3d-3class.py
aravindbisht Apr 30, 2025
975ba86
Create mvxnet_mobilenetv2_fpn_fire_rpfnet_kitti-3d-3class.py
aravindbisht May 1, 2025
8ccd9d5
Update mvxnet_mobilenetv2_fpn_second_fpn_kitti-3d-3class.py
aravindbisht May 1, 2025
182737d
Create mvxnet_efficiency_es_fpn_fire_rpfnet_kitti-3d-3class.py
aravindbisht May 1, 2025
f149e73
Update mvxnet_efficiency_es_fpn_fire_rpfnet_kitti-3d-3class.py
aravindbisht May 4, 2025
06a67b6
Update mvxnet_mobilenetv2_fpn_fire_rpfnet_kitti-3d-3class.py
aravindbisht May 4, 2025
1a6f571
Update mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class.py
aravindbisht May 12, 2025
84ba1a3
Update mvxnet_sqeezefpn_rpfnet_kitti-3d-3class.py
aravindbisht May 12, 2025
e102a34
Update mvxnet_mobilenetv2_fpn_fire_rpfnet_kitti-3d-3class.py
aravindbisht May 12, 2025
013305c
Merge pull request #2 from sjsunadddy/pillarnet_exp_01
aravindbisht May 13, 2025
8cda1d1
Update multi_modality_det3d_inferencer.py
aravindbisht May 14, 2025
4f021b4
Merge pull request #3 from sjsunadddy/inference_update_kitti
aravindbisht May 14, 2025
a250a95
Merge pull request #4 from sjsunadddy/pillarnet_exp_01
aravindbisht May 14, 2025
549d915
feat: expensed to bev model abd v2
aravindbisht Nov 3, 2025
ee4d4be
feat: 2d version for 2d backbone
Nov 3, 2025
2edd856
Merge pull request #5 from aravindbisht/aravindbisht/extension_experi…
aravindbisht Nov 3, 2025
4320840
Create FireRPFNetExtension.md
aravindbisht Nov 4, 2025
b30b19e
Update bevfusion_lidar-cam_voxel0075_firerpfnet_8xb4-cyclic-20e_nus-3…
aravindbisht Nov 5, 2025
34cd101
feat : new config to use fireRPFNet with neck
aravindbisht Nov 11, 2025
d258054
Update bevfusion_lidar_voxel0075_firerpfnet_with_neck_8xb4-cyclic-20e…
aravindbisht Nov 11, 2025
c4d0123
Update bevfusion_lidar-cam_voxel0075_firerpfnet_with_neck_8xb4-cyclic…
aravindbisht Nov 11, 2025
209 changes: 209 additions & 0 deletions FireRPFNetExtension.md
@@ -0,0 +1,209 @@
# FireRPFNet Models - Quick Start Guide

Custom 3D object detection models built on the FireRPFNet architecture, which combines Fire modules, residual connections, and CBAM attention.

## 🔥 FireRPFNet Variants

- **FireRPFNetV2**: Enhanced 3D LiDAR backbone with improved attention
- **FireRPFNet2D**: 2D image backbone variant for camera features

**Plug-and-Play Design:**
- **FireRPFNetV2** can replace the SECOND backbone in any model (BEVFusion is one example shown here)
- **FireRPFNet2D** can be used as an efficient image backbone in multi-modal architectures
- Simply update the backbone config to integrate into your existing models

---

## 📋 Available Models

| Model | Config | Image Backbone | LiDAR Backbone | Dataset | Modality |
|-------|--------|---------------|----------------|---------|----------|
| MVXNet-Squeeze | `configs/mvxnet/mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class.py` | SQUEEZE | **FireRPFNetV2** | KITTI | Multi-modal |
| MVXNet-Fire2D | `configs/mvxnet/mvxnet_firerpfnet2dfpn_fire_rpfnet_kitti-3d-3class.py` | **FireRPFNet2D** | **FireRPFNetV2** | KITTI | Multi-modal |
| BEVFusion-Lidar | `projects/BEVFusion/configs/bevfusion_lidar_voxel0075_firerpfnet_8xb4-cyclic-20e_nus-3d.py` | - | **FireRPFNetV2** | nuScenes | LiDAR-only |
| BEVFusion-Cam | `projects/BEVFusion/configs/bevfusion_lidar-cam_voxel0075_firerpfnet_8xb4-cyclic-20e_nus-3d.py` | Swin-T | **FireRPFNetV2** | nuScenes | Multi-modal |

---

## 🚀 Installation

Follow the official MMDetection3D installation guide: https://mmdetection3d.readthedocs.io/en/latest/get_started.html

**Quick Setup:**
```bash
# Install dependencies
pip install -U openmim
mim install mmengine
mim install 'mmcv>=2.0.0rc4'
mim install 'mmdet>=3.0.0'

# Install mmdetection3d
cd mmdetection3d
pip install -v -e .
```

---

## 📦 Dataset Setup

### KITTI (MVXNet models)
```bash
# Download from http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d
# Organize: data/kitti/training/{image_2, velodyne, calib, label_2}

# Create data infos
python tools/create_data.py kitti --root-path ./data/kitti --out-dir ./data/kitti --extra-tag kitti
```

### nuScenes (BEVFusion models)
```bash
# Download from https://www.nuscenes.org/download
# Organize: data/nuscenes/{samples, sweeps, v1.0-trainval}

# Create data infos
python tools/create_data.py nuscenes --root-path ./data/nuscenes --out-dir ./data/nuscenes --extra-tag nuscenes
```
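
Before running `create_data.py`, it can help to confirm that the directory layout matches the "Organize" comments above. The helper below is a minimal sketch; `missing_dirs` and `EXPECTED` are hypothetical names, and the script only checks that the folders exist, not that their contents are complete.

```python
from pathlib import Path
from typing import List

# Sub-directories named in the "Organize" comments of the dataset sections.
EXPECTED = {
    "kitti": [
        "training/image_2",
        "training/velodyne",
        "training/calib",
        "training/label_2",
    ],
    "nuscenes": ["samples", "sweeps", "v1.0-trainval"],
}


def missing_dirs(dataset: str, root: str) -> List[str]:
    """Return the expected sub-directories that do not exist under root."""
    base = Path(root)
    return [sub for sub in EXPECTED[dataset] if not (base / sub).is_dir()]


if __name__ == "__main__":
    for name in EXPECTED:
        gaps = missing_dirs(name, f"data/{name}")
        print(name, "OK" if not gaps else f"missing: {gaps}")
```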

---

## 🏋️ Training Commands

### MVXNet Models (KITTI)

**Model 1: SqueezeFPN + FireRPFNetV2**
```bash
# Single GPU
python tools/train.py configs/mvxnet/mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class.py
```
- Batch size: 2/GPU | Epochs: 20 | LR: 0.001 | Val: Every 5 epochs

**Model 2: FireRPFNet2D + FireRPFNetV2**
```bash
# Single GPU
python tools/train.py configs/mvxnet/mvxnet_firerpfnet2dfpn_fire_rpfnet_kitti-3d-3class.py
```
- Batch size: 4/GPU | Epochs: 16 | LR: 0.001 | Val: Every 2 epochs | Early stopping enabled

---

### BEVFusion Models (nuScenes)

**Model 3: BEVFusion LiDAR-only + FireRPFNetV2**
```bash
# Single GPU
python tools/train.py projects/BEVFusion/configs/bevfusion_lidar_voxel0075_firerpfnet_8xb4-cyclic-20e_nus-3d.py
```
- Batch size: 4/GPU | Epochs: 20 | LR: 0.0002 | Cyclic scheduler

**Model 4: BEVFusion Multi-Modal + FireRPFNetV2**
```bash
# With mixed precision
python tools/train.py \
projects/BEVFusion/configs/bevfusion_lidar-cam_voxel0075_firerpfnet_8xb4-cyclic-20e_nus-3d.py \
--amp
```
- Batch size: 4/GPU (32 total across 8 GPUs, per the `8xb4` tag in the config name) | Epochs: 6 | LR: 0.0002 | Val: Every epoch
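
The "32 total" figure follows from the OpenMMLab naming convention, where a `<G>xb<B>` tag in the config name means G GPUs with B samples each. A small sketch of that arithmetic (`parse_batch_tag` is a hypothetical helper, not part of MMDetection3D):

```python
import re
from typing import Tuple


def parse_batch_tag(config_name: str) -> Tuple[int, int, int]:
    """Extract (num_gpus, samples_per_gpu, total_batch) from a '<G>xb<B>' tag."""
    match = re.search(r"(\d+)xb(\d+)", config_name)
    if match is None:
        raise ValueError(f"no '<gpus>xb<batch>' tag in {config_name!r}")
    gpus, per_gpu = int(match.group(1)), int(match.group(2))
    return gpus, per_gpu, gpus * per_gpu


name = "bevfusion_lidar-cam_voxel0075_firerpfnet_8xb4-cyclic-20e_nus-3d.py"
print(parse_batch_tag(name))  # (8, 4, 32)
```

Running the same config on fewer GPUs shrinks the total batch proportionally, so the learning rate listed above may need retuning in that case.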

---

## 🧪 Testing

### MVXNet Models
```bash
# Single GPU
python tools/test.py CONFIG CHECKPOINT

# Multi-GPU
bash tools/dist_test.sh CONFIG CHECKPOINT 4
```

**Examples:**
```bash
# Model 1
python tools/test.py \
configs/mvxnet/mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class.py \
work_dirs/mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class/best_checkpoint.pth

# Model 2
bash tools/dist_test.sh \
configs/mvxnet/mvxnet_firerpfnet2dfpn_fire_rpfnet_kitti-3d-3class.py \
work_dirs/mvxnet_firerpfnet2dfpn_fire_rpfnet_kitti-3d-3class/best_checkpoint.pth 4
```

### BEVFusion Models
```bash
# Model 3
bash tools/dist_test.sh \
projects/BEVFusion/configs/bevfusion_lidar_voxel0075_firerpfnet_8xb4-cyclic-20e_nus-3d.py \
work_dirs/bevfusion_lidar_voxel0075_firerpfnet_8xb4-cyclic-20e_nus-3d/best_checkpoint.pth 8

# Model 4
bash tools/dist_test.sh \
projects/BEVFusion/configs/bevfusion_lidar-cam_voxel0075_firerpfnet_8xb4-cyclic-20e_nus-3d.py \
work_dirs/bevfusion_lidar-cam_voxel0075_firerpfnet_8xb4-cyclic-20e_nus-3d/best_checkpoint.pth 8
```

---

## 💡 Tips

**Resume Training:**
```bash
python tools/train.py CONFIG --resume work_dirs/MODEL_NAME/epoch_X.pth
```
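
To fill in `epoch_X.pth` without browsing the directory by hand, a helper along these lines could pick the newest epoch checkpoint. It is a sketch that assumes checkpoints follow the `epoch_<N>.pth` pattern shown above; `latest_epoch_checkpoint` is a hypothetical name.

```python
import re
from pathlib import Path
from typing import Optional


def latest_epoch_checkpoint(work_dir: str) -> Optional[Path]:
    """Return the epoch_<N>.pth file with the largest N, or None if none exist."""
    best: Optional[Path] = None
    best_epoch = -1
    for path in Path(work_dir).glob("epoch_*.pth"):
        match = re.fullmatch(r"epoch_(\d+)\.pth", path.name)
        # Compare numerically: as strings, "epoch_9" would sort after "epoch_10".
        if match and int(match.group(1)) > best_epoch:
            best, best_epoch = path, int(match.group(1))
    return best
```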

**Specify GPUs:**
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 bash tools/dist_train.sh CONFIG 4
```

**Debug Mode:**
```bash
python tools/train.py CONFIG \
--cfg-options train_dataloader.num_workers=0 \
train_dataloader.persistent_workers=False \
train_dataloader.batch_size=1
```
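
Each `--cfg-options` entry is a dotted path into the nested config, so the path has to match the config's actual nesting. The snippet below is a plain-Python illustration of that mapping, not MMEngine's implementation; `apply_dotted_overrides` is a hypothetical name.

```python
from typing import Any, Dict


def apply_dotted_overrides(cfg: Dict[str, Any],
                           overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Set nested keys in cfg from 'a.b.c'-style keys, creating missing levels."""
    for dotted_key, value in overrides.items():
        node = cfg
        *parents, leaf = dotted_key.split(".")
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return cfg


cfg = {"train_dataloader": {"batch_size": 4, "num_workers": 4}}
apply_dotted_overrides(cfg, {"train_dataloader.batch_size": 1})
print(cfg["train_dataloader"])  # {'batch_size': 1, 'num_workers': 4}

# A path that does not exist in the config creates a new branch and leaves
# the real setting untouched, so a mistyped prefix fails silently.
apply_dotted_overrides(cfg, {"typo.batch_size": 8})
print(cfg["train_dataloader"]["batch_size"])  # 1
```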

**Monitor Training:**
```bash
tensorboard --logdir=work_dirs/
```
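
TensorBoard only has something to show if the run writes event files. Assuming the configs in this guide do not already enable it, a config fragment like the following would add MMEngine's TensorBoard backend next to the default local one:

```python
# Config fragment (assumption: not already present in the configs listed above).
vis_backends = [
    dict(type='LocalVisBackend'),
    dict(type='TensorboardVisBackend'),  # writes event files under the work dir
]
visualizer = dict(
    type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
```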

---

## 🐛 Common Issues

**CUDA OOM:** Reduce the batch size in the config or via `--cfg-options train_dataloader.batch_size=1`

**Dataset not found:** Verify the `data/` paths and rerun the matching `tools/create_data.py` command from Dataset Setup

**Import errors:** Reinstall with `pip install -v -e .`

---

## 📚 References

- [MMDetection3D Documentation](https://mmdetection3d.readthedocs.io)
- [KITTI Dataset](http://www.cvlibs.net/datasets/kitti/)
- [nuScenes Dataset](https://www.nuscenes.org/)

---

## 📝 Citation

```bibtex
@article{firerpfnet2024,
title={FireRPFNet: Efficient 3D Object Detection with Fire Modules and Attention},
author={Aravind Singh},
journal={arXiv preprint},
year={2024}
}
```

---

**Happy Training! 🚀**

@@ -0,0 +1,91 @@
voxel_size = [0.2, 0.2, 8]
model = dict(
    type='CenterPoint',
    data_preprocessor=dict(
        type='Det3DDataPreprocessor',
        voxel=True,
        voxel_layer=dict(
            max_num_points=20,
            voxel_size=voxel_size,
            max_voxels=(30000, 40000))),
    pts_voxel_encoder=dict(
        type='PillarFeatureNet',
        in_channels=5,
        feat_channels=[64],
        with_distance=False,
        voxel_size=(0.2, 0.2, 8),
        norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
        legacy=False),
    pts_middle_encoder=dict(
        type='PointPillarsScatter', in_channels=64, output_shape=(512, 512)),
    pts_backbone=dict(
        type='SQUEEZE',
        in_channels=64,
        out_channels=[64, 128, 256, 512],
        # layer_nums=[3, 5, 5],
        # layer_strides=[2, 2, 2],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        conv_cfg=dict(type='Conv2d', bias=False)),
    pts_neck=dict(
        type='SQUEEZEFPN',
        in_channels=[64, 128, 256, 512],
        out_channels=[512, 512, 512, 512],
        # upsample_strides=[0.5, 1, 2],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        upsample_cfg=dict(type='deconv', bias=False),
        # use_conv_for_no_stride=True
    ),
    pts_bbox_head=dict(
        type='CenterHead',
        in_channels=sum([128, 128, 128, 128]),
        # in_channels=256,
        tasks=[
            dict(num_class=1, class_names=['car']),
            dict(num_class=2, class_names=['truck', 'construction_vehicle']),
            dict(num_class=2, class_names=['bus', 'trailer']),
            dict(num_class=1, class_names=['barrier']),
            dict(num_class=2, class_names=['motorcycle', 'bicycle']),
            dict(num_class=2, class_names=['pedestrian', 'traffic_cone']),
        ],
        common_heads=dict(
            reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)),
        share_conv_channel=64,
        bbox_coder=dict(
            type='CenterPointBBoxCoder',
            post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
            max_num=500,
            score_threshold=0.1,
            out_size_factor=4,
            voxel_size=voxel_size[:2],
            code_size=9),
        separate_head=dict(
            type='SeparateHead', init_bias=-2.19, final_kernel=3),
        loss_cls=dict(type='mmdet.GaussianFocalLoss', reduction='mean'),
        loss_bbox=dict(
            type='mmdet.L1Loss', reduction='mean', loss_weight=0.25),
        norm_bbox=True),
    # model training and testing settings
    train_cfg=dict(
        pts=dict(
            grid_size=[512, 512, 1],
            voxel_size=voxel_size,
            out_size_factor=4,
            dense_reg=1,
            gaussian_overlap=0.1,
            max_objs=500,
            min_radius=2,
            code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2])),
    test_cfg=dict(
        pts=dict(
            post_center_limit_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
            max_per_img=500,
            max_pool_nms=False,
            min_radius=[4, 12, 10, 1, 0.85, 0.175],
            score_threshold=0.1,
            pc_range=[-51.2, -51.2],
            out_size_factor=4,
            voxel_size=voxel_size[:2],
            nms_type='rotate',
            pre_max_size=1000,
            post_max_size=83,
            nms_thr=0.2)))
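
Review note on the config above: `grid_size`, `output_shape`, and `out_size_factor` have to agree with the voxel size and the point-cloud range. The check below is a small worked example; the upper bound `51.2` is an assumption (a range symmetric to the `pc_range=[-51.2, -51.2]` lower bound), since the config does not state it.

```python
voxel_size = [0.2, 0.2, 8]
pc_min = -51.2        # from test_cfg.pc_range
pc_max = 51.2         # assumption: range is symmetric around the sensor
out_size_factor = 4   # from train_cfg / test_cfg

# Number of pillars along x (and y): range length divided by pillar size.
grid_xy = round((pc_max - pc_min) / voxel_size[0])
# Side length of the heatmap that CenterHead predicts on.
heatmap_xy = grid_xy // out_size_factor

print(grid_xy, heatmap_xy)  # 512 128
```

Under that assumption the result matches `grid_size=[512, 512, 1]` and `output_shape=(512, 512)` in the config.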
@@ -0,0 +1,46 @@
model = dict(
    type='VoxelNet',
    data_preprocessor=dict(
        type='Det3DDataPreprocessor',
        voxel=True,
        voxel_layer=dict(
            max_num_points=5,
            point_cloud_range=[0, -40, -3, 70.4, 40, 1],
            voxel_size=[0.05, 0.05, 0.1],
            max_voxels=(16000, 40000))),
    voxel_encoder=dict(type='HardSimpleVFE'),
    middle_encoder=dict(
        type='SparseEncoder',
        in_channels=4,
        sparse_shape=[41, 1600, 1408],
        order=('conv', 'norm', 'act')),
    backbone=dict(
        type='SQUEEZE',
        in_channels=3,
        out_channels=[64, 128, 256, 512],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        conv_cfg=dict(type='Conv2d', bias=False)),
    neck=dict(
        type='SQUEEZEFPN',
        in_channels=[64, 128, 256, 512],
        out_channels=[256, 256, 256, 256],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        upsample_cfg=dict(type='deconv', bias=False),
        conv_cfg=dict(type='Conv2d', bias=False)),
    bbox_head=dict(
        type='Anchor3DHead',
        num_classes=3,
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='Anchor3DRangeGenerator',
            ranges=[[0, -40, -1.8, 70.4, 40, -1.8]],
            sizes=[[1.6, 3.9, 1.56]],
            rotations=[0, 1.57]),
        bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
        loss_cls=dict(type='FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0),
        loss_dir=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)),
    train_cfg=dict(assigner=dict(type='MaxIoUAssigner')),
    test_cfg=dict(use_rotate_nms=True, nms_across_levels=False, nms_pre=1000, nms_thr=0.01, score_thr=0.1, min_bbox_size=0, max_num=500)
)
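
Review note on the config above: `sparse_shape` is tied to `point_cloud_range` and `voxel_size`. The worked example below reproduces `[41, 1600, 1408]` using the usual SECOND convention (z, y, x order, with one extra slice along z); that convention is assumed to apply here as well.

```python
point_cloud_range = [0, -40, -3, 70.4, 40, 1]  # x_min, y_min, z_min, x_max, y_max, z_max
voxel_size = [0.05, 0.05, 0.1]                 # x, y, z

# Voxel counts along x, y, z.
grid = [
    round((point_cloud_range[i + 3] - point_cloud_range[i]) / voxel_size[i])
    for i in range(3)
]

# SECOND convention (assumed): reverse to (z, y, x) and add one slice along z.
sparse_shape = [grid[2] + 1, grid[1], grid[0]]

print(grid)          # [1408, 1600, 40]
print(sparse_shape)  # [41, 1600, 1408]
```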