Trim train_gpt_mlx_kl.py to ≤1500 lines; fix orphaned clip_grad_tree (#8)
Conversation
- Fix orphaned `clip_grad_tree` function body by adding proper `def` line
- Remove verbose section separator comment blocks (17+ instances)
- Compact 26-line module docstring to 2-line summary
- Trim multi-line docstrings to single lines throughout
- Remove redundant inline comments that restate the code
- Remove unnecessary blank lines within function bodies
- Compact Hyperparameters class by removing section comment headers

All functionality, logic, algorithms, and class/function signatures preserved.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: kailean <49617037+kailean@users.noreply.github.com>
Merged commit 5c2050e into copilot/create-clean-submission-ready-pr
Reviewer's Guide

Fixes a latent NameError by properly defining `clip_grad_tree`, and reduces train_gpt_mlx_kl.py from ~1847 to 1493 lines via doc/comment pruning and minor docstring tightening without changing model/optimizer/eval behavior.

Sequence diagram for training step with clip_grad_tree

```mermaid
sequenceDiagram
    participant main
    participant model
    participant compiled_loss_and_grad
    participant clip_grad_tree
    participant optimizer
    main->>compiled_loss_and_grad: loss_and_grad_chunked(args, train_loader)
    compiled_loss_and_grad-->>main: loss, grads_tree
    main->>clip_grad_tree: clip_grad_tree(grads_tree, args.grad_clip_norm)
    alt max_norm <= 0
        clip_grad_tree-->>main: grads_tree (unchanged)
    else max_norm > 0
        clip_grad_tree->>clip_grad_tree: flat = dict(tree_flatten(grads_tree))
        clip_grad_tree->>clip_grad_tree: total_sq = sum((g**2).sum() for g in flat.values())
        clip_grad_tree->>clip_grad_tree: scale = max_norm / sqrt(total_sq + 1e-12)
        clip_grad_tree-->>main: tree_unflatten((k, g * scale))
    end
    main->>optimizer: opt.step(model, clipped_grads, step, lr_mul)
    optimizer-->>model: update(model.parameters())
    model-->>main: parameters_updated
```
Flow diagram for clip_grad_tree gradient clipping

```mermaid
flowchart TD
    A_start[Start clip_grad_tree] --> B_check_norm{max_norm <= 0}
    B_check_norm -->|yes| C_return_orig[Return grads_tree]
    B_check_norm -->|no| D_flatten["flat = dict(tree_flatten(grads_tree))"]
    D_flatten --> E_total_sq["total_sq = sum((g * g).sum() for g in flat.values())"]
    E_total_sq --> F_scale["scale = max_norm / (sqrt(total_sq) + 1e-12)"]
    F_scale --> G_scale_grads["scaled_items = (k, g * scale) for k, g in flat.items()"]
    G_scale_grads --> H_unflatten["clipped = tree_unflatten(scaled_items)"]
    H_unflatten --> I_return_clipped[Return clipped]
```
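The diagrammed logic can be sketched in plain Python. The real script operates on MLX arrays via `mlx.utils.tree_flatten`/`tree_unflatten`; here, simplified stand-ins for those helpers work on nested dicts whose leaves are lists of floats, so this is a structural sketch rather than the script's actual code:

```python
import math

def tree_flatten(tree, prefix=""):
    # Stand-in for mlx.utils.tree_flatten: yields (dotted_key, leaf) pairs.
    items = []
    for k, v in tree.items():
        key = f"{prefix}.{k}" if prefix else k
        if isinstance(v, dict):
            items.extend(tree_flatten(v, key))
        else:
            items.append((key, v))
    return items

def tree_unflatten(items):
    # Stand-in for mlx.utils.tree_unflatten: rebuilds the nested dict.
    tree = {}
    for key, v in items:
        parts = key.split(".")
        node = tree
        for p in parts[:-1]:
            node = node.setdefault(p, {})
        node[parts[-1]] = v
    return tree

def clip_grad_tree(grads_tree, max_norm):
    """Rescale the gradient tree by global norm, as shown in the diagram."""
    if max_norm <= 0:
        return grads_tree
    flat = dict(tree_flatten(grads_tree))
    total_sq = sum(g * g for grads in flat.values() for g in grads)
    scale = max_norm / (math.sqrt(total_sq) + 1e-12)
    return tree_unflatten([(k, [g * scale for g in grads]) for k, grads in flat.items()])
```

Note that, as diagrammed, the scale is applied unconditionally, so a gradient whose norm is already below `max_norm` gets scaled up; that behavior is what the Sourcery review on this page flags.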
Hey - I've found 1 issue, and left some high level feedback:

- The new `clip_grad_tree` helper is currently defined inline between evaluation functions; consider moving it closer to the optimizer/gradient logic (e.g., near `SplitOptimizers` or training loop helpers) to keep related concerns grouped together.
- In `clip_grad_tree`, you repeatedly convert between tree and dict (`tree_flatten` → `dict` → `tree_unflatten`); if performance becomes an issue, you could operate directly on the flattened list or reuse the original structure to avoid extra allocations.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- The new `clip_grad_tree` helper is currently defined inline between evaluation functions; consider moving it closer to the optimizer/gradient logic (e.g., near `SplitOptimizers` or training loop helpers) to keep related concerns grouped together.
- In `clip_grad_tree`, you repeatedly convert between tree and dict (`tree_flatten` → `dict` → `tree_unflatten`); if performance becomes an issue, you could operate directly on the flattened list or reuse the original structure to avoid extra allocations.
## Individual Comments
### Comment 1
<location path="train_gpt_mlx_kl.py" line_range="1159-1163" />
<code_context>
-
-
+def clip_grad_tree(grads_tree, max_norm):
+ """Clip gradient tree by global norm."""
if max_norm <= 0:
</code_context>
<issue_to_address>
**issue (bug_risk):** clip_grad_tree increases small gradients up to max_norm instead of only shrinking large ones
This rescales the gradient tree to have norm `max_norm` even when the original norm is already smaller, which deviates from standard clipping and unintentionally increases gradients. A typical implementation only rescales when the norm exceeds `max_norm`, e.g.
```python
if total_sq <= max_norm * max_norm:
return grads_tree
scale = max_norm / (math.sqrt(total_sq) + 1e-12)
...
```
so gradients within the bound are left unchanged and only oversized gradients are reduced.
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
The inline comment is anchored to this snippet:

```python
def clip_grad_tree(grads_tree, max_norm):
    """Clip gradient tree by global norm."""
    if max_norm <= 0:
        return grads_tree
    flat = dict(tree_flatten(grads_tree))
```
Pull request overview
This PR trims train_gpt_mlx_kl.py to meet the ≤1500-line target for the Parameter Golf artifact-size constraints and fixes a runtime bug where clip_grad_tree was referenced in training but had no function definition.
Changes:
- Reduced script length primarily by removing/condensing docstrings, separators, and redundant comments/whitespace.
- Added a proper
def clip_grad_tree(grads_tree, max_norm): ...implementation so gradient clipping works when enabled. - Kept existing model/optimizer/eval features intact while reorganizing/condensing surrounding text.
```diff
-bigram_hash_size: int = int(os.environ.get("BIGRAM_HASH_SIZE", 16384))  # Task 2: fill budget (was 10240)
-qat_start_frac: float = float(os.environ.get("QAT_START_FRAC", 0.15))  # UNUSED — late_qat_threshold controls QAT
+bigram_hash_size: int = int(os.environ.get("BIGRAM_HASH_SIZE", 16384))
+qat_start_frac: float = float(os.environ.get("QAT_START_FRAC", 0.15))
```
`qat_start_frac` is defined but never referenced anywhere in the script (only appears in `Hyperparameters`). Since this repo is optimizing for minimal code bytes, consider removing it (and the `QAT_START_FRAC` env var) or wiring it into the QAT toggle logic so the config surface matches actual behavior.
train_gpt_mlx_kl.py:54 — `qat_start_frac` is read from `QAT_START_FRAC` but isn't referenced anywhere else, so that env var currently has no effect on when QAT turns on. Since the explanatory comment was removed in this trim, consider either wiring this parameter into the QAT toggle logic or removing it to avoid misleading configuration.
Severity: medium
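One way to wire the parameter in, as both reviewers suggest. This is a hypothetical sketch: `qat_active`, `total_steps`, and the step-fraction threshold are illustrative, not taken from the script:

```python
import os

def qat_active(step, total_steps, qat_start_frac=None):
    # Hypothetical toggle: enable QAT once training passes the given
    # fraction of total steps; falls back to the QAT_START_FRAC env var.
    if qat_start_frac is None:
        qat_start_frac = float(os.environ.get("QAT_START_FRAC", 0.15))
    return step >= int(qat_start_frac * total_steps)
```

With 100 total steps and the default fraction of 0.15, QAT would switch on at step 15; the actual script gates QAT via `late_qat_threshold` instead.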
```diff
 # TTT: gradient steps on LoRA params using context tokens (s=0..wlen-stride)
 s = 0 if ws == 0 else max(wlen - stride, 0)
-if s > 0:  # Only train if there are context tokens before the eval window
+if s > 0:
```
train_gpt_mlx_kl.py:1215 — Inside this TTT block, `lora_A`/`lora_B` never get updated (and `ttt_lr` is unused), so enabling TTT appears to be a no-op besides recomputing loss. If the intent is to actually adapt per-window weights, this likely needs an update step for the LoRA parameters.
Severity: medium
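A minimal sketch of the update step the reviewer says is missing. The names `lora_A`, `lora_B`, and `ttt_lr` mirror the script, but the gradient values and the plain-list SGD step are stand-ins, since the script's loss and MLX gradient machinery are not shown here:

```python
def sgd_step(params, grads, lr):
    # In-place SGD on a flat list of floats.
    for i in range(len(params)):
        params[i] -= lr * grads[i]

# Hypothetical TTT inner step: adapt only the LoRA factors on context tokens.
lora_A = [0.5, -0.5]
lora_B = [0.0, 0.0]
ttt_lr = 0.1
grads_A = [1.0, -1.0]  # stand-in gradients from the context-token loss
grads_B = [0.2, 0.2]
sgd_step(lora_A, grads_A, ttt_lr)
sgd_step(lora_B, grads_B, ttt_lr)
```

The point of the review comment is that without some step like this, computing a loss over the context tokens changes nothing: the base weights are frozen by design and the LoRA factors stay at their initial values.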
Script was 1847 lines, over the 1500-line target for the sub-1.0 BPB build. Code bytes count toward the 16MB artifact limit per challenge rules. Also found a bug: the `clip_grad_tree` body existed but had no `def` statement — would `NameError` when `grad_clip_norm > 0`.

Bug fix

Added missing function definition. Previously this was an orphaned code block (indented body with no `def`) following `eval_val_sliding_ngram`, unreachable as written but called in the training loop at line 1747.

Line reduction (1847 → 1493)
Removed `# ====...====` separator comment blocks (~17). All 17 classes, 25 functions, and every feature (EngramLite, BackoffNgramMixer, ComplementaryTraining, SkipGramHash, SmearGate, XSA, LoRA TTT, GPTQ-lite, sliding-window eval) preserved. Verified via `ast.parse` and AST name enumeration.

Summary by Sourcery
Trim and clean up the GPT training script while preserving functionality and add a proper gradient clipping helper to fix a missing definition bug.
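As a footnote, the `ast.parse` verification described in the PR body can be reproduced with the standard library alone. A minimal sketch; the two-definition `sample` string stands in for the real 1493-line script:

```python
import ast

def defined_names(source):
    # Collect every class and function name defined anywhere in the module,
    # so before/after trims can be compared for preserved definitions.
    tree = ast.parse(source)
    return {
        node.name
        for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
    }

sample = (
    "def clip_grad_tree(grads_tree, max_norm):\n"
    "    return grads_tree\n"
    "\n"
    "class SplitOptimizers:\n"
    "    pass\n"
)
```

Running `defined_names` on the pre- and post-trim sources and diffing the two sets is enough to confirm that no class or function was dropped.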