diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst
index 49c84dab28..729c86c34f 100644
--- a/docs/source/handlers.rst
+++ b/docs/source/handlers.rst
@@ -83,6 +83,12 @@ Panoptic Quality metrics handler
     :members:
 
 
+:math:`R^{2}` score
+-------------------
+.. autoclass:: R2Score
+    :members:
+
+
 Mean squared error metrics handler
 ----------------------------------
 .. autoclass:: MeanSquaredError
diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 45e0827cf9..b2bc2f114d 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -124,6 +124,13 @@ Metrics
 .. autoclass:: PanopticQualityMetric
     :members:
 
+:math:`R^{2}` score
+-------------------
+.. autofunction:: compute_r2_score
+
+.. autoclass:: R2Metric
+    :members:
+
 `Mean squared error`
 --------------------
 .. autoclass:: MSEMetric
diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py
index ed5db8a7f3..39565c0903 100644
--- a/monai/handlers/__init__.py
+++ b/monai/handlers/__init__.py
@@ -35,6 +35,7 @@
 from .parameter_scheduler import ParamSchedulerHandler
 from .postprocessing import PostProcessing
 from .probability_maps import ProbMapProducer
+from .r2_score import R2Score
 from .regression_metrics import MeanAbsoluteError, MeanSquaredError, PeakSignalToNoiseRatio, RootMeanSquaredError
 from .roc_auc import ROCAUC
 from .smartcache_handler import SmartCacheHandler
diff --git a/monai/handlers/r2_score.py b/monai/handlers/r2_score.py
new file mode 100644
index 0000000000..dc94182885
--- /dev/null
+++ b/monai/handlers/r2_score.py
@@ -0,0 +1,56 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Callable
+
+from monai.handlers.ignite_metric import IgniteMetricHandler
+from monai.metrics import R2Metric
+from monai.utils import MultiOutput
+
+
+class R2Score(IgniteMetricHandler):
+    """
+    Computes the :math:`R^{2}` score by accumulating predictions and ground truth during an epoch, then applying `compute_r2_score`.
+
+    Args:
+        multi_output: {``"raw_values"``, ``"uniform_average"``, ``"variance_weighted"``}
+            Type of aggregation performed on multi-output scores.
+            Defaults to ``"uniform_average"``.
+
+            - ``"raw_values"``: the scores for each output are returned.
+            - ``"uniform_average"``: the scores of all outputs are averaged with uniform weight.
+            - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances
+              of each individual output.
+        p: non-negative integer.
+            Number of independent variables used for regression. ``p`` is used to compute adjusted :math:`R^{2}` score.
+            Defaults to 0 (standard :math:`R^{2}` score).
+        output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
+            construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
+            lists of `channel-first` Tensors. The form of `(y_pred, y)` is required by the `update()`.
+            `engine.state` and `output_transform` inherit from the ignite concept:
+            https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
+            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
+
+    See also:
+        :py:class:`monai.metrics.R2Metric`
+
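+    Example (a minimal usage sketch, assuming an ignite ``evaluator`` engine whose
+    ``state.output`` yields a ``(y_pred, y)`` pair; the engine name is illustrative)::
+
+        from monai.handlers import R2Score
+
+        r2 = R2Score(multi_output="uniform_average")
+        r2.attach(evaluator, name="r2")
+        # after evaluator.run(...), the score is stored in evaluator.state.metrics["r2"]
+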
+    """
+
+    def __init__(
+        self,
+        multi_output: MultiOutput | str = MultiOutput.UNIFORM,
+        p: int = 0,
+        output_transform: Callable = lambda x: x,
+    ) -> None:
+        metric_fn = R2Metric(multi_output=multi_output, p=p)
+        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False)
diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py
index 7176f3311f..6368467c76 100644
--- a/monai/metrics/__init__.py
+++ b/monai/metrics/__init__.py
@@ -26,6 +26,7 @@
 from .metric import Cumulative, CumulativeIterationMetric, IterationMetric, Metric
 from .mmd import MMDMetric, compute_mmd
 from .panoptic_quality import PanopticQualityMetric, compute_panoptic_quality
+from .r2_score import R2Metric, compute_r2_score
 from .regression import (
     MAEMetric,
     MSEMetric,
diff --git a/monai/metrics/r2_score.py b/monai/metrics/r2_score.py
new file mode 100644
index 0000000000..22b18ab87d
--- /dev/null
+++ b/monai/metrics/r2_score.py
@@ -0,0 +1,194 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+
+if TYPE_CHECKING:
+    import numpy.typing as npt
+
+import torch
+
+from monai.utils import MultiOutput, look_up_option
+
+from .metric import CumulativeIterationMetric
+
+
+class R2Metric(CumulativeIterationMetric):
+    r"""Computes :math:`R^{2}` score (coefficient of determination). :math:`R^{2}` is used to evaluate
+    a regression model. In the best case, when the predictions match exactly the observed values, :math:`R^{2} = 1`.
+    It has no lower bound, and the more negative it is, the worse the model is. Finally, a baseline model, which always
+    predicts the mean of observed values, will get :math:`R^{2} = 0`.
+
+    .. math::
+        \operatorname {R^{2}}\left(Y, \hat{Y}\right) = 1 - \frac {\sum _{i=1}^{n}\left(y_i-\hat{y_i} \right)^{2}}
+        {\sum _{i=1}^{n}\left(y_i-\bar{y} \right)^{2}},
+        :label: r2
+
+    where :math:`\bar{y}` is the mean of observed :math:`y`.
+
+    However, :math:`R^{2}` automatically increases when extra variables are added to the model. To account
+    for this phenomenon and penalize the addition of unnecessary variables, the adjusted :math:`R^{2}` score
+    (:math:`\bar{R}^{2}`) is defined:
+
+    .. math::
+        \operatorname {\bar{R}^{2}} = 1 - (1-R^{2}) \frac {n-1}{n-p-1},
+        :label: r2_adjusted
+
+    where :math:`p` is the number of independent variables used for the regression.
+
+    More info: https://en.wikipedia.org/wiki/Coefficient_of_determination
+
+    Input `y_pred` is compared with ground truth `y`.
+    `y_pred` and `y` are expected to be 1D (single-output regression) or 2D (multi-output regression)
+    real-valued tensors of the same shape.
+
+    The typical execution steps of this metric class follow :py:class:`monai.metrics.metric.Cumulative`.
+
+    Args:
+        multi_output: {``"raw_values"``, ``"uniform_average"``, ``"variance_weighted"``}
+            Type of aggregation performed on multi-output scores.
+            Defaults to ``"uniform_average"``.
+
+            - ``"raw_values"``: the scores for each output are returned.
+            - ``"uniform_average"``: the scores of all outputs are averaged with uniform weight.
+            - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances of
+              each individual output.
+        p: non-negative integer.
+            Number of independent variables used for regression. ``p`` is used to compute :math:`\bar{R}^{2}` score.
+            Defaults to 0 (standard :math:`R^{2}` score).
+
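+    Example (a minimal sketch; the data values are illustrative)::
+
+        import torch
+
+        from monai.metrics import R2Metric
+
+        metric = R2Metric(multi_output="uniform_average")
+        # accumulate two iterations of batch-first predictions and ground truth
+        metric(y_pred=torch.tensor([[0.1], [0.9]]), y=torch.tensor([[0.2], [1.0]]))
+        metric(y_pred=torch.tensor([[0.5], [0.3]]), y=torch.tensor([[0.4], [0.2]]))
+        score = metric.aggregate()
+        metric.reset()
+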
+    """
+
+    def __init__(self, multi_output: MultiOutput | str = MultiOutput.UNIFORM, p: int = 0) -> None:
+        super().__init__()
+        multi_output, p = _check_r2_params(multi_output, p)
+        self.multi_output = multi_output
+        self.p = p
+
+    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:  # type: ignore[override]
+        _check_dim(y_pred, y)
+        return y_pred, y
+
+    def aggregate(self, multi_output: MultiOutput | str | None = None) -> np.ndarray | float | npt.ArrayLike:
+        """
+        Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration.
+        This function reads the buffers and computes the :math:`R^{2}` score.
+
+        Args:
+            multi_output: {``"raw_values"``, ``"uniform_average"``, ``"variance_weighted"``}
+                Type of aggregation performed on multi-output scores. Defaults to `self.multi_output`.
+
+        """
+        y_pred, y = self.get_buffer()
+        return compute_r2_score(y_pred=y_pred, y=y, multi_output=multi_output or self.multi_output, p=self.p)
+
+
+def _check_dim(y_pred: torch.Tensor, y: torch.Tensor) -> None:
+    if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
+        raise ValueError("y_pred and y must be PyTorch Tensor.")
+
+    if y.shape != y_pred.shape:
+        raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.")
+
+    dim = y.ndimension()
+    if dim not in (1, 2):
+        raise ValueError(
+            f"predictions and ground truths should be of shape (batch_size, num_outputs) or (batch_size, ), got {y.shape}."
+        )
+
+
+def _check_r2_params(multi_output: MultiOutput | str, p: int) -> tuple[MultiOutput | str, int]:
+    multi_output = look_up_option(multi_output, MultiOutput)
+    if not isinstance(p, int) or p < 0:
+        raise ValueError(f"`p` must be an integer larger or equal to 0, got {p}.")
+
+    return multi_output, p
+
+
+def _calculate(y_pred: np.ndarray, y: np.ndarray, p: int) -> float:
+    num_obs = len(y)
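+    # residual sum of squares (rss) and total sum of squares (tss), as in the R^2 definition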
+    rss = np.sum((y_pred - y) ** 2)
+    tss = np.sum((y - np.mean(y)) ** 2)
+    r2 = 1 - (rss / tss)
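+    # with p=0 the adjustment factor is (n-1)/(n-1) = 1, reducing to the plain R^2 score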
+    r2_adjusted = 1 - (1 - r2) * (num_obs - 1) / (num_obs - p - 1)
+
+    return r2_adjusted  # type: ignore[no-any-return]
+
+
+def compute_r2_score(
+    y_pred: torch.Tensor, y: torch.Tensor, multi_output: MultiOutput | str = MultiOutput.UNIFORM, p: int = 0
+) -> np.ndarray | float | npt.ArrayLike:
+    """Computes :math:`R^{2}` score (coefficient of determination). :math:`R^{2}` is used to evaluate
+    a regression model according to equations :eq:`r2` and :eq:`r2_adjusted`.
+
+    Args:
+        y_pred: input data to compute :math:`R^{2}` score, the first dim must be batch.
+            For example: shape `[16]` or `[16, 1]` for a single-output regression, shape `[16, x]` for x output variables.
+        y: ground truth to compute :math:`R^{2}` score, the first dim must be batch.
+            For example: shape `[16]` or `[16, 1]` for a single-output regression, shape `[16, x]` for x output variables.
+        multi_output: {``"raw_values"``, ``"uniform_average"``, ``"variance_weighted"``}
+            Type of aggregation performed on multi-output scores.
+            Defaults to ``"uniform_average"``.
+
+            - ``"raw_values"``: the scores for each output are returned.
+            - ``"uniform_average"``: the scores of all outputs are averaged with uniform weight.
+            - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances
+              each individual output.
+        p: non-negative integer.
+            Number of independent variables used for regression. ``p`` is used to compute :math:`\bar{R}^{2}` score.
+            Defaults to 0 (standard :math:`R^{2}` score).
+
+    Raises:
+        ValueError: When ``multi_output`` is not one of ["raw_values", "uniform_average", "variance_weighted"].
+        ValueError: When ``p`` is not a non-negative integer.
+        ValueError: When ``y_pred`` or ``y`` are not PyTorch tensors.
+        ValueError: When ``y_pred`` and ``y`` don't have the same shape.
+        ValueError: When ``y_pred`` or ``y`` dimension is not one of [1, 2].
+        ValueError: When n_samples is less than 2.
+        ValueError: When ``p`` is greater than or equal to n_samples - 1.
+
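+    Example (a minimal sketch with illustrative values, returning one score per output column)::
+
+        import torch
+
+        from monai.metrics import compute_r2_score
+
+        y_pred = torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]])
+        y = torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]])
+        scores = compute_r2_score(y_pred=y_pred, y=y, multi_output="raw_values", p=1)
+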
+    """
+    multi_output, p = _check_r2_params(multi_output, p)
+    _check_dim(y_pred, y)
+    dim = y.ndimension()
+    n = y.shape[0]
+    y = y.cpu().numpy()  # type: ignore[assignment]
+    y_pred = y_pred.cpu().numpy()  # type: ignore[assignment]
+
+    if n < 2:
+        raise ValueError("There is no enough data for computing. Needs at least two samples to calculate r2 score.")
+    if p >= n - 1:
+        raise ValueError("`p` must be smaller than n_samples - 1, " f"got p={p}, n_samples={n}.")
+
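+    # a single-column 2D input is treated as single-output regression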
+    if dim == 2 and y_pred.shape[1] == 1:
+        y_pred = np.squeeze(y_pred, axis=-1)  # type: ignore[assignment]
+        y = np.squeeze(y, axis=-1)  # type: ignore[assignment]
+        dim = 1
+
+    if dim == 1:
+        return _calculate(y_pred, y, p)  # type: ignore[arg-type]
+
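+    # move the outputs to the leading axis and compute one score per output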
+    y, y_pred = np.transpose(y, axes=(1, 0)), np.transpose(y_pred, axes=(1, 0))  # type: ignore[assignment]
+    r2_values = [_calculate(y_pred_, y_, p) for y_pred_, y_ in zip(y_pred, y)]
+    if multi_output == MultiOutput.RAW:
+        return r2_values
+    if multi_output == MultiOutput.UNIFORM:
+        return np.mean(r2_values)
+    if multi_output == MultiOutput.VARIANCE:
+        weights = np.var(y, axis=1)
+        return np.average(r2_values, weights=weights)  # type: ignore[no-any-return]
+    raise ValueError(
+        f'Unsupported multi_output: {multi_output}, available options are ["raw_values", "uniform_average", "variance_weighted"].'
+    )
diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py
index 8f2f400b5d..d0fdd8369c 100644
--- a/monai/utils/__init__.py
+++ b/monai/utils/__init__.py
@@ -46,6 +46,7 @@
     MetaKeys,
     Method,
     MetricReduction,
+    MultiOutput,
     NdimageMode,
     NumpyPadMode,
     OrderingTransformations,
diff --git a/monai/utils/enums.py b/monai/utils/enums.py
index 3463a92e4b..793f32a16f 100644
--- a/monai/utils/enums.py
+++ b/monai/utils/enums.py
@@ -29,6 +29,7 @@
     "NdimageMode",
     "GridSamplePadMode",
     "Average",
+    "MultiOutput",
     "MetricReduction",
     "LossReduction",
     "DiceCEReduction",
@@ -223,6 +224,16 @@ class Average(StrEnum):
     NONE = "none"
 
 
+class MultiOutput(StrEnum):
+    """
+    See also: :py:func:`monai.metrics.r2_score.compute_r2_score`
+    """
+
+    RAW = "raw_values"
+    UNIFORM = "uniform_average"
+    VARIANCE = "variance_weighted"
+
+
 class MetricReduction(StrEnum):
     """
     See also: :py:func:`monai.metrics.utils.do_metric_reduction`
diff --git a/tests/handlers/test_handler_r2_score.py b/tests/handlers/test_handler_r2_score.py
new file mode 100644
index 0000000000..b4d4c1613e
--- /dev/null
+++ b/tests/handlers/test_handler_r2_score.py
@@ -0,0 +1,72 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from monai.handlers import R2Score
+from tests.test_utils import DistCall, DistTestCase
+
+
+class TestHandlerR2Score(unittest.TestCase):
+
+    def test_compute(self):
+        r2_score = R2Score(multi_output="variance_weighted", p=1)
+
+        y_pred = [torch.tensor([0.1, 1.0]), torch.tensor([-0.25, 0.5])]
+        y = [torch.tensor([0.1, 0.82]), torch.tensor([-0.2, 0.01])]
+        r2_score.update([y_pred, y])
+
+        y_pred = [torch.tensor([3.0, -0.2]), torch.tensor([0.99, 2.1])]
+        y = [torch.tensor([2.7, -0.1]), torch.tensor([1.58, 2.0])]
+
+        r2_score.update([y_pred, y])
+
+        r2 = r2_score.compute()
+        np.testing.assert_allclose(0.867314, r2, rtol=1e-5)
+
+
+class DistributedR2Score(DistTestCase):
+
+    @DistCall(nnodes=1, nproc_per_node=2, node_rank=0)
+    def test_compute(self):
+        r2_score = R2Score(multi_output="variance_weighted", p=1)
+
+        device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
+        if dist.get_rank() == 0:
+            y_pred = [torch.tensor([0.1, 1.0], device=device), torch.tensor([-0.25, 0.5], device=device)]
+            y = [torch.tensor([0.1, 0.82], device=device), torch.tensor([-0.2, 0.01], device=device)]
+            r2_score.update([y_pred, y])
+
+        if dist.get_rank() == 1:
+            y_pred = [
+                torch.tensor([3.0, -0.2], device=device),
+                torch.tensor([0.99, 2.1], device=device),
+                torch.tensor([-0.1, 0.0], device=device),
+            ]
+            y = [
+                torch.tensor([2.7, -0.1], device=device),
+                torch.tensor([1.58, 2.0], device=device),
+                torch.tensor([-1.0, -0.1], device=device),
+            ]
+            r2_score.update([y_pred, y])
+
+        result = r2_score.compute()
+        np.testing.assert_allclose(0.829185, result, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/metrics/test_compute_r2_score.py b/tests/metrics/test_compute_r2_score.py
new file mode 100644
index 0000000000..0cea11cf47
--- /dev/null
+++ b/tests/metrics/test_compute_r2_score.py
@@ -0,0 +1,150 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.metrics import R2Metric, compute_r2_score
+
+_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+TEST_CASE_1 = [
+    torch.tensor([0.1, -0.25, 3.0, 0.99], device=_device),
+    torch.tensor([0.1, -0.2, -2.7, 1.58], device=_device),
+    "uniform_average",
+    0,
+    -2.469944,
+]
+
+TEST_CASE_2 = [
+    torch.tensor([0.1, -0.25, 3.0, 0.99]),
+    torch.tensor([0.1, -0.2, 2.7, 1.58]),
+    "uniform_average",
+    2,
+    0.75828,
+]
+
+TEST_CASE_3 = [
+    torch.tensor([[0.1], [-0.25], [3.0], [0.99]]),
+    torch.tensor([[0.1], [-0.2], [2.7], [1.58]]),
+    "raw_values",
+    2,
+    0.75828,
+]
+
+TEST_CASE_4 = [
+    torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]),
+    torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]),
+    "raw_values",
+    1,
+    [0.87914, 0.844375],
+]
+
+TEST_CASE_5 = [
+    torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]),
+    torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]),
+    "variance_weighted",
+    1,
+    0.867314,
+]
+
+TEST_CASE_6 = [
+    torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]),
+    torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]),
+    "uniform_average",
+    0,
+    0.907838,
+]
+
+TEST_CASE_ERROR_1 = [
+    torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]),
+    torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]),
+    "abc",
+    0,
+]
+
+TEST_CASE_ERROR_2 = [
+    torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]),
+    torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]),
+    "uniform_average",
+    -1,
+]
+
+TEST_CASE_ERROR_3 = [
+    torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]),
+    np.array([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]),
+    "uniform_average",
+    0,
+]
+
+TEST_CASE_ERROR_4 = [
+    torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]),
+    torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1]]),
+    "uniform_average",
+    0,
+]
+
+TEST_CASE_ERROR_5 = [
+    torch.tensor([[[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]]),
+    torch.tensor([[[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]]),
+    "uniform_average",
+    0,
+]
+
+TEST_CASE_ERROR_6 = [
+    torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]),
+    torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]),
+    "uniform_average",
+    3,
+]
+
+TEST_CASE_ERROR_7 = [torch.tensor([[0.1, 1.0]]), torch.tensor([[0.1, 0.82]]), "uniform_average", 0]
+
+
+class TestComputeR2Score(unittest.TestCase):
+
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
+    def test_value(self, y_pred, y, multi_output, p, expected_value):
+        result = compute_r2_score(y_pred=y_pred, y=y, multi_output=multi_output, p=p)
+        np.testing.assert_allclose(expected_value, result, rtol=1e-5)
+
+    @parameterized.expand(
+        [
+            TEST_CASE_ERROR_1,
+            TEST_CASE_ERROR_2,
+            TEST_CASE_ERROR_3,
+            TEST_CASE_ERROR_4,
+            TEST_CASE_ERROR_5,
+            TEST_CASE_ERROR_6,
+            TEST_CASE_ERROR_7,
+        ]
+    )
+    def test_error(self, y_pred, y, multi_output, p):
+        with self.assertRaises(ValueError):
+            _ = compute_r2_score(y_pred=y_pred, y=y, multi_output=multi_output, p=p)
+
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
+    def test_class_value(self, y_pred, y, multi_output, p, expected_value):
+        metric = R2Metric(multi_output=multi_output, p=p)
+        metric(y_pred=y_pred, y=y)
+        result = metric.aggregate()
+        np.testing.assert_allclose(expected_value, result, rtol=1e-5)
+        result = metric.aggregate(multi_output=multi_output)  # test optional argument
+        metric.reset()
+        np.testing.assert_allclose(expected_value, result, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/min_tests.py b/tests/min_tests.py
index 12f494be9c..6e70bb77c0 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -99,6 +99,7 @@ def run_testsuit():
         "test_handler_parameter_scheduler",
         "test_handler_post_processing",
         "test_handler_prob_map_producer",
+        "test_handler_r2_score",
         "test_handler_regression_metrics",
         "test_handler_regression_metrics_dist",
         "test_handler_rocauc",