Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.

CU-86983ruw9 Fix test train split #521

Merged
merged 3 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions medcat/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,9 @@ def prepare_from_json_chars(data: Dict,
return out_data


def make_mc_train_test(data: Dict, cdb: CDB, test_size: float = 0.2) -> Tuple:
def make_mc_train_test(data: Dict, cdb: CDB, test_size: float = 0.2,
min_test_count: int = 10,
max_test_fraction: float = 0.3) -> Tuple:
"""Make train set.

This is a disaster.
Expand All @@ -823,6 +825,10 @@ def make_mc_train_test(data: Dict, cdb: CDB, test_size: float = 0.2) -> Tuple:
data (Dict): The data.
cdb (CDB): The concept database.
test_size (float): The test size. Defaults to 0.2.
min_test_count (int): The minimum number of examples of a concept
for it to be considered for the test set. Defaults to 10.
max_test_fraction (float): The maximum fraction of a concept
in the test set. Defaults to 0.3.

Returns:
Tuple: The train set, the test set, the test annotations, and the total annotations
Expand Down Expand Up @@ -912,14 +918,17 @@ def make_mc_train_test(data: Dict, cdb: CDB, test_size: float = 0.2) -> Tuple:


# Did we get more than 30% of concepts for any CUI with >=10 cnt
is_test = True
for cui, v in _cnts.items():
if (v + test_cnts.get(cui, 0)) / cnts[cui] > 0.3:
if cnts[cui] >= 10:
# We only care for concepts if count >= 10, else they will be ignored
#during the test phase (for all metrics and similar)
is_test = False
break
# NOTE: This implementation is true to the INTENT of the previous one
# but the previous one would act quite a bit differently since
# the logic was flawed. The previous implementation guaranteed
# any document with only rare concepts (i.e. ones with fewer than 10
# examples across the entire dataset) would get a chance to be
# included in the test set (as long as the test size wasn't met)
is_test = any(
cnts[cui] >= min_test_count and
(v + test_cnts.get(cui, 0)) / cnts[cui] < max_test_fraction
for cui, v in _cnts.items()
)

# Add to test set
if is_test and np.random.rand() < test_prob:
Expand Down
90 changes: 90 additions & 0 deletions tests/utils/test_data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import os
import json
from copy import deepcopy

from medcat.utils import data_utils
from medcat.stats.mctexport import count_all_annotations, count_all_docs

from unittest import TestCase


class FakeCDB:
    """Minimal stand-in for a MedCAT CDB.

    Provides just the attributes/methods that ``make_mc_train_test``
    touches: an (empty) ``cui2tui`` mapping and a ``get_name`` that
    simply echoes the CUI back as its name.
    """

    def __init__(self):
        # No type information is needed for the split logic.
        self.cui2tui = {}

    def get_name(self, cui: str) -> str:
        # The CUI itself doubles as its display name in these tests.
        return cui


class TestTrainSplitTestsBase(TestCase):
    """Base fixture: load a trainer export and split it into train/test.

    Subclasses tune the class attributes (``seed``, the ``expect_empty_*``
    flags) and/or narrow ``undertest`` to a filtered copy of the data.
    """
    file_name = os.path.join(os.path.dirname(__file__),
                             "..", "resources", "medcat_trainer_export.json")
    allowed_doc_ids = {3204, 3205}
    test_size = 0.2
    expect_empty_train_set = False
    expect_empty_test_set = False
    seed = None

    @classmethod
    def setUpClass(cls):
        # Load the raw export once per class; subclasses may replace
        # ``undertest`` with a filtered view of ``data``.
        with open(cls.file_name) as f:
            cls.data = json.load(f)
        cls.undertest = cls.data
        cls.cdb = FakeCDB()

    def setUp(self):
        # Seed all RNGs first so the (random) split is reproducible.
        if self.seed is not None:
            data_utils.set_all_seeds(self.seed)
        split = data_utils.make_mc_train_test(
            self.undertest, self.cdb, test_size=self.test_size)
        (self.train_set, self.test_set,
         self.num_test_anns, self.num_total_anns) = split


class TestTrainSplitUnfilteredTests(TestTrainSplitTestsBase):
    """Sanity checks on the split of the full (unfiltered) export."""

    def test_all_docs_accounted_for(self):
        # No document may be dropped or duplicated by the split.
        split_docs = (count_all_docs(self.train_set)
                      + count_all_docs(self.test_set))
        self.assertEqual(count_all_docs(self.undertest), split_docs)

    def test_all_anns_accounted_for(self):
        # Annotations must be conserved across the two partitions.
        split_anns = (count_all_annotations(self.train_set)
                      + count_all_annotations(self.test_set))
        self.assertEqual(count_all_annotations(self.undertest), split_anns)

    def test_total_anns_match(self):
        # The reported totals must agree with a direct recount.
        total = count_all_annotations(self.undertest)
        self.assertEqual(self.num_total_anns, total)
        self.assertEqual(
            total,
            self.num_test_anns + count_all_annotations(self.train_set))

    def test_nonempty_train(self):
        # Only meaningful when a non-empty train set is expected.
        if self.expect_empty_train_set:
            return
        self.assertTrue(self.train_set)
        train_anns = self.num_total_anns - self.num_test_anns
        self.assertTrue(train_anns)
        self.assertEqual(train_anns, count_all_annotations(self.train_set))

    def test_nonempty_test(self):
        # Only meaningful when a non-empty test set is expected.
        if self.expect_empty_test_set:
            return
        self.assertTrue(self.test_set)
        self.assertTrue(self.num_test_anns)
        self.assertEqual(self.num_test_anns,
                         count_all_annotations(self.test_set))


class TestTrainSplitFilteredTestsBase(TestTrainSplitUnfilteredTests):
    """Re-runs the split checks on an export narrowed to a few documents.

    With only ``allowed_doc_ids`` left, the remaining concepts are rare,
    so the test partition is expected to end up empty.
    """
    expect_empty_test_set = True
    # would work with previous version:
    # seed = 332378110
    # was guaranteed to fail with previous version:
    seed = 73607120

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Work on a deep copy so the class-level raw data stays intact.
        filtered = deepcopy(cls.data)
        for project in filtered['projects']:
            kept = [doc for doc in project['documents']
                    if doc['id'] in cls.allowed_doc_ids]
            project['documents'] = kept
        cls.filtered = filtered
        cls.undertest = cls.filtered