diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 019c5d11ed0..03f3a35255e 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -35,7 +35,7 @@
 from xarray.backends.locks import _get_scheduler
 from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder
 from xarray.core import indexing
-from xarray.core.chunk import _get_chunk, _maybe_chunk
+from xarray.core.chunk import _get_chunk, _maybe_chunk, _maybe_get_path_chunk
 from xarray.core.combine import (
     _infer_concat_order_from_positions,
     _nested_combine,
@@ -450,7 +450,7 @@ def _datatree_from_backend_datatree(
                     node.dataset,
                     filename_or_obj,
                     engine,
-                    chunks,
+                    _maybe_get_path_chunk(node.path, chunks),
                     overwrite_encoded_chunks,
                     inline_array,
                     chunked_array_type,
diff --git a/xarray/core/chunk.py b/xarray/core/chunk.py
index e8ceba30e4e..d1c1d5f5cfb 100644
--- a/xarray/core/chunk.py
+++ b/xarray/core/chunk.py
@@ -145,3 +145,15 @@ def _maybe_chunk(
         return var
     else:
         return var
+
+
+def _maybe_get_path_chunk(path: str, chunks: int | dict | Any) -> int | dict | Any:
+    """Return the path-specific chunks from a chunks dictionary if ``path`` is a key of ``chunks``.
+    Otherwise, return ``chunks`` unchanged."""
+    if isinstance(chunks, dict):
+        try:
+            return chunks[path]
+        except KeyError:
+            pass
+
+    return chunks
diff --git a/xarray/core/types.py b/xarray/core/types.py
index 186738ed718..99fa00f52b2 100644
--- a/xarray/core/types.py
+++ b/xarray/core/types.py
@@ -217,7 +217,9 @@ def copy(
 T_ChunkDimFreq: TypeAlias = Union["TimeResampler", T_ChunkDim]
 T_ChunksFreq: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDimFreq]
 # We allow the tuple form of this (though arguably we could transition to named dims only)
-T_Chunks: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDim]
+T_Chunks: TypeAlias = (
+    T_ChunkDim | Mapping[Any, T_ChunkDim] | Mapping[Any, Mapping[Any, T_ChunkDim]]
+)
 T_NormalizedChunks = tuple[tuple[int, ...], ...]
 
 DataVars = Mapping[Any, Any]
diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py
index efc1e131722..664ecafda0a 100644
--- a/xarray/tests/test_backends_datatree.py
+++ b/xarray/tests/test_backends_datatree.py
@@ -256,6 +256,37 @@ def test_open_datatree_chunks(self, tmpdir, simple_datatree) -> None:
 
             assert_chunks_equal(tree, original_tree, enforce_dask=True)
 
+    @requires_dask
+    def test_open_datatree_path_chunks(self, tmpdir, simple_datatree) -> None:
+        filepath = tmpdir / "test.nc"
+
+        root_chunks = {"x": 2, "y": 1}
+        set1_chunks = {"x": 1, "y": 2}
+        set2_chunks = {"x": 2, "y": 3}
+
+        root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
+        set1_data = xr.Dataset({"a": ("y", [-1, 0, 1]), "b": ("x", [-10, 6])})
+        set2_data = xr.Dataset({"a": ("y", [1, 2, 3]), "b": ("x", [0.1, 0.2])})
+        original_tree = DataTree.from_dict(
+            {
+                "/": root_data.chunk(root_chunks),
+                "/group1": set1_data.chunk(set1_chunks),
+                "/group2": set2_data.chunk(set2_chunks),
+            }
+        )
+        original_tree.to_netcdf(filepath, engine="netcdf4")
+
+        chunks = {
+            "/": root_chunks,
+            "/group1": set1_chunks,
+            "/group2": set2_chunks,
+        }
+
+        with open_datatree(filepath, engine="netcdf4", chunks=chunks) as tree:
+            xr.testing.assert_identical(tree, original_tree)
+
+            assert_chunks_equal(tree, original_tree, enforce_dask=True)
+
     def test_open_groups(self, unaligned_datatree_nc) -> None:
         """Test `open_groups` with a netCDF4 file with an unaligned group hierarchy."""
         unaligned_dict_of_datasets = open_groups(unaligned_datatree_nc)
@@ -549,6 +580,36 @@ def test_open_datatree_chunks(self, tmpdir, simple_datatree) -> None:
             # from each node.
             xr.testing.assert_identical(tree.compute(), original_tree)
 
+    def test_open_datatree_path_chunks(self, tmpdir, simple_datatree) -> None:
+        filepath = tmpdir / "test.zarr"
+
+        root_chunks = {"x": 2, "y": 1}
+        set1_chunks = {"x": 1, "y": 2}
+        set2_chunks = {"x": 2, "y": 3}
+
+        root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
+        set1_data = xr.Dataset({"a": ("y", [-1, 0, 1]), "b": ("x", [-10, 6])})
+        set2_data = xr.Dataset({"a": ("y", [1, 2, 3]), "b": ("x", [0.1, 0.2])})
+        original_tree = DataTree.from_dict(
+            {
+                "/": root_data.chunk(root_chunks),
+                "/group1": set1_data.chunk(set1_chunks),
+                "/group2": set2_data.chunk(set2_chunks),
+            }
+        )
+        original_tree.to_zarr(filepath)
+
+        chunks = {
+            "/": root_chunks,
+            "/group1": set1_chunks,
+            "/group2": set2_chunks,
+        }
+
+        with open_datatree(filepath, engine="zarr", chunks=chunks) as tree:
+            xr.testing.assert_identical(tree, original_tree)
+            assert_chunks_equal(tree, original_tree, enforce_dask=True)
+            xr.testing.assert_identical(tree.compute(), original_tree)
+
     def test_open_groups(self, unaligned_datatree_zarr) -> None:
         """Test `open_groups` with a zarr store of an unaligned group hierarchy."""