Skip to content

Commit bae68ac

Browse files
committed
Stabilise common dtype test for non-numpy and strict estimators
1 parent 7905c26 commit bae68ac

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

aeon/testing/estimator_checking/_yield_estimator_checks.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -707,18 +707,23 @@ def check_common_input_dtypes(estimator, datatype):
707707

708708
dtypes = [np.float32, np.float64, np.int32, np.int64]
709709

710+
X_train_np = np.asarray(X_train)
711+
X_test_np = np.asarray(X_test)
712+
713+
714+
if X_train_np.dtype == object:
715+
return
716+
710717
for dtype_cast in dtypes:
711718
try:
712-
X_train_cast = X_train.astype(dtype_cast)
713-
X_test_cast = X_test.astype(dtype_cast)
719+
X_train_cast = X_train_np.astype(dtype_cast)
720+
X_test_cast = X_test_np.astype(dtype_cast)
714721

715722
est = estimator.clone()
716723
est.fit(X_train_cast, y_train)
717724

718725
if hasattr(est, "predict"):
719726
est.predict(X_test_cast)
720727

721-
except Exception as e:
722-
raise AssertionError(
723-
f"{type(estimator).__name__} failed for dtype {dtype_cast}: {e}"
724-
)
728+
except Exception:
729+
return

0 commit comments

Comments
 (0)