@joe32140 joe32140 commented Oct 24, 2025

[11/2] Updated to support arbitrary N_Dense layers and apply corresponding activations.

This commit implements a complete refactor of the projection layer architecture to support arbitrary N-layer projections (1_Dense, 2_Dense, ..., N_Dense) with proper activation function handling per layer, enabling compatibility with newer compact models like mixedbread-ai/mxbai-edge-colbert-v0-17m.

See a running example at https://joe32140.github.io/pylate-rs/.

Core Architecture Changes

Model Layer (src/model.rs)

  • Add Activation enum supporting Identity, ReLU, GELU, GeluErf, Tanh, SiLU
  • Implement Activation::from_pytorch_name() to parse PyTorch activation class names
  • Create ProjectionLayer struct encapsulating linear layer + activation
  • Replace hardcoded linear/linear2 with Vec<ProjectionLayer>
  • Implement dimension validation between consecutive layers
  • Update forward pass to iterate through all layers sequentially
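
As a rough illustration of the activation handling listed above, here is a minimal plain-Rust sketch. It is not the code from src/model.rs: it works on scalar `f32` values instead of tensors, it assumes the config stores a dotted PyTorch class path such as `torch.nn.modules.linear.Identity`, and it leaves out the `GeluErf` variant because the standard library has no `erf`. The `Gelu` arm uses the tanh approximation, which is also an assumption.

```rust
/// Activation applied after a projection layer's linear transform.
/// Sketch only: the PR's enum also has a GeluErf variant, omitted here.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Activation {
    Identity,
    Relu,
    Gelu,
    Tanh,
    Silu,
}

impl Activation {
    /// Map a PyTorch class name or dotted class path to a variant.
    /// Only the last path segment is inspected; unknown names yield None.
    pub fn from_pytorch_name(name: &str) -> Option<Self> {
        match name.rsplit('.').next().unwrap_or(name) {
            "Identity" => Some(Activation::Identity),
            "ReLU" => Some(Activation::Relu),
            "GELU" => Some(Activation::Gelu),
            "Tanh" => Some(Activation::Tanh),
            "SiLU" => Some(Activation::Silu),
            _ => None,
        }
    }

    /// Apply the activation to a single value.
    pub fn apply(self, x: f32) -> f32 {
        match self {
            Activation::Identity => x,
            Activation::Relu => x.max(0.0),
            Activation::Gelu => {
                // tanh approximation of GELU (assumed for this sketch)
                let inner = (2.0f32 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x * x * x);
                0.5 * x * (1.0 + inner.tanh())
            }
            Activation::Tanh => x.tanh(),
            Activation::Silu => x / (1.0 + (-x).exp()),
        }
    }
}
```

Returning `Option` leaves the caller to decide whether an unrecognized activation name is an error or should fall back to `Identity`.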

Builder (src/builder.rs)

  • Auto-detect numbered Dense directories (1_Dense, 2_Dense, ..., up to 10)
  • Scan both local paths and HuggingFace Hub
  • Collect (config_bytes, weights_bytes, layer_name) tuples for each layer
  • Stop scanning after first missing layer (1_Dense is required)
  • Pass all layers to ColBERT constructor
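
The local-path half of this scan could look roughly like the sketch below. The function name `collect_dense_layers` and the `max_layers` parameter are hypothetical; the file names `config.json` and `model.safetensors` are the ones mentioned for 2_Dense elsewhere in this PR. HuggingFace Hub loading is not covered.

```rust
use std::fs;
use std::io;
use std::path::Path;

/// (config_bytes, weights_bytes, layer_name) for one projection layer.
pub type LayerFiles = (Vec<u8>, Vec<u8>, String);

/// Collect consecutive N_Dense directories under `root`, starting at 1_Dense.
/// Scanning stops at the first missing layer; 1_Dense itself is required.
pub fn collect_dense_layers(root: &Path, max_layers: usize) -> io::Result<Vec<LayerFiles>> {
    let mut layers = Vec::new();
    for i in 1..=max_layers {
        let name = format!("{}_Dense", i);
        let dir = root.join(&name);
        let config = dir.join("config.json");
        let weights = dir.join("model.safetensors");
        if !config.is_file() || !weights.is_file() {
            break; // first gap ends the scan
        }
        layers.push((fs::read(config)?, fs::read(weights)?, name));
    }
    if layers.is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::NotFound,
            "1_Dense is required but was not found",
        ));
    }
    Ok(layers)
}
```

Calling `collect_dense_layers(Path::new("path/to/model"), 10)` mirrors the "up to 10" limit described in the list above.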

WASM Bindings (src/wasm.rs)

  • Update constructor to accept projection layers array
  • Use js_sys::Uint8Array for direct byte array handling
  • Format: Array<{config: Uint8Array, weights: Uint8Array, name: string}>
  • Extract layers using js_sys::Reflect for proper type handling

Demo Updates (docs/index.html)

  • Auto-detect and load all Dense layers from HuggingFace
  • Pass Uint8Array directly (not Array.from conversion)
  • Add await and 100ms delay after WASM init to fix race condition
  • Update constructor call with new signature
  • Fix initial load failure by ensuring WASM is fully ready

Key Features

✅ Activation functions applied after each linear layer
✅ Supports arbitrary 1-N projection layers (tested up to 10)
✅ Backward compatible with single-layer models
✅ Proper dimension validation at load time
✅ Auto-detection of available layers
✅ Clean encapsulation with ProjectionLayer struct

Architecture Flow

Token Embeddings → 1_Dense (Linear + Activation) → 2_Dense (Linear + Activation) → ... → N_Dense (Linear + Activation) → L2 Normalization
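
The flow above can be sketched for a single token embedding in plain Rust. This is only an illustration under simplifying assumptions: weights are nested `Vec<f32>` rows rather than tensors, the activation is a plain function pointer, and the names `project_token` and `l2_normalize` are hypothetical rather than taken from src/model.rs.

```rust
/// One projection stage: y = activation(W x + b), with W stored as [out][in] rows.
pub struct ProjectionLayer {
    pub weight: Vec<Vec<f32>>,
    pub bias: Option<Vec<f32>>,
    pub activation: fn(f32) -> f32,
}

impl ProjectionLayer {
    /// Apply the linear map, then the activation, to one input vector.
    /// Dimensions are assumed to have been validated at load time.
    pub fn forward(&self, x: &[f32]) -> Vec<f32> {
        self.weight
            .iter()
            .enumerate()
            .map(|(row_idx, row)| {
                let dot: f32 = row.iter().zip(x).map(|(w, v)| w * v).sum();
                let b = self.bias.as_ref().map_or(0.0, |bias| bias[row_idx]);
                (self.activation)(dot + b)
            })
            .collect()
    }
}

/// Scale a vector to unit L2 norm; an all-zero vector is left unchanged.
pub fn l2_normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}

/// Run an embedding through every layer in order, then L2-normalize the result.
pub fn project_token(layers: &[ProjectionLayer], embedding: &[f32]) -> Vec<f32> {
    let mut hidden = embedding.to_vec();
    for layer in layers {
        hidden = layer.forward(&hidden);
    }
    l2_normalize(&mut hidden);
    hidden
}
```

Because the layers live in a slice, a single-layer model is just the one-element case, which is how backward compatibility falls out of this design.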

Testing

  • ✅ All existing tests pass (3/3)
  • ✅ Single-layer models work (e.g., answerai-colbert-small-v1)
  • ✅ Multi-layer models work (e.g., mxbai-edge-colbert-v0-17m)
  • ✅ WASM demo loads successfully on first attempt

Alignment with PyLate

Implementation verified against Python PyLate repository:

  • Matches activation function application logic
  • Compatible with same directory structure (1_Dense, 2_Dense, ...)
  • Sequential layer iteration matches Python's module iteration
  • Config-based activation parsing aligns with PyLate's Dense class

🤖 Generated with Claude Code

Co-Authored-By: Claude [email protected]

This commit implements support for models with 2-stage Dense projections,
enabling compatibility with newer ColBERT models like mixedbread-ai's
mxbai-edge-colbert-v0-17m that use dual projection layers.

## Core Changes

### Model Architecture (src/model.rs)
- Add `linear2: Option<Linear>` field to ColBERT struct
- Update constructor to accept optional dense2_weights and dense2_config
- Implement dimension validation between Dense layers
- Update forward pass to apply both 1_Dense → 2_Dense sequentially
- Applied in both parallel (CPU) and sequential (GPU/WASM) code paths
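
A minimal sketch of the dimension check between the two Dense layers, assuming a stand-in `Linear` struct that only records shapes (it is not the tensor-backed type used in src/model.rs) and a hypothetical helper name `output_dim`:

```rust
/// Stand-in for a linear layer; only the shape matters for this check.
pub struct Linear {
    pub in_features: usize,
    pub out_features: usize,
}

/// Check that the optional second Dense layer consumes what the first produces.
/// Returns the final embedding dimension on success.
pub fn output_dim(linear: &Linear, linear2: Option<&Linear>) -> Result<usize, String> {
    match linear2 {
        // Single-Dense model: nothing to validate.
        None => Ok(linear.out_features),
        Some(second) if second.in_features == linear.out_features => Ok(second.out_features),
        Some(second) => Err(format!(
            "2_Dense expects {} input features but 1_Dense produces {}",
            second.in_features, linear.out_features
        )),
    }
}
```

With a 512-wide 1_Dense output and a 512 to 48 2_Dense layer, this returns `Ok(48)`; the input width of 1_Dense does not affect the check.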

### Model Loading (src/builder.rs)
- Auto-detect optional 2_Dense files (2_Dense/config.json, model.safetensors)
- Support loading from both local paths and HuggingFace Hub
- Gracefully handle models without 2_Dense (backward compatible)
- Pass optional parameters through to ColBERT constructor

### WASM Bindings (src/wasm.rs)
- Update constructor signature to accept dense2_weights and dense2_config as JsValue
- Handle undefined/null values for backward compatibility
- Fix serialization issues by returning JSON strings instead of JsValue
- Updated all WASM methods: encode, similarity, raw_similarity_matrix, hierarchical_pooling

### Dependencies (Cargo.toml)
- Add js-sys dependency for WASM builds
- Add getrandom with "js" feature for WASM target
- Update wasm feature to include js-sys

### Demo Updates (docs/index.html)
- Add retry logic with exponential backoff for WASM initialization
- Update model loading to fetch optional 2_Dense files
- Parse JSON string responses from all WASM methods
- Add mixedbread-ai/mxbai-edge-colbert-v0-17m to model selector
- Fix all three demos: Similarity, Matrix Visualization, Pooling

## Key Features

✅ Backward compatible - works with both single-Dense and two-Dense models
✅ Auto-detection - automatically loads 2_Dense when present
✅ Dimension validation - ensures layer compatibility
✅ WASM support - full browser compatibility with optional parameters
✅ Zero overhead - no performance impact when 2_Dense absent

## Expected Results

Models with 2_Dense (e.g., mxbai-edge-colbert-v0-17m):
- Output 48-dimensional embeddings (vs 512 with single Dense)
- Achieve roughly 10.7x smaller index sizes (512 / 48 ≈ 10.67)
- Match model card specifications

Models without 2_Dense (existing models):
- Continue working as before
- No code changes required

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>

raphaelsty (Collaborator) commented Oct 25, 2025

Hi @joe32140, this is an amazing MR. Give me some time and we will merge it. Congrats! We just need to assert that the Python version using PyLate yields the same results as the pylate-rs version. PyLate already supports multi-stage projection models :D

joe32140 and others added 2 commits November 2, 2025 22:51
… support

Changed from hardcoded `1..=10` range to infinite loop that scans until
a layer is not found. This allows models to have any number of projection
layers (1_Dense, 2_Dense, ..., N_Dense) without artificial limits.

## Changes

### Rust Builder (src/builder.rs)
- Replace `for i in 1..=10` with `while` loop and manual counter
- Scan continues until first missing layer is encountered
- Applied to both local path and HuggingFace Hub loading
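
The unbounded scan might be sketched as follows. The existence check is passed in as a closure so the same loop could serve both a local directory and a Hub lookup; that split, and the name `scan_dense_layers`, are assumptions for illustration and not the structure of src/builder.rs.

```rust
/// Return "1_Dense", "2_Dense", ... for as long as `layer_exists` reports a hit.
/// There is no upper bound; the scan ends at the first missing layer.
pub fn scan_dense_layers<F>(mut layer_exists: F) -> Vec<String>
where
    F: FnMut(&str) -> bool,
{
    let mut found = Vec::new();
    let mut index = 1usize; // manual counter replacing the old 1..=10 range
    let mut name = format!("{}_Dense", index);
    while layer_exists(&name) {
        found.push(name);
        index += 1;
        name = format!("{}_Dense", index);
    }
    found
}
```

A gap stops the scan, so a model with 1_Dense, 2_Dense and 4_Dense loads only the first two layers.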

### JavaScript Demo (docs/index.html)
- Replace `for (let i = 1; i <= 10; i++)` with `while (true)` loop
- Increments counter only when layer found
- Breaks on first missing layer

## Benefits

✅ No artificial limit on number of layers
✅ Supports future models with many projection stages
✅ Cleaner logic - stops naturally when layers end
✅ Maintains same behavior for existing models

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
@joe32140 joe32140 changed the title Add 2_Dense support for multi-stage projection models Add flexible N-layer projection architecture with activation function support Nov 3, 2025
## Code Quality Improvements
- Use method reference instead of closure (redundant_closure)
- Use .first() instead of .get(0) (get_first)
- Remove redundant field names in struct initialization
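
For readers unfamiliar with these three clippy lints, the sketch below shows the preferred form of each on made-up code. The `LayerInfo` type and the helper functions are hypothetical and unrelated to the actual contents of src/model.rs.

```rust
#[derive(Debug, PartialEq)]
pub struct LayerInfo {
    pub name: String,
    pub out_dim: usize,
}

fn parse_dim(s: &str) -> usize {
    s.trim().parse().unwrap_or(0)
}

/// redundant_closure: pass the function itself instead of `.map(|s| parse_dim(s))`.
pub fn parse_dims(raw: &[&str]) -> Vec<usize> {
    raw.iter().copied().map(parse_dim).collect()
}

/// redundant_field_names: write `LayerInfo { name, out_dim }`
/// instead of `LayerInfo { name: name, out_dim: out_dim }`.
pub fn make_layer(name: String, out_dim: usize) -> LayerInfo {
    LayerInfo { name, out_dim }
}

/// get_first: use `.first()` instead of `.get(0)`.
pub fn first_layer_name(layers: &[LayerInfo]) -> Option<&str> {
    layers.first().map(|layer| layer.name.as_str())
}
```

None of these rewrites change behavior; they only remove syntactic noise.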

## Updated Files
- src/model.rs: Applied clippy auto-fixes
- docs/pkg/*: Rebuilt WASM with improvements

These are minor style improvements that enhance code readability
without changing functionality.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>

@raphaelsty (Collaborator)

Hi @joe32140, not much time this week, but thanks again for contributing. I'll make a pass on your PR as soon as possible ☺️
