Skip to content

Commit 5f3fd0f

Browse files
Merge pull request #11 from dev-diaries41/fix/cliptext
fix: prevent IllegalCapacity in embed
2 parents e032ccd + f9b7230 commit 5f3fd0f

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

ml/src/androidTest/kotlin/com/fpf/smartscansdk/ml/models/providers/embeddings/clip/ClipTextEmbedderTest.kt

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,28 @@ class ClipTextEmbedderInstrumentedTest {
115115

116116
verify(exactly = 1) { (mockModel as AutoCloseable).close() }
117117
}
118+
119+
@Test
120+
fun `embed handles strings longer than 77 tokens`() = runBlocking {
121+
val embedder = ClipTextEmbedder(context, ResourceId(0))
122+
val mockModel = mockk<OnnxModel>(relaxed = true)
123+
every { mockModel.isLoaded() } returns true
124+
every { mockModel.getInputNames() } returns listOf("input")
125+
every { mockModel.getEnv() } returns mockk<OrtEnvironment>()
126+
127+
val raw = Array(1) { FloatArray(embedder.embeddingDim) { 1.0f } }
128+
every { mockModel.run(any<Map<String, TensorData>>()) } returns mapOf("out" to raw)
129+
130+
val field = embedder::class.java.getDeclaredField("model")
131+
field.isAccessible = true
132+
field.set(embedder, mockModel)
133+
134+
val longText = "a".repeat(2000)
135+
val embedding = embedder.embed(longText)
136+
137+
assertEquals(embedder.embeddingDim, embedding.size)
138+
val l2 = sqrt(embedding.map { it * it }.sum())
139+
assertTrue(abs(l2 - 1.0f) < 1e-3)
140+
}
141+
118142
}

ml/src/main/java/com/fpf/smartscansdk/ml/models/providers/embeddings/clip/ClipTextEmbedder.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ class ClipTextEmbedder(
5050
if (!isInitialized()) throw IllegalStateException("Model not initialized")
5151

5252
val clean = Regex("[^A-Za-z0-9 ]").replace(data, "").lowercase()
53-
var tokens = mutableListOf(tokenBOS) + tokenizer.encode(clean) + tokenEOS
54-
tokens = tokens.take(77) + List(77 - tokens.size) { 0 }
53+
var tokens = (mutableListOf(tokenBOS) + tokenizer.encode(clean) + tokenEOS).take(77).toMutableList()
54+
if (tokens.size < 77) tokens += List(77 - tokens.size) { 0 }
55+
5556

5657
val inputIds = LongBuffer.allocate(1 * 77).apply {
5758
tokens.forEach { put(it.toLong()) }

0 commit comments

Comments
 (0)