feat(mm-cpt): extract OCR image-tiling into the mm_tiling plugin#40
feat(mm-cpt): extract OCR image-tiling into the mm_tiling plugin#40thad0ctor wants to merge 40 commits into
Conversation
Adds a streaming-first multimodal CPT path: raw `(text, images[])` rows
are tokenized once with a placeholder-count guardrail, batched through a
hardened collator, and fed to a VLM with image-family tokens masked out
of labels. Gated by `type: multimodal_pretrain` (or `multimodal: true`)
on a `pretraining_dataset` entry; works end-to-end for train and eval,
including multi-entry eval and mixed image/text batches.
Features
--------
- Streaming MM CPT encoder (`encode_streaming_multimodal`): counts
placeholders by token id (not substring), enforces
`placeholders == len(images)` per row, and rejects rows that exceed
`sequence_len` instead of silently truncating mid-placeholder.
- MM CPT collator (`MultiModalPretrainDataCollator`): security-hardened
image loader (path traversal / NUL byte / remote URL / multi-frame
bomb / pixel cap rejection), per-row image cap, processor-call retry
that pinpoints the offending row, and label-side masking of every
image-family token id.
- Mixed image/text batches: text-only rows in a batch take a
tokenizer-only fallback (no `pixel_values`); rows with images go
through the processor as usual.
- Eval support: `test_datasets` accepts MM entries via a dedicated
`MultiModalEvalDataset` model so per-entry `text_column` /
`image_column` / `image_base_dir` / `image_token` survive validation.
Multi-entry MM eval streams are concatenated.
- `dispatch_batches: true` support: non-main ranks get a placeholder
dataset that mirrors the configured text + image columns.
- Config validation gates: `processor_type` required, `sample_packing:
false` enforced, `chat_template` rejected, single
`pretraining_dataset` entry required, MM eval entries must share
`image_base_dir` / `image_token`, mixed MM/non-MM eval rejected,
incompatible processor classes (Mllama, Pixtral, InternVL) rejected
at startup. `remove_unused_columns` is auto-set to `false` with an
INFO log.
- Docs: new section in `docs/multimodal.qmd` covering the YAML shape,
placeholder-token table, eval contract, and supported/rejected model
families.
YAML example
------------
base_model: HuggingFaceTB/SmolVLM-500M-Instruct
processor_type: AutoProcessor
pretraining_dataset:
- path: /path/to/shards/*.jsonl
ds_type: json
type: multimodal_pretrain
text_column: text
image_column: images
image_base_dir: /path/to/images
streaming: true
sequence_len: 2048
sample_packing: false
Tests
-----
59 tests across four suites covering the encoder, collator (including
mixed/all-text batches and security gates), prompt strategy, schema
preservation, multi-entry eval merge, eval homogeneity validation,
eval-aware collator, dispatch-batches placeholder shape, and the
auto-set log record.
Addresses CodeRabbit review on PR axolotl-ai-cloud#3629. No behavior change for the happy path; expands schema, hardens fallbacks, tightens validation. Bug fixes --------- - Gemma-3 autodetect: prefer `processor.boi_token` over `image_token` when they differ (Gemma-3's `image_token` is the post-expansion soft token, not the user-facing placeholder). Without this, MM CPT crashed on the first batch with "Prompt contained 0 image tokens". - `dispatch_batches: true` placeholder dataset now mirrors the configured `image_column` so worker ranks don't KeyError on empty rows. - `tokenize_prompt` rejects falsy non-None image cells (`""`, `0`, `False`) instead of coercing to `[]` — keeps malformed rows from silently turning into text-only samples. Schema completeness ------------------- - Add `ds_type` to `PretrainingDataset` and `MultiModalEvalDataset` (the documented `ds_type: json` shape now actually reaches `load_dataset`; previously dropped at validation). - Preserve `trust_remote_code` through `_pretraining_config_from_entry` and pass it to `load_dataset` (was silently dropped). - Honor `cfg.eval_sequence_len` in MM CPT eval streams (encoder + collator) with documented fallback to `cfg.sequence_len` when unset. Validation tightening (config-load time) ---------------------------------------- - Reject mixed multimodal/text entries in `test_datasets`. - Reject MM `test_datasets` paired with non-MM training. - Reject non-MM `test_datasets` paired with MM training. - The redundant runtime check in `sft.py` is removed; schema is the single source of truth. Hardening / observability ------------------------- - Mixed/all-text batch handling: collator routes all-text batches to the tokenizer (no `pixel_values`); mixed batches go through the processor as-is. Documented per-VLM compatibility (verified on SmolVLM/SmolVLM2, Gemma-3, Gemma-4, Qwen2.5-VL, Qwen3-VL). - Reject cloud/object-store URIs (`s3://`, `gs://`, `gcs://`, `az://`, `azure://`, `hf://`) in image paths so users see "Non-local scheme" instead of a confusing FileNotFoundError. - Warn when `MultiModalPretrainDataCollator.tokenizer is not processor.tokenizer` (all-text vs image batches could otherwise tokenize the same text differently). - Warn at retry kickoff when a processor call fails on a batch, so users see why processing stalls during per-row diagnosis. - INFO log when `remove_unused_columns` is auto-set to `false` for MM CPT. - DEBUG log when `tokenizer.get_added_vocab()` fails (was silent pass). - Clarify "exceeds sequence_len" error to note image-patch expansion may push the final length higher. Tests ----- +8 regression tests across the four MM CPT suites covering: Gemma-3 boi_token autodetection, eval_sequence_len (encoder + collator, including the fallback case), `trust_remote_code` and `ds_type` preservation through validation, three modality-mismatch validation cases, tokenizer-mismatch warning, cloud-URI rejection. 68 tests pass across `tests/test_multimodal_streaming.py`, `tests/prompt_strategies/test_multimodal_pretrain.py`, `tests/utils/schemas/validation/test_multimodal_cpt.py`, `tests/utils/data/test_mm_cpt_eval.py`. Lint clean against ruff v0.15.8 (upstream pre-commit pin).
Addresses CodeRabbit review on PR axolotl-ai-cloud#3629. No behavior change for the happy path; expands schema, hardens fallbacks, tightens validation. Bug fixes --------- - Gemma-3 autodetect: prefer `processor.boi_token` over `image_token` when they differ. Without this, MM CPT crashed on the first batch with "Prompt contained 0 image tokens". - `dispatch_batches: true` placeholder dataset mirrors the configured `image_column` so worker ranks don't KeyError on empty rows. - `tokenize_prompt` rejects falsy non-None image cells (`""`, `0`, `False`) instead of coercing to `[]`. - `_tokenize` now honors `add_eos_token` / `strip_bos_token` instead of silently ignoring them. Schema ------ - Add `ds_type` to `PretrainingDataset` and `MultiModalEvalDataset` (the documented `ds_type: json` shape now reaches `load_dataset`). - Preserve `trust_remote_code` through `_pretraining_config_from_entry` and pass to `load_dataset`. - Honor `cfg.eval_sequence_len` in MM CPT eval streams (encoder + collator) with documented fallback to `cfg.sequence_len`. Validation (config-load time) ----------------------------- - Reject mixed multimodal/text entries in `test_datasets`. - Reject MM `test_datasets` paired with non-MM training. - Reject non-MM `test_datasets` paired with MM training. - Removed the redundant runtime check in `sft.py`; the schema is now the single source of truth. Hardening / observability ------------------------- - Mixed/all-text batch handling: collator routes all-text batches to the tokenizer (no `pixel_values`); mixed batches go through the processor as-is. Documented per-VLM compatibility (verified on SmolVLM/SmolVLM2, Gemma-3, Gemma-4, Qwen2.5-VL, Qwen3-VL). - Reject cloud / object-store URIs (`s3://`, `gs://`, `gcs://`, `az://`, `azure://`, `hf://`) in image paths so users see the explicit "Non-local scheme" error instead of a confusing FileNotFoundError. - Warn at construction when `MultiModalPretrainDataCollator.tokenizer` is not `processor.tokenizer` (all-text vs image batches could otherwise tokenize the same text differently). - Warn at retry kickoff when a processor call fails on a batch, so users see why processing stalls during per-row diagnosis. - INFO log when `remove_unused_columns` is auto-set to `false` for MM CPT. - DEBUG log when `tokenizer.get_added_vocab()` fails (was silent pass). - Clarify "exceeds sequence_len" error in both encoder paths to note image-patch expansion may push the final length higher. Code quality ------------ - Lift `image_token_spec` into `MultimodalPretrainTokenizationStrategy. __init__` instead of post-construction monkey-patch + `type: ignore`. - Hoist `import importlib` out of the per-class loop. - Drop dead `n_chunks` multiplication; replace with explicit invariant assertion. - Replace ambiguous `×` (U+00D7) with ASCII `x` in code/comments and the user-facing pixel-cap error. Tests ----- +15 regression tests across the four MM CPT suites covering: Gemma-3 boi_token autodetect (with id-mapping assertion), `eval_sequence_len` on encoder + collator (set + unset-fallback), `trust_remote_code` and `ds_type` preservation, three modality-mismatch validation cases, tokenizer-mismatch warning, `remove_unused_columns` auto-set log, cloud-URI rejection. 68 tests pass across `test_multimodal_streaming`, `test_multimodal_pretrain`, `test_multimodal_cpt`, and `test_mm_cpt_eval`. Lint clean against ruff v0.15.8 (upstream pre-commit pin).
- mm_pretrain.py: return BatchEncoding (not dict) from all-text branch so
it matches the imaged path.
- test_multimodal_cpt.py, test_multimodal_streaming.py: monkeypatch
axolotl logger propagate=True so caplog can capture records (axolotl's
logging config sets propagate=False, blocking root capture in CI).
multimodal_pretrain.py: scope the boi_token swap in build_image_token_spec
to processors whose `image_token` name contains "soft_token" (the Gemma-3
convention). Without this, Gemma-4 (`image_token=<|image|>`,
`boi_token=<|image>`) gets the wrong placeholder autodetected and every
row fails validation with a 0-vs-N placeholder/image mismatch.
test_multimodal_streaming.py: 6 new tests
- Two for the new autodetection behavior (Gemma-4 keeps image_token,
Gemma-3 still swaps to boi_token), using stub processors.
- Three branch-coverage tests for build_image_token_spec failure modes:
override not registered as special token, override resolves to unk,
nothing autodetectable.
- Three collator-path tests: skip_bad_images drops a row and continues,
all-rows-dropped surfaces a RuntimeError, multi-frame GIF triggers
the animation-bomb guard via _open_image_hardened.
fix(test): patch parent `axolotl` logger so negative caplog assertion has teeth
The previous monkeypatch targeted `axolotl.utils.schemas.validation`, which
is already propagate=True by inheritance — the actual block sits one level
up at the `axolotl` logger (propagate=False from logging_config.py). The
result: caplog never received any records, and `assert not any("Auto-set"
... in caplog.records)` would have passed even if the regression fired.
Mirror the positive test by flipping propagate on `logging.getLogger("axolotl")`
and add a comment explaining why the leaf isn't the right target.
…ry loop Some HF processors reject `images=[[]]`, which made the per-row retry flag innocent text-only rows as the offender. Mirror the all-text bypass — diagnostic-only path, mainline unchanged.
Fix text following Gemma 4 regression fix
MultiModalPretrainDataCollator.torch_call calls processor(text=...) which re-tokenizes _mm_text from scratch, discarding the EOS that encode_streaming_multimodal appended to input_ids. Without this, labels never contain EOS at end-of-document and the model never learns to emit a stop token — symptoms: non-terminating / repetitive generation. Match the text CPT contract (encode_streaming keeps EOS in both input_ids and labels) by appending EOS to _mm_text, idempotently, gated on a new add_eos_token field (default True).
Aligns mm_pretrain.py with mm_chat.py's image-loading posture. Drops NUL/URL/path-traversal/pixel-cap/multi-frame/per-row-count guards that defended against threat models that don't apply to a CLI trainer loading its own dataset. Routes image loading through transformers.image_utils.load_image, matching the chat path. Keeps image_base_dir join, skip_bad_images, label masking, processor compatibility check, and the tokenizer/processor mismatch warning.
Adds two opt-in knobs that make `auto_resume_from_checkpoints: true` cheap for streaming `pretraining_dataset` + `type: multimodal_pretrain`: - `cfg.ignore_data_skip` forwards to `TrainingArguments.ignore_data_skip` via the existing kwargs passthrough; lets HF Trainer skip the dataloader fast-forward on resume. - When `dataset_prepared_path` is set (and `skip_prepare_dataset` is false, not eval, not `dispatch_batches`), the streaming MM-CPT loader builds a map-style Arrow cache once and reuses it on subsequent runs. Cache hits return a map-style `Dataset` so HF Trainer seeks by index instead of iterating; multi-rank coordinated by the existing `FileLockLoader`. Cache key (`generate_pretraining_dataset_hash`) includes a real processor fingerprint (class + image_token + image_processor size/patch_size/etc.) so cached arrows don't collide when only the processor variant changes under the same class name. `eval_sequence_len` and `image_base_dir` are deliberately excluded — eval-only / collator-runtime, neither binds the cached arrows. Map-style return path also computes `total_num_steps` via `calculate_total_num_steps` so warmup/scheduler math stays correct. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
CodeRabbit nit: use module + __qualname__ in _processor_fingerprint so same-named classes in different modules don't collide in the cache key. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The non-streaming `datasets:` MM CPT route was never wired through `build_collator`, which only routes MM batches under the pretraining branch — `datasets:` entries would emit `images`/`_mm_text` rows into a text-only collator. Strip the strategy class + `load()` and their unit tests; keep `ImageTokenSpec`, `build_image_token_spec`, and `check_processor_compatibility` since the streaming collator imports them. Add a docs callout that only the streaming `pretraining_dataset` route is currently wired. Fold `MultiModalEvalDataset` into `PretrainingDataset` via inheritance; the only intentional divergence is the `type` default and the `_require_mm_markers` validator. Drops ~60 lines of duplicated `Field` declarations the reviewer flagged. Tighten the collator `KeyError` message to mention only `encode_streaming_multimodal` now that the strategy class is gone.
…datasets: Previously failed with a raw AttributeError at strategy load time. Now raises a small ValueError pointing users to the supported entry point.
… cache
Add multimodal image tiling for OCR continual-pretraining and SFT:
- Config-driven tiling (image_tiling_*) with shape-bucketed policies; the
ocr_pages preset maps landscape->3x2, portrait->2x3, tall->2x4.
- Persistent SSD tile cache keyed by source path/stat (optional sha256) +
resolved tiling policy + cache-format/Pillow version; fsync-durable writes
with manifest-key validation.
- DocOwl/InternVL-style textual position labels injected per tile
(<global_img> for the overview, <row{r}_col{c}> for each tile, columns
numbered left-to-right regardless of RTL traversal). Toggle via
image_tiling_tile_labels (default on).
- OCR eval harness (scripts/mm_ocr_eval.py) reporting weighted loss and CER.
- Tiling applied consistently in the length-metadata path and the runtime
collator so packed-sample lengths stay aligned with tiled processor inputs.
- Tiling policy folded into the dataset/packing cache hashes so a policy
change invalidates stale prepared datasets.
- EXIF-orientation handling, min_area canvas pinning, and config validation
(overlap range, shape-bucket presets, unused-field warnings).
- mm_packing: md5 usedforsecurity=False (Bandit B324); validate positive packing capacities before the sort divides by them (ZeroDivisionError); narrow and log tokenizer token-id conversion failures (Bandit B112). - multimodal_pretrain: honor add_eos_token when building cached row ids so add_eos_token=False no longer drifts from the processor-length path. - mm_ocr_eval: guard empty dict in first_data_file; use cfg.get for sequence_len. - benchmark/probe scripts: portable tempfile.gettempdir() cache dirs (Bandit S108); drop hardcoded model paths in favor of AXOLOTL_MODEL_ROOTS.
Reformat to satisfy the ruff-format pre-commit hook (no logic changes).
ruff/ruff-format/bandit already passed after the prior commits; the remaining
pre-commit failure was mypy. The benchmark scripts' `from scripts.X import`
made mypy bail early ("source file found twice"), masking the type checks.
- add scripts/__init__.py so mypy resolves scripts as a package
- mm_image: guard None and avoid param reassignment in resize_image_for_processor
- multimodal_pretrain: annotate the encode output dict
- mm_packing / probe scripts: build image_grid_thw as explicit (t, h, w) tuples
- streaming: annotate the packed accumulator
- mm_ocr_eval: resolve sequence_len to a guarded int
- benchmarks: index pixel_values after the membership check
Pass each tile at its native crop resolution (ignoring image_tiling_tile_size) so the reconstructed page approximates the original — detail-preserving for OCR at the cost of higher token count and variable tile shapes; capped by the processor's max-pixels. Position labels and the rest of the tiling path are unchanged. Verified: native tiles reconstruct to ~100% of original area vs 41% (fixed 512px) / 166% padded (fixed 1024px) on a 1629x2329 page.
Stage 1 of extracting tiling into a plugin. Adds a tiling-agnostic image-transform extension point: MMImageTransform protocol + resolve_mm_image_transform(cfg) + BasePlugin.get_mm_image_transform. resolve_ asks plugins first, then falls back to the in-tree tiling adapter (TilingImageTransform) so behavior is unchanged. No core call sites rewired yet.
Replace the threaded ImageTilingConfig with the resolved MMImageTransform at every parity-critical site so core no longer references tiling specifics: - runtime collator (mm_pretrain), both length estimators (multimodal_pretrain, mm_packing), and the SFT processing strategy now call transform.per_source / expand_placeholders / prepare / new_cache. - dataset + metadata cache hashes fold in transform.policy_payload(). - entry points (causal builder, streaming) resolve via resolve_mm_image_transform(cfg). Behavior is unchanged (the resolver still falls back to the in-tree tiling adapter); 145 tests + mypy pass. Tiling impl + config still live in core util/ schema; Stage 3 moves them into the plugin.
New axolotl.integrations.mm_tiling.MMTilingPlugin provides the tiling image transform through the get_mm_image_transform hook. With the plugin enabled the transform resolves from the plugin (the core fallback remains for back-compat). Config fields + tiling impl still live in core; moving them fully into the plugin (and dropping the fallback) is the remaining cleanup.
…age 3b) BREAKING: image tiling now requires `plugins: [axolotl.integrations.mm_tiling.MMTilingPlugin]`. - Move the tiling impl (mm_tiling.py -> integrations/mm_tiling/tiling.py) and the ~16 image_tiling_* config fields + validators into the plugin (args.py, get_input_args). Core schema/training_args/builder plumbing dropped. - Remove the transitional fallback in resolve_mm_image_transform: the transform now comes solely from the plugin hook. Core is tiling-free (only a string cache-key label and a docstring remain). - SFT strategy resolves the transform from cfg (not training_args). - Migrate config-path tests to enable the plugin (shared mm_tiling_plugin fixture); add native MMTilingArgs tests. 148 tests + mypy + ruff pass.
…format) mm_ocr_eval builds the TilingImageTransform directly (standalone, no plugin registered) and passes image_transform= to encode/collator. ruff-format applied. pre-commit --all-files (ruff/ruff-format/mypy/bandit) + 148 tests pass.
|
@coderabbitai full review |
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughAdds multimodal CPT: docs/examples; prompt strategy and collators; streaming, 2D packing, and caching; image tiling/resize integration and plugin; trainer/dataloader args; schema validation; CLI tooling; and extensive tests. ChangesMultimodal CPT Enablement
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes ✨ Finishing Touches🧪 Generate unit tests (beta)
|
✅ Actions performedFull review triggered. |
|
📖 Documentation Preview: Deployed on Netlify from commit d0fa87f |
There was a problem hiding this comment.
Actionable comments posted: 6
🧹 Nitpick comments (5)
src/axolotl/prompt_strategies/multimodal_pretrain.py (1)
500-513: 💤 Low valueConsider adding debug logging to the fallback exception handler.
The broad
except Exception:at line 500 catches processor batching failures and falls back to per-row processing, which is a reasonable degradation strategy. However, adding a debug-level log statement would help diagnose why batch processing failed without changing behavior.📝 Optional logging enhancement
except Exception: + LOG.debug("Processor batch call failed; falling back to per-row length computation") lengths: list[int] = []🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 500 - 513, In the except Exception: fallback inside multimodal_pretrain (where processor_texts and loaded_images are zipped and batches are created using processor or tokenizer and lengths are computed via _batch_lengths_from_processor_output), add a debug-level log that records the caught exception and a brief context (e.g., number of items in processor_texts/loaded_images or a sample text) before falling back to per-row processing; use the existing logger (or introduce one named logger or process_logger) and include exception information (exc_info or str(exception)) so the fallback behavior remains unchanged but the root cause is recorded for debugging.src/axolotl/utils/data/mm_image_transform.py (1)
76-81: ⚡ Quick winBroad
except Exceptionsilently hides real plugin/config errors.This intentionally tolerates an uninitialized manager, but it also swallows genuine failures (e.g. a misconfigured tiling config that raises during
get_mm_image_transform), which surfaces downstream as silently untiled training — the exact footgun called out as a known limitation. Consider a debug log so the cause is recoverable without changing the fallback behavior.♻️ Log the swallowed exception
try: from axolotl.integrations.base import PluginManager return PluginManager.get_instance().get_mm_image_transform(cfg) - except Exception: # noqa: BLE001 - manager may be uninitialized in tooling + except Exception: # noqa: BLE001 - manager may be uninitialized in tooling + from axolotl.utils.logging import get_logger + + get_logger(__name__).debug( + "resolve_mm_image_transform: no transform resolved", exc_info=True + ) return None🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/utils/data/mm_image_transform.py` around lines 76 - 81, The broad except in the function returning PluginManager.get_instance().get_mm_image_transform(cfg) hides real errors; keep the fallback behavior (return None) but log the caught exception so misconfigurations are discoverable. Update the except block to capture the exception as e and emit a debug/exception-level log (e.g. via logging.getLogger(__name__) or the module's logger) with context like "get_mm_image_transform failed" and include the exception traceback or message, then return None; do not change the current fallback logic.tests/utils/data/test_mm_packing.py (1)
197-209: ⚡ Quick winReplace mutable class attribute with a property.
The
sizedict as a class attribute could lead to test pollution if any test modifies it. Use a property or define it in__init__to ensure each instance gets its own dict.♻️ Proposed fix
class _ImageProcessor: - size = {"height": 224, "width": 224} + `@property` + def size(self): + return {"height": 224, "width": 224}🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/utils/data/test_mm_packing.py` around lines 197 - 209, The _ImageProcessor class defines a mutable class attribute size = {"height": 224, "width": 224} which can be mutated across tests; change it to an instance attribute or a read-only property to avoid test pollution: update _ImageProcessor (referenced by _Processor.image_processor) so that size is created per-instance (e.g., set self.size = {"height": 224, "width": 224} in __init__) or expose a `@property` that returns a new dict each call, and keep the external API (attribute name size) the same so existing tests using _ImageProcessor or _Processor.image_processor continue to work.scripts/mm_packing_probe.py (2)
193-195: ⚡ Quick winAvoid ternary expressions for side effects.
Using a ternary/conditional expression (
... if ... else ...) for method calls with side effects (likeextendandappend) is considered poor style and reduces readability.♻️ Recommended refactor for clarity
- paths.extend(str(path) for path in value) if isinstance( - value, list - ) else paths.append(str(value)) + if isinstance(value, list): + paths.extend(str(path) for path in value) + else: + paths.append(str(value))🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@scripts/mm_packing_probe.py` around lines 193 - 195, Replace the ternary used for side effects with an explicit if/else to improve readability: check isinstance(value, list) and call paths.extend(str(path) for path in value) in the if branch, otherwise call paths.append(str(value)) in the else branch; keep the same variables and behavior for paths, value, extend, append and isinstance.
202-204: ⚡ Quick winAvoid ternary expressions for side effects.
Using a ternary/conditional expression for method calls with side effects reduces readability.
♻️ Recommended refactor for clarity
- expanded.extend( - Path(match) for match in matches - ) if matches else expanded.append(Path(path)) + if matches: + expanded.extend(Path(match) for match in matches) + else: + expanded.append(Path(path))🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@scripts/mm_packing_probe.py` around lines 202 - 204, The ternary-style expression performing side effects reduces readability; replace the conditional expression with an explicit if/else around the side-effecting calls: check the variable matches and if truthy call expanded.extend(Path(match) for match in matches), otherwise call expanded.append(Path(path)); this targets the code that currently uses expanded.extend(Path(match) for match in matches) if matches else expanded.append(Path(path)) and preserves the same behavior while improving clarity.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/axolotl/core/builders/causal.py`:
- Around line 548-549: The code silently allows image_tiling=True with no
resolved transform (image_transform == None) which converts a tiled run into an
untiled one; update the builder to validate after calling
resolve_mm_image_transform(self.cfg) (the symbol image_transform returned by
resolve_mm_image_transform) and if self.cfg.image_tiling is truthy but
image_transform is None, raise a clear exception (ValueError/RuntimeError)
describing the missing plugin/transform so the run fails fast; apply the same
check and error at the other call site that also uses resolve_mm_image_transform
(the block referenced around lines 681-682) so both locations enforce tiling
configuration consistency.
- Around line 654-660: The per-dataset aggregation builds field_messages by
appending entire list/tuple objects, causing nested lists; change the loop that
calls _ds_get over ds_entries to normalize each field_message: if it's a str,
treat as single-item list; if it's a list/tuple, iterate its elements and append
each string individually to field_messages only if not already present (preserve
order and dedupe). Ensure you update the block that currently defines
field_messages and iterates ds_entries so ProcessingStrategy._get_messages_field
receives flat string names rather than lists.
In `@src/axolotl/core/trainers/base.py`:
- Around line 322-324: The eval-dataloader retention logic must use the
effective persistent_workers after applying multimodal overrides: after calling
_apply_multimodal_dataloader_overrides(dataloader_params, self.args,
data_collator) (and/or any code that mutates dataloader_params), compute the
actual flag e.g. effective_persistent =
dataloader_params.get("persistent_workers",
self.args.dataloader_persistent_workers) and use effective_persistent (not
self.args.dataloader_persistent_workers) in the _eval_dataloaders retention
branch so loaders created with persistent_workers=True are kept for
Accelerator.free_memory() to work correctly.
In `@src/axolotl/utils/collators/mm_pretrain.py`:
- Around line 206-226: The fallback length computation is using raw "_mm_text"
but must use the same prepared text and image expansion used in the packed path;
update _fallback_row_lengths to call _prepare_row_text_and_images(row) for each
row, take the returned prepared text (after any placeholder expansion) and pass
prepare_text_for_packed_boundary(prepared_text) (not raw _mm_text) to
compute_multimodal_processor_lengths, and use the image list returned by
_prepare_row_text_and_images for image_sources; apply the same change at the
other fallback call site around the block flagged (the similar call at lines
~275-278) so both fallback paths compute lengths from the prepared/packed text
and expanded images rather than raw _mm_text.
In `@src/axolotl/utils/data/sft.py`:
- Around line 458-468: The placeholder CSV path leaks because a
NamedTemporaryFile is created with delete=False and never removed; change the
logic around NamedTemporaryFile (in the branch where image_column is None) to
ensure the temporary file (f.name) is unlinked after use: create/write/flush the
temp file, call load_dataset("csv", data_files=f.name, split="train",
streaming=True) while the file exists, then immediately remove it with
os.unlink(f.name) in a finally block (or use TemporaryDirectory/ContextManager)
so the temp file is always cleaned up; reference the NamedTemporaryFile, f.name,
load_dataset, image_column and text_column symbols to locate where to add the
unlink/cleanup.
In `@src/axolotl/utils/schemas/validation.py`:
- Around line 1449-1467: The validation currently allows tiling-related keys
without the mm_tiling plugin; update the multimodal CPT validation in
validation.py (the block that inspects data and pd_is_mm) to reject tiling
configs when the mm_tiling integration is not enabled: detect presence of tiling
knobs (e.g. "image_tiling", "mm_tiling", "tile_size", or
"multimodal_sample_packing_cache_path") in data and check the configured plugins
list (data.get("plugins") or equivalent) for the mm_tiling integration (e.g.
"axolotl.integrations.mm_tiling.MMTilingPlugin" or "mm_tiling"); if tiling knobs
are set but the mm_tiling plugin is not present, raise a ValueError with a clear
message instructing the user to add the plugin to the plugins: list.
---
Nitpick comments:
In `@scripts/mm_packing_probe.py`:
- Around line 193-195: Replace the ternary used for side effects with an
explicit if/else to improve readability: check isinstance(value, list) and call
paths.extend(str(path) for path in value) in the if branch, otherwise call
paths.append(str(value)) in the else branch; keep the same variables and
behavior for paths, value, extend, append and isinstance.
- Around line 202-204: The ternary-style expression performing side effects
reduces readability; replace the conditional expression with an explicit if/else
around the side-effecting calls: check the variable matches and if truthy call
expanded.extend(Path(match) for match in matches), otherwise call
expanded.append(Path(path)); this targets the code that currently uses
expanded.extend(Path(match) for match in matches) if matches else
expanded.append(Path(path)) and preserves the same behavior while improving
clarity.
In `@src/axolotl/prompt_strategies/multimodal_pretrain.py`:
- Around line 500-513: In the except Exception: fallback inside
multimodal_pretrain (where processor_texts and loaded_images are zipped and
batches are created using processor or tokenizer and lengths are computed via
_batch_lengths_from_processor_output), add a debug-level log that records the
caught exception and a brief context (e.g., number of items in
processor_texts/loaded_images or a sample text) before falling back to per-row
processing; use the existing logger (or introduce one named logger or
process_logger) and include exception information (exc_info or str(exception))
so the fallback behavior remains unchanged but the root cause is recorded for
debugging.
In `@src/axolotl/utils/data/mm_image_transform.py`:
- Around line 76-81: The broad except in the function returning
PluginManager.get_instance().get_mm_image_transform(cfg) hides real errors; keep
the fallback behavior (return None) but log the caught exception so
misconfigurations are discoverable. Update the except block to capture the
exception as e and emit a debug/exception-level log (e.g. via
logging.getLogger(__name__) or the module's logger) with context like
"get_mm_image_transform failed" and include the exception traceback or message,
then return None; do not change the current fallback logic.
In `@tests/utils/data/test_mm_packing.py`:
- Around line 197-209: The _ImageProcessor class defines a mutable class
attribute size = {"height": 224, "width": 224} which can be mutated across
tests; change it to an instance attribute or a read-only property to avoid test
pollution: update _ImageProcessor (referenced by _Processor.image_processor) so
that size is created per-instance (e.g., set self.size = {"height": 224,
"width": 224} in __init__) or expose a `@property` that returns a new dict each
call, and keep the external API (attribute name size) the same so existing tests
using _ImageProcessor or _Processor.image_processor continue to work.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: dd62d29e-c66a-4f83-8b5f-241f158183c5
📒 Files selected for processing (51)
docs/multimodal.qmdexamples/qwen2_5-vl/mm-cpt-nonstreaming-qlora.yamlexamples/qwen2_5-vl/mm-cpt-streaming-qlora.yamlscripts/__init__.pyscripts/mm_ocr_eval.pyscripts/mm_packing_benchmark.pyscripts/mm_packing_probe.pyscripts/mm_training_benchmark.pyscripts/mm_visual_shape_probe.pysrc/axolotl/core/builders/base.pysrc/axolotl/core/builders/causal.pysrc/axolotl/core/trainers/base.pysrc/axolotl/core/training_args_base.pysrc/axolotl/integrations/base.pysrc/axolotl/integrations/mm_tiling/README.mdsrc/axolotl/integrations/mm_tiling/__init__.pysrc/axolotl/integrations/mm_tiling/args.pysrc/axolotl/integrations/mm_tiling/plugin.pysrc/axolotl/integrations/mm_tiling/tiling.pysrc/axolotl/processing_strategies.pysrc/axolotl/prompt_strategies/multimodal_pretrain.pysrc/axolotl/utils/collators/mm_chat.pysrc/axolotl/utils/collators/mm_pretrain.pysrc/axolotl/utils/config/__init__.pysrc/axolotl/utils/data/mm_image.pysrc/axolotl/utils/data/mm_image_transform.pysrc/axolotl/utils/data/mm_packing.pysrc/axolotl/utils/data/sft.pysrc/axolotl/utils/data/shared.pysrc/axolotl/utils/data/streaming.pysrc/axolotl/utils/schemas/config.pysrc/axolotl/utils/schemas/datasets.pysrc/axolotl/utils/schemas/multimodal.pysrc/axolotl/utils/schemas/validation.pysrc/axolotl/utils/trainer.pytests/conftest.pytests/integrations/mm_tiling/__init__.pytests/integrations/mm_tiling/test_plugin.pytests/patched/test_validation.pytests/prompt_strategies/test_multimodal_pretrain.pytests/scripts/test_mm_packing_benchmark.pytests/test_mm_chat_collator.pytests/test_multimodal_streaming.pytests/utils/data/test_hash.pytests/utils/data/test_mm_cpt_eval.pytests/utils/data/test_mm_image_transform.pytests/utils/data/test_mm_packing.pytests/utils/data/test_mm_pretrain_cache.pytests/utils/data/test_mm_pretrain_cache_integration.pytests/utils/data/test_mm_tiling_regressions.pytests/utils/schemas/validation/test_multimodal_cpt.py
| def _fallback_row_lengths(self, rows: list[dict]) -> list[int]: | ||
| texts = [row["_mm_text"] for row in rows] | ||
| image_sources = [ | ||
| self._raw_image_sources(row.get("images"), row_index=i) | ||
| for i, row in enumerate(rows) | ||
| ] | ||
| return compute_multimodal_processor_lengths( | ||
| texts, | ||
| image_sources, | ||
| tokenizer=self.tokenizer, | ||
| processor=self.processor, | ||
| image_base_dir=self.image_base_dir, | ||
| add_eos_token=self.add_eos_token, | ||
| image_size=self.image_size, | ||
| image_resize_algorithm=self.image_resize_algorithm, | ||
| image_resize_buckets=self.image_resize_buckets, | ||
| image_resize_no_upscale=self.image_resize_no_upscale, | ||
| image_resize_pad_color=self.image_resize_pad_color, | ||
| image_transform=self.image_transform, | ||
| image_token=self.image_token_spec.image_token, | ||
| ) |
There was a problem hiding this comment.
Fallback packed lengths are computed from the wrong text form.
Line 207 recomputes lengths from raw _mm_text, but the packed path actually tokenizes prepare_text_for_packed_boundary(...) plus any placeholder expansion from _prepare_row_text_and_images(...). For packed non-streaming CPT rows that do not already carry length, sample_lengths can drift from the processor output and then _add_packing_masks will either raise the length-mismatch error or assign segment IDs to the wrong span.
Also applies to: 275-278
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/axolotl/utils/collators/mm_pretrain.py` around lines 206 - 226, The
fallback length computation is using raw "_mm_text" but must use the same
prepared text and image expansion used in the packed path; update
_fallback_row_lengths to call _prepare_row_text_and_images(row) for each row,
take the returned prepared text (after any placeholder expansion) and pass
prepare_text_for_packed_boundary(prepared_text) (not raw _mm_text) to
compute_multimodal_processor_lengths, and use the image list returned by
_prepare_row_text_and_images for image_sources; apply the same change at the
other fallback call site around the block flagged (the similar call at lines
~275-278) so both fallback paths compute lengths from the prepared/packed text
and expanded images rather than raw _mm_text.
- validation: reject image_tiling_* knobs when the mm_tiling plugin is not loaded, instead of silently training untiled (+ test) - causal builder: fail fast if image_tiling is enabled but no transform resolves (e.g. plugin missing or tiling config invalid) - causal builder: flatten per-dataset field_messages so list/tuple entries don't nest and break ProcessingStrategy._get_messages_field - trainers/base: retain eval dataloaders using the effective persistent_workers value after MM overrides, not the raw arg - sft: drop the leaked NamedTemporaryFile placeholder path; build the placeholder via a generator (also removes unused tempfile import) Skipped the mm_pretrain fallback-length finding: the fallback uses the same standalone _mm_text form as the stored estimator length, which is exactly the invariant _add_packing_masks relies on (BOS is compensated by prepare_text_for_packed_boundary).
|
Addressed in 4e9049c:
Skipped — fallback packed lengths ( |
There was a problem hiding this comment.
Actionable comments posted: 12
♻️ Duplicate comments (1)
src/axolotl/utils/collators/mm_pretrain.py (1)
206-226:⚠️ Potential issue | 🟠 Major | ⚡ Quick winCompute fallback packed lengths from the prepared text/images, not raw
_mm_text.
_pack_rows()packsprepare_text_for_packed_boundary(...)plus any placeholder expansion from_prepare_row_text_and_images(...), but_fallback_row_lengths()still measures raw_mm_textand rawimages. Rows missinglengthcan therefore get a different span than the one the packed processor actually sees, which later breaks_add_packing_masks()or assigns segment ids to the wrong slice.This is the same issue that was already raised earlier and still appears unresolved in the current revision.
Also applies to: 275-278
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/utils/collators/mm_pretrain.py` around lines 206 - 226, The fallback length computation is using raw _mm_text and raw images, causing mismatches with what _pack_rows() actually packs; change _fallback_row_lengths to derive each row's text and images by calling _prepare_row_text_and_images(row, i) and then run prepare_text_for_packed_boundary(...) on that prepared text (matching the same transformations _pack_rows uses), and pass the prepared image sources (from _prepare_row_text_and_images rather than row.get("images") or _raw_image_sources) into compute_multimodal_processor_lengths; apply the same replacement to the other similar call site that mirrors this logic so both places measure lengths from the prepared/expanded text+images used by _pack_rows, not the raw fields.
🧹 Nitpick comments (3)
tests/patched/test_validation.py (1)
1746-1753: ⚡ Quick winPrefer a behavioral passthrough assertion here.
Line 1752 is guarding this contract via source-text inspection, so it can false-pass on comments/strings and false-fail on harmless refactors. Assert the actual kwargs forwarded into
TrainingArgumentsinstead.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/patched/test_validation.py` around lines 1746 - 1753, Replace the fragile source-text inspection in test_listed_in_base_training_args_passthrough with a behavioral assertion: call TrainerBuilderBase._set_base_training_args (or invoke the builder path that uses it) with a sentinel value for ignore_data_skip and capture the kwargs passed into TrainingArguments (patch/mocking the TrainingArguments constructor or inspect the created instance) to assert that ignore_data_skip is forwarded; ensure the test uses TrainerBuilderBase._set_base_training_args (or the builder that calls it), mocks TrainingArguments to record received keyword arguments, and asserts that "ignore_data_skip" appears with the expected value in those kwargs.scripts/mm_packing_benchmark.py (1)
564-567: 💤 Low valueRedundant branch in packed vs unpacked collation.
Both branches call the same method:
if variant.packed: batch = collator.torch_call([feature]) else: batch = collator.torch_call([feature])This appears to be dead code or a copy-paste artifact. Consider removing the conditional:
Proposed simplification
- if variant.packed: - batch = collator.torch_call([feature]) - else: - batch = collator.torch_call([feature]) + batch = collator.torch_call([feature])🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@scripts/mm_packing_benchmark.py` around lines 564 - 567, The conditional checking variant.packed is redundant because both branches call the same collator.torch_call([feature]); remove the if/else and replace with a single call setting batch = collator.torch_call([feature]) (or implement the intended packed vs unpacked behavior if different) so references to variant.packed, collator.torch_call, batch, and feature are preserved and the dead code is eliminated.tests/integrations/mm_tiling/test_plugin.py (1)
10-23: ⚡ Quick winReuse the shared
mm_tiling_pluginfixture here.This duplicates the
PluginManagersave/clear/restore logic now centralized intests/conftest.py, so the two setups can drift the next time plugin-state handling changes.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/integrations/mm_tiling/test_plugin.py` around lines 10 - 23, The test defines a duplicate fixture _registered_plugin that manually manipulates PluginManager plugins instead of reusing the centralized mm_tiling_plugin fixture in tests/conftest.py; replace the custom _registered_plugin fixture with a dependency on the existing mm_tiling_plugin fixture (and remove the manual PluginManager save/clear/restore logic and the direct import/instantiation of MMTilingPlugin) so the test reuses the shared setup provided by mm_tiling_plugin and avoids duplicated plugin-state handling.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/multimodal.qmd`:
- Around line 441-442: Update the CPT documentation text that currently states
"MM CPT requires sample_packing: false" to reflect the new validator behavior:
state that sample_packing can be either true or false (both streaming and
non-streaming MM CPT are allowed), and that the validator will accept
sample_packing: true; also keep the note that remove_unused_columns is auto-set
by the validator. Replace any wording that claims sample_packing:true is
rejected with language indicating it is supported for MM CPT (streaming and
non-streaming) so all occurrences (the sections referencing sample_packing and
remove_unused_columns) are consistent with the updated validation tests.
In `@scripts/mm_ocr_eval.py`:
- Around line 205-211: The evaluate() data-file selection currently only looks
at pretraining_dataset[0]["data_files"]; update the logic around dataset_cfg,
first_data_file and args.data_file to handle both config routes and both key
names: check for cfg.get("pretraining_dataset") or cfg.get("datasets") and pick
the first dataset entry, then when extracting the file accept either a
"data_files" list or a "path" string (and handle non-streaming formats
accordingly), falling back to args.data_file if provided; keep using
first_data_file to normalize list entries but add a branch to return
dataset_entry.get("path") when present so both streaming (data_files) and
path-based examples are supported.
In `@scripts/mm_packing_benchmark.py`:
- Around line 336-343: The loop that builds feature dicts uses
range(len(encoded["_mm_text"])) but then indexes rows[idx], causing index
mismatches when outputs are packed; update the block that builds features (the
loop that references encoded, features, rows, and keys "_mm_text",
"_bench_row_index", "_tile_count") to stop assuming a 1:1 mapping: either remove
the conditional assignments that read rows[idx] ("_bench_row_index" and
"_tile_count") for packed outputs, or, if correct source mapping exists in the
packing stage, pull the original source row indices from that packing metadata
and use those to annotate each feature instead of using idx. Ensure the code
only attempts rows[...] lookups when you can guarantee the index corresponds to
a single original row.
In `@scripts/mm_packing_probe.py`:
- Around line 307-310: The loop keeps file descriptors open by calling
Image.open(path).convert("RGB") directly; change it to open images inside a
context manager and copy them into memory so the underlying file is closed
immediately — e.g., use "with Image.open(path) as img:" then "image =
img.convert('RGB').copy()" before appending to loaded and sizes; update the loop
that references resolved, Image.open, loaded, and sizes accordingly.
In `@src/axolotl/integrations/mm_tiling/args.py`:
- Around line 219-241: The validator warn_unused_image_tiling_fields only checks
a hard-coded subset of tiling fields so some image_tiling_* settings are
silently ignored; update it to dynamically detect all attributes that start with
"image_tiling_" (excluding the boolean "image_tiling" itself) and are not None
(e.g., iterate over vars(self).keys() or self.__dict__.keys(), filter names with
name.startswith("image_tiling_") and name != "image_tiling" and getattr(self,
name) is not None), then build set_fields from that list and log them as before;
keep the method name warn_unused_image_tiling_fields and return self.
In `@src/axolotl/utils/data/mm_image_transform.py`:
- Around line 76-81: The current blanket except in the PluginManager resolution
swallows all plugin errors (in the try block around
PluginManager.get_instance().get_mm_image_transform(cfg)), hiding real
misconfigurations; change it so only import/manager-unavailable cases are
swallowed (e.g., catch ImportError or a specific "manager uninitialized"
indicator) and let other exceptions propagate (or log and re-raise) so plugin
resolution failures surface as errors rather than silently returning None;
update the try/except around PluginManager.get_instance().get_mm_image_transform
to narrow the exception types and re-raise or surface unexpected exceptions.
In `@src/axolotl/utils/data/mm_image.py`:
- Around line 75-88: The pad color default is an RGB tuple from
_normalize_pad_color(None) which may be incompatible with non-RGB modes when
_pad_to_canvas calls Image.new(fitted.mode, ..., color); update _pad_to_canvas
to accept the fitted.mode (or call image.mode inside it) and convert or derive
the pad color to a mode-appropriate value before calling Image.new — e.g., if
color is None or a tuple and fitted.mode is 'L' convert to a single luminance
value or use ImageCms/convert logic, or alternatively change
_normalize_pad_color to accept a mode parameter and return a value compatible
with fitted.mode; ensure the conversion is applied wherever
_pad_to_canvas(image, target, ...) is invoked so bucketed/padded resizing won't
raise mode errors.
In `@src/axolotl/utils/data/sft.py`:
- Around line 214-231: The helper _pretraining_config_from_entry currently drops
dataset revision info; update the DictDefault returned by
_pretraining_config_from_entry to include the "revision" field (e.g., set
"revision": entry.get("revision")) so dataset loading and cache keys preserve HF
revision pinning; ensure it uses entry.get("revision") (defaulting to None if
absent) alongside the existing keys like "path", "split", and
"trust_remote_code".
In `@src/axolotl/utils/data/shared.py`:
- Around line 518-528: The constructed cache key in variable component omits
loader-identifying fields (name, data_files, revision) for non-multimodal
datasets; update the assembly so the dataset loader fields are always included
in the component string. Use the existing _get accessor to append _get('name'),
_get('revision'), and _get('data_files') (and keep the current multimodal
fields) — e.g., ensure component includes _get('name') and _get('revision') for
the normal branch and also include _get('name')/_get('revision') in the
multimodal branch if not already present so different loader configs cannot
collide; modify the logic around the _get("type") / _get("multimodal") check and
the component concatenations accordingly.
- Around line 569-574: The fingerprint construction in shared.py can produce
"None+overrides:..." when cfg.tokenizer_config is unset, causing different
tokenizers to collide in the mm-pretraining cache; update the
tokenizer_fingerprint logic to use a resolved tokenizer identity (fallback to
tokenizer_name or the actual resolved tokenizer identifier) rather than the raw
cfg.tokenizer_config when building the string in the block that handles
cfg.get("added_tokens_overrides"); specifically change the expression that
builds tokenizer_fingerprint (the branch setting tokenizer_fingerprint =
f"{cfg.tokenizer_config}+overrides:...") to use cfg.tokenizer_config if truthy
else tokenizer_name (or the resolved tokenizer variable) so the cache key
reflects the real tokenizer identity.
In `@src/axolotl/utils/data/streaming.py`:
- Around line 318-329: The collate and encoding for the non-multimodal packed
path still use cfg.sequence_len; change both uses to the computed
effective_seq_len coming from wrap_streaming_dataset so packed eval streams
respect eval_sequence_len. Specifically update the
PretrainingBatchSamplerDataCollatorForSeq2Seq instantiation (pad_to_multiple_of)
and the functools.partial for encode_packed_streaming (max_seq_length) to use
effective_seq_len instead of cfg.sequence_len; keep multipack_attn, collate_fn,
ds_wrapper_fn and the rest of the encode setup unchanged.
In `@src/axolotl/utils/schemas/validation.py`:
- Around line 1389-1411: The multimodal validation only examines plain dicts by
building test_dicts = [t for t in test_datasets if isinstance(t, dict)], which
lets dataset objects bypass _entry_is_mm and the later image_base_dir /
image_token checks; fix by operating on the full test_datasets list (e.g.,
rename test_entries = data.get("test_datasets") or [] and pass each entry to
_entry_is_mm so it can accept both dicts and dataset objects) and remove the
dict-only filter used in the multimodal checks (affecting the block that
references test_datasets, _entry_is_mm, and train_is_mm and the later block
handling image_base_dir/image_token) so all programmatic dataset objects are
validated the same as dict configs.
---
Duplicate comments:
In `@src/axolotl/utils/collators/mm_pretrain.py`:
- Around line 206-226: The fallback length computation is using raw _mm_text and
raw images, causing mismatches with what _pack_rows() actually packs; change
_fallback_row_lengths to derive each row's text and images by calling
_prepare_row_text_and_images(row, i) and then run
prepare_text_for_packed_boundary(...) on that prepared text (matching the same
transformations _pack_rows uses), and pass the prepared image sources (from
_prepare_row_text_and_images rather than row.get("images") or
_raw_image_sources) into compute_multimodal_processor_lengths; apply the same
replacement to the other similar call site that mirrors this logic so both
places measure lengths from the prepared/expanded text+images used by
_pack_rows, not the raw fields.
---
Nitpick comments:
In `@scripts/mm_packing_benchmark.py`:
- Around line 564-567: The conditional checking variant.packed is redundant
because both branches call the same collator.torch_call([feature]); remove the
if/else and replace with a single call setting batch =
collator.torch_call([feature]) (or implement the intended packed vs unpacked
behavior if different) so references to variant.packed, collator.torch_call,
batch, and feature are preserved and the dead code is eliminated.
In `@tests/integrations/mm_tiling/test_plugin.py`:
- Around line 10-23: The test defines a duplicate fixture _registered_plugin
that manually manipulates PluginManager plugins instead of reusing the
centralized mm_tiling_plugin fixture in tests/conftest.py; replace the custom
_registered_plugin fixture with a dependency on the existing mm_tiling_plugin
fixture (and remove the manual PluginManager save/clear/restore logic and the
direct import/instantiation of MMTilingPlugin) so the test reuses the shared
setup provided by mm_tiling_plugin and avoids duplicated plugin-state handling.
In `@tests/patched/test_validation.py`:
- Around line 1746-1753: Replace the fragile source-text inspection in
test_listed_in_base_training_args_passthrough with a behavioral assertion: call
TrainerBuilderBase._set_base_training_args (or invoke the builder path that uses
it) with a sentinel value for ignore_data_skip and capture the kwargs passed
into TrainingArguments (patch/mocking the TrainingArguments constructor or
inspect the created instance) to assert that ignore_data_skip is forwarded;
ensure the test uses TrainerBuilderBase._set_base_training_args (or the builder
that calls it), mocks TrainingArguments to record received keyword arguments,
and asserts that "ignore_data_skip" appears with the expected value in those
kwargs.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 6a1bafcb-7aa5-4395-a7d7-2e9ab0d81bb2
📒 Files selected for processing (51)
docs/multimodal.qmdexamples/qwen2_5-vl/mm-cpt-nonstreaming-qlora.yamlexamples/qwen2_5-vl/mm-cpt-streaming-qlora.yamlscripts/__init__.pyscripts/mm_ocr_eval.pyscripts/mm_packing_benchmark.pyscripts/mm_packing_probe.pyscripts/mm_training_benchmark.pyscripts/mm_visual_shape_probe.pysrc/axolotl/core/builders/base.pysrc/axolotl/core/builders/causal.pysrc/axolotl/core/trainers/base.pysrc/axolotl/core/training_args_base.pysrc/axolotl/integrations/base.pysrc/axolotl/integrations/mm_tiling/README.mdsrc/axolotl/integrations/mm_tiling/__init__.pysrc/axolotl/integrations/mm_tiling/args.pysrc/axolotl/integrations/mm_tiling/plugin.pysrc/axolotl/integrations/mm_tiling/tiling.pysrc/axolotl/processing_strategies.pysrc/axolotl/prompt_strategies/multimodal_pretrain.pysrc/axolotl/utils/collators/mm_chat.pysrc/axolotl/utils/collators/mm_pretrain.pysrc/axolotl/utils/config/__init__.pysrc/axolotl/utils/data/mm_image.pysrc/axolotl/utils/data/mm_image_transform.pysrc/axolotl/utils/data/mm_packing.pysrc/axolotl/utils/data/sft.pysrc/axolotl/utils/data/shared.pysrc/axolotl/utils/data/streaming.pysrc/axolotl/utils/schemas/config.pysrc/axolotl/utils/schemas/datasets.pysrc/axolotl/utils/schemas/multimodal.pysrc/axolotl/utils/schemas/validation.pysrc/axolotl/utils/trainer.pytests/conftest.pytests/integrations/mm_tiling/__init__.pytests/integrations/mm_tiling/test_plugin.pytests/patched/test_validation.pytests/prompt_strategies/test_multimodal_pretrain.pytests/scripts/test_mm_packing_benchmark.pytests/test_mm_chat_collator.pytests/test_multimodal_streaming.pytests/utils/data/test_hash.pytests/utils/data/test_mm_cpt_eval.pytests/utils/data/test_mm_image_transform.pytests/utils/data/test_mm_packing.pytests/utils/data/test_mm_pretrain_cache.pytests/utils/data/test_mm_pretrain_cache_integration.pytests/utils/data/test_mm_tiling_regressions.pytests/utils/schemas/validation/test_multimodal_cpt.py
Core: - mm_image_transform: stop swallowing plugin resolution failures; only the PluginManager import is guarded, so a transform bug surfaces instead of silently downgrading to the non-tiling path - shared: include name/revision/data_files in the dataset cache-key identity, and fall back to tokenizer_name when tokenizer_config is unset with added_tokens_overrides (both SFT and pretraining hashes); add revision to the pretraining hash keys - streaming: use effective_seq_len (not cfg.sequence_len) in the non-MM packed path so eval honors eval_sequence_len - mm_image: derive a mode-compatible canvas fill so single-band (L) images don't crash bucketed/padded resize - sft: thread dataset through the extracted pretraining config and into load_dataset so revision-pinned configs load/cache correctly - validation: validate object-shaped test_datasets entries (not just dicts) - mm_tiling args: warn on every explicitly-set tiling field when tiling is off, via model_fields_set (handles bool flags with non-None defaults) Scripts/docs: - mm_ocr_eval: accept both pretraining_dataset/datasets routes and path/data_files - mm_packing_benchmark: don't map per-row metadata onto packed (aggregated) features - mm_packing_probe: close image file handles in the probe loop - docs/multimodal: sample_packing is supported for MM CPT; only the streaming pretraining_dataset + dataset_prepared_path combo is rejected
|
Second batch addressed in bb02b2c — all 12 actionable findings: Core
Scripts / docs
149 MM + 23 dataset-hash tests pass; |
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/axolotl/utils/data/mm_image.py (1)
41-56:⚠️ Potential issue | 🟡 Minor | ⚡ Quick winThe
no_upscalebranch is identical to the fallthrough — likely dead code or missing logic.Lines 41-48 and 50-56 return the same
max(buckets, key=...)with identical keys, sono_upscalehas no effect on the non-containing path. Either this branch is redundant and should be removed, or it was meant to encode a different (upscale-capping) selection that got lost. Worth confirming the intended behavior.♻️ If the branch is genuinely redundant
- if no_upscale: - return max( - buckets, - key=lambda bucket: ( - min(bucket[0] / width, bucket[1] / height), - -(bucket[0] * bucket[1]), - ), - ) - return max( buckets, key=lambda bucket: ( min(bucket[0] / width, bucket[1] / height), -(bucket[0] * bucket[1]), ), )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/utils/data/mm_image.py` around lines 41 - 56, The no_upscale branch currently duplicates the fallthrough selection (both return max(buckets, key=...)), so either remove the redundant branch or implement the intended "no upscaling" behavior: when no_upscale is True filter buckets to those that do not upscale (e.g., keep only buckets where min(bucket[0]/width, bucket[1]/height) <= 1) and then pick the best with the existing key; if that filter yields no candidates, fall back to the original max(buckets, key=...) or pick the largest bucket that is <= target by area. Update the branch around no_upscale to reflect one of these fixes, referencing buckets, width, height, and the existing key lambda.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@src/axolotl/utils/data/mm_image.py`:
- Around line 41-56: The no_upscale branch currently duplicates the fallthrough
selection (both return max(buckets, key=...)), so either remove the redundant
branch or implement the intended "no upscaling" behavior: when no_upscale is
True filter buckets to those that do not upscale (e.g., keep only buckets where
min(bucket[0]/width, bucket[1]/height) <= 1) and then pick the best with the
existing key; if that filter yields no candidates, fall back to the original
max(buckets, key=...) or pick the largest bucket that is <= target by area.
Update the branch around no_upscale to reflect one of these fixes, referencing
buckets, width, height, and the existing key lambda.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: d85887b0-d657-43c8-88ac-f3364673954e
📒 Files selected for processing (14)
docs/multimodal.qmdscripts/mm_ocr_eval.pyscripts/mm_packing_benchmark.pyscripts/mm_packing_probe.pysrc/axolotl/core/builders/causal.pysrc/axolotl/core/trainers/base.pysrc/axolotl/integrations/mm_tiling/args.pysrc/axolotl/utils/data/mm_image.pysrc/axolotl/utils/data/mm_image_transform.pysrc/axolotl/utils/data/sft.pysrc/axolotl/utils/data/shared.pysrc/axolotl/utils/data/streaming.pysrc/axolotl/utils/schemas/validation.pytests/utils/schemas/validation/test_multimodal_cpt.py
🚧 Files skipped from review as they are similar to previous changes (12)
- src/axolotl/core/trainers/base.py
- src/axolotl/integrations/mm_tiling/args.py
- src/axolotl/utils/data/shared.py
- docs/multimodal.qmd
- src/axolotl/utils/schemas/validation.py
- src/axolotl/utils/data/sft.py
- src/axolotl/utils/data/streaming.py
- scripts/mm_ocr_eval.py
- src/axolotl/core/builders/causal.py
- tests/utils/schemas/validation/test_multimodal_cpt.py
- scripts/mm_packing_probe.py
- scripts/mm_packing_benchmark.py
Summary
Extracts the multimodal OCR image-tiling feature out of core and into a first-class Axolotl plugin (
axolotl.integrations.mm_tiling). This is the plugin-architecture rework of the tiling feature that #35 prototyped inline — same behavior, but core no longer carries any tiling-specific code.Opt in via config:
Why a plugin
#35 baked
image_tiling_*fields into the core schema, a 16-field copy loop intotraining_args, and tiling calls into 7 core call sites. That couples a niche OCR feature to every training path. This PR introduces a generic core seam so core knows nothing about tiling:MMImageTransformProtocol +resolve_mm_image_transform(cfg)inutils/data/mm_image_transform.pyBasePlugin.get_mm_image_transform(cfg)/PluginManager.get_mm_image_transformimage_transforminstead of a tiling configThe plugin (
integrations/mm_tiling/) owns everything tiling:tiling.py(impl +TilingImageTransformadapter),args.py(MMTilingArgsschema),plugin.py(MMTilingPlugin).What's preserved from #35
ocr_pages: landscape 3×2 / portrait 2×3 / tall 2×4), SSD tile cache<rowX_colY>+<global_img>)image_tiling_native_resolutionmode (tile near original resolution)_add_packing_masksassertion holds)Verification
"image_tiling_config", kept for hash stability) and a genericimage_tiling_min_area >= 0validation checkpre-commit run --all-filesgreen (ruff, ruff-format, mypy, bandit)Relationship to other PRs
Known limitation
Setting
image_tiling: truewithout listing the plugin silently trains untiled (standard Axolotl plugin-arg behavior — args only attach when the plugin is loaded). The README shows the requiredplugins:line.Summary by CodeRabbit
New Features
New Tools & Scripts
Documentation
Tests