Skip to content

Commit c1f79b0

Browse files
committed
support worst case membership score
Signed-off-by: ron-shm <[email protected]>
1 parent 697d745 commit c1f79b0

File tree

4 files changed

+428
-44
lines changed

4 files changed

+428
-44
lines changed

art/metrics/privacy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
Module providing metrics and verifications.
33
"""
44
from art.metrics.privacy.membership_leakage import PDTP
5+
from art.metrics.privacy.worst_case_mia_score import get_roc_for_fpr, get_roc_for_multi_fprs
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# MIT License
2+
#
3+
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2020
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6+
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
7+
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
8+
# persons to whom the Software is furnished to do so, subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
11+
# Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
14+
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
15+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
16+
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17+
# SOFTWARE.
18+
"""
19+
This module implements metrics for measuring the worst-case accuracy of inference attacks.
20+
"""
21+
from __future__ import absolute_import, division, print_function, unicode_literals
22+
from typing import Optional, List, Tuple
23+
24+
import numpy as np
25+
from sklearn.metrics import roc_curve
26+
import logging
27+
28+
29+
# Readability aliases used in this module's return annotations: the true positive rate, false positive rate and
# decision threshold are all represented as plain floats.
TPR = float
30+
FPR = float
31+
THR = float
32+
33+
34+
def _calculate_roc_for_fpr(
    y_true: np.ndarray, y_proba: np.ndarray, targeted_fpr: float
) -> Tuple[Optional[float], Optional[float], Optional[float]]:
    """
    Select the ROC operating point with the highest FPR that does not exceed `targeted_fpr`.

    :param y_true: True binary labels, forwarded to `sklearn.metrics.roc_curve` as `y_true`.
    :param y_proba: Predicted scores/probabilities, forwarded to `sklearn.metrics.roc_curve` as `y_score`.
    :param targeted_fpr: Upper bound on the false positive rate of the selected operating point.
    :return: Tuple `(fpr, tpr, threshold)` of the selected operating point, or `(None, None, None)` (after logging a
             warning) when all FPR or all TPR values are NaN, or when no point satisfies `FPR <= targeted_fpr`.
    """
    fpr, tpr, thr = roc_curve(y_true=y_true, y_score=y_proba)

    # Use a named logger: the root-level `logging.warning` would implicitly configure the root logger, which library
    # code should leave to the application.
    logger = logging.getLogger(__name__)

    if np.isnan(fpr).all() or np.isnan(tpr).all():
        logger.warning("TPR or FPR values are NaN")
        return None, None, None

    candidates = np.where(fpr <= targeted_fpr)[0]
    if candidates.size == 0:
        # Happens e.g. for a negative or NaN target; report it the same way as the other degenerate case instead of
        # failing with an opaque IndexError.
        logger.warning("No ROC point satisfies FPR <= %s", targeted_fpr)
        return None, None, None

    # roc_curve returns FPR values in increasing order, so the last candidate is the highest admissible FPR.
    idx = candidates[-1]
    return fpr[idx], tpr[idx], thr[idx]
44+
45+
46+
def get_roc_for_fpr(  # pylint: disable=C0103
    attack_proba: np.ndarray,
    attack_true: np.ndarray,
    target_model_labels: Optional[np.ndarray] = None,
    targeted_fpr: Optional[float] = 0.001,
) -> List[Tuple[Optional[int], FPR, TPR, THR]]:
    """
    Compute the attack TPR, THRESHOLD and achieved FPR based on the targeted FPR. This implementation supports only
    binary attack prediction labels {0,1}. The returned THRESHOLD defines the decision threshold on the attack
    probabilities (meaning if p < THRESHOLD predict 0, otherwise predict 1).

    | Related paper link: https://arxiv.org/abs/2112.03570

    :param attack_proba: Predicted attack probabilities.
    :param attack_true: True attack labels.
    :param target_model_labels: Original labels. If provided, the operating point is computed separately for each
                                distinct label value.
    :param targeted_fpr: The targeted False Positive Rate; the operating point is chosen such that the achieved FPR
                         does not exceed it. Defaults to `0.001`.
    :return: List of tuples. If `target_model_labels` is provided, one tuple
             `(label, achieved FPR, TPR, threshold)` per distinct label; otherwise a single tuple
             `(achieved FPR, TPR, threshold)` without a label element.
    :raises ValueError: If the number of rows of `attack_proba` differs from that of `attack_true` or of
                        `target_model_labels`.
    """
    n_rows = attack_proba.shape[0]
    if n_rows != attack_true.shape[0]:
        raise ValueError("Number of rows in attack_proba and attack_true do not match")
    if target_model_labels is not None and n_rows != target_model_labels.shape[0]:
        raise ValueError("Number of rows in target_model_labels and attack_proba do not match")

    if target_model_labels is None:
        # No per-class split: a single 3-tuple (no label element) is returned, as existing callers expect.
        fpr, tpr, thr = _calculate_roc_for_fpr(y_proba=attack_proba, y_true=attack_true, targeted_fpr=targeted_fpr)
        return [(fpr, tpr, thr)]

    results = []
    for label in np.unique(target_model_labels):
        # Row indices of the samples whose original label equals the current one
        idxs = np.where(target_model_labels == label)[0]
        fpr, tpr, thr = _calculate_roc_for_fpr(
            y_proba=attack_proba[idxs], y_true=attack_true[idxs], targeted_fpr=targeted_fpr
        )
        results.append((label, fpr, tpr, thr))

    return results
90+
91+
92+
def get_roc_for_multi_fprs(
    attack_proba: np.ndarray,
    attack_true: np.ndarray,
    targeted_fprs: np.ndarray,
) -> Tuple[List[FPR], List[TPR], List[THR]]:
    """
    Compute the attack ROC based on the targeted FPRs. This implementation supports only binary
    attack prediction labels. The returned list of THRESHOLDs defines the decision threshold on the attack
    probabilities (meaning if p < THRESHOLD predict 0, otherwise predict 1) for each provided FPR.

    | Related paper link: https://arxiv.org/abs/2112.03570

    :param attack_proba: Predicted attack probabilities.
    :param attack_true: True attack labels.
    :param targeted_fprs: The set of targeted FPRs (False Positive Rates); one operating point is selected for each
                          of them.
    :return: Tuple of three lists `(achieved FPRs, TPRs, thresholds)`, each holding one entry per element of
             `targeted_fprs`, in the same order.
    :raises ValueError: If the number of rows of `attack_proba` and `attack_true` differ.
    """
    if attack_proba.shape[0] != attack_true.shape[0]:
        raise ValueError("Number of rows in attack_proba and attack_true do not match")

    points = [
        _calculate_roc_for_fpr(y_proba=attack_proba, y_true=attack_true, targeted_fpr=t_fpr)
        for t_fpr in targeted_fprs
    ]

    # Split the (fpr, tpr, thr) triples column-wise; separate comprehensions keep the empty-input case returning
    # three empty lists.
    fpr = [point[0] for point in points]
    tpr = [point[1] for point in points]
    thr = [point[2] for point in points]

    return fpr, tpr, thr

0 commit comments

Comments
 (0)