2828from . import sgrid
2929from .criteria import (
3030 _DSG_ROLES ,
31+ _GEOMETRY_TYPES ,
3132 cf_role_criteria ,
3233 coordinate_criteria ,
34+ geometry_var_criteria ,
3335 grid_mapping_var_criteria ,
3436 regex ,
3537)
3941 _format_data_vars ,
4042 _format_dsg_roles ,
4143 _format_flags ,
44+ _format_geometries ,
4245 _format_sgrid ,
4346 _maybe_panel ,
4447)
@@ -198,7 +201,9 @@ def _get_groupby_time_accessor(
198201
199202
200203def _get_custom_criteria (
201- obj : DataArray | Dataset , key : Hashable , criteria : Mapping | None = None
204+ obj : DataArray | Dataset ,
205+ key : Hashable ,
206+ criteria : Iterable [Mapping ] | Mapping | None = None ,
202207) -> list [Hashable ]:
203208 """
204209 Translate from axis, coord, or custom name to variable name.
@@ -227,18 +232,16 @@ def _get_custom_criteria(
227232 except ImportError :
228233 from re import match as regex_match # type: ignore[no-redef]
229234
230- if isinstance (obj , DataArray ):
231- obj = obj ._to_temp_dataset ()
232- variables = obj ._variables
233-
234235 if criteria is None :
235236 if not OPTIONS ["custom_criteria" ]:
236237 return []
237238 criteria = OPTIONS ["custom_criteria" ]
238239
239- if criteria is not None :
240- criteria_iter = always_iterable (criteria , allowed = (tuple , list , set ))
240+ if isinstance (obj , DataArray ):
241+ obj = obj ._to_temp_dataset ()
242+ variables = obj ._variables
241243
244+ criteria_iter = always_iterable (criteria , allowed = (tuple , list , set ))
242245 criteria_map = ChainMap (* criteria_iter )
243246 results : set = set ()
244247 if key in criteria_map :
@@ -367,6 +370,21 @@ def _get_measure(obj: DataArray | Dataset, key: str) -> list[str]:
367370 return list (results )
368371
369372
373+ def _parse_related_geometry_vars (attrs : Mapping ) -> tuple [Hashable ]:
374+ names = itertools .chain (
375+ * [
376+ attrs .get (attr , "" ).split (" " )
377+ for attr in [
378+ "interior_ring" ,
379+ "node_coordinates" ,
380+ "node_count" ,
381+ "part_node_count" ,
382+ ]
383+ ]
384+ )
385+ return tuple (n for n in names if n )
386+
387+
370388def _get_bounds (obj : DataArray | Dataset , key : Hashable ) -> list [Hashable ]:
371389 """
372390 Translate from key (either CF key or variable name) to its bounds' variable names.
@@ -470,8 +488,14 @@ def _get_all(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
470488 """
471489 all_mappers : tuple [Mapper ] = (
472490 _get_custom_criteria ,
473- functools .partial (_get_custom_criteria , criteria = cf_role_criteria ), # type: ignore[assignment]
474- functools .partial (_get_custom_criteria , criteria = grid_mapping_var_criteria ),
491+ functools .partial (
492+ _get_custom_criteria ,
493+ criteria = (
494+ cf_role_criteria ,
495+ grid_mapping_var_criteria ,
496+ geometry_var_criteria ,
497+ ),
498+ ), # type: ignore[assignment]
475499 _get_axis_coord ,
476500 _get_measure ,
477501 _get_grid_mapping_name ,
@@ -821,6 +845,23 @@ def check_results(names, key):
821845 successful [k ] = bool (grid_mapping )
822846 if grid_mapping :
823847 varnames .extend (grid_mapping )
848+ elif "geometries" not in skip and (k == "geometry" or k in _GEOMETRY_TYPES ):
849+ geometries = _get_all (obj , k )
850+ if geometries and k in _GEOMETRY_TYPES :
851+ new = itertools .chain (
852+ _parse_related_geometry_vars (
853+ ChainMap (obj [g ].attrs , obj [g ].encoding )
854+ )
855+ for g in geometries
856+ )
857+ geometries .extend (* new )
858+ if len (geometries ) > 1 and scalar_key :
859+ raise ValueError (
860+ f"CF geometries must be represented by an Xarray Dataset. To request a Dataset in return please pass `[{ k !r} ]` instead."
861+ )
862+ successful [k ] = bool (geometries )
863+ if geometries :
864+ varnames .extend (geometries )
824865 elif k in custom_criteria or k in cf_role_criteria :
825866 names = _get_all (obj , k )
826867 check_results (names , k )
@@ -1559,8 +1600,7 @@ def _generate_repr(self, rich=False):
15591600 _format_flags (self , rich ), title = "Flag Variable" , rich = rich
15601601 )
15611602
1562- roles = self .cf_roles
1563- if roles :
1603+ if roles := self .cf_roles :
15641604 if any (role in roles for role in _DSG_ROLES ):
15651605 yield _maybe_panel (
15661606 _format_dsg_roles (self , dims , rich ),
@@ -1576,6 +1616,13 @@ def _generate_repr(self, rich=False):
15761616 rich = rich ,
15771617 )
15781618
1619+ if self .geometries :
1620+ yield _maybe_panel (
1621+ _format_geometries (self , dims , rich ),
1622+ title = "Geometries" ,
1623+ rich = rich ,
1624+ )
1625+
15791626 yield _maybe_panel (
15801627 _format_coordinates (self , dims , coords , rich ),
15811628 title = "Coordinates" ,
@@ -1755,12 +1802,42 @@ def cf_roles(self) -> dict[str, list[Hashable]]:
17551802
17561803 vardict : dict [str , list [Hashable ]] = {}
17571804 for k , v in variables .items ():
1758- if "cf_role" in v .attrs :
1759- role = v . attrs [ "cf_role" ]
1805+ attrs_or_encoding = ChainMap ( v .attrs , v . encoding )
1806+ if role := attrs_or_encoding . get ( "cf_role" , None ):
17601807 vardict [role ] = vardict .setdefault (role , []) + [k ]
17611808
17621809 return {role_ : sort_maybe_hashable (v ) for role_ , v in vardict .items ()}
17631810
1811+ @property
1812+ def geometries (self ) -> dict [str , list [Hashable ]]:
1813+ """
1814+ Mapping geometry type names to variable names.
1815+
1816+ Returns
1817+ -------
1818+ dict
1819+ Dictionary mapping geometry names to variable names.
1820+
1821+ References
1822+ ----------
1823+ Please refer to the CF conventions document : http://cfconventions.org/Data/cf-conventions/cf-conventions-1.8/cf-conventions.html#coordinates-metadata
1824+ """
1825+ vardict : dict [str , list [Hashable ]] = {}
1826+
1827+ if isinstance (self ._obj , Dataset ):
1828+ variables = self ._obj ._variables
1829+ elif isinstance (self ._obj , DataArray ):
1830+ variables = {"_" : self ._obj ._variable }
1831+
1832+ for v in variables .values ():
1833+ attrs_or_encoding = ChainMap (v .attrs , v .encoding )
1834+ if geometry := attrs_or_encoding .get ("geometry" , None ):
1835+ gtype = self ._obj [geometry ].attrs ["geometry_type" ]
1836+ vardict .setdefault (gtype , [])
1837+ if geometry not in vardict [gtype ]:
1838+ vardict [gtype ] += [geometry ]
1839+ return {type_ : sort_maybe_hashable (v ) for type_ , v in vardict .items ()}
1840+
17641841 def get_associated_variable_names (
17651842 self , name : Hashable , skip_bounds : bool = False , error : bool = True
17661843 ) -> dict [str , list [Hashable ]]:
@@ -1795,15 +1872,15 @@ def get_associated_variable_names(
17951872 "bounds" ,
17961873 "grid_mapping" ,
17971874 "grid" ,
1875+ "geometry" ,
17981876 ]
17991877
18001878 coords : dict [str , list [Hashable ]] = {k : [] for k in keys }
18011879 attrs_or_encoding = ChainMap (self ._obj [name ].attrs , self ._obj [name ].encoding )
18021880
1803- coordinates = attrs_or_encoding .get ("coordinates" , None )
18041881 # Handles case where the coordinates attribute is None
18051882 # This is used to tell xarray to not write a coordinates attribute
1806- if coordinates :
1883+ if coordinates := attrs_or_encoding . get ( "coordinates" , None ) :
18071884 coords ["coordinates" ] = coordinates .split (" " )
18081885
18091886 if "cell_measures" in attrs_or_encoding :
@@ -1822,27 +1899,32 @@ def get_associated_variable_names(
18221899 )
18231900 coords ["cell_measures" ] = []
18241901
1825- if (
1826- isinstance (self ._obj , Dataset )
1827- and "ancillary_variables" in attrs_or_encoding
1902+ if isinstance (self ._obj , Dataset ) and (
1903+ anc := attrs_or_encoding .get ("ancillary_variables" , None )
18281904 ):
1829- coords ["ancillary_variables" ] = attrs_or_encoding [
1830- "ancillary_variables"
1831- ].split (" " )
1905+ coords ["ancillary_variables" ] = anc .split (" " )
18321906
18331907 if not skip_bounds :
1834- if " bounds" in attrs_or_encoding :
1835- coords ["bounds" ] = [attrs_or_encoding [ " bounds" ] ]
1908+ if bounds := attrs_or_encoding . get ( "bounds" , None ) :
1909+ coords ["bounds" ] = [bounds ]
18361910 for dim in self ._obj [name ].dims :
1837- dbounds = self ._obj [dim ].attrs .get ("bounds" , None )
1838- if dbounds :
1911+ if dbounds := self ._obj [dim ].attrs .get ("bounds" , None ):
18391912 coords ["bounds" ].append (dbounds )
18401913
1841- if "grid" in attrs_or_encoding :
1842- coords ["grid" ] = [attrs_or_encoding ["grid" ]]
1914+ for attrname in ["grid" , "grid_mapping" ]:
1915+ if maybe := attrs_or_encoding .get (attrname , None ):
1916+ coords [attrname ] = [maybe ]
18431917
1844- if "grid_mapping" in attrs_or_encoding :
1845- coords ["grid_mapping" ] = [attrs_or_encoding ["grid_mapping" ]]
1918+ more : Sequence [Hashable ] = ()
1919+ if geometry_var := attrs_or_encoding .get ("geometry" , None ):
1920+ coords ["geometry" ] = [geometry_var ]
1921+ _attrs = ChainMap (
1922+ self ._obj [geometry_var ].attrs , self ._obj [geometry_var ].encoding
1923+ )
1924+ more = _parse_related_geometry_vars (_attrs )
1925+ elif "geometry_type" in attrs_or_encoding :
1926+ more = _parse_related_geometry_vars (attrs_or_encoding )
1927+ coords ["geometry" ].extend (more )
18461928
18471929 allvars = itertools .chain (* coords .values ())
18481930 missing = set (allvars ) - set (self ._maybe_to_dataset ()._variables )
0 commit comments