Skip to content

Commit

Permalink
Issue #1602 - Add get_label_instances to Analysis (#1608)
Browse files Browse the repository at this point in the history
Addresses #1602. Adds a method to analysis/error_analysis that wraps the get_label_buckets functionality: given a bucket, a NumPy array x of data instances, and the corresponding label array(s) y, it returns only the instances of x that belong to that bucket.
  • Loading branch information
DavidKoleczek authored Sep 5, 2020
1 parent 86a21a2 commit ed77718
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/packages/analysis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ Generic model analysis utilities shared across Snorkel.

Scorer
get_label_buckets
get_label_instances
metric_score
2 changes: 1 addition & 1 deletion snorkel/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Generic model analysis utilities shared across Snorkel."""

from .error_analysis import get_label_buckets # noqa: F401
from .error_analysis import get_label_buckets, get_label_instances # noqa: F401
from .metrics import metric_score # noqa: F401
from .scorer import Scorer # noqa: F401
61 changes: 61 additions & 0 deletions snorkel/analysis/error_analysis.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from typing import DefaultDict, Dict, List, Tuple

Expand Down Expand Up @@ -55,3 +56,63 @@ def get_label_buckets(*y: np.ndarray) -> Dict[Tuple[int, ...], np.ndarray]:
for i, labels in enumerate(zip(*y_flat)):
buckets[labels].append(i)
return {k: np.array(v) for k, v in buckets.items()}


def get_label_instances(
    bucket: Tuple[int, ...], x: np.ndarray, *y: np.ndarray
) -> np.ndarray:
    """Return instances in x with the specified combination of labels.

    Parameters
    ----------
    bucket
        A tuple of label values corresponding to which instances from x are returned
    x
        NumPy array of data instances to be returned
    *y
        A list of np.ndarray of (int) labels

    Returns
    -------
    np.ndarray
        NumPy array of instances from x with the specified combination of labels.
        If no instance has that combination, a warning is logged and an empty
        array (``np.array([])``) is returned instead.

    Raises
    ------
    ValueError
        If the number of label arrays differs from the number of labels in
        ``bucket``, if no label arrays are given at all, or if the number of
        rows in ``x`` differs from the number of elements in the first label array

    Example
    -------
    A common use case is calling ``get_label_instances(bucket, x.to_numpy(), Y_gold, Y_pred)``
    where ``x`` is a NumPy array of data instances that the labels correspond to,
    ``Y_gold`` is a list of gold (i.e. ground truth) labels, and
    ``Y_pred`` is a corresponding list of predicted labels.

    >>> import pandas as pd
    >>> x = pd.DataFrame(data={'col1': ["this is a string", "a second string", "a third string"], 'col2': ["1", "2", "3"]})
    >>> Y_gold = np.array([1, 1, 1])
    >>> Y_pred = np.array([1, 0, 0])
    >>> bucket = (1, 0)

    The returned NumPy array of data instances from ``x`` will correspond to
    the rows where the first list had a 1 and the second list had a 0.

    >>> get_label_instances(bucket, x.to_numpy(), Y_gold, Y_pred)
    array([['a second string', '2'],
           ['a third string', '3']], dtype=object)

    More generally, given bucket ``(i, j, ...)`` and lists ``y1, y2, ...``
    the returned data instances from ``x`` will correspond to the rows where
    y1 had label i, y2 had label j, and so on. Note that ``x`` and ``y``
    must all be the same length.
    """
    if len(y) != len(bucket):
        raise ValueError("Number of lists must match the amount of labels in bucket")
    if not y:
        # An empty bucket with no label arrays passes the check above; without
        # this guard the y[0] lookup below would fail with a bare IndexError.
        raise ValueError("At least one label array must be provided")
    if x.shape[0] != len(y[0]):
        # Note: the check for all y having the same number of elements occurs in get_label_buckets
        raise ValueError(
            "Number of rows in x does not match number of elements in at least one label list"
        )
    buckets = get_label_buckets(*y)
    try:
        indices = buckets[bucket]
    except KeyError:
        # The emitted text is unchanged (including the missing space after
        # "Bucket"); only the formatting is deferred to the logging module.
        logging.warning("Bucket%s does not exist.", bucket)
        return np.array([])
    # Integer-array indexing selects the matching rows along the first axis.
    return x[indices]
38 changes: 37 additions & 1 deletion test/analysis/test_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from snorkel.analysis import get_label_buckets
from snorkel.analysis import get_label_buckets, get_label_instances


class ErrorAnalysisTest(unittest.TestCase):
Expand Down Expand Up @@ -37,6 +37,42 @@ def test_get_label_buckets_bad_shape(self) -> None:
with self.assertRaisesRegex(ValueError, "same number of elements"):
get_label_buckets(np.array([0, 1, 1]), np.array([1, 1]))

def test_get_label_instances(self) -> None:
    """Check that rows matching a label combination are selected from x."""
    # Two flat label arrays over 2-D data: rows 1 and 2 carry labels (0, 1).
    data = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
    first_labels = np.array([1, 0, 0, 0])
    second_labels = np.array([1, 1, 1, 0])
    np.testing.assert_equal(
        get_label_instances((0, 1), data, first_labels, second_labels),
        np.array([[3, 4], [5, 6]]),
    )

    # Three label arrays (column-vector and flat shapes mixed) over 1-D
    # string data: only the last element is labelled 3 in all of them.
    words = np.array(["this", "is", "a", "test", "of", "multi"])
    column_a = np.array([[2], [1], [3], [1], [1], [3]])
    flat_b = np.array([1, 2, 3, 1, 2, 3])
    column_c = np.array([[3], [2], [1], [1], [2], [3]])
    np.testing.assert_equal(
        get_label_instances((3, 3, 3), words, column_a, flat_b, column_c),
        np.array(["multi"]),
    )

def test_get_label_instances_exceptions(self) -> None:
    """Check the missing-bucket result and both argument-validation errors."""
    data = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
    first_labels = np.array([1, 0, 0, 0])
    second_labels = np.array([1, 1, 1, 0])

    # No row carries the label combination (2, 0), so an empty array comes back.
    missing = get_label_instances((2, 0), data, first_labels, second_labels)
    np.testing.assert_equal(missing, np.array([]))

    # A single label array against a two-label bucket.
    bucket_mismatch = "Number of lists must match the amount of labels in bucket"
    with self.assertRaisesRegex(ValueError, bucket_mismatch):
        get_label_instances((1, 0), data, first_labels)

    # Three data rows against four-element label arrays.
    short_data = np.array([[1, 2], [3, 4], [5, 6]])
    row_mismatch = "Number of rows in x does not match number of elements in at least one label list"
    with self.assertRaisesRegex(ValueError, row_mismatch):
        get_label_instances((1, 0), short_data, first_labels, second_labels)


# Run the unittest test runner when this file is executed as a script.
if __name__ == "__main__":
unittest.main()

0 comments on commit ed77718

Please sign in to comment.