-
Notifications
You must be signed in to change notification settings - Fork 13
Add sanitization functions and netcdf export (with cleanup) to FileIO #173
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,25 +34,149 @@ def __init__(self,xr_obj): | |
def savePickle(self,filename):
    """
    Pickle the object held in ``self._obj`` and write it to ``filename``.

    Parameters:
        filename: Path of the file to create; it is opened in binary write
            mode, so an existing file at that path is overwritten.
    """
    payload = self._obj
    with open(filename, mode='wb') as handle:
        pickle.dump(payload, handle)
|
|
||
|
|
||
@staticmethod
def sanitize_attrs(xr_obj):
    """
    Sanitize the attributes of an xarray object to make them JSON serializable,
    handling deeply nested dictionaries, lists, tuples, and array-like objects.

    Declared as a static method because the save methods call it as
    ``self.sanitize_attrs(obj)`` while the signature takes only the object.

    Parameters:
        xr_obj (xarray.DataArray or xarray.Dataset): The xarray object to sanitize.
            Only ``xr_obj.copy()`` and ``.attrs`` are used.

    Returns:
        xarray.DataArray or xarray.Dataset: A copy of the input object with sanitized
        attributes. Top-level attributes that cannot be made serializable (or that
        are None) are dropped; the input object's attrs are left untouched.
    """
    # Imported locally under an alias: this file uses ``datetime`` as the module
    # (``datetime.datetime.now()``), so ``isinstance(value, datetime)`` would raise
    # TypeError instead of detecting timestamps.
    import datetime as _dt

    def array_to_list(arr):
        """Convert an ndarray to plain Python data."""
        listed = arr.tolist()
        # Numeric/string arrays already yield plain Python values; object and
        # datetime-like arrays may still hold values json cannot encode.
        if arr.dtype.kind in "OMm":
            return sanitize_value(listed)
        return listed

    def sanitize_value(value):
        """Recursively sanitize a value to ensure JSON serializability."""
        if isinstance(value, _dt.date):  # also matches datetime.datetime, a date subclass
            return value.isoformat()  # Convert dates/datetimes to ISO 8601 strings
        if isinstance(value, np.ndarray):
            return array_to_list(value)
        if hasattr(value, "__array__"):  # Handles other array-like objects (incl. numpy scalars)
            return array_to_list(np.asarray(value))
        if isinstance(value, dict):
            # Recursively sanitize dictionary values
            return {k: sanitize_value(v) for k, v in value.items()}
        if isinstance(value, list):
            # Recursively sanitize list elements
            return [sanitize_value(v) for v in value]
        if isinstance(value, tuple):
            # Recurse into tuples too, so one array element does not cause the
            # whole attribute to be dropped; the tuple type is preserved.
            return tuple(sanitize_value(v) for v in value)
        try:
            # Check if the value can be serialized to JSON
            json.dumps(value)
        except (TypeError, OverflowError):
            return None  # Mark non-serializable values as None
        return value

    sanitized_obj = xr_obj.copy()
    sanitized_attrs = {}
    dropped_attrs = {}

    for key, value in sanitized_obj.attrs.items():
        sanitized_value = sanitize_value(value)
        if sanitized_value is not None:
            sanitized_attrs[key] = sanitized_value
        else:
            dropped_attrs[key] = value

    sanitized_obj.attrs = sanitized_attrs

    # Print a summary of the sanitized attributes
    if dropped_attrs:
        print("Dropped non-serializable attributes:")
        for key, value in dropped_attrs.items():
            print(f"  {key}: {type(value)} - {value}")
    else:
        print("No attributes were dropped.")

    if sanitized_attrs:
        print("\nConverted attributes:")
        for key, value in sanitized_attrs.items():
            print(f"  {key}: {type(value)} -> {value}")

    return sanitized_obj
@staticmethod
def make_attrs_netcdf_safe(xr_obj):
    """
    Make the attributes of an xarray object safe for NetCDF by JSON-encoding
    dictionaries, lists and tuples into strings.

    Declared as a static method because saveNetCDF calls it as
    ``self.make_attrs_netcdf_safe(obj)`` while the signature takes only the object.

    Parameters:
        xr_obj (xarray.DataArray or xarray.Dataset): The xarray object to process.
            Only ``xr_obj.copy()`` and ``.attrs`` are used.

    Returns:
        xarray.DataArray or xarray.Dataset: A copy of the input object with NetCDF-safe
        attributes. Attributes that are None, or containers that cannot be
        JSON-encoded, are dropped (a message is printed for each).
    """
    def encode_complex(value):
        """
        Encode container types (dict, list, tuple) into JSON strings.
        Any other value is returned unchanged; None signals "drop this attribute".
        """
        if isinstance(value, (dict, list, tuple)):
            try:
                # Convert to a JSON string
                return json.dumps(value)
            except (TypeError, OverflowError, ValueError) as e:
                # ValueError covers circular references reported by json.dumps.
                print(f"Error encoding attribute value: {value} ({e})")
                return None
        return value

    sanitized_obj = xr_obj.copy()
    encoded_attrs = {}

    for key, value in sanitized_obj.attrs.items():
        encoded_value = encode_complex(value)
        if encoded_value is not None:
            encoded_attrs[key] = encoded_value
        else:
            print(f"Dropping unsupported attribute: {key} -> {value}")

    sanitized_obj.attrs = encoded_attrs

    return sanitized_obj
|
|
||
| # - This was copied from the Toney group contribution for GIWAXS. | ||
def saveZarr(self, filename, mode: str = 'w'):
    """
    Save the DataArray held in ``self._obj`` as a .zarr store.

    The array's attributes are sanitized and any multi-indexed dimensions are
    unstacked before the array is wrapped in a Dataset (variable name 'DA')
    and written.

    Parameters:
        filename (Union[str, pathlib.Path]): Path of the .zarr store to write.
        mode (str): The mode passed to ``Dataset.to_zarr``. Default is 'w'.
    """
    # pandas is a hard dependency of xarray; imported locally so this method
    # does not rely on a module-level import.
    import pandas as pd

    da = self._obj
    # Sanitize the DataArray itself: to_dataset(name=...) keeps the array's attrs
    # on the 'DA' variable (promote_attrs defaults to False), so sanitizing the
    # resulting Dataset would not touch them.
    # Called through the class so it works whether or not the helper is declared
    # as a staticmethod.
    da = type(self).sanitize_attrs(da)
    # unstack any multiindexes on the array *before* building the dataset,
    # otherwise the unstacked array is never the one that gets written
    if hasattr(da, "indexes"):
        indexes = da.indexes  # mapping of pandas.Index objects
        multiindexes = [
            dim for dim in da.dims
            if dim in indexes and isinstance(indexes[dim], pd.MultiIndex)
        ]
        if multiindexes:
            da = da.unstack(multiindexes)

    ds = da.to_dataset(name='DA')
    file_path = pathlib.Path(filename)
    ds.to_zarr(file_path, mode=mode)
| def saveNetCDF(self, filename): | ||
| """ | ||
| Save the DataArray as a netcdf file in a specific path, with a file name constructed from a prefix and suffix. | ||
|
|
||
| Parameters: | ||
| da (xr.DataArray): The DataArray to be saved. | ||
| base_path (Union[str, pathlib.Path]): The base path to save the .zarr file. | ||
| prefix (str): The prefix to use for the file name. | ||
| suffix (str): The suffix to use for the file name. | ||
| mode (str): The mode to use when saving the file. Default is 'w'. | ||
| """ | ||
| da = self._obj | ||
| ds = da.to_dataset(name='DA') | ||
| file_path = pathlib.Path(filename) | ||
| ds.to_zarr(file_path, mode=mode) | ||
|
|
||
| """ | ||
| da = self._obj | ||
| # sanitize attrs and make netcdf safe by converting dicts to json strings | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are there two functions that clean up attrs? Is there a case where they would be used separately?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. NetCDF-safe is a stricter standard than JSON-serializable. |
||
| da = self.sanitize_attrs(da) | ||
| da = self.make_attrs_netcdf_safe(da) | ||
| # unstack any multiindexes on the array | ||
| if hasattr(da, "indexes"): | ||
| multiindexes = [dim for dim in da.indexes if isinstance(da.indexes[dim], xr.core.indexes.MultiIndex)] | ||
| da = da.unstack(multiindexes) if multiindexes else da | ||
| file_path = pathlib.Path(filename) | ||
| da.to_netcdf(file_path) | ||
|
|
||
| def saveNexus(self,fileName,compression=5): | ||
| data = self._obj | ||
| timestamp = datetime.datetime.now() | ||
|
|
@@ -309,4 +433,4 @@ def _make_coords(f): | |
| else: | ||
| coords[axes[n]] = f['entry']['sasdata'][axis] | ||
|
|
||
| return coords | ||
| return coords | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I dislike the nested function, but Copilot seems to think it's justified, so I'm not sure what to say. A general utility for jsonifying Python objects doesn't seem like a reasonable thing to put in a utils.py or even in this module.