@@ -227,6 +227,11 @@ def _yield_estimator_checks(estimator_class, estimator_instances, datatypes):
227227 yield partial (
228228 check_fit_deterministic , estimator = estimator , datatype = datatypes [i ][0 ]
229229 )
230+ yield partial (
231+ check_common_input_dtypes ,
232+ estimator = estimator ,
233+ datatype = datatypes [i ][0 ],
234+ )
230235
231236
232237def check_create_test_instance (estimator_class ):
@@ -690,3 +695,30 @@ def check_fit_deterministic(estimator, datatype):
690695 f"Check equivalence message: { msg } "
691696 )
692697 i += 1
698+
699+
700+ def check_common_input_dtypes (estimator , datatype ):
701+ """Check estimator works with common numpy dtypes."""
702+ estimator = _clone_estimator (estimator )
703+
704+ X_train = deepcopy (FULL_TEST_DATA_DICT [datatype ]["train" ][0 ])
705+ y_train = deepcopy (FULL_TEST_DATA_DICT [datatype ]["train" ][1 ])
706+ X_test = deepcopy (FULL_TEST_DATA_DICT [datatype ]["test" ][0 ])
707+
708+ dtypes = [np .float32 , np .float64 , np .int32 , np .int64 ]
709+
710+ for dtype_cast in dtypes :
711+ try :
712+ X_train_cast = X_train .astype (dtype_cast )
713+ X_test_cast = X_test .astype (dtype_cast )
714+
715+ est = estimator .clone ()
716+ est .fit (X_train_cast , y_train )
717+
718+ if hasattr (est , "predict" ):
719+ est .predict (X_test_cast )
720+
721+ except Exception as e :
722+ raise AssertionError (
723+ f"{ type (estimator ).__name__ } failed for dtype { dtype_cast } : { e } "
724+ )
0 commit comments