Skip to content

Commit 7905c26

Browse files
committed
[ENH] Add common dtype compatibility check to estimator tests
1 parent ff28266 commit 7905c26

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

aeon/testing/estimator_checking/_yield_estimator_checks.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,11 @@ def _yield_estimator_checks(estimator_class, estimator_instances, datatypes):
227227
yield partial(
228228
check_fit_deterministic, estimator=estimator, datatype=datatypes[i][0]
229229
)
230+
yield partial(
231+
check_common_input_dtypes,
232+
estimator=estimator,
233+
datatype=datatypes[i][0],
234+
)
230235

231236

232237
def check_create_test_instance(estimator_class):
@@ -690,3 +695,30 @@ def check_fit_deterministic(estimator, datatype):
690695
f"Check equivalence message: {msg}"
691696
)
692697
i += 1
698+
699+
700+
def check_common_input_dtypes(estimator, datatype):
701+
"""Check estimator works with common numpy dtypes."""
702+
estimator = _clone_estimator(estimator)
703+
704+
X_train = deepcopy(FULL_TEST_DATA_DICT[datatype]["train"][0])
705+
y_train = deepcopy(FULL_TEST_DATA_DICT[datatype]["train"][1])
706+
X_test = deepcopy(FULL_TEST_DATA_DICT[datatype]["test"][0])
707+
708+
dtypes = [np.float32, np.float64, np.int32, np.int64]
709+
710+
for dtype_cast in dtypes:
711+
try:
712+
X_train_cast = X_train.astype(dtype_cast)
713+
X_test_cast = X_test.astype(dtype_cast)
714+
715+
est = estimator.clone()
716+
est.fit(X_train_cast, y_train)
717+
718+
if hasattr(est, "predict"):
719+
est.predict(X_test_cast)
720+
721+
except Exception as e:
722+
raise AssertionError(
723+
f"{type(estimator).__name__} failed for dtype {dtype_cast}: {e}"
724+
)

0 commit comments

Comments
 (0)