From ba2f778bbdad2f3cf6ba9d71fbd670b67ae1a628 Mon Sep 17 00:00:00 2001
From: Luke Van Seters
Date: Sun, 6 Aug 2023 00:56:37 -0400
Subject: [PATCH 1/2] Fix CVModel creation in TopicMembershipModel

---
 adatest/_topic_model.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/adatest/_topic_model.py b/adatest/_topic_model.py
index 752280d..6536466 100644
--- a/adatest/_topic_model.py
+++ b/adatest/_topic_model.py
@@ -157,8 +157,7 @@ def __init__(self, topic, test_tree):
         else:
 
             # we are in a highly overparametrized situation, so we use a linear SVC to get "max-margin" based generalization
-            self.model = CVModel()
-            self.model.fit(embeddings, labels)
+            self.model = CVModel(embeddings, labels)
 
     def __call__(self, input):
         embeddings = adatest.embed([input])[0]

From 848bbabc0a27ec845e105e2be94648069300526e Mon Sep 17 00:00:00 2001
From: Luke Van Seters
Date: Sun, 6 Aug 2023 13:17:48 -0400
Subject: [PATCH 2/2] Fix class_weights to support any labels

---
 adatest/_topic_model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/adatest/_topic_model.py b/adatest/_topic_model.py
index 6536466..e46f242 100644
--- a/adatest/_topic_model.py
+++ b/adatest/_topic_model.py
@@ -19,7 +19,8 @@ def predict_prob(self, embeddings):
 
 class CVModel():
     def __init__(self, embeddings, labels):
-        self.inner_model = RidgeClassifierCV(class_weight={"pass": 1, "fail": 1})
+        class_weight = {label: 1 for label in labels}
+        self.inner_model = RidgeClassifierCV(class_weight=class_weight)
         self.inner_model.fit(embeddings, labels)
 
     def predict_prob(self, embeddings):