-
Notifications
You must be signed in to change notification settings - Fork 1
fix(eval): fix embedding averaging in privacy metrics #81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
3d557e2
fix embedding average with unequal weighting, DRY & test
3a4f833
make format && make lint
f930523
nit type
nina-xu 6c07468
skip tests when sentence transformers is not available
nina-xu da805bf
guide import of sentence transformers; formating
nina-xu 01987a5
reformat docstrings
nina-xu f80594f
reapply changes in pr 141
nina-xu 11edf78
fix type checking
nina-xu 559cfc2
update tests
nina-xu 8642450
make format
nina-xu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
78 changes: 78 additions & 0 deletions
78
src/nemo_safe_synthesizer/evaluation/components/privacy_metric_utils.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,78 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| import pandas as pd | ||
| import torch | ||
|
|
||
| if TYPE_CHECKING: | ||
| from sentence_transformers import SentenceTransformer | ||
|
|
||
| from ...artifacts.analyzers.field_features import describe_field | ||
|
|
||
|
|
||
| def find_text_fields(df: pd.DataFrame) -> list[str]: | ||
| """Identify columns in ``df`` whose content is free-form text. | ||
| Each column is passed through ``describe_field``; those classified | ||
| as ``"text"`` are returned. | ||
| Args: | ||
| df: DataFrame whose columns are inspected. | ||
| Returns: | ||
| Column names classified as free-form text. | ||
| """ | ||
| text_fields: list[str] = [] | ||
| for col in df.columns: | ||
| field_info = describe_field(col, df[col]) | ||
| if field_info.type.value == "text": | ||
| text_fields.append(col) | ||
| return text_fields | ||
|
|
||
|
|
||
| def divide_tabular_text(df: pd.DataFrame, text_fields: list[str]) -> tuple[pd.DataFrame, pd.DataFrame]: | ||
| """Split ``df`` into a tabular-only and a text-only DataFrame. | ||
| Columns present in ``text_fields`` go into the text DataFrame; the | ||
| remaining columns go into the tabular DataFrame. | ||
| Args: | ||
| df: Source DataFrame to split. | ||
| text_fields: Column names to treat as text. | ||
| Returns: | ||
| A ``(tabular_df, text_df)`` tuple where ``tabular_df`` contains only | ||
| the non-text columns and ``text_df`` contains only the text columns. | ||
| """ | ||
| tabular_fields = [col for col in df.columns if col not in text_fields] | ||
| return df.filter(tabular_fields), df.filter(text_fields) | ||
|
|
||
|
|
||
| def embed_text(df: pd.DataFrame, embedder: SentenceTransformer) -> pd.DataFrame: | ||
| """Embed every text column in ``df`` and return a single averaged embedding per row. | ||
| For each column the ``embedder`` produces a ``(n_rows, embed_dim)`` matrix. | ||
| The per-column matrices are stacked and averaged across columns so that | ||
| every column contributes equally to the final embedding. | ||
| Args: | ||
| df: DataFrame whose columns are all text to be embedded. | ||
| embedder: Sentence-transformer model used to produce embeddings. | ||
| Returns: | ||
| Single-column DataFrame with column ``"embedding"`` whose values are | ||
| 1-D tensors of shape ``(embed_dim,)``. | ||
| """ | ||
| embeddings = {} | ||
| for col in df.columns: | ||
| data = [str(r) for r in df[col].to_list()] | ||
| embeddings[col] = torch.as_tensor(embedder.encode(data, show_progress_bar=False, convert_to_tensor=True)) | ||
|
|
||
| stacked = torch.stack([embeddings[col] for col in df.columns], dim=0) # shape: (n_cols, n_rows, embed_dim) | ||
| avg_embeddings = torch.mean(stacked, dim=0) # shape: (n_rows, embed_dim) | ||
|
|
||
| return pd.DataFrame({"embedding": list(avg_embeddings)}) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from unittest.mock import MagicMock | ||
|
|
||
| import numpy as np | ||
| import pandas as pd | ||
| import pytest | ||
| import torch | ||
|
|
||
| from nemo_safe_synthesizer.evaluation.components.privacy_metric_utils import ( | ||
| divide_tabular_text, | ||
| embed_text, | ||
| ) | ||
nina-xu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| @pytest.fixture | ||
| def mock_embedder(): | ||
| """A mock SentenceTransformer whose .encode() returns deterministic tensors.""" | ||
| embedder = MagicMock() | ||
|
|
||
| def _encode(data, **kwargs): | ||
| # Return a distinct but deterministic embedding per string. | ||
| # Use the length of each string as a simple seed for reproducibility. | ||
| return torch.tensor([[float(len(s)), float(len(s)) * 2, float(len(s)) * 3] for s in data], dtype=torch.float32) | ||
|
|
||
| embedder.encode = MagicMock(side_effect=_encode) | ||
| return embedder | ||
|
|
||
|
|
||
| def test_divide_tabular_text(train_df): | ||
| text_fields = ["text", "other"] | ||
| tabular, text = divide_tabular_text(train_df, text_fields) | ||
|
|
||
| assert "text" not in tabular.columns | ||
| assert "other" not in tabular.columns | ||
| assert set(text.columns) == {"other", "text"} | ||
| assert len(tabular) == len(train_df) | ||
| assert len(text) == len(train_df) | ||
|
|
||
|
|
||
| def test_embed_text(mock_embedder): | ||
| """Regression test: with 3+ columns the old pairwise-averaging code | ||
| over-weighted later columns. The corrected stack/mean reduction must give | ||
| each column equal weight. | ||
| """ | ||
| df = pd.DataFrame( | ||
| { | ||
| "a": ["x"], # len 1 → embedding [1, 2, 3] | ||
| "b": ["xx"], # len 2 → embedding [2, 4, 6] | ||
| "c": ["xxxx"], # len 4 → embedding [4, 8, 12] | ||
| } | ||
| ) | ||
| result = embed_text(df, mock_embedder) | ||
|
|
||
| # True mean of [1,2,3], [2,4,6], [4,8,12] across columns: | ||
| # = [(1+2+4)/3, (2+4+8)/3, (3+6+12)/3] = [7/3, 14/3, 21/3] | ||
| expected = np.array([7 / 3, 14 / 3, 7.0]) | ||
| assert isinstance(result["embedding"].iloc[0], torch.Tensor) | ||
| np.testing.assert_array_almost_equal(result["embedding"].iloc[0].numpy(), expected) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.