diff --git a/src/centimators/model_estimators/keras_estimators/sequence.py b/src/centimators/model_estimators/keras_estimators/sequence.py index 9cd3540..19a832b 100644 --- a/src/centimators/model_estimators/keras_estimators/sequence.py +++ b/src/centimators/model_estimators/keras_estimators/sequence.py @@ -10,7 +10,7 @@ import numpy from .base import BaseKerasEstimator, _ensure_numpy -from keras import ops, layers, models +from keras import layers, models @dataclass(kw_only=True) @@ -25,15 +25,14 @@ def __post_init__(self): def _reshape(self, X: IntoFrame, validation_data: tuple[Any, Any] | None = None): X = _ensure_numpy(X) - X_reshaped = ops.reshape( - X, (X.shape[0], self.seq_length, self.n_features_per_timestep) + X_reshaped = X.reshape( + (X.shape[0], self.seq_length, self.n_features_per_timestep) ) if validation_data: X_val, y_val = validation_data X_val = _ensure_numpy(X_val) - X_val_reshaped = ops.reshape( - X_val, + X_val_reshaped = X_val.reshape( (X_val.shape[0], self.seq_length, self.n_features_per_timestep), ) validation_data = X_val_reshaped, _ensure_numpy(y_val)