diff --git a/scripts/e2e_eval/testsets/models_with_acc.json b/scripts/e2e_eval/testsets/models_with_acc.json index 99744e381..2d3067b22 100644 --- a/scripts/e2e_eval/testsets/models_with_acc.json +++ b/scripts/e2e_eval/testsets/models_with_acc.json @@ -1271,6 +1271,23 @@ } } }, + { + "hf_id": "cross-encoder/nli-deberta-v3-small", + "task": "zero-shot-classification", + "model_type": "deberta-v2", + "group": "Top200", + "priority": "P1", + "dataset_config": { + "path": "nyu-mll/multi_nli", + "split": "validation_matched", + "metric": "accuracy", + "columns_mapping": { + "input_column": "premise", + "label_column": "genre", + "candidate_labels": "fiction,government,slate,telephone,travel" + } + } + }, { "hf_id": "openai/clip-vit-base-patch32", "task": "zero-shot-image-classification", @@ -1289,6 +1306,23 @@ } } }, + { + "hf_id": "joeddav/xlm-roberta-large-xnli", + "task": "zero-shot-classification", + "model_type": "xlm-roberta", + "group": "Top200", + "priority": "P1", + "dataset_config": { + "path": "nyu-mll/multi_nli", + "split": "validation_matched", + "metric": "accuracy", + "columns_mapping": { + "input_column": "premise", + "label_column": "genre", + "candidate_labels": "fiction,government,slate,telephone,travel" + } + } + }, { "hf_id": "openai/clip-vit-large-patch14", "task": "zero-shot-image-classification", @@ -1307,6 +1341,23 @@ } } }, + { + "hf_id": "lxyuan/distilbert-base-multilingual-cased-sentiments-student", + "task": "zero-shot-classification", + "model_type": "distilbert", + "group": "Top200", + "priority": "P1", + "dataset_config": { + "path": "nyu-mll/multi_nli", + "split": "validation_matched", + "metric": "accuracy", + "columns_mapping": { + "input_column": "premise", + "label_column": "genre", + "candidate_labels": "fiction,government,slate,telephone,travel" + } + } + }, { "hf_id": "openai/clip-vit-large-patch14-336", "task": "zero-shot-image-classification", @@ -1325,6 +1376,23 @@ } } }, + { + "hf_id": 
"MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli", + "task": "zero-shot-classification", + "model_type": "deberta-v2", + "group": "Top200", + "priority": "P1", + "dataset_config": { + "path": "nyu-mll/multi_nli", + "split": "validation_matched", + "metric": "accuracy", + "columns_mapping": { + "input_column": "premise", + "label_column": "genre", + "candidate_labels": "fiction,government,slate,telephone,travel" + } + } + }, { "hf_id": "openai/clip-vit-base-patch16", "task": "zero-shot-image-classification", @@ -1343,6 +1411,23 @@ } } }, + { + "hf_id": "MoritzLaurer/deberta-v3-large-zeroshot-v2.0", + "task": "zero-shot-classification", + "model_type": "deberta-v2", + "group": "Top200", + "priority": "P1", + "dataset_config": { + "path": "nyu-mll/multi_nli", + "split": "validation_matched", + "metric": "accuracy", + "columns_mapping": { + "input_column": "premise", + "label_column": "genre", + "candidate_labels": "fiction,government,slate,telephone,travel" + } + } + }, { "hf_id": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", "task": "zero-shot-image-classification", @@ -1361,6 +1446,23 @@ } } }, + { + "hf_id": "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli", + "task": "zero-shot-classification", + "model_type": "deberta-v2", + "group": "Top200", + "priority": "P1", + "dataset_config": { + "path": "nyu-mll/multi_nli", + "split": "validation_matched", + "metric": "accuracy", + "columns_mapping": { + "input_column": "premise", + "label_column": "genre", + "candidate_labels": "fiction,government,slate,telephone,travel" + } + } + }, { "hf_id": "patrickjohncyh/fashion-clip", "task": "zero-shot-image-classification", @@ -1378,6 +1480,23 @@ } } }, + { + "hf_id": "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7", + "task": "zero-shot-classification", + "model_type": "deberta-v2", + "group": "Top200", + "priority": "P1", + "dataset_config": { + "path": "nyu-mll/multi_nli", + "split": "validation_matched", + "metric": "accuracy", + "columns_mapping": { + "input_column": 
"premise", + "label_column": "genre", + "candidate_labels": "fiction,government,slate,telephone,travel" + } + } + }, { "hf_id": "google/siglip-so400m-patch14-384", "task": "zero-shot-image-classification", diff --git a/src/winml/modelkit/datasets/__init__.py b/src/winml/modelkit/datasets/__init__.py index 490ea1c57..532db1628 100644 --- a/src/winml/modelkit/datasets/__init__.py +++ b/src/winml/modelkit/datasets/__init__.py @@ -45,6 +45,7 @@ "sentence-similarity": TextDataset, "next-sentence-prediction": TextDataset, "fill-mask": TextDataset, + "zero-shot-classification": TextDataset, "image-segmentation": ImageSegmentationDataset, "random": RandomDataset, # Add more task types as needed diff --git a/src/winml/modelkit/eval/__init__.py b/src/winml/modelkit/eval/__init__.py index 06966de91..48f99d6f9 100644 --- a/src/winml/modelkit/eval/__init__.py +++ b/src/winml/modelkit/eval/__init__.py @@ -15,6 +15,7 @@ from .fill_mask_evaluator import WinMLFillMaskEvaluator from .image_feature_extraction_evaluator import WinMLImageFeatureExtractionEvaluator from .image_segmentation_evaluator import WinMLImageSegmentationEvaluator +from .metrics.classification import ClassificationMetric from .metrics.knn_accuracy import KNNAccuracyMetric from .metrics.mean_average_precision import MAPMetric from .metrics.mean_iou import IGNORE_INDEX, MeanIoUMetric @@ -25,11 +26,13 @@ from .question_answering_evaluator import WinMLQuestionAnsweringEvaluator from .text_classification_evaluator import WinMLTextClassificationEvaluator from .token_classification_evaluator import WinMLTokenClassificationEvaluator +from .zero_shot_classification_evaluator import WinMLZeroShotClassificationEvaluator from .zero_shot_image_classification_evaluator import WinMLZeroShotImageClassificationEvaluator __all__ = [ "IGNORE_INDEX", + "ClassificationMetric", "EvalResult", "KNNAccuracyMetric", "MAPMetric", @@ -47,6 +50,7 @@ "WinMLQuestionAnsweringEvaluator", "WinMLTextClassificationEvaluator", 
def _pad_or_truncate(self, encoding: Any, tokenizer: Any) -> Any:
    """Resize tokenized inputs to the model's fixed sequence length.

    When ``_fixed_seq_length`` reports no fixed length (dynamic-shape
    model) ``encoding`` is returned untouched. Otherwise every entry with
    at least two dimensions whose second dimension exceeds the fixed
    length is sliced down to it, and padding up to that length is
    delegated to ``tokenizer.pad``.

    Args:
        encoding: Mapping of input name to batched array/tensor. Entries
            that are too long are replaced in place.
        tokenizer: Object exposing a ``pad`` method that accepts
            ``padding``, ``max_length`` and ``return_tensors``.

    Returns:
        ``encoding`` itself for dynamic-shape models, otherwise whatever
        ``tokenizer.pad`` returns.
    """
    seq_len = self._fixed_seq_length()
    if seq_len is None:
        return encoding
    # Snapshot the items because entries are reassigned while looping.
    for key, value in list(encoding.items()):
        # Inspect ``shape`` instead of calling the torch-only ``dim()`` so
        # any array type that carries a shape (e.g. NumPy) is handled
        # rather than raising AttributeError.
        shape = getattr(value, "shape", None)
        if shape is not None and len(shape) >= 2 and shape[1] > seq_len:
            encoding[key] = value[:, :seq_len]
    return tokenizer.pad(
        encoding,
        padding="max_length",
        max_length=seq_len,
        return_tensors="pt",
    )
diff --git a/src/winml/modelkit/eval/evaluate.py b/src/winml/modelkit/eval/evaluate.py index b48c334a9..1fa949cee 100644 --- a/src/winml/modelkit/eval/evaluate.py +++ b/src/winml/modelkit/eval/evaluate.py @@ -23,6 +23,7 @@ from .question_answering_evaluator import WinMLQuestionAnsweringEvaluator from .text_classification_evaluator import WinMLTextClassificationEvaluator from .token_classification_evaluator import WinMLTokenClassificationEvaluator +from .zero_shot_classification_evaluator import WinMLZeroShotClassificationEvaluator from .zero_shot_image_classification_evaluator import WinMLZeroShotImageClassificationEvaluator @@ -43,6 +44,7 @@ "sentence-similarity": WinMLFeatureExtractionEvaluator, "image-feature-extraction": WinMLImageFeatureExtractionEvaluator, "fill-mask": WinMLFillMaskEvaluator, + "zero-shot-classification": WinMLZeroShotClassificationEvaluator, "zero-shot-image-classification": WinMLZeroShotImageClassificationEvaluator, } @@ -127,6 +129,16 @@ streaming=True, columns_mapping={"input_column": "text"}, ), + "zero-shot-classification": DatasetConfig( + path="fancyzhx/ag_news", + split="test", + samples=100, + shuffle=True, + columns_mapping={ + "input_column": "text", + "label_column": "label", + }, + ), "zero-shot-image-classification": DatasetConfig( path="uoft-cs/cifar100", split="test", diff --git a/src/winml/modelkit/eval/metrics/__init__.py b/src/winml/modelkit/eval/metrics/__init__.py index 5f1c51010..5eaa18413 100644 --- a/src/winml/modelkit/eval/metrics/__init__.py +++ b/src/winml/modelkit/eval/metrics/__init__.py @@ -5,6 +5,7 @@ """Evaluation metrics.""" +from .classification import ClassificationMetric from .knn_accuracy import KNNAccuracyMetric from .mean_average_precision import MAPMetric from .mean_iou import IGNORE_INDEX, MeanIoUMetric @@ -15,6 +16,7 @@ __all__ = [ "IGNORE_INDEX", + "ClassificationMetric", "KNNAccuracyMetric", "MAPMetric", "MeanIoUMetric", diff --git a/src/winml/modelkit/eval/metrics/classification.py 
class ClassificationMetric:
    """Accuracy and macro-F1 over string labels."""

    def compute(
        self,
        predictions: list[str],
        references: list[str],
        labels: list[str],
    ) -> dict[str, Any]:
        """Compute accuracy and macro-F1.

        Args:
            predictions: Predicted label strings, one per sample.
            references: Ground-truth label strings, one per sample.
            labels: Full set of class labels for macro-F1 averaging.

        Returns:
            Dict with ``accuracy`` and ``f1`` (both floats in [0, 1]).

        Raises:
            ValueError: If ``predictions`` and ``references`` differ in
                length, or if ``references`` or ``labels`` is empty.
        """
        # Validate first: these checks are cheap and should report the real
        # problem before the scikit-learn import is attempted.
        if len(predictions) != len(references):
            raise ValueError(
                "predictions and references must have the same length, "
                f"got {len(predictions)} vs {len(references)}.",
            )
        if not references:
            raise ValueError("references must not be empty.")
        if not labels:
            raise ValueError("labels must not be empty.")

        # Imported lazily so the module can be imported without scikit-learn.
        from sklearn.metrics import accuracy_score, f1_score

        accuracy = accuracy_score(references, predictions)
        # ``labels`` pins the averaging set, so classes absent from the data
        # still count (as 0, via ``zero_division``) toward the macro average.
        macro_f1 = f1_score(
            references,
            predictions,
            labels=labels,
            average="macro",
            zero_division=0,
        )
        return {"accuracy": float(accuracy), "f1": float(macro_f1)}
class _FixedShapeZeroShotPipeline(ZeroShotClassificationPipeline):
    """Zero-shot pipeline that resizes tokenized pairs for fixed-shape ONNX exports.

    The resizing itself lives on the owning ``WinMLEvaluator``; this class
    only forwards the tokenizer output to it. The owner is stored on
    ``_winml_evaluator`` after the pipeline has been constructed.
    """

    # Owning evaluator; stays ``None`` until assigned, in which case the
    # tokenizer output is returned as-is.
    _winml_evaluator: WinMLEvaluator | None = None

    def _parse_and_tokenize(self, sequence_pairs: Any, **kwargs: Any) -> Any:
        # Padding and truncation default to on; explicit caller values win.
        tokenize_kwargs = {"padding": True, "truncation": True, **kwargs}
        batch = super()._parse_and_tokenize(sequence_pairs, **tokenize_kwargs)

        owner = self._winml_evaluator
        if owner is not None and self.tokenizer is not None:
            return owner._pad_or_truncate(batch, self.tokenizer)
        return batch
def prepare_pipeline(self) -> Pipeline:
    """Create the zero-shot pipeline with fixed-length tokenization for ONNX.

    Builds a CPU ``zero-shot-classification`` pipeline backed by
    ``_FixedShapeZeroShotPipeline`` and registers this evaluator on it as
    ``_winml_evaluator``. When the pipeline has a tokenizer, it is then
    adjusted to the model's I/O contract:

    * ``model_max_length`` is set to the fixed sequence length, if any.
    * ``model_input_names`` is narrowed to the names listed in
      ``io_config["input_names"]``, unless that would leave it empty.

    Returns:
        The configured pipeline.
    """
    from transformers import pipeline

    max_length = self._fixed_seq_length()

    pipe = pipeline(
        "zero-shot-classification",
        model=self.model,
        framework="pt",
        tokenizer=self.config.model_id,
        device="cpu",
        pipeline_class=_FixedShapeZeroShotPipeline,
    )
    pipe._winml_evaluator = self

    tokenizer = pipe.tokenizer
    if tokenizer is None:
        # Nothing to adjust; previously the input-name filter below
        # dereferenced the missing tokenizer and raised AttributeError.
        return pipe

    if max_length is not None:
        tokenizer.model_max_length = max_length

    # Drop tokenizer keys the ONNX graph does not accept
    # (some NLI exports omit token_type_ids).
    io_config = getattr(self.model, "io_config", None) or {}
    accepted_names = io_config.get("input_names") or []
    if accepted_names:
        kept = [name for name in tokenizer.model_input_names if name in accepted_names]
        # Keep the tokenizer's own list if nothing overlaps, rather than
        # leaving it with no inputs at all.
        if kept:
            tokenizer.model_input_names = kept

    return pipe
def compute(self) -> dict[str, Any]:
    """Compute accuracy and macro-F1 over all labeled samples.

    Each sample's input text is scored by the pipeline against the
    candidate labels and the top-ranked label is taken as the prediction.
    References are the dataset's class names when the label column exposes
    ``names`` (replaced positionally by the override when one is given),
    otherwise the stringified raw label value.

    Samples whose class id is negative are skipped: Hugging Face
    ``ClassLabel`` columns use ``-1`` for "no label", and indexing the
    class names with it would silently alias to the last class.

    Returns:
        Dict with ``accuracy`` and ``f1``.

    Raises:
        ValueError: If a candidate-label override is combined with a
            ``ClassLabel`` column of a different size, or for the reasons
            raised by ``_resolve_candidate_labels`` and
            ``ClassificationMetric.compute``.
    """
    from .metrics import ClassificationMetric

    candidate_labels = self._resolve_candidate_labels(self.data)
    class_names = getattr(self.data.features.get(self._label_col), "names", None)

    # An override replaces ClassLabel.names positionally, so references use
    # the same vocabulary as predictions.
    if self._candidate_labels_override and class_names is not None:
        if len(candidate_labels) != len(class_names):
            raise ValueError(
                f"candidate_labels override has {len(candidate_labels)} entries "
                f"but dataset ClassLabel has {len(class_names)}; provide one "
                f"override label per class, in order.",
            )
        class_names = candidate_labels

    pipe_kwargs: dict[str, Any] = {"candidate_labels": candidate_labels}
    if self._hypothesis_template is not None:
        pipe_kwargs["hypothesis_template"] = self._hypothesis_template

    predictions: list[str] = []
    references: list[str] = []
    for sample in tqdm(self.data, desc="Evaluating zero-shot (accuracy)"):
        raw = sample[self._label_col]
        if class_names:
            class_id = int(raw)
            if class_id < 0:
                # Unlabeled sample: skip before paying for inference.
                continue
            reference = class_names[class_id]
        else:
            reference = str(raw)

        result = self.pipe(sample[self._input_col], **pipe_kwargs)
        predictions.append(result["labels"][0])
        references.append(reference)

    return ClassificationMetric().compute(predictions, references, candidate_labels)
@pytest.mark.slow
@pytest.mark.network
@pytest.mark.integration
@pytest.mark.parametrize("model_id", _MODEL_IDS)
def test_zero_shot_classification_end_to_end(model_id: str) -> None:
    """Evaluate a real NLI checkpoint on five AG News samples and sanity-check the metrics."""
    from transformers import AutoModelForSequenceClassification

    checkpoint = AutoModelForSequenceClassification.from_pretrained(model_id)

    dataset_cfg = DatasetConfig(
        path="fancyzhx/ag_news",
        split="test",
        samples=5,
        shuffle=False,
    )
    eval_cfg = WinMLEvaluationConfig(
        model_id=model_id,
        task="zero-shot-classification",
        dataset=dataset_cfg,
    )

    results = WinMLZeroShotClassificationEvaluator(eval_cfg, checkpoint).compute()

    # Both metrics must be reported ...
    for metric_name in ("accuracy", "f1"):
        assert metric_name in results
    # ... and each must be a valid proportion.
    for metric_name in ("accuracy", "f1"):
        assert 0.0 <= results[metric_name] <= 1.0
class TestClassificationMetricBasic:
    """Accuracy and macro-F1 over string labels."""

    def test_perfect_predictions(self) -> None:
        # Every prediction matches its reference, so both scores are 1.
        scores = ClassificationMetric().compute(
            predictions=["a", "b", "c"],
            references=["a", "b", "c"],
            labels=["a", "b", "c"],
        )
        assert scores["accuracy"] == pytest.approx(1.0)
        assert scores["f1"] == pytest.approx(1.0)

    def test_all_wrong(self) -> None:
        # No prediction matches, so both scores are 0.
        scores = ClassificationMetric().compute(
            predictions=["a", "a"],
            references=["b", "b"],
            labels=["a", "b"],
        )
        assert scores["accuracy"] == pytest.approx(0.0)
        assert scores["f1"] == pytest.approx(0.0)

    def test_half_correct(self) -> None:
        # Two of the four predictions match their reference.
        scores = ClassificationMetric().compute(
            predictions=["a", "b", "a", "b"],
            references=["a", "a", "b", "b"],
            labels=["a", "b"],
        )
        assert scores["accuracy"] == pytest.approx(0.5)
class TestClassificationMetricValidation:
    """Malformed arguments are rejected with a ``ValueError``."""

    def test_length_mismatch_raises(self) -> None:
        bad_call = {
            "predictions": ["a", "b"],
            "references": ["a"],
            "labels": ["a", "b"],
        }
        with pytest.raises(ValueError, match="same length"):
            ClassificationMetric().compute(**bad_call)

    def test_empty_references_raises(self) -> None:
        with pytest.raises(ValueError, match="references"):
            ClassificationMetric().compute(predictions=[], references=[], labels=["a"])

    def test_empty_labels_raises(self) -> None:
        with pytest.raises(ValueError, match="labels"):
            ClassificationMetric().compute(predictions=["a"], references=["a"], labels=[])
def _make_classlabel_dataset(
    texts: list[str],
    labels: list[int],
    class_names: list[str] | None = None,
    input_col: str = "text",
    label_col: str = "label",
) -> Dataset:
    """Build an in-memory dataset with a string input column and a ``ClassLabel`` column.

    ``class_names`` falls back to ``CANDIDATE_LABELS`` when omitted or empty.
    """
    label_feature = ClassLabel(names=class_names or CANDIDATE_LABELS)
    schema = Features({input_col: Value("string"), label_col: label_feature})
    columns = {input_col: texts, label_col: labels}
    return Dataset.from_dict(columns, features=schema)
class TestRegistry:
    """The task is wired into the evaluator and default-dataset registries."""

    def test_evaluator_registered(self) -> None:
        from winml.modelkit.eval.evaluate import _EVALUATOR_REGISTRY

        task = "zero-shot-classification"
        assert task in _EVALUATOR_REGISTRY
        assert _EVALUATOR_REGISTRY[task] is WinMLZeroShotClassificationEvaluator

    def test_default_dataset_registered(self) -> None:
        from winml.modelkit.eval.evaluate import _DEFAULT_DATASETS

        default_cfg = _DEFAULT_DATASETS["zero-shot-classification"]
        assert default_cfg.path is not None
        # Both column roles must be mapped for the default dataset.
        for role in ("input_column", "label_column"):
            assert default_cfg.columns_mapping.get(role) is not None

    def test_exported_from_package(self) -> None:
        from winml.modelkit import eval as eval_pkg

        assert "WinMLZeroShotClassificationEvaluator" in eval_pkg.__all__
--------------------------------------------------------------------------- +# _FixedShapeZeroShotPipeline +# --------------------------------------------------------------------------- + + +class TestFixedShapePipeline: + """The subclass delegates resizing to the evaluator's pad/truncate helper.""" + + def test_calls_evaluator_pad_or_truncate(self) -> None: + from transformers.pipelines.zero_shot_classification import ( + ZeroShotClassificationPipeline, + ) + + from winml.modelkit.eval.zero_shot_classification_evaluator import ( + _FixedShapeZeroShotPipeline, + ) + + captured: dict = {} + sentinel_encoding = {"input_ids": MagicMock()} + padded = {"input_ids": MagicMock()} + + def _fake_super_parse(self, sequence_pairs, **kwargs): + captured.update(kwargs) + return sentinel_encoding + + evaluator = MagicMock() + evaluator._pad_or_truncate.return_value = padded + + instance = _FixedShapeZeroShotPipeline.__new__(_FixedShapeZeroShotPipeline) + instance._winml_evaluator = evaluator + instance.tokenizer = MagicMock() + with patch.object(ZeroShotClassificationPipeline, "_parse_and_tokenize", _fake_super_parse): + result = instance._parse_and_tokenize([("a", "b")]) + + assert captured["truncation"] is True + assert captured["padding"] is True + evaluator._pad_or_truncate.assert_called_once_with(sentinel_encoding, instance.tokenizer) + assert result is padded + + def test_passthrough_when_evaluator_unset(self) -> None: + from transformers.pipelines.zero_shot_classification import ( + ZeroShotClassificationPipeline, + ) + + from winml.modelkit.eval.zero_shot_classification_evaluator import ( + _FixedShapeZeroShotPipeline, + ) + + sentinel_encoding = {"input_ids": MagicMock()} + + def _fake_super_parse(self, sequence_pairs, **kwargs): + return sentinel_encoding + + instance = _FixedShapeZeroShotPipeline.__new__(_FixedShapeZeroShotPipeline) + instance._winml_evaluator = None + instance.tokenizer = MagicMock() + with patch.object(ZeroShotClassificationPipeline, 
def _make_mock_pipe_with_tokenizer(model_input_names: list[str] | None = None) -> MagicMock:
    """Return a mock pipeline whose mock tokenizer has known starting attributes.

    ``model_max_length`` starts at 0 and ``model_input_names`` is the given
    list, or the three standard names when none (or an empty list) is given.
    """
    default_names = ["input_ids", "attention_mask", "token_type_ids"]

    tokenizer = MagicMock()
    tokenizer.model_max_length = 0
    tokenizer.model_input_names = model_input_names or default_names

    mock_pipe = MagicMock()
    mock_pipe.tokenizer = tokenizer
    return mock_pipe
mock_ds.select.return_value = mock_ds + mock_ds.column_names = ["text", "label"] + mock_load_ds.return_value = mock_ds + + mock_pipe = _make_mock_pipe_with_tokenizer( + model_input_names=["input_ids", "attention_mask", "token_type_ids"], + ) + mock_pipeline.return_value = mock_pipe + + model = MagicMock() + model.config.label2id = None + model.io_config = { + "input_shapes": [[1, 128]], + "input_names": ["input_ids", "attention_mask"], + } + + config = WinMLEvaluationConfig( + model_id="test/model", + task="zero-shot-classification", + dataset=DatasetConfig(path="dummy"), + ) + + with patch.object( + WinMLZeroShotClassificationEvaluator, + "align_labels", + side_effect=lambda dataset, ds_config: dataset, + ): + WinMLZeroShotClassificationEvaluator(config, model) + + assert mock_pipe.tokenizer.model_input_names == ["input_ids", "attention_mask"] + + @patch("transformers.pipeline") + @patch("datasets.load_dataset") + def test_no_tokenizer_change_without_io_config(self, mock_load_ds, mock_pipeline) -> None: + from winml.modelkit.datasets import DatasetConfig + + mock_ds = MagicMock() + mock_ds.__len__ = lambda self: 2 + mock_ds.shuffle.return_value = mock_ds + mock_ds.select.return_value = mock_ds + mock_ds.column_names = ["text", "label"] + mock_load_ds.return_value = mock_ds + + mock_pipe = _make_mock_pipe_with_tokenizer() + original_names = list(mock_pipe.tokenizer.model_input_names) + mock_pipeline.return_value = mock_pipe + + model = MagicMock() + model.config.label2id = None + model.io_config = {} + + config = WinMLEvaluationConfig( + model_id="test/model", + task="zero-shot-classification", + dataset=DatasetConfig(path="dummy"), + ) + + with patch.object( + WinMLZeroShotClassificationEvaluator, + "align_labels", + side_effect=lambda dataset, ds_config: dataset, + ): + WinMLZeroShotClassificationEvaluator(config, model) + + # model_max_length untouched and input_names unchanged. 
+ assert mock_pipe.tokenizer.model_max_length == 0 + assert mock_pipe.tokenizer.model_input_names == original_names + + +# --------------------------------------------------------------------------- +# schema_info +# --------------------------------------------------------------------------- + + +class TestSchemaInfo: + def test_schema_has_input_and_label(self) -> None: + cols = WinMLZeroShotClassificationEvaluator.schema_info() + overrides = {c.override for c in cols if c.override} + assert "input_column" in overrides + assert "label_column" in overrides + + def test_schema_has_optional_overrides(self) -> None: + cols = WinMLZeroShotClassificationEvaluator.schema_info() + override_to_required = {c.override: c.required for c in cols if c.override} + assert override_to_required.get("candidate_labels") is False + assert override_to_required.get("hypothesis_template") is False + + +# --------------------------------------------------------------------------- +# align_labels / schema validation +# --------------------------------------------------------------------------- + + +class TestAlignLabels: + def test_valid_dataset_passes(self) -> None: + ds = _make_classlabel_dataset(["a", "b"], [0, 1]) + ev = make_evaluator(ds) + out = ev.align_labels(ds, ev.config.dataset) + assert out is ds + + def test_missing_input_column_raises(self) -> None: + ds = _make_classlabel_dataset(["a"], [0]) + ev = make_evaluator( + ds, + columns_mapping={"input_column": "nope", "label_column": "label"}, + ) + with pytest.raises(ValueError, match="Column 'nope'"): + ev.align_labels(ds, ev.config.dataset) + + def test_missing_label_column_raises(self) -> None: + ds = _make_classlabel_dataset(["a"], [0]) + ev = make_evaluator( + ds, + columns_mapping={"input_column": "text", "label_column": "missing"}, + ) + with pytest.raises(ValueError, match="Column 'missing'"): + ev.align_labels(ds, ev.config.dataset) + + def test_no_alignment_against_nli_label2id(self) -> None: + """Regression: base-class 
alignment must not be applied.""" + ds = _make_classlabel_dataset(["a", "b"], [0, 1]) + ev = make_evaluator(ds) + out = ev.align_labels(ds, ev.config.dataset) + # Labels unchanged — still 0 and 1 (not remapped to NLI ids) + assert out[0]["label"] == 0 + assert out[1]["label"] == 1 + + +# --------------------------------------------------------------------------- +# _resolve_candidate_labels +# --------------------------------------------------------------------------- + + +class TestResolveCandidateLabels: + def test_user_override_comma_separated(self) -> None: + ds = _make_classlabel_dataset(["a"], [0]) + ev = make_evaluator( + ds, + columns_mapping={ + "input_column": "text", + "label_column": "label", + "candidate_labels": "politics, sports ,tech", + }, + ) + labels = ev._resolve_candidate_labels(ds) + assert labels == ["politics", "sports", "tech"] + + def test_auto_from_classlabel(self) -> None: + ds = _make_classlabel_dataset(["a"], [0]) + ev = make_evaluator(ds) + labels = ev._resolve_candidate_labels(ds) + assert labels == CANDIDATE_LABELS + + def test_string_label_without_override_raises(self) -> None: + ds = _make_string_dataset(["a"], ["World"]) + ev = make_evaluator(ds) + with pytest.raises(ValueError, match="not a ClassLabel"): + ev._resolve_candidate_labels(ds) + + def test_empty_override_raises(self) -> None: + ds = _make_classlabel_dataset(["a"], [0]) + ev = make_evaluator( + ds, + columns_mapping={ + "input_column": "text", + "label_column": "label", + "candidate_labels": ", , ", + }, + ) + with pytest.raises(ValueError, match="empty"): + ev._resolve_candidate_labels(ds) + + +# --------------------------------------------------------------------------- +# compute() +# --------------------------------------------------------------------------- + + +def _pipe_returning(predictions: list[str]) -> MagicMock: + """Build a MagicMock pipeline that emits predictions in order.""" + pipe = MagicMock() + pipe.tokenizer = None + state = {"i": 0} + + def 
_call(text: str, candidate_labels: list[str], **kwargs: object): + idx = state["i"] + state["i"] += 1 + top = predictions[idx] + ordered = [top] + [c for c in candidate_labels if c != top] + scores = [0.9] + [0.1 / max(1, len(ordered) - 1)] * (len(ordered) - 1) + return {"sequence": text, "labels": ordered, "scores": scores} + + pipe.side_effect = _call + return pipe + + +class TestCompute: + def test_perfect_accuracy_and_f1(self) -> None: + ds = _make_classlabel_dataset( + ["a", "b", "c", "d"], + [0, 1, 2, 3], + ) + # Predictions exactly match gold labels. + pipe = _pipe_returning(["World", "Sports", "Business", "Sci/Tech"]) + ev = make_evaluator(ds, pipe=pipe) + metrics = ev.compute() + assert metrics["accuracy"] == pytest.approx(1.0) + assert metrics["f1"] == pytest.approx(1.0) + + def test_half_accuracy(self) -> None: + ds = _make_classlabel_dataset( + ["a", "b", "c", "d"], + [0, 1, 2, 3], + ) + # 2 out of 4 correct. + pipe = _pipe_returning(["World", "Business", "Business", "World"]) + ev = make_evaluator(ds, pipe=pipe) + metrics = ev.compute() + assert metrics["accuracy"] == pytest.approx(0.5) + + def test_custom_hypothesis_template_passed(self) -> None: + ds = _make_classlabel_dataset(["a"], [0]) + pipe = _pipe_returning(["World"]) + ev = make_evaluator( + ds, + columns_mapping={ + "input_column": "text", + "label_column": "label", + "hypothesis_template": "The topic is {}.", + }, + pipe=pipe, + ) + ev.compute() + _, call_kwargs = pipe.call_args + assert call_kwargs["hypothesis_template"] == "The topic is {}." + + def test_default_template_not_passed_when_unset(self) -> None: + ds = _make_classlabel_dataset(["a"], [0]) + pipe = _pipe_returning(["World"]) + ev = make_evaluator(ds, pipe=pipe) + ev.compute() + _, call_kwargs = pipe.call_args + # Template is omitted so pipeline uses its own default. 
+ assert "hypothesis_template" not in call_kwargs + + def test_candidate_labels_passed_to_pipe(self) -> None: + ds = _make_classlabel_dataset(["a"], [0]) + pipe = _pipe_returning(["World"]) + ev = make_evaluator(ds, pipe=pipe) + ev.compute() + _, call_kwargs = pipe.call_args + assert call_kwargs["candidate_labels"] == CANDIDATE_LABELS + + def test_f1_zero_division_handled(self) -> None: + """Macro F1 should not crash when a class has no predictions.""" + ds = _make_classlabel_dataset( + ["a", "b"], + [0, 1], + ) + # Both predictions collapse to one class — other classes have no preds. + pipe = _pipe_returning(["World", "World"]) + ev = make_evaluator(ds, pipe=pipe) + metrics = ev.compute() + assert 0.0 <= metrics["f1"] <= 1.0 + + def test_string_labels_compute_end_to_end(self) -> None: + ds = _make_string_dataset( + ["a", "b"], + ["World", "Sports"], + ) + pipe = _pipe_returning(["World", "Sports"]) + ev = make_evaluator( + ds, + columns_mapping={ + "input_column": "text", + "label_column": "label", + "candidate_labels": ",".join(CANDIDATE_LABELS), + }, + pipe=pipe, + ) + metrics = ev.compute() + assert metrics["accuracy"] == pytest.approx(1.0) + + def test_override_remaps_classlabel_references(self) -> None: + """Override replaces ClassLabel names positionally for references.""" + ds = _make_classlabel_dataset( + ["a", "b", "c", "d"], + [0, 1, 2, 3], + ) + override = ["politics", "sports", "technology", "science"] + # Predictions in override vocab, perfectly aligned with gold IDs. 
+ pipe = _pipe_returning(override) + ev = make_evaluator( + ds, + columns_mapping={ + "input_column": "text", + "label_column": "label", + "candidate_labels": ",".join(override), + }, + pipe=pipe, + ) + metrics = ev.compute() + assert metrics["accuracy"] == pytest.approx(1.0) + assert metrics["f1"] == pytest.approx(1.0) + + def test_override_length_mismatch_raises(self) -> None: + """Override length must match ClassLabel cardinality.""" + ds = _make_classlabel_dataset( + ["a", "b", "c", "d"], + [0, 1, 2, 3], + ) + pipe = _pipe_returning(["politics", "sports"]) + ev = make_evaluator( + ds, + columns_mapping={ + "input_column": "text", + "label_column": "label", + "candidate_labels": "politics,sports", + }, + pipe=pipe, + ) + with pytest.raises(ValueError, match="one override label per class"): + ev.compute()