Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions scripts/e2e_eval/testsets/models_with_acc.json
Original file line number Diff line number Diff line change
Expand Up @@ -1014,5 +1014,21 @@
"label_column": "answers"
}
}
},
{
"hf_id": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
"task": "zero-shot-classification",
"model_type": "deberta-v2",
"group": "Top200",
"priority": "P1",
"dataset_config": {
"path": "fancyzhx/ag_news",
"split": "test",
"metric": "accuracy",
"columns_mapping": {
"input_column": "text",
"label_column": "label"
}
}
}
]
1 change: 1 addition & 0 deletions src/winml/modelkit/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"sentence-similarity": TextDataset,
"next-sentence-prediction": TextDataset,
"fill-mask": TextDataset,
"zero-shot-classification": TextDataset,
"image-segmentation": ImageSegmentationDataset,
"random": RandomDataset,
# Add more task types as needed
Expand Down
4 changes: 4 additions & 0 deletions src/winml/modelkit/eval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,20 @@
from .evaluate import EvalResult, evaluate
from .feature_extraction_evaluator import WinMLFeatureExtractionEvaluator
from .image_segmentation_evaluator import WinMLImageSegmentationEvaluator
from .metrics.classification import ClassificationMetric
from .metrics.mean_average_precision import MAPMetric
from .metrics.mean_iou import IGNORE_INDEX, MeanIoUMetric
from .metrics.spearman_correlation import SpearmanCorrelationMetric
from .object_detection_evaluator import WinMLObjectDetectionEvaluator
from .question_answering_evaluator import WinMLQuestionAnsweringEvaluator
from .text_classification_evaluator import WinMLTextClassificationEvaluator
from .token_classification_evaluator import WinMLTokenClassificationEvaluator
from .zero_shot_classification_evaluator import WinMLZeroShotClassificationEvaluator


__all__ = [
"IGNORE_INDEX",
"ClassificationMetric",
"EvalResult",
"MAPMetric",
"MeanIoUMetric",
Expand All @@ -36,5 +39,6 @@
"WinMLQuestionAnsweringEvaluator",
"WinMLTextClassificationEvaluator",
"WinMLTokenClassificationEvaluator",
"WinMLZeroShotClassificationEvaluator",
"evaluate",
]
12 changes: 12 additions & 0 deletions src/winml/modelkit/eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .question_answering_evaluator import WinMLQuestionAnsweringEvaluator
from .text_classification_evaluator import WinMLTextClassificationEvaluator
from .token_classification_evaluator import WinMLTokenClassificationEvaluator
from .zero_shot_classification_evaluator import WinMLZeroShotClassificationEvaluator


if TYPE_CHECKING:
Expand All @@ -39,6 +40,7 @@
"question-answering": WinMLQuestionAnsweringEvaluator,
"feature-extraction": WinMLFeatureExtractionEvaluator,
"sentence-similarity": WinMLFeatureExtractionEvaluator,
"zero-shot-classification": WinMLZeroShotClassificationEvaluator,
}

_FE_DEFAULT = DatasetConfig(
Expand Down Expand Up @@ -107,6 +109,16 @@
),
"feature-extraction": _FE_DEFAULT,
"sentence-similarity": _FE_DEFAULT,
"zero-shot-classification": DatasetConfig(
path="fancyzhx/ag_news",
split="test",
samples=100,
shuffle=True,
columns_mapping={
"input_column": "text",
"label_column": "label",
},
),
}


Expand Down
9 changes: 8 additions & 1 deletion src/winml/modelkit/eval/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,16 @@

"""Evaluation metrics."""

from .classification import ClassificationMetric
from .mean_average_precision import MAPMetric
from .mean_iou import IGNORE_INDEX, MeanIoUMetric
from .spearman_correlation import SpearmanCorrelationMetric


__all__ = ["IGNORE_INDEX", "MAPMetric", "MeanIoUMetric", "SpearmanCorrelationMetric"]
__all__ = [
"IGNORE_INDEX",
"ClassificationMetric",
"MAPMetric",
"MeanIoUMetric",
"SpearmanCorrelationMetric",
]
56 changes: 56 additions & 0 deletions src/winml/modelkit/eval/metrics/classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

"""Classification metrics.

Accuracy and macro-F1 over string labels, for classification evaluators
that do not have an HF evaluate wrapper (e.g. zero-shot-classification).
"""

from __future__ import annotations

from typing import Any


class ClassificationMetric:
    """Accuracy and macro-F1 over string labels.

    Intended for classification evaluators that compare predicted label
    strings against ground-truth label strings directly.
    """

    def compute(
        self,
        predictions: list[str],
        references: list[str],
        labels: list[str],
    ) -> dict[str, Any]:
        """Compute accuracy and macro-F1.

        Args:
            predictions: Predicted label strings, one per sample.
            references: Ground-truth label strings, one per sample.
            labels: Full set of class labels for macro-F1 averaging.

        Returns:
            Dict with ``accuracy`` and ``f1`` (both floats in [0, 1]).

        Raises:
            ValueError: If ``predictions`` and ``references`` differ in
                length, or if ``references`` or ``labels`` is empty.
        """
        # Validate before importing scikit-learn: malformed input should be
        # reported as a ValueError immediately, without paying for (or
        # depending on) the lazy import below.
        if len(predictions) != len(references):
            raise ValueError(
                f"predictions and references must have the same length, "
                f"got {len(predictions)} vs {len(references)}.",
            )
        if not references:
            raise ValueError("references must not be empty.")
        if not labels:
            raise ValueError("labels must not be empty.")

        # Imported lazily so importing this module does not require sklearn.
        from sklearn.metrics import accuracy_score, f1_score

        accuracy = accuracy_score(references, predictions)
        # ``labels`` pins the classes that are averaged; ``zero_division=0``
        # scores a class with no predicted/true samples as 0 instead of
        # emitting a warning.
        macro_f1 = f1_score(
            references,
            predictions,
            labels=labels,
            average="macro",
            zero_division=0,
        )
        return {"accuracy": float(accuracy), "f1": float(macro_f1)}
175 changes: 175 additions & 0 deletions src/winml/modelkit/eval/zero_shot_classification_evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

"""Zero-shot classification evaluator for NLI checkpoints.

Computes accuracy and macro-F1 via ClassificationMetric.
The HF ``evaluate`` library has no zero-shot-classification evaluator, so
this class runs the metric loop manually: each text is scored against every
candidate label as a (premise, hypothesis) NLI pair, and the label with
the top entailment score wins.

Candidate labels come from ``columns_mapping["candidate_labels"]`` if set
(comma-separated), otherwise from ``dataset.features[label_column].names``
when the column is a ``ClassLabel``.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from transformers.pipelines.zero_shot_classification import ZeroShotClassificationPipeline

from .base_evaluator import WinMLEvaluator


if TYPE_CHECKING:
from datasets import Dataset
from transformers.pipelines.base import Pipeline

from ..datasets.config import DatasetConfig
from ..models.winml.base import WinMLPreTrainedModel
from .config import WinMLEvaluationConfig


class _FixedShapeZeroShotPipeline(ZeroShotClassificationPipeline):
    """Zero-shot pipeline that always pads tokenizer output to ``max_length``.

    Fixed-shape ONNX exports need every batch padded to one constant length;
    the tokenizer's ``model_max_length`` is presumably what ``"max_length"``
    padding resolves to here (no explicit ``max_length`` is passed).
    """

    def _parse_and_tokenize(self, sequence_pairs: Any, **kwargs: Any) -> Any:
        # Padding is always forced; truncation is enabled only when the
        # caller did not choose a strategy of its own.
        tokenizer_options = {**kwargs, "padding": "max_length"}
        if "truncation" not in tokenizer_options:
            tokenizer_options["truncation"] = True
        return super()._parse_and_tokenize(sequence_pairs, **tokenizer_options)


class WinMLZeroShotClassificationEvaluator(WinMLEvaluator):
    """Evaluator for zero-shot text classification using NLI models.

    Each input text is scored against every candidate label by the
    ``zero-shot-classification`` pipeline; the top-ranked label is the
    prediction. Accuracy and macro-F1 are computed by comparing predicted
    label strings with reference label strings.

    Note: when the label column is a ``ClassLabel``, references are the
    dataset's class names. A ``candidate_labels`` override therefore only
    scores as correct where its strings equal those class names.
    """

    @classmethod
    def schema_info(cls) -> list:
        """Return expected dataset schema for zero-shot classification."""
        from .config import SchemaColumn

        return [
            SchemaColumn("text", "Value(string)", "input_column", description="input text"),
            SchemaColumn(
                "label",
                "ClassLabel",
                "label_column",
                description="gold label (ClassLabel or string)",
            ),
            SchemaColumn(
                "<candidate_labels>",
                "comma-separated str",
                "candidate_labels",
                required=False,
                description="override candidate labels (required for non-ClassLabel columns)",
            ),
            SchemaColumn(
                "<hypothesis_template>",
                "Value(string)",
                "hypothesis_template",
                required=False,
                description='NLI hypothesis template (default: "This example is {}.")',
            ),
        ]

    def __init__(
        self,
        config: WinMLEvaluationConfig,
        model: WinMLPreTrainedModel,
    ) -> None:
        """Read column settings from the dataset config, then initialise the base.

        Args:
            config: Evaluation config; ``config.dataset.columns_mapping`` supplies
                ``input_column``, ``label_column`` and the optional
                ``candidate_labels`` / ``hypothesis_template`` entries.
            model: Model to evaluate.
        """
        # Tolerate a missing mapping so the documented defaults still apply.
        mapping = config.dataset.columns_mapping or {}
        self._input_col = mapping.get("input_column", "text")
        self._label_col = mapping.get("label_column", "label")
        self._candidate_labels_override = mapping.get("candidate_labels")
        self._hypothesis_template = mapping.get("hypothesis_template")
        # Attributes are set before the base initialiser runs; presumably the
        # base class calls hooks that read them — confirm in base_evaluator.
        super().__init__(config, model)

    def prepare_pipeline(self) -> Pipeline:
        """Create pipeline with fixed-length tokenization for ONNX.

        Returns:
            A ``zero-shot-classification`` pipeline built on
            ``_FixedShapeZeroShotPipeline`` whose tokenizer, when present, is
            adjusted to the model's ``io_config``.
        """
        from transformers import pipeline

        io_config = getattr(self.model, "io_config", None) or {}
        shapes = io_config.get("input_shapes", [[]])
        max_length: int | None = None
        # The second dimension of the first input is used as the padded
        # length (presumably the sequence axis); non-int dims are ignored.
        if shapes and len(shapes[0]) > 1 and isinstance(shapes[0][1], int):
            max_length = shapes[0][1]

        pipe = pipeline(
            "zero-shot-classification",
            model=self.model,
            framework="pt",
            tokenizer=self.config.model_id,
            device="cpu",
            pipeline_class=_FixedShapeZeroShotPipeline,
        )

        tokenizer = pipe.tokenizer
        if tokenizer is None:
            # Nothing to adjust; avoids dereferencing a missing tokenizer below.
            return pipe

        if max_length is not None:
            tokenizer.model_max_length = max_length

        # Drop tokenizer keys the ONNX graph does not accept
        # (e.g. DeBERTa-v3 MNLI exports omit token_type_ids).
        input_names = io_config.get("input_names", [])
        if input_names:
            filtered = [n for n in tokenizer.model_input_names if n in input_names]
            # Keep the tokenizer's own list if nothing overlaps, rather than
            # leaving it with no model inputs at all.
            if filtered:
                tokenizer.model_input_names = filtered

        return pipe

    def align_labels(
        self,
        dataset: Dataset,
        ds_config: DatasetConfig,
    ) -> Dataset:
        """Validate input and label columns.

        Base-class label alignment is bypassed: NLI ``label2id`` identifies
        entailment/neutral/contradiction classes, which are unrelated to the
        ground-truth labels used for accuracy.

        Args:
            dataset: Dataset to validate.
            ds_config: Dataset config (unused here; kept for the base signature).

        Returns:
            ``dataset`` unchanged.

        Raises:
            ValueError: If the input or label column is missing.
        """
        col_names = set(dataset.column_names)
        for col in (self._input_col, self._label_col):
            if col not in col_names:
                raise ValueError(f"Column '{col}' not found in dataset: {sorted(col_names)}")
        return dataset

    def _resolve_candidate_labels(self, dataset: Dataset) -> list[str]:
        """Return candidate labels from user override or dataset ``ClassLabel``.

        Raises:
            ValueError: If the override contains no non-blank entries, or if
                no override is given and the label column exposes no ``names``.
        """
        if self._candidate_labels_override:
            labels = [s.strip() for s in self._candidate_labels_override.split(",") if s.strip()]
            if not labels:
                raise ValueError("candidate_labels override must not be empty.")
            return labels

        names = getattr(dataset.features.get(self._label_col), "names", None)
        if names:
            return list(names)

        raise ValueError(
            f"Column '{self._label_col}' is not a ClassLabel; pass "
            f'--column "candidate_labels=a,b,...".',
        )

    @staticmethod
    def _reference_label(raw: Any, class_names: list[str] | None) -> str:
        """Convert one raw label value into its reference label string.

        Args:
            raw: Value read from the label column.
            class_names: ``ClassLabel`` names, or ``None``/empty when the
                column is not a ``ClassLabel``.

        Raises:
            ValueError: If ``raw`` is a negative class index (Hugging Face
                datasets use ``-1`` for a missing label); indexing with it
                would silently select the last class.
        """
        if not class_names:
            return str(raw)
        index = int(raw)
        if index < 0:
            raise ValueError(
                f"Sample has no gold label (ClassLabel index {index}); "
                "zero-shot evaluation requires labelled samples.",
            )
        return class_names[index]

    def compute(self) -> dict[str, Any]:
        """Compute accuracy and macro-F1 over all samples.

        Returns:
            The dict produced by ``ClassificationMetric.compute``.

        Raises:
            ValueError: If candidate labels cannot be resolved or a sample
                carries a negative ``ClassLabel`` index.
        """
        from .metrics import ClassificationMetric

        candidate_labels = self._resolve_candidate_labels(self.data)
        class_names = getattr(self.data.features.get(self._label_col), "names", None)

        pipe_kwargs: dict[str, Any] = {"candidate_labels": candidate_labels}
        # Only forward a template when one was configured, so the pipeline's
        # own default applies otherwise.
        if self._hypothesis_template is not None:
            pipe_kwargs["hypothesis_template"] = self._hypothesis_template

        predictions: list[str] = []
        references: list[str] = []
        for sample in self.data:
            result = self.pipe(sample[self._input_col], **pipe_kwargs)
            # The first entry of ``labels`` is taken as the predicted label.
            predictions.append(result["labels"][0])
            references.append(self._reference_label(sample[self._label_col], class_names))

        return ClassificationMetric().compute(predictions, references, candidate_labels)
Empty file.
53 changes: 53 additions & 0 deletions tests/integration/eval/test_zero_shot_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""End-to-end integration test for zero-shot classification evaluation.

Downloads small NLI checkpoints (one per parametrized model id) and runs the
evaluator against a handful of samples from AG News. Skipped by default via
`pytest -m "not slow"`.
"""

from __future__ import annotations

import pytest

from winml.modelkit.datasets.config import DatasetConfig
from winml.modelkit.eval import WinMLZeroShotClassificationEvaluator
from winml.modelkit.eval.config import WinMLEvaluationConfig


# Representative NLI checkpoints across the three families listed in issue #325.
# Each id below is one parametrized case of the end-to-end test.
_MODEL_IDS = [
# DistilBERT family
"typeform/distilbert-base-uncased-mnli",
# RoBERTa family
"cross-encoder/nli-roberta-base",
# DeBERTa-v3 family
"MoritzLaurer/deberta-v3-base-zeroshot-v2.0",
]


@pytest.mark.slow
@pytest.mark.network
@pytest.mark.integration
@pytest.mark.parametrize("model_id", _MODEL_IDS)
def test_zero_shot_classification_end_to_end(model_id: str) -> None:
    """Run the zero-shot evaluator on five AG News test samples for one checkpoint.

    Checks only that both metrics are reported and lie in [0, 1].
    """
    from transformers import AutoModelForSequenceClassification

    checkpoint = AutoModelForSequenceClassification.from_pretrained(model_id)

    ag_news = DatasetConfig(
        path="fancyzhx/ag_news",
        split="test",
        samples=5,
        shuffle=False,
    )
    eval_config = WinMLEvaluationConfig(
        model_id=model_id,
        task="zero-shot-classification",
        dataset=ag_news,
    )

    evaluator = WinMLZeroShotClassificationEvaluator(eval_config, checkpoint)
    metrics = evaluator.compute()

    expected_keys = ("accuracy", "f1")
    # Presence first, then range, in the same order for both metrics.
    for key in expected_keys:
        assert key in metrics
    for key in expected_keys:
        assert 0.0 <= metrics[key] <= 1.0
Loading