Skip to content

feat: implement fill-mask model task evaluator#341

Merged
zhenchaoni merged 8 commits into
mainfrom
private/zhenni/fm_eval_v2
Apr 22, 2026
Merged

feat: implement fill-mask model task evaluator#341
zhenchaoni merged 8 commits into
mainfrom
private/zhenni/fm_eval_v2

Conversation

@zhenchaoni

@zhenchaoni zhenchaoni commented Apr 14, 2026

Copy link
Copy Markdown
Member

Implement #317

Fill-Mask Evaluator Design

1. Overview

Fill-mask models (BERT, RoBERTa, DistilBERT, etc.) are Masked Language Models (MLM). Given a sentence with one or more tokens replaced by [MASK], the model predicts the original token.

Example:

Input:  "The cat [MASK] on the mat."
Output: "The cat sat on the mat."   (predicted token: "sat", score: 0.82)

2. Usage

uv run winml eval \
    -m ~/.cache/winml/artifacts/google-bert_bert-base-uncased/mask_e7ec673175d3b94d_model.onnx \
    --model-id google-bert/bert-base-uncased \
    --device npu \
    --task fill-mask \
    --dataset Salesforce/wikitext \
    --dataset-name wikitext-2-raw-v1 \
    --split test \
    --samples 100 \
    --column input_column=text

Output:

╭───────────────────────────────────────────╮
│ Evaluation: google-bert/bert-base-uncased │
╰───────────────────────────────────────────╯

Task:       fill-mask
Device:     npu
Dataset:    Salesforce/wikitext
Samples:    100

┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃ Metric             ┃  Value ┃
┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ pseudo_perplexity  │ 4.3528 │
│ nll                │ 1.4708 │
└────────────────────┴────────┘

3. Metric

Pseudo-perplexity (PPPL) — the community-standard score for comparing Masked Language Models, introduced by Salazar et al. 2020 ("Masked Language Model Scoring", https://arxiv.org/abs/1910.14659).

For each real token $w_i$ in the corpus, we mask only that one position and measure the model's log-probability of the original token given the rest of the sentence:

$$\text{PLL} = \sum_{i=1}^{N} \log P(w_i \mid w_{\setminus i})$$

$$\text{PPPL} = \exp!\left(-\frac{\text{PLL}}{N}\right)$$

where $N$ is the total number of scored tokens and $w_{\setminus i}$ denotes the sentence with position $i$ replaced by [MASK].

Why pseudo-perplexity and not "regular" perplexity?

Classical perplexity is defined for autoregressive LMs via the chain rule: $P(w_1,\dots,w_N) = \prod_i P(w_i \mid w_{<i})$. MLMs are bidirectional — they only provide conditionals $P(w_i \mid w_{\setminus i})$, which do not come from any consistent joint distribution. So standard perplexity is mathematically undefined for MLMs. "Pseudo-likelihood" (Besag 1975) is the tractable surrogate used when a joint is unavailable; "pseudo-perplexity" is its exponentiated, normalized form.

Why pseudo-perplexity and not exp(MLM-loss) with 15% masking?

The 15% random-masking protocol (BERT's training objective with the 80/10/10 split) is sometimes repurposed as an eval metric, but it has drawbacks:

  • Protocol-dependent — the numeric value depends on which random tokens get masked, the seed, and the 80/10/10 probabilities. Two reruns give slightly different numbers.
  • Degraded context — when scoring token $i$, ~15% of the surrounding tokens are also masked, so the model isn't operating in its natural full-context regime. Quantization errors get partly absorbed by the noise this introduces, which dilutes the signal for regression detection.

PPPL avoids both: it scores every real token (deterministic), and each is scored with a clean N−1 context — the regime the model is actually used in.

Why not top-k accuracy?

A top-1 accuracy metric only checks whether the correct token is the model's No.1 prediction. PPPL captures the full probability distribution — a model that assigns 40% to the correct token scores much better than one that assigns 5%, even if both get it "wrong" by top-1. This makes PPPL more sensitive to quality differences, especially when comparing ONNX-quantized models against PyTorch baselines.

Interpretation. PPPL is the effective branching factor of the model: a PPPL of 4 means the model is, on average, as uncertain as if it were choosing uniformly between 4 candidates per token. Lower is better. Published values for BERT-family MLMs on English text sit in the 3–10 range.

4. Implementation

The evaluator processes each text sample in four steps:

  1. Tokenize — Convert text to token IDs with the model's tokenizer. Pad to the ONNX model's fixed sequence length if required.

  2. Identify real positions — Use tokenizer.get_special_tokens_mask and the pad-token ID to filter out [CLS], [SEP], and padding. Only real content tokens are scored.

  3. Mask one at a time and infer — For each real position $i$, replace only that position with [MASK], run the model forward, and gather the log-softmax probability of the original token at position $i$. The rest of the sentence stays intact (N−1 correct context). One forward pass per scored token.

  4. AggregatePseudoPerplexityMetric accumulates per-token log-probabilities across the corpus, then returns mean NLL and $\exp(\text{mean NLL})$ as pseudo_perplexity.

                  ┌────────────────────────────────┐
                  │  For each real position i:     │
                  │    set input_ids[i] = [MASK]   │
                  │    restore after forward       │
                  └────────────────────────────────┘
                                 │
┌──────────┐    ┌──────────────┐    ┌───────────┐    ┌─────────────────┐
│ Tokenize │───>│ Mask single  │───>│ Model     │───>│ PseudoPerplexity│
│          │    │ position     │    │ inference │    │ Metric          │
│ text →   │    │ input_ids[i] │    │ logits    │    │ accumulate      │
│ tokens   │    │ = mask_id    │    │ [seq, V]  │    │ log P(w_i|w\i)  │
└──────────┘    └──────────────┘    └───────────┘    └─────────────────┘

The WinML ONNX builds for NPU use a fixed batch=1 shape — so one forward per masked position is the natural access pattern, with no batching gymnastics. Same code path handles the HF PyTorch baseline (dynamic batch) without changes.

5. Evaluation results

Run: run_eval.py --eval-type accuracy --device npu on all 8 fill-mask models in the registry (100 samples each, wikitext-2-raw-v1 test split).

Model ONNX PPPL Baseline PPPL Δ relative (lower better) Verdict
bert-base-uncased 4.3528 4.2981 +1.27% PASS
roberta-base 5.3561 5.2706 +1.62% PASS
xlm-roberta-large 2.7162 2.6770 +1.46% PASS
roberta-large 5.4702 5.3352 +2.53% PASS
bert-base-multilingual-uncased 4.6449 4.4645 +4.04% PASS
xlm-roberta-base 3.9015 3.6435 +7.08% AT_RISK
bert-base-multilingual-cased 5.7810 5.2262 +10.62% REGRESSION
distilbert-base-uncased 3339.66 6.6132 +50 400% REGRESSION

Thresholds: |Δ relative| < 5% → PASS, < 10% → AT_RISK, ≥ 10% → REGRESSION. The 5% boundary corresponds to a per-token NLL increase of ~0.05 nats — small but user-perceptible quality loss — and sits 10× above the paired-comparison noise floor (~0.5% relative at 10k tokens).

Quantization-only regression, confirmed

For both REGRESSION cases (distilbert, mbert-cased), rerunning the un-quantized fp32 optimized ONNX on NPU isolates the cause:

Model PyTorch baseline fp32 ONNX (NPU) w8a16 quantized ONNX (NPU) Regression source
distilbert-base-uncased 6.6132 6.5940 (−0.29%) 3339.66 (+50 400%) Quantization
bert-base-multilingual-cased 5.2262 5.2285 (+0.04%) 5.781 (+10.62%) Quantization

The fp32 ONNX matches the PyTorch baseline within noise in both cases, so export and ORT optimization are clean. The entire regression is introduced by the w8a16 quantization stage.

@zhenchaoni zhenchaoni requested a review from a team as a code owner April 14, 2026 06:18
Comment thread scripts/e2e_eval/cache/baseline_cache.json Outdated
Comment thread src/winml/modelkit/eval/fill_mask_evaluator.py
Comment thread src/winml/modelkit/eval/fill_mask_evaluator.py Outdated
Comment thread scripts/e2e_eval/testsets/models_with_acc.json
@zhenchaoni zhenchaoni merged commit ce6c358 into main Apr 22, 2026
9 of 15 checks passed
@zhenchaoni zhenchaoni deleted the private/zhenni/fm_eval_v2 branch April 22, 2026 03:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants