Add activation matching distillation support #672
base: master
Conversation
Pull Request Overview
This PR implements knowledge distillation functionality that enables training a smaller student model using guidance from a larger teacher model. The implementation supports multiple distillation loss variants including forward KL divergence, reverse KL, symmetric KL, Jensen-Shannon divergence, logit MSE, and layer activation matching.
Key changes:
- Added the `distillation_loss_variants.py` module with 6 loss variants and a configurable builder function
- Extended CLI arguments to support distillation configuration (teacher checkpoint, loss type, temperature, weight, etc.)
- Modified training loop to load teacher model, compute teacher logits, and combine distillation loss with student loss
- Added layer activation matching support with forward hooks for intermediate layer comparison
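For orientation, here is a minimal sketch of how one such variant might combine a temperature-scaled forward-KL term with the student's cross-entropy loss; all names, shapes, and defaults are placeholders rather than the PR's actual code.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits: torch.Tensor,
                      teacher_logits: torch.Tensor,
                      targets: torch.Tensor,
                      temperature: float = 2.0,
                      weight: float = 0.5) -> torch.Tensor:
    """Sketch: weighted sum of cross-entropy and forward KL(teacher || student)."""
    vocab = student_logits.size(-1)
    # Standard next-token cross-entropy on the student logits (-1 marks ignored positions).
    ce = F.cross_entropy(student_logits.reshape(-1, vocab), targets.reshape(-1), ignore_index=-1)
    # Temperature-softened distributions; the KL term is scaled by T^2 so its
    # gradient magnitude stays comparable to the cross-entropy term.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    kd = F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2
    return (1.0 - weight) * ce + weight * kd
```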
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| train_variations/distillation_loss_variants.py | Implements 6 distillation loss variants and builder function for configurable loss selection |
| train_args.py | Adds 8 new CLI arguments for distillation configuration including teacher checkpoint path, loss variant, temperature, and activation layers |
| train.py | Integrates distillation into training loop with teacher model loading, activation caching, and combined loss computation |
| explorations/distillation_sweep.yaml | Configuration file for hyperparameter sweep across distillation variants |
| demos/distillation_minipile_comparison.sh | Demo script comparing baseline vs distilled student training |
```python
def _register_activation_hooks(self, model, cache, *, detach_outputs: bool, context_attr: str):
    handles = []
    module_list = model.transformer['h']
```
Copilot AI · Oct 29, 2025
Accessing model.transformer using dictionary syntax ['h'] assumes transformer is a ModuleDict, but if the model structure changes or uses attribute access, this will raise a KeyError. Consider using getattr(model.transformer, 'h', None) with proper error handling or verify that model.transformer is always a ModuleDict.
Suggested change:
```diff
-    module_list = model.transformer['h']
+    # Robustly access the 'h' block list from the transformer, supporting both attribute and dict access
+    if hasattr(model.transformer, 'h'):
+        module_list = getattr(model.transformer, 'h')
+    elif isinstance(model.transformer, dict) and 'h' in model.transformer:
+        module_list = model.transformer['h']
+    else:
+        raise AttributeError("model.transformer does not have an 'h' attribute or key.")
```
```python
loss_fn.requires_layer_activations = True
loss_fn.activation_layers = activation_layers
```
Copilot AI · Oct 29, 2025
Setting dynamic attributes on a function object is unconventional and may cause issues with type checkers and IDE support. Consider returning a dataclass or wrapper object that encapsulates both the function and its metadata, or use function attributes explicitly with __dict__ if this pattern is necessary.
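If the wrapper-object route were taken, it might look roughly like the following; the class and field names are illustrative and not part of the PR.

```python
from dataclasses import dataclass, field
from typing import Callable, Tuple

import torch

@dataclass
class DistillationLossSpec:
    """Illustrative wrapper pairing the loss callable with its activation-matching metadata."""
    fn: Callable[..., torch.Tensor]
    requires_layer_activations: bool = False
    activation_layers: Tuple[int, ...] = field(default_factory=tuple)

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        # Delegate to the wrapped loss function so callers can use the spec like a callable.
        return self.fn(*args, **kwargs)
```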
```python
if hasattr(self.optimizer, "set_entropy") and not isinstance(logits, (list, tuple)):
    with torch.no_grad():
        probs = torch.softmax(logits, dim=-1)
        ent = -(probs * torch.log(probs + 1e-9)).sum(dim=-1).mean()
        ent = ent / math.log(logits.size(-1))
        self.optimizer.set_entropy(float(ent))
```
Copilot AI · Oct 29, 2025
Entropy computation is performed before distillation loss calculation, meaning softmax is computed on student logits twice when distillation is active (once for entropy, once inside distillation loss). Consider computing entropy after distillation or reusing probability tensors if possible to avoid redundant softmax operations.
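A hedged sketch of the reuse idea, computing the softmax once and deriving the entropy signal from it; apart from `set_entropy` (shown in the snippet above), the names here are assumptions.

```python
import math

import torch

def normalized_entropy(probs: torch.Tensor) -> float:
    """Mean token entropy normalized by log(vocab_size), from precomputed probabilities."""
    ent = -(probs * torch.log(probs + 1e-9)).sum(dim=-1).mean()
    return float(ent / math.log(probs.size(-1)))

# Sketch: compute softmax once and share it between the entropy signal and any
# downstream consumer that can accept precomputed student probabilities.
# with torch.no_grad():
#     probs = torch.softmax(logits, dim=-1)
#     self.optimizer.set_entropy(normalized_entropy(probs))
```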
```python
)
if self._distillation_requires_layer_activations and getattr(self.args, "compile", False):
    raise ValueError(
        "Layer-activation distillation is not compatible with --compile on the student model."
```
Copilot AI · Oct 29, 2025
The error message states incompatibility but doesn't explain why. Consider adding a brief explanation such as 'Layer-activation distillation is not compatible with --compile on the student model because torch.compile interferes with forward hook registration.'
Suggested change:
```diff
-        "Layer-activation distillation is not compatible with --compile on the student model."
+        "Layer-activation distillation is not compatible with --compile on the student model because torch.compile interferes with forward hook registration."
```
```python
if targets is None:
    return flat

mask = (targets != -1).reshape(-1)
```
Copilot AI · Oct 29, 2025
The magic number -1 for ignore_index is hardcoded here and appears in multiple places throughout the codebase. Consider defining a module-level constant IGNORE_INDEX = -1 to make this value more maintainable and self-documenting.
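The suggested constant might look like this; `IGNORE_INDEX` is an illustrative name, not one defined by the PR.

```python
import torch

# Hypothetical module-level constant; the value mirrors the -1 used throughout the PR.
IGNORE_INDEX = -1

targets = torch.tensor([[5, 2, IGNORE_INDEX], [7, IGNORE_INDEX, IGNORE_INDEX]])
mask = (targets != IGNORE_INDEX).reshape(-1)  # True where the position contributes to the loss
```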
```python
def hook(module, inputs, output, *, idx=layer_idx):
    context_key = getattr(self, context_attr)
    if context_key is None:
        return
    tensor = output[0] if isinstance(output, (tuple, list)) else output
    if detach_outputs:
        tensor = tensor.detach()
    cache[context_key][idx] = tensor
```
Copilot AI · Oct 29, 2025
The hook function captures self, cache, and context_attr in its closure. This creates strong references that may prevent garbage collection of the trainer object. Consider using weak references or ensure hooks are properly removed when the trainer is destroyed to avoid memory leaks.
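One possible shape for the cleanup side of this suggestion, sketched with assumed names (none of this is the PR's code): a small helper owns the handles and removes them explicitly when distillation ends.

```python
from typing import Dict, Iterable, List

import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle

class ActivationCapture:
    """Illustrative helper that owns its hook handles so they can be removed deterministically."""

    def __init__(self) -> None:
        self._handles: List[RemovableHandle] = []
        self.cache: Dict[int, torch.Tensor] = {}

    def register(self, modules: Iterable[nn.Module]) -> None:
        for idx, module in enumerate(modules):
            self._handles.append(module.register_forward_hook(self._make_hook(idx)))

    def _make_hook(self, idx: int):
        def hook(module, inputs, output):
            tensor = output[0] if isinstance(output, (tuple, list)) else output
            self.cache[idx] = tensor.detach()
        return hook

    def remove(self) -> None:
        # Drop the hooks (and the closures that reference self) as soon as they are no longer needed.
        for handle in self._handles:
            handle.remove()
        self._handles.clear()
        self.cache.clear()
```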
```python
)
if self._distillation_requires_layer_activations:
    raise RuntimeError(
        "Layer-activation distillation does not support compiling the teacher model."
```
Copilot AI · Oct 29, 2025
The error message states incompatibility but doesn't explain why. Consider adding a brief explanation such as 'Layer-activation distillation does not support compiling the teacher model because torch.compile interferes with forward hook registration needed to capture intermediate activations.'
Suggested change:
```diff
-        "Layer-activation distillation does not support compiling the teacher model."
+        "Layer-activation distillation does not support compiling the teacher model because torch.compile interferes with forward hook registration needed to capture intermediate activations."
```
```python
    raise ValueError("No valid layer indices were supplied for activation-based distillation.")
activation_layers = tuple(sorted(set(parsed_layers)))

def loss_fn(student_logits: torch.Tensor, teacher_logits: torch.Tensor, targets: torch.Tensor, *, iter_num=None, **kwargs):
```
Copilot AI · Oct 29, 2025
The nested loss_fn function is defined inside build_distillation_loss but is missing type hints for iter_num (should be int | None) and has an untyped **kwargs. Adding proper type hints would improve code clarity and enable better static analysis.
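A possible fully annotated signature, assuming the function returns a scalar loss tensor:

```python
from typing import Any, Optional

import torch

def loss_fn(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    targets: torch.Tensor,
    *,
    iter_num: Optional[int] = None,
    **kwargs: Any,
) -> torch.Tensor:
    ...
```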
This pull request introduces comprehensive support for knowledge distillation in the training workflow, enabling student models to learn from teacher models using a variety of distillation loss functions. The implementation includes new command-line arguments, teacher model loading and management, support for layer activation matching, logging of distillation metrics, and example scripts/configurations for running distillation experiments.
The most important changes are:
Knowledge Distillation Infrastructure:
Added support for loading a teacher model checkpoint (ckpt.pt) and applying a configurable distillation loss during student training, including options for temperature, weight, and numerical-stability epsilon. The teacher model can optionally be compiled or cast to a specific dtype. [1] [2] [3]
Implemented infrastructure for matching intermediate layer activations between student and teacher models, with activation hooks, caching, and compatibility checks for compiled models. [1] [2] [3] [4]
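For intuition, the activation-matching term over cached activations could be as simple as the following sketch; the function name, cache layout, and equal-width assumption are illustrative rather than the PR's implementation.

```python
from typing import Dict, Sequence

import torch
import torch.nn.functional as F

def activation_matching_loss(student_acts: Dict[int, torch.Tensor],
                             teacher_acts: Dict[int, torch.Tensor],
                             layers: Sequence[int]) -> torch.Tensor:
    """Average MSE between cached student and teacher activations for the selected layers."""
    terms = []
    for idx in layers:
        # Assumes matching hidden sizes; mismatched widths would need a learned projection.
        terms.append(F.mse_loss(student_acts[idx], teacher_acts[idx].detach()))
    return torch.stack(terms).mean()
```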
Training Loop & Metrics:
Integrated distillation loss computation into the main training loop, ensuring correct combination with the primary loss and proper handling of activation caches. [1] [2]
Enhanced logging to TensorBoard by recording the latest distillation loss alongside standard loss metrics for both training and validation, improving experiment traceability. [1] [2]
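That logging could amount to something like the sketch below; the helper and tag names are assumptions, while `SummaryWriter.add_scalar` is the standard TensorBoard API.

```python
from torch.utils.tensorboard import SummaryWriter

def log_distillation_metrics(writer: SummaryWriter, step: int,
                             loss: float, distill_loss: float) -> None:
    """Record the combined training loss and the latest distillation loss side by side."""
    writer.add_scalar("train/loss", loss, step)
    writer.add_scalar("train/distill_loss", distill_loss, step)
```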
User Interface & Configuration:
Documentation & Examples:
Added an example shell script (
demos/distillation_minipile_comparison.sh) demonstrating end-to-end distillation workflow on the minipile dataset, including teacher/student training and statistics comparison.Added a YAML sweep configuration (
explorations/distillation_sweep.yaml) for running systematic distillation experiments with various loss functions and hyperparameters.Codebase Organization: