from typing import List, Optional, TYPE_CHECKING

from numpy import mean

from mcp_agent.workflows.embedding.embedding_base import (
    FloatArray,
    EmbeddingModel,
    compute_confidence,
    compute_similarity_scores,
)
from mcp_agent.workflows.intent_classifier.intent_classifier_base import (
    Intent,
    IntentClassifier,
    IntentClassificationResult,
)

if TYPE_CHECKING:
    from mcp_agent.context import Context


class EmbeddingIntent(Intent):
    """An intent with embedding information"""

    embedding: FloatArray | None = None
    """Pre-computed embedding for this intent"""


class EmbeddingIntentClassifier(IntentClassifier):
    """
    An intent classifier that uses embedding similarity for classification.
    Supports different embedding models through the EmbeddingModel interface.

    Features:
    - Semantic similarity based classification
    - Support for example-based learning
    - Flexible embedding model support
    - Multiple similarity computation strategies
    """

    def __init__(
        self,
        intents: List[Intent],
        embedding_model: EmbeddingModel,
        context: Optional["Context"] = None,
        **kwargs,
    ):
        super().__init__(intents=intents, context=context, **kwargs)
        self.embedding_model = embedding_model
        self.initialized = False

    @classmethod
    async def create(
        cls,
        intents: List[Intent],
        embedding_model: EmbeddingModel,
    ) -> "EmbeddingIntentClassifier":
        """
        Factory method to create and initialize a classifier.
        Use this instead of the constructor, since initialization is async.
        """
        instance = cls(
            intents=intents,
            embedding_model=embedding_model,
        )
        await instance.initialize()
        return instance

    async def initialize(self):
        """
        Precompute embeddings for all intents by combining their
        names, descriptions, and examples.
        """
        if self.initialized:
            return

        for intent in self.intents.values():
            # Combine all text for a rich intent representation
            intent_texts = [intent.name, intent.description] + intent.examples

            # Get embeddings for all texts
            embeddings = await self.embedding_model.embed(intent_texts)

            # Use mean pooling to combine the embeddings into a single vector
            embedding = mean(embeddings, axis=0)

            # Recreate the intent with its precomputed embedding attached
            # (Intent is a Pydantic model, so unpack its fields via model_dump)
            self.intents[intent.name] = EmbeddingIntent(
                **intent.model_dump(),
                embedding=embedding,
            )

        self.initialized = True

    async def classify(
        self, request: str, top_k: int = 1
    ) -> List[IntentClassificationResult]:
        """
        Classify the input text into one or more intents.

        Args:
            request: Input text to classify
            top_k: Maximum number of top matches to return

        Returns:
            List of classification results, ordered by confidence
        """
        if not self.initialized:
            await self.initialize()

        # Get embedding for the input request
        embeddings = await self.embedding_model.embed([request])
        request_embedding = embeddings[0]  # Take the first since only one text was embedded

        results: List[IntentClassificationResult] = []
        for intent_name, intent in self.intents.items():
            if intent.embedding is None:
                continue

            similarity_scores = compute_similarity_scores(
                request_embedding, intent.embedding
            )

            # Compute an overall confidence score from the similarity scores
            confidence = compute_confidence(similarity_scores)
            results.append(
                IntentClassificationResult(
                    intent=intent_name,
                    p_score=confidence,
                )
            )

        results.sort(key=lambda x: x.p_score, reverse=True)
        return results[:top_k]
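

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring up the classifier. The OpenAIEmbeddingModel
# import path below is an assumption based on this package's layout; swap in
# whichever EmbeddingModel implementation your installation provides.
if __name__ == "__main__":
    import asyncio

    from mcp_agent.workflows.embedding.embedding_openai import (  # assumed path
        OpenAIEmbeddingModel,
    )

    async def main():
        intents = [
            Intent(
                name="greeting",
                description="The user is saying hello or starting a conversation",
                examples=["hi there", "hello", "good morning"],
            ),
            Intent(
                name="farewell",
                description="The user is ending the conversation",
                examples=["bye", "see you later", "good night"],
            ),
        ]

        # create() both constructs the classifier and precomputes intent embeddings
        classifier = await EmbeddingIntentClassifier.create(
            intents=intents,
            embedding_model=OpenAIEmbeddingModel(),
        )

        # Classify a request and print the top matches with their confidence scores
        results = await classifier.classify("hey, how's it going?", top_k=2)
        for result in results:
            print(result.intent, result.p_score)

    asyncio.run(main())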