Skip to content

Commit 97c3824

Browse files
committed
update sklearn
1 parent 22a2b71 commit 97c3824

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

pytorch_fob/tasks/tabular/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from sklearn.metrics import r2_score, mean_squared_error
2+
from sklearn.metrics import r2_score, root_mean_squared_error
33
from rtdl_revisiting_models import FTTransformer, _CLSEmbedding, LinearEmbeddings, CategoricalEmbeddings
44
from pytorch_fob.engine.configs import TaskConfig
55
from pytorch_fob.engine.parameter_groups import GroupedModel, ParameterGroup, group_named_parameters
@@ -80,7 +80,7 @@ def compute_and_log_metrics(self, preds: torch.Tensor, targets: torch.Tensor, st
8080
preds = preds.detach().cpu().float().numpy()
8181
targets = targets.detach().cpu().float().numpy()
8282
metrics = {
83-
"rmse": mean_squared_error(targets, preds, squared=False),
83+
"rmse": root_mean_squared_error(targets, preds),
8484
"r2_score": r2_score(targets, preds)
8585
}
8686
for k, v in metrics.items():

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ tqdm~=4.66.1
1010
wget~=3.2
1111
deepspeed~=0.12.6
1212
rtdl_revisiting_models~=0.0.2
13-
scikit-learn~=1.3.2
13+
scikit-learn~=1.5.0
1414
transformers~=4.38.0
1515
tokenizers~=0.15.0
1616
sentencepiece~=0.2.0

0 commit comments

Comments
 (0)