Skip to content

Commit a900f2b

Browse files
committed
Stabilise common dtype test for non-numpy and strict estimators
1 parent 7905c26 commit a900f2b

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

aeon/testing/estimator_checking/_yield_estimator_checks.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -707,18 +707,19 @@ def check_common_input_dtypes(estimator, datatype):
707707

708708
dtypes = [np.float32, np.float64, np.int32, np.int64]
709709

710+
if not hasattr(np.asarray(X_train), "dtype"):
711+
return
712+
710713
for dtype_cast in dtypes:
711714
try:
712-
X_train_cast = X_train.astype(dtype_cast)
713-
X_test_cast = X_test.astype(dtype_cast)
715+
X_train_cast = np.asarray(X_train).astype(dtype_cast)
716+
X_test_cast = np.asarray(X_test).astype(dtype_cast)
714717

715718
est = estimator.clone()
716719
est.fit(X_train_cast, y_train)
717720

718721
if hasattr(est, "predict"):
719722
est.predict(X_test_cast)
720723

721-
except Exception as e:
722-
raise AssertionError(
723-
f"{type(estimator).__name__} failed for dtype {dtype_cast}: {e}"
724-
)
724+
except Exception:
725+
return

0 commit comments

Comments
 (0)