diff --git a/aeon/testing/estimator_checking/_yield_estimator_checks.py b/aeon/testing/estimator_checking/_yield_estimator_checks.py index c4fc3b81dd..a4355a1f1e 100644 --- a/aeon/testing/estimator_checking/_yield_estimator_checks.py +++ b/aeon/testing/estimator_checking/_yield_estimator_checks.py @@ -227,6 +227,11 @@ def _yield_estimator_checks(estimator_class, estimator_instances, datatypes): yield partial( check_fit_deterministic, estimator=estimator, datatype=datatypes[i][0] ) + yield partial( + check_common_input_dtypes, + estimator=estimator, + datatype=datatypes[i][0], + ) def check_create_test_instance(estimator_class): @@ -690,3 +695,34 @@ def check_fit_deterministic(estimator, datatype): f"Check equivalence message: {msg}" ) i += 1 + + +def check_common_input_dtypes(estimator, datatype): + """Check estimator works with common numpy dtypes.""" + estimator = _clone_estimator(estimator) + + X_train = deepcopy(FULL_TEST_DATA_DICT[datatype]["train"][0]) + y_train = deepcopy(FULL_TEST_DATA_DICT[datatype]["train"][1]) + X_test = deepcopy(FULL_TEST_DATA_DICT[datatype]["test"][0]) + + dtypes = [np.float32, np.float64, np.int32, np.int64] + + X_train_np = np.asarray(X_train) + X_test_np = np.asarray(X_test) + + if X_train_np.dtype == object: + return + + for dtype_cast in dtypes: + try: + X_train_cast = X_train_np.astype(dtype_cast) + X_test_cast = X_test_np.astype(dtype_cast) + + est = estimator.clone() + est.fit(X_train_cast, y_train) + + if hasattr(est, "predict"): + est.predict(X_test_cast) + + except Exception: + return