From 4b8109ea39c56bb2ed9df121bfb0f59abf42b245 Mon Sep 17 00:00:00 2001 From: Jason Krone Date: Wed, 4 Dec 2024 13:51:08 -0500 Subject: [PATCH] update mean reduction zloss to ignore labels == ignore_index vs. setting them to 0 --- olmo/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/olmo/train.py b/olmo/train.py index 105f82e40..4bcfe6a98 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -135,7 +135,8 @@ def cross_entropy_loss( z_squared = logits.logsumexp(-1).pow(2) if reduction == "mean": - z_squared = (z_squared * (labels != ignore_index)).mean() + mask = labels != ignore_index + z_squared = (z_squared * mask).sum() / mask.sum() elif reduction == "sum": z_squared = (z_squared * (labels != ignore_index)).sum()