12 changes: 10 additions & 2 deletions examples/models/vlm/qwen3_vl/README.md
```bash
export WORKSPACE=/your/custom/path
```

Directory structure:

- `${WORKSPACE}/models/` - Converted checkpoints
- `${WORKSPACE}/results/` - Training outputs and experiment results
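Assuming the layout above, workspace setup is just two directories under `WORKSPACE`. A minimal sketch (the `mktemp -d` fallback is only for illustration; use your own path):

```shell
# Create the workspace layout described above.
# mktemp -d stands in for /your/custom/path when WORKSPACE is unset.
WORKSPACE="${WORKSPACE:-$(mktemp -d)}"
mkdir -p "${WORKSPACE}/models" "${WORKSPACE}/results"
ls "${WORKSPACE}"
```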

## Checkpoint Conversion

### Import HF → Megatron

To import the HF VL model to your desired Megatron path:

```bash
python examples/conversion/convert_checkpoints.py import \
--hf-model Qwen/Qwen3-VL-8B-Instruct \
--megatron-path ${WORKSPACE}/models/Qwen3-VL-8B-Instruct
```

### Export Megatron → HF

```bash
python examples/conversion/convert_checkpoints.py export \
--hf-model Qwen/Qwen3-VL-8B-Instruct \
    ...
```

```bash
python -m torch.distributed.run --nproc_per_node=4 examples/conversion/hf_to_meg...
```

Note:

- `--megatron_model_path` is optional. If it is not specified, the script converts the model on the fly and then runs the forward pass.
- You can also use image URLs: `--image_path="https://example.com/image.jpg"`

See the [inference.sh](inference.sh) script for commands to:

- Run inference with Hugging Face checkpoints
- Run inference with imported Megatron checkpoints
- Run inference with exported Hugging Face checkpoints

**Expected output:**

```
...
Generation step 46
```

Here is a breakdown of the key specifications:
- `qwen3_vl_8b_finetune_config`: Finetuning for 8B VL model with PEFT support
- `qwen3_vl_30b_a3b_finetune_config`: Finetuning for 30B-A3B VL model with PEFT support
- `qwen3_vl_235b_a22b_finetune_config`: Finetuning for 235B-A22B VL model with PEFT support

Before training, ensure the following environment variables are set:

1. `HF_TOKEN`: to download models from HF Hub (if required)
2. `HF_HOME`: (optional) to avoid re-downloading models and datasets
3. `WANDB_API_KEY`: (optional) to enable WandB logging
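For example (all values below are placeholders, not real credentials):

```shell
# Placeholder values; substitute your own.
export HF_TOKEN="hf_your_token_here"          # only needed for gated models
export HF_HOME="${WORKSPACE:-/tmp}/hf_cache"  # optional: persistent HF cache
export WANDB_API_KEY="your_wandb_key"         # optional: enables WandB logging
echo "HF cache: ${HF_HOME}"
```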
Follow the instructions [here](https://github.com/NVIDIA/Megatron-LM/tree/main/e...).
```yaml
__module__: megatron.bridge.recipes.qwen_vl.data.energon.task_encoder
__class__: ChatMLWebdataset
field_map:
  imgs: jpgs
  conversation: json
```
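The `field_map` simply renames webdataset fields (keyed by file extension) to the names the task encoder consumes. A minimal sketch of that remapping, with illustrative sample contents (the real implementation lives in `ChatMLWebdataset`):

```python
# Hypothetical illustration of what a field_map does; sample contents
# are stand-ins, not the real decoded data.
field_map = {"imgs": "jpgs", "conversation": "json"}

# A decoded webdataset sample, keyed by file extension.
sample = {"jpgs": ["<decoded image tensors>"], "json": {"messages": []}}

# Remap extension keys to the field names the encoder expects.
mapped = {field: sample[ext] for field, ext in field_map.items()}
print(sorted(mapped))  # ['conversation', 'imgs']
```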

19 changes: 16 additions & 3 deletions src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py
```python
def __init__(self, imagespec):
    self.extensions_mapping = {"jpgs": "jpg", "mp4s": "jpg", "videos": "jpg"}
    self.image_handler = imagehandler(imagespec)

def _convert_to_tensor(self, data):
    """Convert a numpy array or raw bytes to a tensor.

    The wds conversion script stores images as numpy arrays (HWC, uint8),
    so we need to handle both numpy arrays and raw bytes.
    """
    if isinstance(data, np.ndarray):
        # Data is already a numpy array (HWC, uint8) from pickle.
        # Convert to a tensor (CHW, float32 in [0, 1]).
        return torch.from_numpy(data).permute(2, 0, 1).float() / 255.0
    else:
        # Data is raw bytes; use the image handler to decode it.
        return self.image_handler("jpg", data)

def __call__(self, key, data):
    """Perform nested image decoding."""
    extension = re.sub(r".*[.]", "", key)
    if extension.lower() not in self.extensions:
        return None
    data = pickle.loads(data)
    key = self.extensions_mapping[extension]
    if extension.lower() == "jpgs":
        # A flat list of images.
        data = [self._convert_to_tensor(d) for d in data]
    else:
        # A list of videos, each a list of frames.
        data = [[self._convert_to_tensor(d) for d in video] for video in data]
    return data
```
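The numpy branch of `_convert_to_tensor` is easy to check in isolation. Below is a numpy-only sketch of the same HWC uint8 → CHW float32 transform (the real code returns a `torch.Tensor` via `torch.from_numpy(...).permute(2, 0, 1)`; `to_chw_float` is a hypothetical helper, not part of the codebase):

```python
import numpy as np

def to_chw_float(img):
    # Mirror of the numpy branch above: HWC uint8 -> CHW float32 in [0, 1].
    return np.transpose(img, (2, 0, 1)).astype(np.float32) / 255.0

img = np.full((4, 6, 3), 255, dtype=np.uint8)  # a white 4x6 RGB image
out = to_chw_float(img)
print(out.shape, out.dtype, out.max())  # (3, 4, 6) float32 1.0
```

The channel-first layout and [0, 1] float range match what downstream vision towers typically expect.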