Conversation

@klei22 (Collaborator) commented Oct 29, 2025

This pull request introduces comprehensive support for knowledge distillation in the training workflow, enabling student models to learn from teacher models using a variety of distillation loss functions. The implementation includes new command-line arguments, teacher model loading and management, support for layer activation matching, logging of distillation metrics, and example scripts/configurations for running distillation experiments.

The most important changes are:

Knowledge Distillation Infrastructure:

  • Added support for loading a teacher model checkpoint (ckpt.pt) and applying a configurable distillation loss during student training, with options for temperature, weight, and a numerical-stability epsilon. The teacher model can optionally be compiled or cast to a specific dtype; a hedged loading sketch follows this list.

  • Implemented infrastructure for matching intermediate layer activations between student and teacher models, with activation hooks, caching, and compatibility checks for compiled models.
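
For orientation, here is a minimal sketch of what the teacher-loading step could look like, assuming a nanoGPT-style checkpoint whose state dict lives under a "model" key; the function name and checkpoint key are illustrative, not necessarily what train.py uses:

import torch

def load_frozen_teacher(teacher_model: torch.nn.Module, ckpt_path: str,
                        device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> torch.nn.Module:
    # Restore teacher weights from the checkpoint (e.g. ckpt.pt).
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    teacher_model.load_state_dict(checkpoint["model"])  # "model" key is an assumption
    # Optionally cast to the requested dtype and move to the training device.
    teacher_model.to(device=device, dtype=dtype)
    # The teacher is frozen: eval mode, no gradient tracking.
    teacher_model.eval()
    for p in teacher_model.parameters():
        p.requires_grad_(False)
    return teacher_model

When the teacher-compile option is enabled, torch.compile could be applied to the returned module, except in the layer-activation case flagged in the review below.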

Training Loop & Metrics:

  • Integrated distillation loss computation into the main training loop, ensuring it is combined correctly with the primary loss and that activation caches are handled properly; a combined-loss sketch follows this list.

  • Enhanced TensorBoard logging by recording the latest distillation loss alongside the standard loss metrics for both training and validation, improving experiment traceability.
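
A minimal sketch of how the combination could look inside the training step, assuming a nanoGPT-style forward that returns (logits, loss); the convex blend below is one common convention and the PR may weight the terms differently:

import torch

def training_step(student, teacher, distill_loss_fn, X, Y, distill_weight: float = 0.5):
    # Student forward pass: logits plus the usual cross-entropy loss.
    logits, ce_loss = student(X, Y)

    # Teacher forward pass runs without gradient tracking.
    with torch.no_grad():
        teacher_logits, _ = teacher(X, Y)

    # Blend the primary loss with the distillation term (weight in [0, 1]).
    distill_loss = distill_loss_fn(logits, teacher_logits, Y)
    return (1.0 - distill_weight) * ce_loss + distill_weight * distill_loss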

User Interface & Configuration:

  • Added new command-line arguments for all distillation-related options, including loss-variant selection, teacher checkpoint path, temperature, weight, epsilon, teacher model dtype/compilation, and activation-layer selection; an illustrative argparse sketch follows below.
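
An argparse sketch of options in this spirit; the flag names and defaults below are placeholders, not the exact arguments added in train_args.py:

import argparse

parser = argparse.ArgumentParser()
# Placeholder flag names; the real ones are defined in train_args.py.
parser.add_argument("--distillation_loss_variant", type=str, default=None,
                    help="e.g. forward_kl, reverse_kl, symmetric_kl, js, logit_mse, layer_activation")
parser.add_argument("--distillation_teacher_ckpt", type=str, default=None,
                    help="path to the teacher checkpoint (ckpt.pt)")
parser.add_argument("--distillation_temperature", type=float, default=1.0)
parser.add_argument("--distillation_weight", type=float, default=0.5)
parser.add_argument("--distillation_epsilon", type=float, default=1e-8)
parser.add_argument("--distillation_teacher_dtype", type=str, default="bfloat16")
parser.add_argument("--distillation_compile_teacher", action="store_true")
parser.add_argument("--distillation_activation_layers", type=str, default=None,
                    help="comma-separated layer indices for activation matching")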

Documentation & Examples:

  • Added an example shell script (demos/distillation_minipile_comparison.sh) demonstrating an end-to-end distillation workflow on the minipile dataset, including teacher/student training and statistics comparison.

  • Added a YAML sweep configuration (explorations/distillation_sweep.yaml) for running systematic distillation experiments with various loss functions and hyperparameters.

Codebase Organization:

  • Imported the new distillation loss variant definitions and builder functions into the main training script and argument parser.

@klei22 requested review from Copilot and gkielian on October 29, 2025 19:29

Copilot AI left a comment

Pull Request Overview

This PR implements knowledge distillation functionality that enables training a smaller student model using guidance from a larger teacher model. The implementation supports multiple distillation loss variants including forward KL divergence, reverse KL, symmetric KL, Jensen-Shannon divergence, logit MSE, and layer activation matching.
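
To make the two KL-based variants concrete, here is a hedged sketch of temperature-scaled forward and reverse KL; the exact formulation in distillation_loss_variants.py may differ (e.g. in reduction, masking, or epsilon handling):

import torch
import torch.nn.functional as F

def forward_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # KL(teacher || student): pushes the student to cover the whole teacher distribution.
    s_logp = F.log_softmax(student_logits / temperature, dim=-1)
    t_logp = F.log_softmax(teacher_logits / temperature, dim=-1)
    kl = F.kl_div(s_logp, t_logp, log_target=True, reduction="batchmean")
    return kl * temperature ** 2  # usual T^2 factor to keep gradient scale comparable

def reverse_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # KL(student || teacher): mode-seeking, concentrates the student on teacher modes.
    s_logp = F.log_softmax(student_logits / temperature, dim=-1)
    t_logp = F.log_softmax(teacher_logits / temperature, dim=-1)
    kl = F.kl_div(t_logp, s_logp, log_target=True, reduction="batchmean")
    return kl * temperature ** 2

Symmetric KL averages the two directions, Jensen-Shannon measures divergence to the mixture of the two distributions, and logit MSE compares the raw logits directly.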

Key changes:

  • Added distillation_loss_variants.py module with 6 loss variants and a configurable builder function
  • Extended CLI arguments to support distillation configuration (teacher checkpoint, loss type, temperature, weight, etc.)
  • Modified training loop to load teacher model, compute teacher logits, and combine distillation loss with student loss
  • Added layer activation matching support with forward hooks for intermediate layer comparison

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.

Summary per file:

  • train_variations/distillation_loss_variants.py: implements 6 distillation loss variants and a builder function for configurable loss selection
  • train_args.py: adds 8 new CLI arguments for distillation configuration, including teacher checkpoint path, loss variant, temperature, and activation layers
  • train.py: integrates distillation into the training loop with teacher model loading, activation caching, and combined loss computation
  • explorations/distillation_sweep.yaml: configuration file for a hyperparameter sweep across distillation variants
  • demos/distillation_minipile_comparison.sh: demo script comparing baseline vs. distilled student training

def _register_activation_hooks(self, model, cache, *, detach_outputs: bool, context_attr: str):
    handles = []
    module_list = model.transformer['h']

Copilot AI Oct 29, 2025

Accessing model.transformer using dictionary syntax ['h'] assumes transformer is a ModuleDict, but if the model structure changes or uses attribute access, this will raise a KeyError. Consider using getattr(model.transformer, 'h', None) with proper error handling or verify that model.transformer is always a ModuleDict.

Suggested change:

- module_list = model.transformer['h']
+ # Robustly access the 'h' block list from the transformer, supporting both attribute and dict access
+ if hasattr(model.transformer, 'h'):
+     module_list = getattr(model.transformer, 'h')
+ elif isinstance(model.transformer, dict) and 'h' in model.transformer:
+     module_list = model.transformer['h']
+ else:
+     raise AttributeError("model.transformer does not have an 'h' attribute or key.")

Comment on lines +296 to +297
loss_fn.requires_layer_activations = True
loss_fn.activation_layers = activation_layers

Copilot AI Oct 29, 2025

Setting dynamic attributes on a function object is unconventional and may cause issues with type checkers and IDE support. Consider returning a dataclass or wrapper object that encapsulates both the function and its metadata, or use function attributes explicitly with __dict__ if this pattern is necessary.
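
A minimal sketch of the wrapper this suggests, assuming callers only need the callable plus its activation metadata:

import torch
from dataclasses import dataclass
from typing import Callable, Tuple

@dataclass
class DistillationLoss:
    # Bundles the loss callable with its activation-matching metadata instead of
    # attaching ad-hoc attributes to the function object.
    fn: Callable[..., torch.Tensor]
    requires_layer_activations: bool = False
    activation_layers: Tuple[int, ...] = ()

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        return self.fn(*args, **kwargs)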

Comment on lines +1909 to +1914
if hasattr(self.optimizer, "set_entropy") and not isinstance(logits, (list, tuple)):
    with torch.no_grad():
        probs = torch.softmax(logits, dim=-1)
        ent = -(probs * torch.log(probs + 1e-9)).sum(dim=-1).mean()
        ent = ent / math.log(logits.size(-1))
        self.optimizer.set_entropy(float(ent))

Copilot AI Oct 29, 2025

Entropy computation is performed before distillation loss calculation, meaning softmax is computed on student logits twice when distillation is active (once for entropy, once inside distillation loss). Consider computing entropy after distillation or reusing probability tensors if possible to avoid redundant softmax operations.
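
A hedged sketch of the reuse being suggested, assuming the normalized entropy and a later KL term can both be derived from a single log_softmax over the student logits:

import math
import torch
import torch.nn.functional as F

def entropy_and_log_probs(logits: torch.Tensor):
    # One log_softmax feeds both the normalized entropy and any subsequent KL term.
    log_probs = F.log_softmax(logits, dim=-1)
    with torch.no_grad():
        probs = log_probs.exp()
        ent = -(probs * log_probs).sum(dim=-1).mean()
        ent = ent / math.log(logits.size(-1))  # normalize by log(vocab size)
    return float(ent), log_probs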

)
if self._distillation_requires_layer_activations and getattr(self.args, "compile", False):
    raise ValueError(
        "Layer-activation distillation is not compatible with --compile on the student model."

Copilot AI Oct 29, 2025

The error message states incompatibility but doesn't explain why. Consider adding a brief explanation such as 'Layer-activation distillation is not compatible with --compile on the student model because torch.compile interferes with forward hook registration.'

Suggested change:

- "Layer-activation distillation is not compatible with --compile on the student model."
+ "Layer-activation distillation is not compatible with --compile on the student model because torch.compile interferes with forward hook registration."

if targets is None:
    return flat

mask = (targets != -1).reshape(-1)

Copilot AI Oct 29, 2025

The magic number -1 for ignore_index is hardcoded here and appears in multiple places throughout the codebase. Consider defining a module-level constant IGNORE_INDEX = -1 to make this value more maintainable and self-documenting.
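
For example (constant and helper names here are illustrative):

import torch

IGNORE_INDEX = -1  # single source of truth for the ignored-target label

def valid_token_mask(targets: torch.Tensor) -> torch.Tensor:
    # Flattened boolean mask of the positions that should contribute to the loss.
    return (targets != IGNORE_INDEX).reshape(-1)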

Comment on lines +576 to +583
def hook(module, inputs, output, *, idx=layer_idx):
    context_key = getattr(self, context_attr)
    if context_key is None:
        return
    tensor = output[0] if isinstance(output, (tuple, list)) else output
    if detach_outputs:
        tensor = tensor.detach()
    cache[context_key][idx] = tensor

Copilot AI Oct 29, 2025

The hook function captures self, cache, and context_attr in its closure. This creates strong references that may prevent garbage collection of the trainer object. Consider using weak references or ensure hooks are properly removed when the trainer is destroyed to avoid memory leaks.
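
A minimal cleanup sketch in the spirit of this suggestion; calling it when distillation finishes (or in a teardown path) releases the closures the hooks hold:

def remove_activation_hooks(handles: list) -> None:
    # Detach every registered forward hook so the hooked modules no longer hold
    # closures that reference the trainer, then drop the handle list itself.
    for handle in handles:
        handle.remove()
    handles.clear()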

)
if self._distillation_requires_layer_activations:
    raise RuntimeError(
        "Layer-activation distillation does not support compiling the teacher model."

Copilot AI Oct 29, 2025

The error message states incompatibility but doesn't explain why. Consider adding a brief explanation such as 'Layer-activation distillation does not support compiling the teacher model because torch.compile interferes with forward hook registration needed to capture intermediate activations.'

Suggested change:

- "Layer-activation distillation does not support compiling the teacher model."
+ "Layer-activation distillation does not support compiling the teacher model because torch.compile interferes with forward hook registration needed to capture intermediate activations."

raise ValueError("No valid layer indices were supplied for activation-based distillation.")
activation_layers = tuple(sorted(set(parsed_layers)))

def loss_fn(student_logits: torch.Tensor, teacher_logits: torch.Tensor, targets: torch.Tensor, *, iter_num=None, **kwargs):

Copilot AI Oct 29, 2025

The nested loss_fn function is defined inside build_distillation_loss but is missing type hints for iter_num (should be int | None) and has an untyped **kwargs. Adding proper type hints would improve code clarity and enable better static analysis.
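
A possible annotated signature (sketch only; Optional matches the iter_num=None default and Any keeps **kwargs open for extra context):

from typing import Any, Optional
import torch

def loss_fn(student_logits: torch.Tensor,
            teacher_logits: torch.Tensor,
            targets: torch.Tensor,
            *,
            iter_num: Optional[int] = None,
            **kwargs: Any) -> torch.Tensor:
    ...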
