Skip to content

Commit ba72fb3

Browse files
author
Yehudit Kerido
committed
fix skipped tests
Signed-off-by: Yehudit Kerido <[email protected]>
1 parent 816dbec commit ba72fb3

File tree

5 files changed

+119
-88
lines changed

5 files changed

+119
-88
lines changed

candle-binding/semantic-router_test.go

Lines changed: 34 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,59 +1478,55 @@ func TestGetEmbeddingSmart(t *testing.T) {
14781478
// Initialize embedding models first
14791479
err := InitEmbeddingModels(Qwen3EmbeddingModelPath, GemmaEmbeddingModelPath, true)
14801480
if err != nil {
1481-
if isModelInitializationError(err) {
1482-
t.Skipf("Skipping GetEmbeddingSmart tests due to model initialization error: %v", err)
1483-
}
14841481
t.Fatalf("Failed to initialize embedding models: %v", err)
14851482
}
14861483

14871484
t.Run("ShortTextHighLatency", func(t *testing.T) {
1488-
// Short text with high latency priority should use Traditional BERT
1485+
// Short text with high latency priority - uses Qwen3 (1024) since Gemma is not available
14891486
text := "Hello world"
14901487
embedding, err := GetEmbeddingSmart(text, 0.3, 0.8)
14911488

14921489
if err != nil {
1493-
t.Logf("GetEmbeddingSmart returned error (expected for placeholder): %v", err)
1494-
// This is expected since we're using placeholder implementation
1495-
return
1490+
t.Fatalf("GetEmbeddingSmart failed: %v", err)
14961491
}
14971492

1498-
if len(embedding) != 768 {
1499-
t.Errorf("Expected 768-dim embedding, got %d", len(embedding))
1493+
// Expect Qwen3 (1024) dimension since Gemma is not available
1494+
if len(embedding) != 1024 {
1495+
t.Errorf("Expected 1024-dim embedding, got %d", len(embedding))
15001496
}
15011497

15021498
t.Logf("Short text embedding generated: dim=%d", len(embedding))
15031499
})
15041500

15051501
t.Run("MediumTextBalanced", func(t *testing.T) {
1506-
// Medium text with balanced priorities - may select Qwen3 (1024) or Gemma (768)
1502+
// Medium text with balanced priorities - uses Qwen3 (1024) since Gemma is not available
15071503
text := strings.Repeat("This is a medium length text with enough words to exceed 512 tokens. ", 10)
15081504
embedding, err := GetEmbeddingSmart(text, 0.5, 0.5)
15091505

15101506
if err != nil {
15111507
t.Fatalf("GetEmbeddingSmart failed: %v", err)
15121508
}
15131509

1514-
// Accept both Qwen3 (1024) and Gemma (768) dimensions
1515-
if len(embedding) != 768 && len(embedding) != 1024 {
1516-
t.Errorf("Expected 768 or 1024-dim embedding, got %d", len(embedding))
1510+
// Expect Qwen3 (1024) dimension since Gemma is not available
1511+
if len(embedding) != 1024 {
1512+
t.Errorf("Expected 1024-dim embedding, got %d", len(embedding))
15171513
}
15181514

15191515
t.Logf("Medium text embedding generated: dim=%d", len(embedding))
15201516
})
15211517

15221518
t.Run("LongTextHighQuality", func(t *testing.T) {
1523-
// Long text with high quality priority should use Qwen3
1519+
// Long text with high quality priority should use Qwen3 (1024)
15241520
text := strings.Repeat("This is a very long document that requires Qwen3's 32K context support. ", 50)
15251521
embedding, err := GetEmbeddingSmart(text, 0.9, 0.2)
15261522

15271523
if err != nil {
1528-
t.Logf("GetEmbeddingSmart returned error (expected for placeholder): %v", err)
1529-
return
1524+
t.Fatalf("GetEmbeddingSmart failed: %v", err)
15301525
}
15311526

1532-
if len(embedding) != 768 {
1533-
t.Errorf("Expected 768-dim embedding, got %d", len(embedding))
1527+
// Expect Qwen3 (1024) dimension
1528+
if len(embedding) != 1024 {
1529+
t.Errorf("Expected 1024-dim embedding, got %d", len(embedding))
15341530
}
15351531

15361532
t.Logf("Long text embedding generated: dim=%d", len(embedding))
@@ -1573,9 +1569,9 @@ func TestGetEmbeddingSmart(t *testing.T) {
15731569
return
15741570
}
15751571

1576-
// Smart routing may select Qwen3 (1024) or Gemma (768) based on priorities
1577-
if len(embedding) != 768 && len(embedding) != 1024 {
1578-
t.Errorf("Expected 768 or 1024-dim embedding, got %d", len(embedding))
1572+
// Expect Qwen3 (1024) since Gemma is not available
1573+
if len(embedding) != 1024 {
1574+
t.Errorf("Expected 1024-dim embedding, got %d", len(embedding))
15791575
}
15801576
t.Logf("Priority test %s: generated %d-dim embedding", tc.desc, len(embedding))
15811577
})
@@ -1598,9 +1594,9 @@ func TestGetEmbeddingSmart(t *testing.T) {
15981594
continue
15991595
}
16001596

1601-
// Smart routing may select Qwen3 (1024) or Gemma (768)
1602-
if len(embedding) != 768 && len(embedding) != 1024 {
1603-
t.Errorf("Iteration %d: Expected 768 or 1024-dim embedding, got %d", i, len(embedding))
1597+
// Expect Qwen3 (1024) since Gemma is not available
1598+
if len(embedding) != 1024 {
1599+
t.Errorf("Iteration %d: Expected 1024-dim embedding, got %d", i, len(embedding))
16041600
}
16051601

16061602
// Verify no nil pointers
@@ -1639,11 +1635,12 @@ func BenchmarkGetEmbeddingSmart(b *testing.B) {
16391635
}
16401636

16411637
// Test constants for embedding models (Phase 4.2)
1638+
// Note: Gemma model is gated and requires HF_TOKEN, so tests use Qwen3 only
16421639
const (
16431640
Qwen3EmbeddingModelPath = "../models/Qwen3-Embedding-0.6B"
1644-
GemmaEmbeddingModelPath = "../models/embeddinggemma-300m"
1641+
GemmaEmbeddingModelPath = "" // Gemma is gated, not used in CI tests
16451642
TestEmbeddingText = "This is a test sentence for embedding generation"
1646-
TestLongContextText = "This is a longer text that might benefit from long-context embedding models like Qwen3 or Gemma"
1643+
TestLongContextText = "This is a longer text that might benefit from long-context embedding models like Qwen3"
16471644
)
16481645

16491646
// Test constants for Qwen3 Multi-LoRA
@@ -1705,23 +1702,8 @@ func TestInitEmbeddingModels(t *testing.T) {
17051702
})
17061703

17071704
t.Run("InitGemmaOnly", func(t *testing.T) {
1708-
// Similar to InitBothModels, accept already-initialized state
1709-
err := InitEmbeddingModels("", GemmaEmbeddingModelPath, true)
1710-
if err != nil {
1711-
t.Logf("InitEmbeddingModels (Gemma only) returned error (may already be initialized): %v", err)
1712-
1713-
// Verify functionality
1714-
_, testErr := GetEmbeddingSmart("test", 0.5, 0.5)
1715-
if testErr == nil {
1716-
t.Log("✓ ModelFactory is functional (already initialized)")
1717-
} else {
1718-
if isModelInitializationError(testErr) {
1719-
t.Skipf("Skipping test due to model unavailability: %v", testErr)
1720-
}
1721-
}
1722-
} else {
1723-
t.Log("✓ Gemma model initialized successfully")
1724-
}
1705+
// Gemma is a gated model requiring HF_TOKEN, skip in CI
1706+
t.Skip("Skipping Gemma-only test: Gemma is a gated model requiring HF_TOKEN")
17251707
})
17261708

17271709
t.Run("InitWithInvalidPaths", func(t *testing.T) {
@@ -1739,9 +1721,6 @@ func TestGetEmbeddingWithDim(t *testing.T) {
17391721
// Initialize embedding models first
17401722
err := InitEmbeddingModels(Qwen3EmbeddingModelPath, GemmaEmbeddingModelPath, true)
17411723
if err != nil {
1742-
if isModelInitializationError(err) {
1743-
t.Skipf("Skipping GetEmbeddingWithDim tests due to model initialization error: %v", err)
1744-
}
17451724
t.Fatalf("Failed to initialize embedding models: %v", err)
17461725
}
17471726

@@ -1806,16 +1785,16 @@ func TestGetEmbeddingWithDim(t *testing.T) {
18061785

18071786
t.Run("OversizedDimension", func(t *testing.T) {
18081787
// Test graceful degradation when requested dimension exceeds model capacity
1809-
// Qwen3: 1024, Gemma: 768, so 2048 should fall back to full dimension
1788+
// Qwen3: 1024, so 2048 should fall back to full dimension
18101789
embedding, err := GetEmbeddingWithDim(TestEmbeddingText, 0.5, 0.5, 2048)
18111790
if err != nil {
18121791
t.Errorf("Should gracefully handle oversized dimension, got error: %v", err)
18131792
return
18141793
}
18151794

1816-
// Should return full dimension (1024 for Qwen3 or 768 for Gemma)
1817-
if len(embedding) != 1024 && len(embedding) != 768 {
1818-
t.Errorf("Expected full dimension (1024 or 768), got %d", len(embedding))
1795+
// Should return full dimension (1024 for Qwen3)
1796+
if len(embedding) != 1024 {
1797+
t.Errorf("Expected full dimension (1024), got %d", len(embedding))
18191798
} else {
18201799
t.Logf("✓ Oversized dimension gracefully degraded to full dimension: %d", len(embedding))
18211800
}
@@ -1841,9 +1820,6 @@ func TestGetEmbeddingWithDim(t *testing.T) {
18411820
func TestEmbeddingConsistency(t *testing.T) {
18421821
err := InitEmbeddingModels(Qwen3EmbeddingModelPath, GemmaEmbeddingModelPath, true)
18431822
if err != nil {
1844-
if isModelInitializationError(err) {
1845-
t.Skipf("Skipping consistency tests due to model initialization error: %v", err)
1846-
}
18471823
t.Fatalf("Failed to initialize embedding models: %v", err)
18481824
}
18491825

@@ -1911,12 +1887,11 @@ func TestEmbeddingConsistency(t *testing.T) {
19111887
func TestEmbeddingPriorityRouting(t *testing.T) {
19121888
err := InitEmbeddingModels(Qwen3EmbeddingModelPath, GemmaEmbeddingModelPath, true)
19131889
if err != nil {
1914-
if isModelInitializationError(err) {
1915-
t.Skipf("Skipping priority routing tests due to model initialization error: %v", err)
1916-
}
19171890
t.Fatalf("Failed to initialize embedding models: %v", err)
19181891
}
19191892

1893+
// Note: These tests use Matryoshka dimension truncation (768) with Qwen3 model
1894+
// The dimension is truncated from Qwen3's full 1024 dimensions
19201895
testCases := []struct {
19211896
name string
19221897
text string
@@ -1931,23 +1906,23 @@ func TestEmbeddingPriorityRouting(t *testing.T) {
19311906
qualityPriority: 0.2,
19321907
latencyPriority: 0.9,
19331908
expectedDim: 768,
1934-
description: "Should prefer faster embedding model (Gemma > Qwen3)",
1909+
description: "Uses Qwen3 with Matryoshka 768 truncation",
19351910
},
19361911
{
19371912
name: "HighQualityPriority",
19381913
text: strings.Repeat("Long context text ", 30),
19391914
qualityPriority: 0.9,
19401915
latencyPriority: 0.2,
19411916
expectedDim: 768,
1942-
description: "Should prefer quality model (Qwen3/Gemma)",
1917+
description: "Uses Qwen3 with Matryoshka 768 truncation",
19431918
},
19441919
{
19451920
name: "BalancedPriority",
19461921
text: "Medium length text for embedding",
19471922
qualityPriority: 0.5,
19481923
latencyPriority: 0.5,
19491924
expectedDim: 768,
1950-
description: "Should select based on text length",
1925+
description: "Uses Qwen3 with Matryoshka 768 truncation",
19511926
},
19521927
}
19531928

candle-binding/src/classifiers/unified.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use parking_lot::RwLock;
1111
use std::collections::HashMap;
1212
use std::time::Instant;
1313

14+
use crate::ffi::embedding::GLOBAL_MODEL_FACTORY;
1415
use crate::model_architectures::config::{DualPathConfig, LoRAConfig, TraditionalConfig};
1516
use crate::model_architectures::routing::{DualPathRouter, ProcessingRequirements};
1617
use crate::model_architectures::traits::*;
@@ -1024,6 +1025,45 @@ impl DualPathUnifiedClassifier {
10241025
model_type
10251026
};
10261027

1028+
// Validate model availability and fall back if necessary
1029+
let model_type = match model_type {
1030+
ModelType::GemmaEmbedding => {
1031+
// Check if Gemma is available
1032+
if let Some(factory) = GLOBAL_MODEL_FACTORY.get() {
1033+
if factory.get_gemma_model().is_none() {
1034+
// Gemma not available, fall back to Qwen3
1035+
eprintln!(
1036+
"WARNING: GemmaEmbedding selected but not available, falling back to Qwen3Embedding"
1037+
);
1038+
ModelType::Qwen3Embedding
1039+
} else {
1040+
ModelType::GemmaEmbedding
1041+
}
1042+
} else {
1043+
// No factory available, fall back to Qwen3
1044+
eprintln!(
1045+
"WARNING: ModelFactory not initialized, falling back to Qwen3Embedding"
1046+
);
1047+
ModelType::Qwen3Embedding
1048+
}
1049+
}
1050+
ModelType::Qwen3Embedding => {
1051+
// Qwen3 is the default, should always be available
1052+
// But verify just in case
1053+
if let Some(factory) = GLOBAL_MODEL_FACTORY.get() {
1054+
if factory.get_qwen3_model().is_none() {
1055+
return Err(UnifiedClassifierError::ProcessingError(
1056+
"Qwen3Embedding selected but not available and no fallback available"
1057+
.to_string(),
1058+
));
1059+
}
1060+
}
1061+
ModelType::Qwen3Embedding
1062+
}
1063+
// For non-embedding types, pass through
1064+
other => other,
1065+
};
1066+
10271067
// Log routing decision for monitoring
10281068
if self.config.embedding.enable_performance_tracking {
10291069
println!(

candle-binding/src/ffi/embedding.rs

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ enum PaddingSide {
2929
}
3030

3131
/// Global singleton for ModelFactory
32-
static GLOBAL_MODEL_FACTORY: OnceLock<ModelFactory> = OnceLock::new();
32+
pub(crate) static GLOBAL_MODEL_FACTORY: OnceLock<ModelFactory> = OnceLock::new();
3333

3434
/// Generic internal helper for single text embedding generation
3535
///
@@ -77,14 +77,18 @@ where
7777

7878
// Apply Matryoshka truncation if requested
7979
let result = if let Some(dim) = target_dim {
80-
if dim > embedding_vec.len() {
81-
return Err(format!(
82-
"Target dimension {} exceeds model dimension {}",
80+
// Gracefully degrade to model's max dimension if requested dimension is too large
81+
let actual_dim = if dim > embedding_vec.len() {
82+
eprintln!(
83+
"WARNING: Requested dimension {} exceeds model dimension {}, using full dimension",
8384
dim,
8485
embedding_vec.len()
85-
));
86-
}
87-
embedding_vec[..dim].to_vec()
86+
);
87+
embedding_vec.len()
88+
} else {
89+
dim
90+
};
91+
embedding_vec[..actual_dim].to_vec()
8892
} else {
8993
embedding_vec
9094
};
@@ -185,15 +189,19 @@ where
185189

186190
// Apply Matryoshka truncation if requested
187191
let result_embeddings = if let Some(dim) = target_dim {
188-
if dim > embedding_dim {
189-
return Err(format!(
190-
"Target dimension {} exceeds model dimension {}",
192+
// Gracefully degrade to model's max dimension if requested dimension is too large
193+
let actual_dim = if dim > embedding_dim {
194+
eprintln!(
195+
"WARNING: Requested dimension {} exceeds model dimension {}, using full dimension",
191196
dim, embedding_dim
192-
));
193-
}
197+
);
198+
embedding_dim
199+
} else {
200+
dim
201+
};
194202
embeddings_data
195203
.into_iter()
196-
.map(|emb| emb[..dim].to_vec())
204+
.map(|emb| emb[..actual_dim].to_vec())
197205
.collect()
198206
} else {
199207
embeddings_data
@@ -207,11 +215,11 @@ where
207215
/// # Safety
208216
/// - `qwen3_model_path` and `gemma_model_path` must be valid null-terminated C strings or null
209217
/// - Must be called before any embedding generation functions
210-
/// - Can only be called once (subsequent calls will be ignored)
218+
/// - Can only be called once (subsequent calls will return true as already initialized)
211219
///
212220
/// # Returns
213-
/// - `true` if initialization succeeded
214-
/// - `false` if initialization failed or already initialized
221+
/// - `true` if initialization succeeded or already initialized
222+
/// - `false` if initialization failed
215223
#[no_mangle]
216224
pub extern "C" fn init_embedding_models(
217225
qwen3_model_path: *const c_char,
@@ -220,6 +228,12 @@ pub extern "C" fn init_embedding_models(
220228
) -> bool {
221229
use candle_core::Device;
222230

231+
// Check if already initialized (OnceLock can only be set once)
232+
if GLOBAL_MODEL_FACTORY.get().is_some() {
233+
eprintln!("WARNING: ModelFactory already initialized");
234+
return true; // Already initialized, return success
235+
}
236+
223237
// Parse model paths
224238
let qwen3_path = if qwen3_model_path.is_null() {
225239
None

0 commit comments

Comments (0)