diff --git a/tests/utils/loggers/test_utils.py b/tests/utils/loggers/test_utils.py index e8ec0b2c8f..9d5547448f 100644 --- a/tests/utils/loggers/test_utils.py +++ b/tests/utils/loggers/test_utils.py @@ -29,3 +29,8 @@ def test_scalar_to_float(self) -> None: valid_ndarray = np.array([[[float_x]]]) self.assertAlmostEqual(scalar_to_float(valid_ndarray), float_x) + + def test_scalar_to_float_bf16(self) -> None: + float_x = 3.45 + valid_tensor = torch.Tensor([float_x]).to(torch.bfloat16) + self.assertAlmostEqual(scalar_to_float(valid_tensor), float_x, delta=0.01) diff --git a/torchtnt/utils/loggers/utils.py b/torchtnt/utils/loggers/utils.py index 825e6a9342..453b0aa759 100644 --- a/torchtnt/utils/loggers/utils.py +++ b/torchtnt/utils/loggers/utils.py @@ -20,7 +20,7 @@ def scalar_to_float(scalar: Scalar) -> float: f"Scalar tensor must contain a single item, {numel} given." ) - return float(scalar.cpu().detach().numpy().item()) + return float(scalar.cpu().detach().float().numpy().item()) elif isinstance(scalar, ndarray): numel = scalar.size if numel != 1: