Skip to content

Commit 6c2bd60

Browse files
mjrenomjreno
authored andcommitted
support NCF encodings and grid mappings
1 parent 33a4298 commit 6c2bd60

File tree

6 files changed

+206
-35
lines changed

6 files changed

+206
-35
lines changed

flopy/discretization/grid.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1261,14 +1261,15 @@ def write_shapefile(self, filename="grid.shp", crs=None, prjfile=None, **kwargs)
12611261
)
12621262
return
12631263

1264-
def dataset(self, modeltime=None, mesh=None):
1264+
def dataset(self, modeltime=None, mesh=None, encoding=None):
12651265
"""
12661266
Method to generate baseline xarray dataset
12671267
12681268
Parameters
12691269
----------
12701270
modeltime : FloPy ModelTime object
12711271
mesh : mesh type
1272+
encoding : variable encoding dictionary
12721273
12731274
Returns
12741275
-------

flopy/discretization/structuredgrid.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1676,11 +1676,12 @@ def get_plottable_layer_array(self, a, layer):
16761676
assert plotarray.shape == required_shape, msg
16771677
return plotarray
16781678

1679-
def dataset(self, modeltime=None, mesh=None):
1679+
def dataset(self, modeltime=None, mesh=None, encoding=None):
16801680
"""
16811681
modeltime : FloPy ModelTime object
16821682
mesh : mesh type
16831683
valid mesh types are "layered" or None
1684+
encoding : variable encoding dictionary
16841685
"""
16851686
from ..utils import import_optional_dependency
16861687

@@ -1693,11 +1694,11 @@ def dataset(self, modeltime=None, mesh=None):
16931694
ds.attrs["modflow_grid"] = "STRUCTURED"
16941695

16951696
if mesh and mesh.upper() == "LAYERED":
1696-
return self._layered_mesh_dataset(ds, modeltime)
1697+
return self._layered_mesh_dataset(ds, modeltime, encoding)
16971698
elif mesh is None:
1698-
return self._structured_dataset(ds, modeltime)
1699+
return self._structured_dataset(ds, modeltime, encoding)
16991700

1700-
def _layered_mesh_dataset(self, ds, modeltime=None):
1701+
def _layered_mesh_dataset(self, ds, modeltime=None, encoding=None):
17011702
FILLNA_INT32 = np.int32(-2147483647)
17021703
FILLNA_DBL = 9.96920996838687e36
17031704
lenunits = {0: "m", 1: "ft", 2: "m", 3: "m"}
@@ -1800,9 +1801,25 @@ def _layered_mesh_dataset(self, ds, modeltime=None):
18001801
ds["mesh_face_nodes"].attrs["_FillValue"] = FILLNA_INT32
18011802
ds["mesh_face_nodes"].attrs["start_index"] = np.int32(1)
18021803

1804+
if encoding is not None and "wkt" in encoding and encoding["wkt"] is not None:
1805+
ds = ds.assign({"projection": ([], np.int32(1))})
1806+
# wkt override to existing crs
1807+
ds["projection"].attrs["wkt"] = encoding["wkt"]
1808+
ds["mesh_node_x"].attrs["grid_mapping"] = "projection"
1809+
ds["mesh_node_y"].attrs["grid_mapping"] = "projection"
1810+
ds["mesh_face_x"].attrs["grid_mapping"] = "projection"
1811+
ds["mesh_face_y"].attrs["grid_mapping"] = "projection"
1812+
elif self.crs is not None:
1813+
ds = ds.assign({"projection": ([], np.int32(1))})
1814+
ds["projection"].attrs["wkt"] = self.crs.to_wkt()
1815+
ds["mesh_node_x"].attrs["grid_mapping"] = "projection"
1816+
ds["mesh_node_y"].attrs["grid_mapping"] = "projection"
1817+
ds["mesh_face_x"].attrs["grid_mapping"] = "projection"
1818+
ds["mesh_face_y"].attrs["grid_mapping"] = "projection"
1819+
18031820
return ds
18041821

1805-
def _structured_dataset(self, ds, modeltime=None):
1822+
def _structured_dataset(self, ds, modeltime=None, encoding=None):
18061823
lenunits = {0: "m", 1: "ft", 2: "m", 3: "m"}
18071824

18081825
x = self.xoffset + self.xycenters[0]
@@ -1859,6 +1876,18 @@ def _structured_dataset(self, ds, modeltime=None):
18591876
ds["x"].attrs["long_name"] = "Easting"
18601877
ds["x"].attrs["bounds"] = "x_bnds"
18611878

1879+
if encoding is not None and "wkt" in encoding and encoding["wkt"] is not None:
1880+
ds = ds.assign({"projection": ([], np.int32(1))})
1881+
# wkt override to existing crs
1882+
ds["projection"].attrs["crs_wkt"] = encoding["wkt"]
1883+
ds["x"].attrs["grid_mapping"] = "projection"
1884+
ds["y"].attrs["grid_mapping"] = "projection"
1885+
elif self.crs is not None:
1886+
ds = ds.assign({"projection": ([], np.int32(1))})
1887+
ds["projection"].attrs["crs_wkt"] = self.crs.to_wkt()
1888+
ds["x"].attrs["grid_mapping"] = "projection"
1889+
ds["y"].attrs["grid_mapping"] = "projection"
1890+
18621891
return ds
18631892

18641893
def _set_structured_iverts(self):

flopy/discretization/vertexgrid.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,12 +600,13 @@ def get_plottable_layer_array(self, a, layer):
600600
assert plotarray.shape == required_shape, msg
601601
return plotarray
602602

603-
def dataset(self, modeltime=None, mesh=None):
603+
def dataset(self, modeltime=None, mesh=None, encoding=None):
604604
"""
605605
modeltime : FloPy ModelTime object
606606
mesh : mesh type
607607
valid mesh types are "layered" or None
608608
VertexGrid objects only support layered mesh
609+
encoding : variable encoding dictionary
609610
"""
610611
from ..utils import import_optional_dependency
611612

@@ -717,6 +718,22 @@ def dataset(self, modeltime=None, mesh=None):
717718
ds["mesh_face_nodes"].attrs["_FillValue"] = FILLNA_INT32
718719
ds["mesh_face_nodes"].attrs["start_index"] = np.int32(1)
719720

721+
if encoding is not None and "wkt" in encoding and encoding["wkt"] is not None:
722+
ds = ds.assign({"projection": ([], np.int32(1))})
723+
# wkt override to existing crs
724+
ds["projection"].attrs["wkt"] = encoding["wkt"]
725+
ds["mesh_node_x"].attrs["grid_mapping"] = "projection"
726+
ds["mesh_node_y"].attrs["grid_mapping"] = "projection"
727+
ds["mesh_face_x"].attrs["grid_mapping"] = "projection"
728+
ds["mesh_face_y"].attrs["grid_mapping"] = "projection"
729+
elif self.crs is not None:
730+
ds = ds.assign({"projection": ([], np.int32(1))})
731+
ds["projection"].attrs["wkt"] = self.crs.to_wkt()
732+
ds["mesh_node_x"].attrs["grid_mapping"] = "projection"
733+
ds["mesh_node_y"].attrs["grid_mapping"] = "projection"
734+
ds["mesh_face_x"].attrs["grid_mapping"] = "projection"
735+
ds["mesh_face_y"].attrs["grid_mapping"] = "projection"
736+
720737
return ds
721738

722739
# initialize grid from a grb file

flopy/mf6/mfmodel.py

Lines changed: 135 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,7 +1314,6 @@ def write(
13141314
ext_file_action=ExtFileAction.copy_relative_paths,
13151315
netcdf=None,
13161316
):
1317-
from ..version import __version__
13181317
"""
13191318
Writes out model's package files.
13201319
@@ -1360,7 +1359,10 @@ def write(
13601359
if write_netcdf:
13611360
# set data storage to write ascii for netcdf
13621361
pp._set_netcdf_storage()
1363-
1362+
if pp.package_type.startswith("dis"):
1363+
crs = pp.crs.get_data()
1364+
if crs is not None and self.modelgrid.crs is None:
1365+
self.modelgrid.crs = crs[0][1]
13641366
if (
13651367
self.simulation_data.verbosity_level.value
13661368
>= VerbosityLevel.normal.value
@@ -1373,35 +1375,11 @@ def write(
13731375

13741376
# write netcdf file
13751377
if write_netcdf and netcdf.lower() != "nofile":
1376-
mesh = netcdf
1377-
if mesh.upper() == "STRUCTURED":
1378-
mesh = None
1379-
1380-
ds = self.modelgrid.dataset(
1381-
modeltime=self.modeltime,
1382-
mesh=mesh,
1383-
)
1384-
1385-
nc_info = self.netcdf_info(mesh=mesh)
1386-
nc_info["attrs"]["title"] = f"{self.name.upper()} input"
1387-
nc_info["attrs"]["source"] = f"flopy {__version__}"
1388-
# :history = "first created 2025/8/21 9:46:2.909" ;
1389-
# :Conventions = "CF-1.11 UGRID-1.0" ;
1390-
ds = self.update_dataset(ds, netcdf_info=nc_info, mesh=mesh)
1391-
1392-
# write dataset to netcdf
1393-
fname = self.name_file.nc_filerecord.get_data()[0][0]
1394-
ds.to_netcdf(
1395-
os.path.join(self.model_ws, fname),
1396-
format="NETCDF4",
1397-
engine="netcdf4"
1398-
)
1399-
1378+
self._write_netcdf(mesh=netcdf)
14001379
if nc_fname is not None:
14011380
self.name_file.nc_filerecord = None
14021381

14031382

1404-
14051383
def get_grid_type(self):
14061384
"""
14071385
Return the type of grid used by model 'model_name' in simulation
@@ -2329,12 +2307,142 @@ def update_dataset(self, dataset, netcdf_info=None, mesh=None, update_data=True)
23292307
else:
23302308
nc_info = netcdf_info
23312309

2310+
if (
2311+
self.simulation.simulation_data.verbosity_level.value
2312+
>= VerbosityLevel.normal.value
2313+
):
2314+
print(f" updating model dataset...")
2315+
23322316
for a in nc_info["attrs"]:
23332317
dataset.attrs[a] = nc_info["attrs"][a]
23342318

23352319
# add all packages and update data
23362320
for p in self.packagelist:
23372321
# add package var to dataset
2322+
if (
2323+
self.simulation.simulation_data.verbosity_level.value
2324+
>= VerbosityLevel.normal.value
2325+
):
2326+
print(f" updating dataset for package {p._get_pname()}...")
23382327
dataset = p.update_dataset(dataset, mesh=mesh, update_data=update_data)
23392328

23402329
return dataset
2330+
2331+
def _write_netcdf(self, mesh=None):
2332+
import datetime
2333+
2334+
from ..version import __version__
2335+
if mesh is not None and mesh.upper() == "STRUCTURED":
2336+
mesh = None
2337+
2338+
encode = {}
2339+
for pp in self.packagelist:
2340+
if pp.package_type == "ncf":
2341+
encode["shuffle"] = pp.shuffle.get_data()
2342+
encode["deflate"] = pp.deflate.get_data()
2343+
encode["chunk_time"] = pp.chunk_time.get_data()
2344+
encode["chunk_face"] = pp.chunk_face.get_data()
2345+
encode["chunk_x"] = pp.chunk_x.get_data()
2346+
encode["chunk_y"] = pp.chunk_y.get_data()
2347+
encode["chunk_z"] = pp.chunk_z.get_data()
2348+
wkt = pp.wkt.get_data()
2349+
if wkt is not None:
2350+
wkt = wkt[0][1]
2351+
encode["wkt"] = wkt
2352+
2353+
if (
2354+
self.simulation.simulation_data.verbosity_level.value
2355+
>= VerbosityLevel.normal.value
2356+
):
2357+
print(f" creating model dataset...")
2358+
2359+
ds = self.modelgrid.dataset(
2360+
modeltime=self.modeltime,
2361+
mesh=mesh,
2362+
encoding=encode,
2363+
)
2364+
2365+
dt = datetime.datetime.now()
2366+
timestamp = dt.strftime("%m/%d/%Y %H:%M:%S")
2367+
2368+
nc_info = self.netcdf_info(mesh=mesh)
2369+
nc_info["attrs"]["title"] = f"{self.name.upper()} input"
2370+
nc_info["attrs"]["source"] = f"flopy {__version__}"
2371+
nc_info["attrs"]["history"] = f"first created {timestamp}"
2372+
if mesh is None:
2373+
nc_info["attrs"]["Conventions"] = "CF-1.11"
2374+
elif mesh.upper() is "LAYERED":
2375+
nc_info["attrs"]["Conventions"] = "CF-1.11 UGRID-1.0"
2376+
2377+
ds = self.update_dataset(
2378+
ds,
2379+
netcdf_info=nc_info,
2380+
mesh=mesh,
2381+
)
2382+
2383+
chunk = False
2384+
chunk_t = False
2385+
if mesh is None:
2386+
if (
2387+
"chunk_x" in encode
2388+
and encode["chunk_x"] is not None
2389+
and "chunk_y" in encode
2390+
and encode["chunk_y"] is not None
2391+
and "chunk_z" in encode
2392+
and encode["chunk_z"] is not None
2393+
):
2394+
chunk = True
2395+
elif mesh.upper() == "LAYERED":
2396+
if "chunk_face" in encode and encode["chunk_face"] is not None:
2397+
chunk = True
2398+
if "chunk_time" in encode and encode["chunk_time"] is not None:
2399+
chunk_t = True
2400+
2401+
base_encode = {}
2402+
if "deflate" in encode and encode["deflate"] is not None:
2403+
base_encode["zlib"] = True
2404+
base_encode["complevel"] = encode["deflate"]
2405+
if "shuffle" in encode and encode["deflate"] is not None:
2406+
base_encode["shuffle"] = True
2407+
2408+
encoding = {}
2409+
chunk_dims = {'time', 'nmesh_face', 'z', 'y', 'x'}
2410+
for varname, da in ds.data_vars.items():
2411+
dims = ds.data_vars[varname].dims
2412+
codes = dict(base_encode)
2413+
if (
2414+
not set(dims).issubset(chunk_dims)
2415+
or not chunk or not chunk_t
2416+
):
2417+
encoding[varname] = codes
2418+
continue
2419+
chunksizes = []
2420+
if "time" in dims:
2421+
chunksizes.append(encode["chunk_time"])
2422+
if mesh is None:
2423+
if "z" in dims:
2424+
chunksizes.append(encode["chunk_z"])
2425+
if "y" in dims:
2426+
chunksizes.append(encode["chunk_y"])
2427+
if "x" in dims:
2428+
chunksizes.append(encode["chunk_x"])
2429+
elif mesh.upper() == "LAYERED" and "nmesh_face" in dims:
2430+
chunksizes.append(encode["chunk_face"])
2431+
if len(chunksizes) > 0:
2432+
codes["chunksizes"] = chunksizes
2433+
encoding[varname] = codes
2434+
2435+
fname = self.name_file.nc_filerecord.get_data()[0][0]
2436+
2437+
if (
2438+
self.simulation.simulation_data.verbosity_level.value
2439+
>= VerbosityLevel.normal.value
2440+
):
2441+
print(f" writing NetCDF file {fname}...")
2442+
# write dataset to netcdf
2443+
ds.to_netcdf(
2444+
os.path.join(self.model_ws, fname),
2445+
format="NETCDF4",
2446+
engine="netcdf4",
2447+
encoding=encoding,
2448+
)

flopy/mf6/mfpackage.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3760,6 +3760,8 @@ def _data_shape(shape):
37603760

37613761
return dims_l
37623762

3763+
projection = "projection" in dataset.data_vars
3764+
37633765
last_path = ''
37643766
pitem = None
37653767
pdata = None
@@ -3774,6 +3776,15 @@ def _data_shape(shape):
37743776
dataset = dataset.assign(var_d)
37753777
for a in nc_info[v]["attrs"]:
37763778
dataset[varname].attrs[a] = nc_info[v]["attrs"][a]
3779+
if projection:
3780+
dims = dataset[varname].dims
3781+
if "nmesh_face" in dims or "nmesh_node" in dims:
3782+
dataset[varname].attrs["grid_mapping"] = "projection"
3783+
dataset[varname].attrs["coordinates"] = "mesh_face_x mesh_face_y"
3784+
elif mesh is None and len(dims) > 1:
3785+
# TODO don't set if lon / lat?
3786+
dataset[varname].attrs["grid_mapping"] = "projection"
3787+
dataset[varname].attrs["coordinates"] = "x y"
37773788

37783789
if update_data:
37793790
path = nc_info[v]["attrs"]["modflow_input"]

flopy/utils/datautil.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def reset_delimiter_used():
305305

306306
@staticmethod
307307
def split_data_line(line, external_file=False, delimiter_conf_length=15):
308+
no_split_keys = ['crs', 'wkt']
308309
if PyListUtil.line_num > delimiter_conf_length and PyListUtil.consistent_delim:
309310
# consistent delimiter has been found. continue using that
310311
# delimiter without doing further checks
@@ -358,7 +359,11 @@ def split_data_line(line, external_file=False, delimiter_conf_length=15):
358359
max_split_size = len(max_split_list)
359360
max_split_type = "combo"
360361

361-
if max_split_type is not None and max_split_size > 1:
362+
if (
363+
max_split_type is not None
364+
and max_split_size > 1
365+
and clean_line[0].lower() not in no_split_keys
366+
):
362367
clean_line = max_split_list
363368
if PyListUtil.line_num == 0:
364369
PyListUtil.delimiter_used = max_split_type

0 commit comments

Comments
 (0)