Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 134 additions & 10 deletions src/PyHyperScattering/FileIO.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,149 @@ def __init__(self,xr_obj):
def savePickle(self,filename):
    """
    Pickle the wrapped xarray object to disk.

    Parameters:
        filename: Path of the file to create; it is opened in binary write
            mode, so an existing file at that path is overwritten.
    """
    payload = self._obj
    with open(filename, 'wb') as handle:
        pickle.dump(payload, handle)


@staticmethod
def sanitize_attrs(xr_obj):
    """
    Sanitize the attributes of an xarray object to make them JSON serializable,
    handling deeply nested dictionaries, lists, tuples, and array-like objects.

    This is a static method: the save methods call it as
    ``self.sanitize_attrs(obj)`` and the signature takes no ``self``.

    Parameters:
        xr_obj (xarray.DataArray or xarray.Dataset): The xarray object to sanitize.
            Only ``xr_obj.copy()`` and ``xr_obj.attrs`` are used.

    Returns:
        xarray.DataArray or xarray.Dataset: A copy of the input object with
        sanitized attributes. Top-level attributes that cannot be converted
        (or whose value is None) are dropped; non-convertible values nested
        inside a container are replaced by None. The input is not modified.
    """
    # Imported locally under an alias: elsewhere in this module the name
    # ``datetime`` is used as the module (``datetime.datetime.now()``), and
    # ``isinstance(value, <module>)`` raises TypeError.
    import datetime as _dt

    def sanitize_value(value):
        """Recursively convert ``value`` to a JSON-serializable form (None if impossible)."""
        if isinstance(value, (_dt.datetime, _dt.date, _dt.time)):
            return value.isoformat()  # ISO 8601 string
        if isinstance(value, dict):
            return {k: sanitize_value(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            # Tuples are emitted as lists, which is what JSON would store anyway.
            return [sanitize_value(v) for v in value]
        if isinstance(value, np.ndarray) or hasattr(value, "__array__"):
            converted = np.asarray(value).tolist()
            # Re-sanitize the converted content: tolist() can yield objects
            # such as datetimes that are still not JSON serializable.  The
            # identity guard avoids endless recursion when numpy hands the
            # very same object back (0-d object arrays).
            if converted is not value:
                return sanitize_value(converted)
        try:
            # Check if the value can be serialized to JSON
            json.dumps(value)
        except (TypeError, OverflowError):
            return None  # Mark non-serializable values as None
        return value

    sanitized_obj = xr_obj.copy()
    sanitized_attrs = {}
    dropped_attrs = {}

    for key, value in sanitized_obj.attrs.items():
        sanitized_value = sanitize_value(value)
        if sanitized_value is not None:
            sanitized_attrs[key] = sanitized_value
        else:
            dropped_attrs[key] = value

    sanitized_obj.attrs = sanitized_attrs

    # Print a summary of the sanitized attributes
    if dropped_attrs:
        print("Dropped non-serializable attributes:")
        for key, value in dropped_attrs.items():
            print(f"  {key}: {type(value)} - {value}")
    else:
        print("No attributes were dropped.")

    if sanitized_attrs:
        print("\nConverted attributes:")
        for key, value in sanitized_attrs.items():
            print(f"  {key}: {type(value)} -> {value}")

    return sanitized_obj
@staticmethod
def make_attrs_netcdf_safe(xr_obj):
    """
    Make the attributes of an xarray object safe for NetCDF by JSON-encoding
    dictionaries, lists and tuples into strings.

    This is a static method: the NetCDF save method calls it as
    ``self.make_attrs_netcdf_safe(obj)`` and the signature takes no ``self``.

    Parameters:
        xr_obj (xarray.DataArray or xarray.Dataset): The xarray object to process.
            Only ``xr_obj.copy()`` and ``xr_obj.attrs`` are used.

    Returns:
        xarray.DataArray or xarray.Dataset: A copy of the input object with
        NetCDF-safe attributes. Attributes whose value is None, or whose
        container value cannot be JSON-encoded, are dropped (a message is
        printed for each). All other values are passed through unchanged.
    """
    def encode_complex(value):
        """Encode container types (dict, list, tuple) into JSON strings."""
        if not isinstance(value, (dict, list, tuple)):
            return value
        try:
            # Convert to a JSON string
            return json.dumps(value)
        except (TypeError, OverflowError) as e:
            # Handle unexpected cases gracefully: report and signal "drop"
            print(f"Error encoding attribute value: {value} ({e})")
            return None

    sanitized_obj = xr_obj.copy()
    encoded_attrs = {}

    for key, value in sanitized_obj.attrs.items():
        encoded_value = encode_complex(value)
        if encoded_value is not None:
            encoded_attrs[key] = encoded_value
        else:
            print(f"Dropping unsupported attribute: {key} -> {value}")

    sanitized_obj.attrs = encoded_attrs

    return sanitized_obj

# - This was copied from the Toney group contribution for GIWAXS.
def saveZarr(self, filename, mode: str = 'w'):
    """
    Save the wrapped DataArray as a .zarr store at ``filename``.

    The array's attributes are sanitized to JSON-serializable values and any
    MultiIndex dimensions are unstacked; the result is then converted to a
    dataset holding the array under the variable name ``'DA'`` and written.

    Parameters:
        filename (Union[str, pathlib.Path]): Destination path of the .zarr store.
        mode (str): The mode to use when saving the file. Default is 'w'.
    """
    # Local import: DataArray.indexes exposes pandas index objects, so the
    # MultiIndex check is made against pandas directly rather than through
    # an xarray-internal module path.
    import pandas as pd

    da = self._obj
    # Sanitize the DataArray itself.  to_dataset(name=...) leaves the array's
    # attrs on the 'DA' variable (attrs are not promoted to the dataset by
    # default), so sanitizing the resulting dataset would not touch them.
    da = self.sanitize_attrs(da)
    # Unstack any multiindexes on the array *before* building the dataset,
    # otherwise the written dataset still carries the stacked index.
    if hasattr(da, "indexes"):
        multiindexes = [
            dim for dim in da.dims
            if isinstance(da.indexes.get(dim), pd.MultiIndex)
        ]
        if multiindexes:
            da = da.unstack(multiindexes)

    ds = da.to_dataset(name='DA')
    file_path = pathlib.Path(filename)
    ds.to_zarr(file_path, mode=mode)
def saveNetCDF(self, filename):
    """
    Save the wrapped DataArray as a NetCDF file at ``filename``.

    Attributes are first sanitized to JSON-serializable values and then made
    NetCDF-safe (dicts, lists and tuples become JSON strings); any MultiIndex
    dimensions are unstacked before writing.

    Parameters:
        filename (Union[str, pathlib.Path]): Destination path of the NetCDF file.
    """
    # Local import: DataArray.indexes exposes pandas index objects, so the
    # MultiIndex check is made against pandas directly rather than through
    # an xarray-internal module path.
    import pandas as pd

    da = self._obj
    # NetCDF-safe is a stricter standard than JSON-serializable, so both
    # passes are applied: sanitize first, then JSON-encode the remaining
    # container values into strings.
    da = self.sanitize_attrs(da)
    da = self.make_attrs_netcdf_safe(da)
    # Unstack any multiindexes on the array; only real dimensions are
    # considered because unstack() operates on dimension names.
    if hasattr(da, "indexes"):
        multiindexes = [
            dim for dim in da.dims
            if isinstance(da.indexes.get(dim), pd.MultiIndex)
        ]
        if multiindexes:
            da = da.unstack(multiindexes)

    file_path = pathlib.Path(filename)
    da.to_netcdf(file_path)

def saveNexus(self,fileName,compression=5):
data = self._obj
timestamp = datetime.datetime.now()
Expand Down Expand Up @@ -309,4 +433,4 @@ def _make_coords(f):
else:
coords[axes[n]] = f['entry']['sasdata'][axis]

return coords
return coords
Loading