From e3e81febc29e1d01e6169c051f2ceb8b10513019 Mon Sep 17 00:00:00 2001 From: "Christopher M. Pierce" Date: Sun, 12 Jan 2025 22:13:09 -0800 Subject: [PATCH 1/9] add pre-commit file and workflow --- .github/workflows/formatting.yml | 32 ++++++++++++++++++++++++++++++++ .pre-commit-config.yaml | 14 ++++++++++++++ pyproject.toml | 7 ++++++- 3 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/formatting.yml create mode 100644 .pre-commit-config.yaml diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml new file mode 100644 index 0000000..930d9ee --- /dev/null +++ b/.github/workflows/formatting.yml @@ -0,0 +1,32 @@ +name: Pre-commit Checks + +on: + pull_request: + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install pre-commit + run: | + python -m pip install --upgrade pip + pip install pre-commit + + - name: Run pre-commit + run: pre-commit run --all-files + + # This step will show the exact files that were modified + - name: Check for modified files + run: | + if [[ -n "$(git status --porcelain)" ]]; then + echo "The following files were modified by pre-commit:" + git status --porcelain + exit 1 + fi \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..43f1504 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,14 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.8.1 + hooks: + # Run the linter + - id: ruff + args: [ --fix ] + # Run the formatter + - id: ruff-format +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-added-large-files + args: ['--maxkb=2000'] \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a9fb4ac..68cbfcd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,4 +21,9 @@ dependencies = [ easygdf = ["data/*"] [project.urls] -Homepage = "https://github.com/electronsandstuff/easygdf" \ No newline at end of file +Homepage = "https://github.com/electronsandstuff/easygdf" + +[tool.ruff] +line-length = 120 +indent-width = 4 +target-version = "py312" \ No newline at end of file From b58787306714bc23e0e7e67c04bb6c026aa85b6d Mon Sep 17 00:00:00 2001 From: "Christopher M. Pierce" Date: Sun, 12 Jan 2025 22:19:20 -0800 Subject: [PATCH 2/9] add `environment.yml` --- environment.yml | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 environment.yml diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..efd53a1 --- /dev/null +++ b/environment.yml @@ -0,0 +1,9 @@ +name: easygdf +channels: + - conda-forge + - defaults +dependencies: + - python=3.12 + - numpy + - pip + - pre-commit \ No newline at end of file From 1619dd4d9bd648e4086a9a97b67820de85fb4dbe Mon Sep 17 00:00:00 2001 From: "Christopher M. Pierce" Date: Sun, 12 Jan 2025 22:25:30 -0800 Subject: [PATCH 3/9] run formatter on all files --- examples/initial_distribution.py | 4 +- examples/minimal.py | 2 +- scripts/generate_test_file.py | 41 ++-- scripts/trim_test_files.py | 23 +-- src/easygdf/__init__.py | 61 +++++- src/easygdf/easygdf.py | 177 ++++++++++-------- src/easygdf/initial_distribution.py | 67 ++++++- src/easygdf/screens_touts.py | 111 ++++++++--- src/easygdf/utils.py | 1 + tests/test_easygdf.py | 212 +++++++++++++-------- tests/test_initial_distribution.py | 74 ++++---- tests/test_screens_touts.py | 277 ++++++++++++++++++++-------- tests/utils.py | 3 +- 13 files changed, 711 insertions(+), 342 deletions(-) diff --git a/examples/initial_distribution.py b/examples/initial_distribution.py index 87563c8..f3b1df4 100644 --- a/examples/initial_distribution.py +++ b/examples/initial_distribution.py @@ -4,7 +4,7 @@ # Save some data to an initial distribution file. Unspecified required values are autofilled for us easygdf.save_initial_distribution( "initial.gdf", - x=np.random.normal(size=(3, )), - GBx=np.random.normal(size=(3, )), + x=np.random.normal(size=(3,)), + GBx=np.random.normal(size=(3,)), t=np.random.random((3,)), ) diff --git a/examples/minimal.py b/examples/minimal.py index b56f9f1..d23fccd 100644 --- a/examples/minimal.py +++ b/examples/minimal.py @@ -5,7 +5,7 @@ blocks = [ {"name": "an array", "value": np.array([0, 1, 2, 3])}, {"name": "a string", "value": "Hello world!"}, - {"name": "a group", "value": 3.14, "children": [{"name": "child", "value": 1.0}]} + {"name": "a group", "value": 3.14, "children": [{"name": "child", "value": 1.0}]}, ] easygdf.save("minimal.gdf", blocks) diff --git a/scripts/generate_test_file.py b/scripts/generate_test_file.py index a73dd5f..892e21f 100644 --- a/scripts/generate_test_file.py +++ b/scripts/generate_test_file.py @@ -89,7 +89,7 @@ elif s == 8: NUMPY_TO_GDF[t] = GDF_INT64 else: - raise ValueError("Unable to autodetect GDF flag for numpy data type \"{0}\" with size {1} bytes".format(t, s)) + raise ValueError('Unable to autodetect GDF flag for numpy data type "{0}" with size {1} bytes'.format(t, s)) ######################################################################################################################## @@ -113,10 +113,10 @@ def get_header(magic_number=94325877, gdf_version=(1, 1)): ) -def get_block_header(name="", dtype=easygdf.GDF_NULL, single=True, array=False, group_begin=False, group_end=False, - size=0): - flag = dtype + single * GDF_SINGLE + array * GDF_ARRAY + group_begin * GDF_GROUP_BEGIN \ - + group_end * GDF_GROUP_END +def get_block_header( + name="", dtype=easygdf.GDF_NULL, single=True, array=False, group_begin=False, group_end=False, size=0 +): + flag = dtype + single * GDF_SINGLE + array * GDF_ARRAY + group_begin * GDF_GROUP_BEGIN + group_end * GDF_GROUP_END return struct.pack("16sii", bytes(name, "ascii"), flag, size) @@ -318,23 +318,32 @@ def get_normalize_screen_floats(): :return: Bytes string representing file """ f = get_header() - f += get_block_header(name="position", dtype=GDF_DOUBLE, single=True, array=False, - size=GDF_DTYPES_STRUCT[GDF_DOUBLE][1], group_begin=True) + f += get_block_header( + name="position", + dtype=GDF_DOUBLE, + single=True, + array=False, + size=GDF_DTYPES_STRUCT[GDF_DOUBLE][1], + group_begin=True, + ) f += struct.pack(GDF_DTYPES_STRUCT[GDF_DOUBLE][0], 0.0) - for k in ['ID', 'x', 'y', 'z', 'Bx', 'By', 'Bz', 't', 'm', 'q', 'nmacro', 'rmacro', 'rxy', 'G']: + for k in ["ID", "x", "y", "z", "Bx", "By", "Bz", "t", "m", "q", "nmacro", "rmacro", "rxy", "G"]: dtype = GDF_DOUBLE f += get_block_header(name=k, dtype=dtype, single=False, array=True, size=6 * GDF_DTYPES_STRUCT[dtype][1]) d = GDF_DTYPES_STRUCT[dtype][0] f += struct.pack(6 * d, 0, 1, 2, 3, 4, 5) - f += get_block_header(name="Particles", dtype=GDF_DOUBLE, single=True, array=False, - size=GDF_DTYPES_STRUCT[GDF_DOUBLE][1]) + f += get_block_header( + name="Particles", dtype=GDF_DOUBLE, single=True, array=False, size=GDF_DTYPES_STRUCT[GDF_DOUBLE][1] + ) f += struct.pack(GDF_DTYPES_STRUCT[GDF_DOUBLE][0], 0.0) - f += get_block_header(name="pCentral", dtype=GDF_DOUBLE, single=True, array=False, - size=GDF_DTYPES_STRUCT[GDF_DOUBLE][1]) + f += get_block_header( + name="pCentral", dtype=GDF_DOUBLE, single=True, array=False, size=GDF_DTYPES_STRUCT[GDF_DOUBLE][1] + ) f += struct.pack(GDF_DTYPES_STRUCT[GDF_DOUBLE][0], 0.0) - f += get_block_header(name="Charge", dtype=GDF_DOUBLE, single=True, array=False, - size=GDF_DTYPES_STRUCT[GDF_DOUBLE][1]) + f += get_block_header( + name="Charge", dtype=GDF_DOUBLE, single=True, array=False, size=GDF_DTYPES_STRUCT[GDF_DOUBLE][1] + ) f += struct.pack(GDF_DTYPES_STRUCT[GDF_DOUBLE][0], 0.0) f += get_block_header(name="", group_end=True) return f @@ -351,7 +360,7 @@ def write_file(path, b): data_files_path = "easygdf/tests/data" if __name__ == "__main__": write_file(os.path.join(data_files_path, "normalize_screen_floats.gdf"), get_normalize_screen_floats()) - ''' + """ write_file(os.path.join(data_files_path, "version_mismatch.gdf"), get_file_version_mismatch()) write_file(os.path.join(data_files_path, "wrong_magic_number.gdf"), get_file_wrong_magic_number()) write_file(os.path.join(data_files_path, "too_much_recursion.gdf"), get_file_too_much_recursion()) @@ -367,4 +376,4 @@ def write_file(path, b): write_file(os.path.join(data_files_path, "invalid_size_array.gdf"), get_file_invalid_array_size()) write_file(os.path.join(data_files_path, "nested_groups.gdf"), get_file_nested_group()) write_file(os.path.join(data_files_path, "null_array.gdf"), get_file_null_array()) - ''' + """ diff --git a/scripts/trim_test_files.py b/scripts/trim_test_files.py index 06ff007..45e029b 100644 --- a/scripts/trim_test_files.py +++ b/scripts/trim_test_files.py @@ -62,11 +62,11 @@ def round_sigfigs(x, sigfigs): omags -= 1.0 else: # elif np.all(np.isreal( mantissas )): - fixmsk = mantissas < 1.0, + fixmsk = (mantissas < 1.0,) mantissas[fixmsk] *= 10.0 omags[fixmsk] -= 1.0 - result = xsgn * np.around(mantissas, decimals=sigfigs - 1) * 10.0 ** omags + result = xsgn * np.around(mantissas, decimals=sigfigs - 1) * 10.0**omags if matrixflag: result = np.matrix(result, copy=False) @@ -103,17 +103,11 @@ def trim_screens_tout(): # Trim down the arrays to the correct number of particles particle_blocks = [] for b in trimmed_blocks: - new_block = { - "name": b["name"], - "param": b["param"], - "children": [] - } + new_block = {"name": b["name"], "param": b["param"], "children": []} for c in b["children"]: - new_block["children"].append({ - "name": c["name"], - "param": round_sigfigs(c["param"][:n_particles], 4), - "children": [] - }) + new_block["children"].append( + {"name": c["name"], "param": round_sigfigs(c["param"][:n_particles], 4), "children": []} + ) particle_blocks.append(new_block) d["blocks"] = particle_blocks @@ -150,10 +144,7 @@ def trim_initial_distribution(): trimmed_blocks = [] for b in d["blocks"]: if isinstance(b["param"], np.ndarray): - trimmed_blocks.append({ - "name": b["name"], - "param": round_sigfigs(b["param"][:n_particles], 4) - }) + trimmed_blocks.append({"name": b["name"], "param": round_sigfigs(b["param"][:n_particles], 4)}) d["blocks"] = trimmed_blocks # Save the file diff --git a/src/easygdf/__init__.py b/src/easygdf/__init__.py index 285aa1d..161b76b 100644 --- a/src/easygdf/__init__.py +++ b/src/easygdf/__init__.py @@ -1,8 +1,57 @@ -# This file is part of easygdf and is released under the BSD 3-clause license - -from .easygdf import GDF_ASCII, GDF_DOUBLE, GDF_FLOAT, GDF_INT8, GDF_INT16, GDF_INT32, GDF_INT64, GDF_NULL, GDF_UINT8 -from .easygdf import GDF_UINT16, GDF_UINT32, GDF_UINT64, GDF_UNDEFINED, GDF_NAME_LEN, GDF_MAGIC -from .easygdf import is_gdf, load, save +from .easygdf import ( + GDF_ASCII, + GDF_DOUBLE, + GDF_FLOAT, + GDF_INT8, + GDF_INT16, + GDF_INT32, + GDF_INT64, + GDF_NULL, + GDF_UINT8, + GDF_UINT16, + GDF_UINT32, + GDF_UINT64, + GDF_UNDEFINED, + GDF_NAME_LEN, + GDF_MAGIC, + is_gdf, + load, + save, +) from .initial_distribution import load_initial_distribution, save_initial_distribution from .screens_touts import load_screens_touts, save_screens_touts -from .utils import get_example_screen_tout_filename, get_example_initial_distribution, GDFError, GDFIOError +from .utils import ( + get_example_screen_tout_filename, + get_example_initial_distribution, + GDFError, + GDFIOError, +) + +__all__ = [ + "GDF_ASCII", + "GDF_DOUBLE", + "GDF_FLOAT", + "GDF_INT8", + "GDF_INT16", + "GDF_INT32", + "GDF_INT64", + "GDF_NULL", + "GDF_UINT8", + "GDF_UINT16", + "GDF_UINT32", + "GDF_UINT64", + "GDF_UNDEFINED", + "GDF_NAME_LEN", + "GDF_MAGIC", + "is_gdf", + "load", + "save", + "load_initial_distribution", + "save_initial_distribution", + "load_screens_touts", + "save_screens_touts", + "get_example_screen_tout_filename", + "get_example_initial_distribution", + "GDFError", + "GDFIOError", +] diff --git a/src/easygdf/easygdf.py b/src/easygdf/easygdf.py index 6bd8064..46919f1 100644 --- a/src/easygdf/easygdf.py +++ b/src/easygdf/easygdf.py @@ -86,7 +86,7 @@ elif s == 8: NUMPY_TO_GDF[t] = GDF_INT64 else: - raise ValueError("Unable to autodetect GDF flag for numpy data type \"{0}\" with size {1} bytes".format(t, s)) + raise ValueError('Unable to autodetect GDF flag for numpy data type "{0}" with size {1} bytes'.format(t, s)) # The bit masks for flags in the GDF header GDF_DTYPE = 255 @@ -122,7 +122,7 @@ def is_gdf(f): # Rewind again to read the magic number f.seek(0) - magic_number, = struct.unpack('i', f.read(4)) + (magic_number,) = struct.unpack("i", f.read(4)) return magic_number == GDF_MAGIC @@ -145,13 +145,13 @@ def load_blocks(f, level=0, max_recurse=16, max_block=1e6): blocks = [] # Loop over the blocks - for block_ind in range(int(max_block)): + for _ in range(int(max_block)): # Read the block's header header_raw = f.read(GDF_NAME_LEN + 8) # If no data came back and we are in the root group, then break. If this isn't root group, then something # went wrong. - if header_raw == b'': + if header_raw == b"": if level == 0: break else: @@ -159,7 +159,7 @@ def load_blocks(f, level=0, max_recurse=16, max_block=1e6): # Unpack the header block_name, block_type_flag, block_size = struct.unpack("{0}sii".format(GDF_NAME_LEN), header_raw) - block_name = block_name.split(b'\0', 1)[0].decode('ascii') + block_name = block_name.split(b"\0", 1)[0].decode("ascii") # Make a new empty block with the correct name block = {"name": block_name, "value": None, "children": []} @@ -168,14 +168,17 @@ def load_blocks(f, level=0, max_recurse=16, max_block=1e6): group_begin = bool(block_type_flag & GDF_GROUP_BEGIN) group_end = bool(block_type_flag & GDF_GROUP_END) if group_begin and group_end: - raise ValueError("Potentially invalid group flags in block " - "(\"group_begin\" = {0} \"group_end\" = {1}".format(group_begin, group_end)) + raise ValueError( + "Potentially invalid group flags in block " '("group_begin" = {0} "group_end" = {1}'.format( + group_begin, group_end + ) + ) # If this is a group end block, then break out of the loop. If this end block was encountered in root, then # something went wrong and throw an error if group_end: if level == 0: - raise ValueError("Encountered \"group end\" block in GDF file, but we were not inside of a group") + raise ValueError('Encountered "group end" block in GDF file, but we were not inside of a group') else: break @@ -191,15 +194,18 @@ def load_blocks(f, level=0, max_recurse=16, max_block=1e6): if dtype in GDF_DTYPES_STRUCT: # Confirm that the size is what we expect if block_size != GDF_DTYPES_STRUCT[dtype][1]: - raise ValueError("Potentially bad block size (expected {:d} bytes, got {:d} bytes)".format( - GDF_DTYPES_STRUCT[dtype][1], block_size)) + raise ValueError( + "Potentially bad block size (expected {:d} bytes, got {:d} bytes)".format( + GDF_DTYPES_STRUCT[dtype][1], block_size + ) + ) # Pull the data from the file and convert to the parameter - block["value"], = struct.unpack(GDF_DTYPES_STRUCT[dtype][0], f.read(block_size)) + (block["value"],) = struct.unpack(GDF_DTYPES_STRUCT[dtype][0], f.read(block_size)) # If it is a string, pull it out and decode elif dtype == GDF_ASCII: - block["value"] = f.read(block_size).split(b'\0', 1)[0].decode('ascii') + block["value"] = f.read(block_size).split(b"\0", 1)[0].decode("ascii") # If it is null, put a None object and fast forward through the file by the block size elif dtype == GDF_NULL: @@ -219,14 +225,14 @@ def load_blocks(f, level=0, max_recurse=16, max_block=1e6): if dtype in GDF_DTYPES_NUMPY: # Confirm that the size is what we expect (array is even multiple of type size) if block_size % GDF_DTYPES_NUMPY[dtype][1] != 0: - raise ValueError("Potentially bad block size in array (expected multiple of {:d} bytes," - " got {:d} bytes)".format(GDF_DTYPES_NUMPY[dtype][1], block_size)) + raise ValueError( + "Potentially bad block size in array (expected multiple of {:d} bytes," + " got {:d} bytes)".format(GDF_DTYPES_NUMPY[dtype][1], block_size) + ) # Pull the data from the file and convert to the parameter block["value"] = np.fromfile( - f, - dtype=GDF_DTYPES_NUMPY[dtype][0], - count=block_size // GDF_DTYPES_NUMPY[dtype][1] + f, dtype=GDF_DTYPES_NUMPY[dtype][0], count=block_size // GDF_DTYPES_NUMPY[dtype][1] ) # If it is null, then I don't know how to interpret it as an array so through an error @@ -243,16 +249,11 @@ def load_blocks(f, level=0, max_recurse=16, max_block=1e6): # Something went wrong (single and array are both true or both false) else: - raise ValueError("invalid block flags (\"single\" = {0}, \"array\" = {1})".format(single, array)) + raise ValueError('invalid block flags ("single" = {0}, "array" = {1})'.format(single, array)) # If we have children then recurse to get them if group_begin: - block["children"] = load_blocks( - f, - level=level + 1, - max_recurse=max_recurse, - max_block=max_block - ) + block["children"] = load_blocks(f, level=level + 1, max_recurse=max_recurse, max_block=max_block) # Add this block to the list blocks.append(block) @@ -301,19 +302,21 @@ def load(f, max_recurse=16, max_block=1e6): fh_raw = struct.unpack("2i{0}s{0}s8B".format(GDF_NAME_LEN), f.read(48)) ret = { "creation_time": datetime.datetime.fromtimestamp(fh_raw[1], tz=datetime.timezone.utc), - "creator": fh_raw[2].split(b'\0', 1)[0].decode('ascii'), - "destination": fh_raw[3].split(b'\0', 1)[0].decode('ascii'), + "creator": fh_raw[2].split(b"\0", 1)[0].decode("ascii"), + "destination": fh_raw[3].split(b"\0", 1)[0].decode("ascii"), "gdf_version": (fh_raw[4], fh_raw[5]), "creator_version": (fh_raw[6], fh_raw[7]), "destination_version": (fh_raw[8], fh_raw[9]), - "dummy": (fh_raw[10], fh_raw[11]) + "dummy": (fh_raw[10], fh_raw[11]), } # If the GDF version is too new, then give a warning v = ret["gdf_version"] if v[0] != 1 or v[1] != 1: - warnings.warn("Attempting to open GDF v{:d}.{:d} file. easygdf has only been tested on GDF v1.1 files. " - "Please report any issues to project maintainer at contact@chris-pierce.com".format(v[0], v[1])) + warnings.warn( + "Attempting to open GDF v{:d}.{:d} file. easygdf has only been tested on GDF v1.1 files. " + "Please report any issues to project maintainer at contact@chris-pierce.com".format(v[0], v[1]) + ) # Load all the groups and return ret["blocks"] = load_blocks(f, max_recurse=max_recurse, max_block=max_block) @@ -332,7 +335,7 @@ def save_blocks(f, blocks, level=0, max_recurse=16): """ # Check that blocks is really a list if not isinstance(blocks, list): - raise TypeError("Blocks must be a list, not a \"{0}\"".format(type(blocks))) + raise TypeError('Blocks must be a list, not a "{0}"'.format(type(blocks))) # If we have hit the max recursion depth throw an error if level >= max_recurse: @@ -351,17 +354,19 @@ def save_blocks(f, blocks, level=0, max_recurse=16): for key in user_block: # If the override isn't a valid part of the block, throw an error if key not in block: - raise ValueError("Invalid key in user provided block: \"{:s}\"".format(key)) + raise ValueError('Invalid key in user provided block: "{:s}"'.format(key)) # Check dtype when required if key == "name": if not isinstance(user_block[key], str): - raise TypeError("Block attribute \"name\" must be a string not " - "\"{0}\"".format(type(user_block[key]))) + raise TypeError( + 'Block attribute "name" must be a string not ' '"{0}"'.format(type(user_block[key])) + ) if key == "children": if not isinstance(user_block[key], list): - raise TypeError("Block attribute \"children\" must be a list not " - "\"{0}\"".format(type(user_block[key]))) + raise TypeError( + 'Block attribute "children" must be a list not ' '"{0}"'.format(type(user_block[key])) + ) # Override the header block[key] = user_block[key] @@ -378,19 +383,23 @@ def save_blocks(f, blocks, level=0, max_recurse=16): # If we are a numpy array then write it if isinstance(block["value"], np.ndarray): - if block["value"].dtype == np.dtype('int64'): + if block["value"].dtype == np.dtype("int64"): if (np.abs(block["value"]) > 0x7FFFFFFF).any(): idx = np.argmax(np.abs(block["value"])) - raise ValueError(f'An array element exceeds the range of int32 (max compatible GDF size). The ' - f'element at index {idx} had value {block["value"]}, but int32s must have a max ' - f'absolute value of 2,147,483,647.') + raise ValueError( + f'An array element exceeds the range of int32 (max compatible GDF size). The ' + f'element at index {idx} had value {block["value"]}, but int32s must have a max ' + f'absolute value of 2,147,483,647.' + ) bval = block["value"].astype(np.int32) - elif block["value"].dtype == np.dtype('uint64'): + elif block["value"].dtype == np.dtype("uint64"): if (block["value"] > 0xFFFFFFFF).any(): idx = np.argmax(np.abs(block["value"])) - raise ValueError(f'An array element exceeds the range of uint32 (max compatible GDF size). The ' - f'element at index {idx} had value {block["value"]}, but int32s must have a max ' - f'absolute value of 4,294,967,295.') + raise ValueError( + f'An array element exceeds the range of uint32 (max compatible GDF size). The ' + f'element at index {idx} had value {block["value"]}, but int32s must have a max ' + f'absolute value of 4,294,967,295.' + ) bval = block["value"].astype(np.uint32) else: bval = block["value"] @@ -401,7 +410,7 @@ def save_blocks(f, blocks, level=0, max_recurse=16): # Determine the data type and add it to the header dname = bval.dtype.name if dname not in NUMPY_TO_GDF: - raise TypeError("Cannot write numpy data type \"{0}\" to GDF file".format(dname)) + raise TypeError('Cannot write numpy data type "{0}" to GDF file'.format(dname)) block_type_flag += NUMPY_TO_GDF[dname] # Write the header and then write the numpy array to the file @@ -418,20 +427,28 @@ def save_blocks(f, blocks, level=0, max_recurse=16): if isinstance(block["value"], str): block_type_flag += GDF_ASCII block_size = len(block["value"]) - f.write(bname + struct.pack("ii{:d}s".format(block_size), block_type_flag, block_size, - bytes(block["value"], "ascii"))) + f.write( + bname + + struct.pack( + "ii{:d}s".format(block_size), block_type_flag, block_size, bytes(block["value"], "ascii") + ) + ) elif isinstance(block["value"], int): if block["value"] > 0: if abs(block["value"]) > 0xFFFFFFFF: - raise ValueError(f"Value exceeds range of 32-bit unsigned int (largest supported size in GDF). " - f"Value cannot exceed 4,294,967,295. Received {block['value']}") + raise ValueError( + f"Value exceeds range of 32-bit unsigned int (largest supported size in GDF). " + f"Value cannot exceed 4,294,967,295. Received {block['value']}" + ) block_type_flag += GDF_UINT32 block_size = 4 f.write(bname + struct.pack("iiI", block_type_flag, block_size, block["value"])) else: if abs(block["value"]) > 0x7FFFFFFF: - raise ValueError(f"Value exceeds range of 32-bit signed int (largest supported size in GDF). " - f"Absolute value cannot exceed 2,147,483,647. Received {block['value']}") + raise ValueError( + f"Value exceeds range of 32-bit signed int (largest supported size in GDF). " + f"Absolute value cannot exceed 2,147,483,647. Received {block['value']}" + ) block_type_flag += GDF_INT32 block_size = 4 f.write(bname + struct.pack("iii", block_type_flag, block_size, block["value"])) @@ -444,24 +461,29 @@ def save_blocks(f, blocks, level=0, max_recurse=16): block_size = 0 f.write(bname + struct.pack("ii", block_type_flag, block_size)) else: - raise TypeError("Cannot write data type \"{0}\" to GDF file".format(type(block["value"]))) + raise TypeError('Cannot write data type "{0}" to GDF file'.format(type(block["value"]))) # Recurse on the children of this block if len(block["children"]) != 0: - save_blocks( - f, - block["children"], - level=level + 1, - max_recurse=max_recurse - ) + save_blocks(f, block["children"], level=level + 1, max_recurse=max_recurse) # If we are not the root group, then write a group end block if level > 0: f.write(struct.pack("{0}sii".format(GDF_NAME_LEN), b"", GDF_NULL + GDF_GROUP_END, 0)) -def save(f, blocks=None, creation_time=None, creator="easygdf", destination="", gdf_version=(1, 1), - creator_version=(2, 0), destination_version=(0, 0), dummy=(0, 0), max_recurse=16): +def save( + f, + blocks=None, + creation_time=None, + creator="easygdf", + destination="", + gdf_version=(1, 1), + creator_version=(2, 0), + destination_version=(0, 0), + dummy=(0, 0), + max_recurse=16, +): """ Saves user provided data into a GDF file. Blocks are python dicts with the keys: name, value, children. Name must be a string that may be encoded as ASCII. Values may be an int, a float, a string, None, a bytes object, or a numpy @@ -498,12 +520,13 @@ def save(f, blocks=None, creation_time=None, creator="easygdf", destination="", ff, blocks=blocks, creation_time=creation_time, - creator=creator, destination=destination, + creator=creator, + destination=destination, gdf_version=gdf_version, creator_version=creator_version, destination_version=destination_version, dummy=dummy, - max_recurse=max_recurse + max_recurse=max_recurse, ) # Make sure we have an open file @@ -519,21 +542,23 @@ def save(f, blocks=None, creation_time=None, creator="easygdf", destination="", creation_time = int(datetime.datetime.timestamp(creation_time)) # Write the header - f.write(struct.pack( - "2i{0}s{0}s8B".format(GDF_NAME_LEN), - GDF_MAGIC, - creation_time, - bytes(creator, "ascii"), - bytes(destination, "ascii"), - gdf_version[0], - gdf_version[1], - creator_version[0], - creator_version[1], - destination_version[0], - destination_version[1], - dummy[0], - dummy[1], - )) + f.write( + struct.pack( + "2i{0}s{0}s8B".format(GDF_NAME_LEN), + GDF_MAGIC, + creation_time, + bytes(creator, "ascii"), + bytes(destination, "ascii"), + gdf_version[0], + gdf_version[1], + creator_version[0], + creator_version[1], + destination_version[0], + destination_version[1], + dummy[0], + dummy[1], + ) + ) # Save the root group and then recurse (inside function) save_blocks(f, blocks, max_recurse=max_recurse) diff --git a/src/easygdf/initial_distribution.py b/src/easygdf/initial_distribution.py index a835779..579c1ed 100644 --- a/src/easygdf/initial_distribution.py +++ b/src/easygdf/initial_distribution.py @@ -34,10 +34,34 @@ def load_initial_distribution(f, max_recurse=16, max_block=1e6): return out -def save_initial_distribution(f, x=None, y=None, z=None, GBx=None, GBy=None, GBz=None, Bx=None, By=None, Bz=None, - t=None, G=None, m=None, q=None, nmacro=None, rmacro=None, ID=None, creation_time=None, - creator="easygdf", destination="", gdf_version=(1, 1), creator_version=(2, 0), - destination_version=(0, 0), dummy=(0, 0), max_recurse=16, **kwargs): +def save_initial_distribution( + f, + x=None, + y=None, + z=None, + GBx=None, + GBy=None, + GBz=None, + Bx=None, + By=None, + Bz=None, + t=None, + G=None, + m=None, + q=None, + nmacro=None, + rmacro=None, + ID=None, + creation_time=None, + creator="easygdf", + destination="", + gdf_version=(1, 1), + creator_version=(2, 0), + destination_version=(0, 0), + dummy=(0, 0), + max_recurse=16, + **kwargs, +): """ Saves GPT compatible initial distribution file. All array objects must be the same length (IE the number of particles). If required values (either {x,y,z,GBx,GBy,GBz} or {x,y,z,Bx,By,Bz}) are missing, easyGDF will autofill @@ -72,8 +96,24 @@ def save_initial_distribution(f, x=None, y=None, z=None, GBx=None, GBy=None, GBz :return: None """ # Copy all array elements into dict for processing and get rid of Nones - data_raw = {"x": x, "y": y, "z": z, "GBx": GBx, "GBy": GBy, "GBz": GBz, "Bx": Bx, "By": By, "Bz": Bz, "t": t, - "G": G, "m": m, "q": q, "nmacro": nmacro, "rmacro": rmacro, "ID": ID} + data_raw = { + "x": x, + "y": y, + "z": z, + "GBx": GBx, + "GBy": GBy, + "GBz": GBz, + "Bx": Bx, + "By": By, + "Bz": Bz, + "t": t, + "G": G, + "m": m, + "q": q, + "nmacro": nmacro, + "rmacro": rmacro, + "ID": ID, + } data_raw.update(kwargs) data = {x: data_raw[x] for x in data_raw if data_raw[x] is not None} @@ -104,6 +144,15 @@ def save_initial_distribution(f, x=None, y=None, z=None, GBx=None, GBy=None, GBz blocks = [{"name": x, "value": data[x]} for x in data] # Save the blocks - easygdf.save(f, blocks, creation_time=creation_time, creator=creator, destination=destination, - gdf_version=gdf_version, creator_version=creator_version, destination_version=destination_version, - dummy=dummy, max_recurse=max_recurse) + easygdf.save( + f, + blocks, + creation_time=creation_time, + creator=creator, + destination=destination, + gdf_version=gdf_version, + creator_version=creator_version, + destination_version=destination_version, + dummy=dummy, + max_recurse=max_recurse, + ) diff --git a/src/easygdf/screens_touts.py b/src/easygdf/screens_touts.py index c2e96d2..5ef070f 100644 --- a/src/easygdf/screens_touts.py +++ b/src/easygdf/screens_touts.py @@ -17,7 +17,7 @@ def normalize_screen(screen): out = {} # Write out the list of keys to create array objects from (Warning: G and rxy need to go at end) - arr_keys = ['ID', 'x', 'y', 'z', 'Bx', 'By', 'Bz', 't', 'm', 'q', 'nmacro', 'rmacro', 'rxy', 'G'] + arr_keys = ["ID", "x", "y", "z", "Bx", "By", "Bz", "t", "m", "q", "nmacro", "rmacro", "rxy", "G"] # Handle the position element if "position" not in screen: @@ -63,10 +63,39 @@ def normalize_tout(tout): out = {} # Write out the list of array keys (Warning: G and rxy need to be at end of array) - par_keys = ['x', 'y', 'z', 'Bx', 'By', 'Bz', 'm', 'q', 'nmacro', 'rmacro', 'ID', 'fEx', 'fEy', 'fEz', - 'fBx', 'fBy', 'fBz', 'G', 'rxy'] - scat_keys = ['scat_x', 'scat_y', 'scat_z', 'scat_Qin', 'scat_Qout', 'scat_Qnet', 'scat_Ein', 'scat_Eout', - 'scat_Enet', 'scat_inp'] + par_keys = [ + "x", + "y", + "z", + "Bx", + "By", + "Bz", + "m", + "q", + "nmacro", + "rmacro", + "ID", + "fEx", + "fEy", + "fEz", + "fBx", + "fBy", + "fBz", + "G", + "rxy", + ] + scat_keys = [ + "scat_x", + "scat_y", + "scat_z", + "scat_Qin", + "scat_Qout", + "scat_Qnet", + "scat_Ein", + "scat_Eout", + "scat_Enet", + "scat_inp", + ] # Handle the time element if "time" not in tout: @@ -113,12 +142,32 @@ def normalize_tout(tout): return out -def save_screens_touts(f, screens=None, touts=None, logo="B&M-General Particle Tracer", scat_x=np.array([]), - scat_y=np.array([]), scat_z=np.array([]), scat_Qin=np.array([]), - scat_Qout=np.array([]), scat_Qnet=np.array([]), scat_Ein=np.array([]), - scat_Eout=np.array([]), scat_Enet=np.array([]), scat_inp=np.array([]), numderivs=0, - cputime=0.0, creation_time=None, creator="easygdf", destination="", gdf_version=(1, 1), - creator_version=(2, 0), destination_version=(0, 0), dummy=(0, 0), max_recurse=16): +def save_screens_touts( + f, + screens=None, + touts=None, + logo="B&M-General Particle Tracer", + scat_x=np.array([]), + scat_y=np.array([]), + scat_z=np.array([]), + scat_Qin=np.array([]), + scat_Qout=np.array([]), + scat_Qnet=np.array([]), + scat_Ein=np.array([]), + scat_Eout=np.array([]), + scat_Enet=np.array([]), + scat_inp=np.array([]), + numderivs=0, + cputime=0.0, + creation_time=None, + creator="easygdf", + destination="", + gdf_version=(1, 1), + creator_version=(2, 0), + destination_version=(0, 0), + dummy=(0, 0), + max_recurse=16, +): """ Saves user data into a file with the format of a GPT output. Signature is fully compatible with the output of the corresponding load function. Screens and touts are passed as a list of dicts with the following numpy arrays. @@ -174,9 +223,16 @@ def save_screens_touts(f, screens=None, touts=None, logo="B&M-General Particle T # Deal with the array arguments which must have the same dimensions arr_elems = { - "scat_x": scat_x, "scat_y": scat_y, "scat_z": scat_z, "scat_Qin": scat_Qin, "scat_Qout": scat_Qout, - "scat_Qnet": scat_Qnet, "scat_Ein": scat_Ein, "scat_Eout": scat_Eout, "scat_Enet": scat_Enet, - "scat_inp": scat_inp + "scat_x": scat_x, + "scat_y": scat_y, + "scat_z": scat_z, + "scat_Qin": scat_Qin, + "scat_Qout": scat_Qout, + "scat_Qnet": scat_Qnet, + "scat_Ein": scat_Ein, + "scat_Eout": scat_Eout, + "scat_Enet": scat_Enet, + "scat_inp": scat_inp, } target_len = max([arr_elems[x].size for x in arr_elems]) for key in arr_elems: @@ -191,7 +247,7 @@ def save_screens_touts(f, screens=None, touts=None, logo="B&M-General Particle T blocks = [ {"name": "@logo", "value": logo}, {"name": "numderivs", "value": numderivs}, - {"name": "cputime", "value": cputime} + {"name": "cputime", "value": cputime}, ] # Add all root elements to the blocks @@ -205,9 +261,9 @@ def save_screens_touts(f, screens=None, touts=None, logo="B&M-General Particle T # Create the block pos = nscreen.pop("position") - blocks.append({"name": "position", - "value": pos, - "children": [{"name": x, "value": nscreen[x]} for x in nscreen]}) + blocks.append( + {"name": "position", "value": pos, "children": [{"name": x, "value": nscreen[x]} for x in nscreen]} + ) # Go through each tout and add them for tout in touts: @@ -216,14 +272,21 @@ def save_screens_touts(f, screens=None, touts=None, logo="B&M-General Particle T # Create the block time = ntout.pop("time") - blocks.append({"name": "time", - "value": time, - "children": [{"name": x, "value": ntout[x]} for x in ntout]}) + blocks.append({"name": "time", "value": time, "children": [{"name": x, "value": ntout[x]} for x in ntout]}) # Write the blocks to disk - easygdf.save(f, blocks, creation_time=creation_time, creator=creator, destination=destination, - gdf_version=gdf_version, creator_version=creator_version, destination_version=destination_version, - dummy=dummy, max_recurse=max_recurse) + easygdf.save( + f, + blocks, + creation_time=creation_time, + creator=creator, + destination=destination, + gdf_version=gdf_version, + creator_version=creator_version, + destination_version=destination_version, + dummy=dummy, + max_recurse=max_recurse, + ) def load_screens_touts(f, max_recurse=16, max_block=1e6): diff --git a/src/easygdf/utils.py b/src/easygdf/utils.py index 32e5a54..d5f216e 100644 --- a/src/easygdf/utils.py +++ b/src/easygdf/utils.py @@ -22,6 +22,7 @@ def get_example_initial_distribution(): """ return str(files("easygdf").joinpath("data/initial_distribution.gdf")) + class GDFError(Exception): pass diff --git a/tests/test_easygdf.py b/tests/test_easygdf.py index 03278ea..1104d16 100644 --- a/tests/test_easygdf.py +++ b/tests/test_easygdf.py @@ -1,4 +1,8 @@ -import datetime, os, struct, tempfile, unittest +import datetime +import os +import struct +import tempfile +import unittest import numpy as np import easygdf @@ -8,7 +12,6 @@ class TestEasyGDFHelpers(unittest.TestCase): def test_is_gdf(self): # Attempt to open the file and use the method under test - with load_test_resource("data/test.gdf") as f: is_gdf = easygdf.is_gdf(f) self.assertTrue(is_gdf, True) @@ -21,9 +24,9 @@ def test_is_gdf(self): with tempfile.TemporaryDirectory() as temp_dir: # Create an empty file in the temp directory empty_file_path = os.path.join(temp_dir, "empty.gdf") - with open(empty_file_path, 'wb') as f: + with open(empty_file_path, "wb") as f: pass - + # Test the empty file self.assertFalse(easygdf.is_gdf(empty_file_path)) @@ -72,13 +75,13 @@ def test_load_simple_header(self): # Write out the expected header fh = { - 'creation_time': datetime.datetime(2020, 11, 25, 17, 34, 24, tzinfo=datetime.timezone.utc), - 'creator': 'ASCI2GDF', - 'destination': '', - 'gdf_version': (1, 1), - 'creator_version': (1, 0), - 'destination_version': (0, 0), - 'dummy': (0, 0) + "creation_time": datetime.datetime(2020, 11, 25, 17, 34, 24, tzinfo=datetime.timezone.utc), + "creator": "ASCI2GDF", + "destination": "", + "gdf_version": (1, 1), + "creator_version": (1, 0), + "destination_version": (0, 0), + "dummy": (0, 0), } # Compare with what we got @@ -96,13 +99,13 @@ def test_load_str_URI(self): # Write out the expected header fh = { - 'creation_time': datetime.datetime(2020, 11, 25, 17, 34, 24, tzinfo=datetime.timezone.utc), - 'creator': 'ASCI2GDF', - 'destination': '', - 'gdf_version': (1, 1), - 'creator_version': (1, 0), - 'destination_version': (0, 0), - 'dummy': (0, 0) + "creation_time": datetime.datetime(2020, 11, 25, 17, 34, 24, tzinfo=datetime.timezone.utc), + "creator": "ASCI2GDF", + "destination": "", + "gdf_version": (1, 1), + "creator_version": (1, 0), + "destination_version": (0, 0), + "dummy": (0, 0), } # Compare with what we got @@ -301,13 +304,13 @@ def test_save_simple(self): """ # Write the header expected for the reference file fh = { - 'creation_time': datetime.datetime(2020, 11, 25, 17, 34, 24, tzinfo=datetime.timezone.utc), - 'creator': 'ASCI2GDF', - 'destination': '', - 'gdf_version': (1, 1), - 'creator_version': (1, 0), - 'destination_version': (0, 0), - 'dummy': (0, 0) + "creation_time": datetime.datetime(2020, 11, 25, 17, 34, 24, tzinfo=datetime.timezone.utc), + "creator": "ASCI2GDF", + "destination": "", + "gdf_version": (1, 1), + "creator_version": (1, 0), + "destination_version": (0, 0), + "dummy": (0, 0), } # Write out the block data for the reference file @@ -356,13 +359,13 @@ def test_save_str_URI(self): """ # Write the header expected for the reference file fh = { - 'creation_time': datetime.datetime(2020, 11, 25, 17, 34, 24, tzinfo=datetime.timezone.utc), - 'creator': 'ASCI2GDF', - 'destination': '', - 'gdf_version': (1, 1), - 'creator_version': (1, 0), - 'destination_version': (0, 0), - 'dummy': (0, 0) + "creation_time": datetime.datetime(2020, 11, 25, 17, 34, 24, tzinfo=datetime.timezone.utc), + "creator": "ASCI2GDF", + "destination": "", + "gdf_version": (1, 1), + "creator_version": (1, 0), + "destination_version": (0, 0), + "dummy": (0, 0), } # Write out the block data for the reference file @@ -411,31 +414,40 @@ def test_save_all_dtypes(self): ref_blocks = [] # Dump all of the possible numpy array types into blocks - dtypes = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", - "float32", "float64"] + dtypes = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "float32", "float64"] for dtype in dtypes: - ref_blocks.append({ - "name": "array_" + dtype, - "value": np.linspace(0, 5, 6, dtype=dtype), - }) + ref_blocks.append( + { + "name": "array_" + dtype, + "value": np.linspace(0, 5, 6, dtype=dtype), + } + ) # Add each single type - ref_blocks.append({ - "name": "single_str", - "value": "deadbeef", - }) - ref_blocks.append({ - "name": "single_int", - "value": 1729, - }) - ref_blocks.append({ - "name": "single_float", - "value": 3.14, - }) - ref_blocks.append({ - "name": "single_none", - "value": None, - }) + ref_blocks.append( + { + "name": "single_str", + "value": "deadbeef", + } + ) + ref_blocks.append( + { + "name": "single_int", + "value": 1729, + } + ) + ref_blocks.append( + { + "name": "single_float", + "value": 3.14, + } + ) + ref_blocks.append( + { + "name": "single_none", + "value": None, + } + ) # Save everything as a GDF file testfile = os.path.join(tempfile.gettempdir(), "save_all_dtypes.gdf") @@ -514,8 +526,9 @@ def test_header_write(self): :return: """ # Create a header to save - my_time = datetime.datetime.fromtimestamp(int(datetime.datetime.timestamp(datetime.datetime.now())), - tz=datetime.timezone.utc) + my_time = datetime.datetime.fromtimestamp( + int(datetime.datetime.timestamp(datetime.datetime.now())), tz=datetime.timezone.utc + ) fh = { "creation_time": my_time, "creator": "easygdf", @@ -544,21 +557,23 @@ def test_save_groups(self): """ # The reference blocks ref = [ - {"name": "A", - "value": 0, - "children": [ - {"name": "B", - "value": "string", - "children": [ - {"name": "C", - "value": 1.2, - "children": [ - {"name": "D", - "value": "another string", - "children": []} - ]} - ]} - ]} + { + "name": "A", + "value": 0, + "children": [ + { + "name": "B", + "value": "string", + "children": [ + { + "name": "C", + "value": 1.2, + "children": [{"name": "D", "value": "another string", "children": []}], + } + ], + } + ], + } ] # Write it and read it back @@ -608,16 +623,31 @@ def test_int_single_overflow(self): # Test overflowing the negative value with open(os.path.join(tempfile.gettempdir(), "save_int_single_overflow_1.gdf"), "wb") as f: with self.assertRaises(ValueError): - easygdf.save(f, [{'name': 'ID', 'value': -0x80000000, 'children': []}, ]) + easygdf.save( + f, + [ + {"name": "ID", "value": -0x80000000, "children": []}, + ], + ) # Confirm something bigger than the max int32 but smaller than the max uint32 doesn't overflow with open(os.path.join(tempfile.gettempdir(), "save_int_single_overflow_2.gdf"), "wb") as f: - easygdf.save(f, [{'name': 'ID', 'value': 0x80000000, 'children': []}, ]) + easygdf.save( + f, + [ + {"name": "ID", "value": 0x80000000, "children": []}, + ], + ) # Test overflowing the positive value with open(os.path.join(tempfile.gettempdir(), "save_int_single_overflow_3.gdf"), "wb") as f: with self.assertRaises(ValueError): - easygdf.save(f, [{'name': 'ID', 'value': 0x100000000, 'children': []}, ]) + easygdf.save( + f, + [ + {"name": "ID", "value": 0x100000000, "children": []}, + ], + ) def test_int_array_overflow(self): """ @@ -628,14 +658,22 @@ def test_int_array_overflow(self): # Test overflowing int32 with open(os.path.join(tempfile.gettempdir(), "save_int_array_overflow_1.gdf"), "wb") as f: with self.assertRaises(ValueError): - easygdf.save(f, [{'name': 'ID', 'value': np.array([0x80000000, 0, 0, 0], dtype=np.int64), - 'children': []}, ]) + easygdf.save( + f, + [ + {"name": "ID", "value": np.array([0x80000000, 0, 0, 0], dtype=np.int64), "children": []}, + ], + ) # Test overflowing int64 with open(os.path.join(tempfile.gettempdir(), "save_int_array_overflow_2.gdf"), "wb") as f: with self.assertRaises(ValueError): - easygdf.save(f, [{'name': 'ID', 'value': np.array([0x100000000, 0, 0, 0], dtype=np.uint64), - 'children': []}, ]) + easygdf.save( + f, + [ + {"name": "ID", "value": np.array([0x100000000, 0, 0, 0], dtype=np.uint64), "children": []}, + ], + ) class TestEasyGDFLoadSave(unittest.TestCase): @@ -666,11 +704,21 @@ def test_integer_casting(self): # Test conversion from int64 -> int32 test_file = os.path.join(tempfile.gettempdir(), "save_initial_distribution_test_integer_casting_1.gdf") with open(test_file, "wb") as f: - easygdf.save(f, [{'name': 'ID', 'value': np.zeros(32, dtype=np.int64), 'children': []}, ]) - self.assertEqual(easygdf.load_initial_distribution(test_file)['ID'].dtype, np.dtype('int32')) + easygdf.save( + f, + [ + {"name": "ID", "value": np.zeros(32, dtype=np.int64), "children": []}, + ], + ) + self.assertEqual(easygdf.load_initial_distribution(test_file)["ID"].dtype, np.dtype("int32")) # Test conversion from uint64 -> uint32 test_file = os.path.join(tempfile.gettempdir(), "save_initial_distribution_test_integer_casting_2.gdf") with open(test_file, "wb") as f: - easygdf.save(f, [{'name': 'ID', 'value': np.zeros(32, dtype=np.uint64), 'children': []}, ]) - self.assertEqual(easygdf.load_initial_distribution(test_file)['ID'].dtype, np.dtype('uint32')) \ No newline at end of file + easygdf.save( + f, + [ + {"name": "ID", "value": np.zeros(32, dtype=np.uint64), "children": []}, + ], + ) + self.assertEqual(easygdf.load_initial_distribution(test_file)["ID"].dtype, np.dtype("uint32")) diff --git a/tests/test_initial_distribution.py b/tests/test_initial_distribution.py index 71ec421..3e41a86 100644 --- a/tests/test_initial_distribution.py +++ b/tests/test_initial_distribution.py @@ -20,45 +20,45 @@ def setUp(self): :return: """ self.ref = { - 'x': np.array([3.389e-05, 2.314e-05, 3.261e-05, 3.295e-05, 3.559e-05]), - 'y': np.array([3.096e-08, 3.241e-08, 3.506e-08, 3.717e-08, 3.990e-08]), - 'z': np.array([0., 0., 0., 0., 0.]), - 'GBx': np.array([0., 0., 0., 0., 0.]), - 'GBy': np.array([0., 0., 0., 0., 0.]), - 'GBz': np.array([0., 0., 0., 0., 0.]), - 't': np.array([1.865e-11, -1.712e-11, -1.367e-11, -9.963e-12, -1.238e-11]), - 'q': np.array([-1.602e-19, -1.602e-19, -1.602e-19, -1.602e-19, -1.602e-19]), - 'nmacro': np.array([8.752, 8.752, 8.752, 8.752, 8.752]), - 'creation_time': datetime.datetime(2019, 8, 7, 20, 47, 1, tzinfo=datetime.timezone.utc), - 'creator': 'ASCI2GDF', - 'destination': '', - 'gdf_version': (1, 1), - 'creator_version': (1, 0), - 'destination_version': (0, 0), - 'dummy': (0, 0) + "x": np.array([3.389e-05, 2.314e-05, 3.261e-05, 3.295e-05, 3.559e-05]), + "y": np.array([3.096e-08, 3.241e-08, 3.506e-08, 3.717e-08, 3.990e-08]), + "z": np.array([0.0, 0.0, 0.0, 0.0, 0.0]), + "GBx": np.array([0.0, 0.0, 0.0, 0.0, 0.0]), + "GBy": np.array([0.0, 0.0, 0.0, 0.0, 0.0]), + "GBz": np.array([0.0, 0.0, 0.0, 0.0, 0.0]), + "t": np.array([1.865e-11, -1.712e-11, -1.367e-11, -9.963e-12, -1.238e-11]), + "q": np.array([-1.602e-19, -1.602e-19, -1.602e-19, -1.602e-19, -1.602e-19]), + "nmacro": np.array([8.752, 8.752, 8.752, 8.752, 8.752]), + "creation_time": datetime.datetime(2019, 8, 7, 20, 47, 1, tzinfo=datetime.timezone.utc), + "creator": "ASCI2GDF", + "destination": "", + "gdf_version": (1, 1), + "creator_version": (1, 0), + "destination_version": (0, 0), + "dummy": (0, 0), } self.ref2 = { - 'x': np.array([3.389e-05, 2.314e-05, 3.261e-05, 3.295e-05, 3.559e-05]), - 'y': np.array([3.096e-08, 3.241e-08, 3.506e-08, 3.717e-08, 3.990e-08]), - 'z': np.array([0., 0., 0., 0., 0.]), - 'Bx': np.array([0., 0., 0., 0., 0.]), - 'By': np.array([0., 0., 0., 0., 0.]), - 'Bz': np.array([0., 0., 0., 0., 0.]), - 't': np.array([1.865e-11, -1.712e-11, -1.367e-11, -9.963e-12, -1.238e-11]), - 'q': np.array([-1.602e-19, -1.602e-19, -1.602e-19, -1.602e-19, -1.602e-19]), - 'nmacro': np.array([8.752, 8.752, 8.752, 8.752, 8.752]), + "x": np.array([3.389e-05, 2.314e-05, 3.261e-05, 3.295e-05, 3.559e-05]), + "y": np.array([3.096e-08, 3.241e-08, 3.506e-08, 3.717e-08, 3.990e-08]), + "z": np.array([0.0, 0.0, 0.0, 0.0, 0.0]), + "Bx": np.array([0.0, 0.0, 0.0, 0.0, 0.0]), + "By": np.array([0.0, 0.0, 0.0, 0.0, 0.0]), + "Bz": np.array([0.0, 0.0, 0.0, 0.0, 0.0]), + "t": np.array([1.865e-11, -1.712e-11, -1.367e-11, -9.963e-12, -1.238e-11]), + "q": np.array([-1.602e-19, -1.602e-19, -1.602e-19, -1.602e-19, -1.602e-19]), + "nmacro": np.array([8.752, 8.752, 8.752, 8.752, 8.752]), "m": np.array([8.752, 8.752, 8.752, 8.752, 8.752]), - "G": np.array([1., 1., 1., 1., 1.]), - "ID": np.array([1., 2., 3., 4., 5.]), - 'rmacro': np.array([8.752, 8.752, 8.752, 8.752, 8.752]), - 'creation_time': datetime.datetime(2019, 8, 7, 20, 47, 1, tzinfo=datetime.timezone.utc), - 'creator': 'ASCI2GDF', - 'destination': '', - 'gdf_version': (1, 1), - 'creator_version': (1, 0), - 'destination_version': (0, 0), - 'dummy': (0, 0) + "G": np.array([1.0, 1.0, 1.0, 1.0, 1.0]), + "ID": np.array([1.0, 2.0, 3.0, 4.0, 5.0]), + "rmacro": np.array([8.752, 8.752, 8.752, 8.752, 8.752]), + "creation_time": datetime.datetime(2019, 8, 7, 20, 47, 1, tzinfo=datetime.timezone.utc), + "creator": "ASCI2GDF", + "destination": "", + "gdf_version": (1, 1), + "creator_version": (1, 0), + "destination_version": (0, 0), + "dummy": (0, 0), } def test_load(self): @@ -133,7 +133,7 @@ def test_save_length_normalization(self): all_data = easygdf.load_initial_distribution(f) # Check array lengths - arr_names = ['x', 'y', 'z', 'GBx', 'GBy', 'GBz'] + arr_names = ["x", "y", "z", "GBx", "GBy", "GBz"] for a in arr_names: self.assertEqual(all_data[a].size, 11) @@ -148,7 +148,7 @@ def test_save_length_normalization_B(self): all_data = easygdf.load_initial_distribution(f) # Check array lengths - arr_names = ['x', 'y', 'z', 'Bx', 'By', 'Bz'] + arr_names = ["x", "y", "z", "Bx", "By", "Bz"] for a in arr_names: self.assertEqual(all_data[a].size, 11) diff --git a/tests/test_screens_touts.py b/tests/test_screens_touts.py index 01d10f3..d18594a 100644 --- a/tests/test_screens_touts.py +++ b/tests/test_screens_touts.py @@ -1,4 +1,6 @@ -import os, tempfile, unittest +import os +import tempfile +import unittest import numpy as np import easygdf @@ -8,51 +10,74 @@ class TestEasyGDFScreensTouts(unittest.TestCase): def setUp(self): # Gross data-dump of a test file that we will check against - self.ref = {'screens': [{ - 'x': np.array([3.173e-05, 2.286e-05, 2.331e-05, 3.735e-05, 3.040e-05]), - 'y': np.array([1.553e-06, 1.502e-06, 1.586e-06, 3.577e-06, 3.277e-06]), - 'z': np.array([0.01, 0.01, 0.01, 0.01, 0.01]), - 'rxy': np.array([3.177e-05, 2.291e-05, 2.337e-05, 3.752e-05, 3.057e-05]), - 'Bx': np.array([1.697e-04, 1.010e-04, 9.218e-05, 1.604e-04, 1.269e-04]), - 'By': np.array([-1.008e-06, -9.377e-08, -2.000e-06, 1.504e-05, 9.206e-06]), - 'Bz': np.array([0.5728, 0.5728, 0.5728, 0.5728, 0.5728]), - 'G': np.array([1.22, 1.22, 1.22, 1.22, 1.22]), - 't': np.array([7.681e-11, 7.682e-11, 7.655e-11, 7.662e-11, 7.650e-11]), - 'm': np.array([9.11e-31, 9.11e-31, 9.11e-31, 9.11e-31, 9.11e-31]), - 'q': np.array([-1.602e-19, -1.602e-19, -1.602e-19, -1.602e-19, -1.602e-19]), - 'nmacro': np.array([8.752, 8.752, 8.752, 8.752, 8.752]), - 'rmacro': np.array([0., 0., 0., 0., 0.]), - 'ID': np.array([917., 1199., 1259., 1525., 1778.]), 'position': 0.01 - }], - 'touts': [{ - 'scat_x': np.array([], dtype=np.float64), 'scat_y': np.array([], dtype=np.float64), - 'scat_z': np.array([], dtype=np.float64), 'scat_Qin': np.array([], dtype=np.float64), - 'scat_Qout': np.array([], dtype=np.float64), 'scat_Qnet': np.array([], dtype=np.float64), - 'scat_Ein': np.array([], dtype=np.float64), 'scat_Eout': np.array([], dtype=np.float64), - 'scat_Enet': np.array([], dtype=np.float64), 'scat_inp': np.array([], dtype=np.float64), - 'x': np.array([2.382e-05, 3.331e-05, 3.331e-05, 3.626e-05, 1.664e-05]), - 'y': np.array([-1.266e-08, 4.958e-08, 5.084e-08, 8.397e-08, 4.197e-08]), - 'z': np.array([2.876e-04, 1.831e-04, 9.718e-05, 1.503e-04, 2.548e-05]), - 'G': np.array([1.006, 1.004, 1.002, 1.003, 1.001]), - 'Bx': np.array([2.782e-04, 3.723e-04, 3.003e-04, 4.009e-04, 3.904e-05]), - 'By': np.array([-1.447e-05, 3.594e-06, 9.107e-06, 2.116e-05, 8.227e-06]), - 'Bz': np.array([0.112, 0.08946, 0.0652, 0.08107, 0.03326]), - 'rxy': np.array([2.382e-05, 3.331e-05, 3.331e-05, 3.626e-05, 1.664e-05]), - 'm': np.array([9.11e-31, 9.11e-31, 9.11e-31, 9.11e-31, 9.11e-31]), - 'q': np.array([-1.602e-19, -1.602e-19, -1.602e-19, -1.602e-19, -1.602e-19]), - 'nmacro': np.array([8.752, 8.752, 8.752, 8.752, 8.752]), 'rmacro': np.array([0., 0., 0., 0., 0.]), - 'ID': np.array([2., 3., 4., 5., 6.]), 'fEx': np.array([-20340., -38070., -56560., -48510., -33310.]), - 'fEy': np.array([635.5, 737.7, -882.6, -1992., -1727.]), - 'fEz': np.array([-11250000., -11250000., -11250000., -11250000., -11190000.]), - 'fBx': np.array([-1.715e-07, -2.003e-07, 2.240e-07, 5.139e-07, 4.488e-07]), - 'fBy': np.array([-5.324e-06, -9.967e-06, -1.481e-05, -1.270e-05, -8.721e-06]), - 'fBz': np.array([4.428e-05, 4.423e-05, 4.420e-05, 4.422e-05, 4.417e-05]), 'time': 0.0 - }], 'logo': 'B&M-General Particle Tracer', 'scat_x': np.array([], dtype=np.float64), - 'scat_y': np.array([], dtype=np.float64), 'scat_z': np.array([], dtype=np.float64), - 'scat_Qin': np.array([], dtype=np.float64), 'scat_Qout': np.array([], dtype=np.float64), - 'scat_Qnet': np.array([], dtype=np.float64), 'scat_Ein': np.array([], dtype=np.float64), - 'scat_Eout': np.array([], dtype=np.float64), 'scat_Enet': np.array([], dtype=np.float64), - 'scat_inp': np.array([], dtype=np.float64), 'numderivs': 0, 'cputime': 6054.0} + self.ref = { + "screens": [ + { + "x": np.array([3.173e-05, 2.286e-05, 2.331e-05, 3.735e-05, 3.040e-05]), + "y": np.array([1.553e-06, 1.502e-06, 1.586e-06, 3.577e-06, 3.277e-06]), + "z": np.array([0.01, 0.01, 0.01, 0.01, 0.01]), + "rxy": np.array([3.177e-05, 2.291e-05, 2.337e-05, 3.752e-05, 3.057e-05]), + "Bx": np.array([1.697e-04, 1.010e-04, 9.218e-05, 1.604e-04, 1.269e-04]), + "By": np.array([-1.008e-06, -9.377e-08, -2.000e-06, 1.504e-05, 9.206e-06]), + "Bz": np.array([0.5728, 0.5728, 0.5728, 0.5728, 0.5728]), + "G": np.array([1.22, 1.22, 1.22, 1.22, 1.22]), + "t": np.array([7.681e-11, 7.682e-11, 7.655e-11, 7.662e-11, 7.650e-11]), + "m": np.array([9.11e-31, 9.11e-31, 9.11e-31, 9.11e-31, 9.11e-31]), + "q": np.array([-1.602e-19, -1.602e-19, -1.602e-19, -1.602e-19, -1.602e-19]), + "nmacro": np.array([8.752, 8.752, 8.752, 8.752, 8.752]), + "rmacro": np.array([0.0, 0.0, 0.0, 0.0, 0.0]), + "ID": np.array([917.0, 1199.0, 1259.0, 1525.0, 1778.0]), + "position": 0.01, + } + ], + "touts": [ + { + "scat_x": np.array([], dtype=np.float64), + "scat_y": np.array([], dtype=np.float64), + "scat_z": np.array([], dtype=np.float64), + "scat_Qin": np.array([], dtype=np.float64), + "scat_Qout": np.array([], dtype=np.float64), + "scat_Qnet": np.array([], dtype=np.float64), + "scat_Ein": np.array([], dtype=np.float64), + "scat_Eout": np.array([], dtype=np.float64), + "scat_Enet": np.array([], dtype=np.float64), + "scat_inp": np.array([], dtype=np.float64), + "x": np.array([2.382e-05, 3.331e-05, 3.331e-05, 3.626e-05, 1.664e-05]), + "y": np.array([-1.266e-08, 4.958e-08, 5.084e-08, 8.397e-08, 4.197e-08]), + "z": np.array([2.876e-04, 1.831e-04, 9.718e-05, 1.503e-04, 2.548e-05]), + "G": np.array([1.006, 1.004, 1.002, 1.003, 1.001]), + "Bx": np.array([2.782e-04, 3.723e-04, 3.003e-04, 4.009e-04, 3.904e-05]), + "By": np.array([-1.447e-05, 3.594e-06, 9.107e-06, 2.116e-05, 8.227e-06]), + "Bz": np.array([0.112, 0.08946, 0.0652, 0.08107, 0.03326]), + "rxy": np.array([2.382e-05, 3.331e-05, 3.331e-05, 3.626e-05, 1.664e-05]), + "m": np.array([9.11e-31, 9.11e-31, 9.11e-31, 9.11e-31, 9.11e-31]), + "q": np.array([-1.602e-19, -1.602e-19, -1.602e-19, -1.602e-19, -1.602e-19]), + "nmacro": np.array([8.752, 8.752, 8.752, 8.752, 8.752]), + "rmacro": np.array([0.0, 0.0, 0.0, 0.0, 0.0]), + "ID": np.array([2.0, 3.0, 4.0, 5.0, 6.0]), + "fEx": np.array([-20340.0, -38070.0, -56560.0, -48510.0, -33310.0]), + "fEy": np.array([635.5, 737.7, -882.6, -1992.0, -1727.0]), + "fEz": np.array([-11250000.0, -11250000.0, -11250000.0, -11250000.0, -11190000.0]), + "fBx": np.array([-1.715e-07, -2.003e-07, 2.240e-07, 5.139e-07, 4.488e-07]), + "fBy": np.array([-5.324e-06, -9.967e-06, -1.481e-05, -1.270e-05, -8.721e-06]), + "fBz": np.array([4.428e-05, 4.423e-05, 4.420e-05, 4.422e-05, 4.417e-05]), + "time": 0.0, + } + ], + "logo": "B&M-General Particle Tracer", + "scat_x": np.array([], dtype=np.float64), + "scat_y": np.array([], dtype=np.float64), + "scat_z": np.array([], dtype=np.float64), + "scat_Qin": np.array([], dtype=np.float64), + "scat_Qout": np.array([], dtype=np.float64), + "scat_Qnet": np.array([], dtype=np.float64), + "scat_Ein": np.array([], dtype=np.float64), + "scat_Eout": np.array([], dtype=np.float64), + "scat_Enet": np.array([], dtype=np.float64), + "scat_inp": np.array([], dtype=np.float64), + "numderivs": 0, + "cputime": 6054.0, + } def test_load_screens_touts(self): # Try to load the file @@ -60,8 +85,15 @@ def test_load_screens_touts(self): all_data = easygdf.load_screens_touts(f) # Ditch the header - header_names = ["creation_time", "creator", "destination", "gdf_version", "creator_version", - "destination_version", "dummy"] + header_names = [ + "creation_time", + "creator", + "destination", + "gdf_version", + "creator_version", + "destination_version", + "dummy", + ] for hn in header_names: all_data.pop(hn) @@ -109,8 +141,15 @@ def test_save_screens_touts(self): all_data = easygdf.load_screens_touts(f) # Ditch the header - header_names = ["creation_time", "creator", "destination", "gdf_version", "creator_version", - "destination_version", "dummy"] + header_names = [ + "creation_time", + "creator", + "destination", + "gdf_version", + "creator_version", + "destination_version", + "dummy", + ] for hn in header_names: all_data.pop(hn) @@ -158,8 +197,18 @@ def test_save_screens_touts_scatter(self): all_data = easygdf.load_screens_touts(f) # Confirm the lengths of the array objects - arr = ["scat_x", "scat_y", "scat_z", "scat_Qin", "scat_Qout", "scat_Qnet", "scat_Ein", "scat_Eout", - "scat_Enet", "scat_inp"] + arr = [ + "scat_x", + "scat_y", + "scat_z", + "scat_Qin", + "scat_Qout", + "scat_Qnet", + "scat_Ein", + "scat_Eout", + "scat_Enet", + "scat_inp", + ] for an in arr: self.assertEqual(all_data[an].shape, all_data["scat_Ein"].shape) @@ -180,25 +229,64 @@ def test_save_screens_touts_tout_arr(self): :return: """ # Make a tout to write - tout = {"scat_x": np.linspace(0, 1, 7), "x": np.linspace(0, 1, 11), "y": np.linspace(0, 1, 11), - "Bx": np.linspace(0, 0.1, 11), "By": np.linspace(0, 0.1, 11)} + tout = { + "scat_x": np.linspace(0, 1, 7), + "x": np.linspace(0, 1, 11), + "y": np.linspace(0, 1, 11), + "Bx": np.linspace(0, 0.1, 11), + "By": np.linspace(0, 0.1, 11), + } # Write it to the temp directory test_file = os.path.join(tempfile.gettempdir(), "save_screens_tout_tout_arr.gdf") with open(test_file, "wb") as f: - easygdf.save_screens_touts(f, touts=[tout, ]) + easygdf.save_screens_touts( + f, + touts=[ + tout, + ], + ) # Read it back with open(test_file, "rb") as f: all_data = easygdf.load_screens_touts(f) # Confirm that the tout has the correctly shaped arrays - arr = ["scat_x", "scat_y", "scat_z", "scat_Qin", "scat_Qout", "scat_Qnet", "scat_Ein", "scat_Eout", - "scat_Enet", "scat_inp"] + arr = [ + "scat_x", + "scat_y", + "scat_z", + "scat_Qin", + "scat_Qout", + "scat_Qnet", + "scat_Ein", + "scat_Eout", + "scat_Enet", + "scat_inp", + ] for an in arr: self.assertEqual(all_data["touts"][0][an].shape, all_data["touts"][0]["scat_x"].shape) - arr2 = ["x", "y", "z", "G", "Bx", "By", "Bz", "rxy", "m", "q", "nmacro", "rmacro", "ID", "fEx", "fEy", "fEz", - "fBx", "fBy", "fBz"] + arr2 = [ + "x", + "y", + "z", + "G", + "Bx", + "By", + "Bz", + "rxy", + "m", + "q", + "nmacro", + "rmacro", + "ID", + "fEx", + "fEy", + "fEz", + "fBx", + "fBy", + "fBz", + ] for an in arr2: self.assertEqual(all_data["touts"][0][an].shape, all_data["touts"][0]["x"].shape) @@ -220,7 +308,12 @@ def test_save_screens_touts_tout_arr_wrong_dim(self): test_file = os.path.join(tempfile.gettempdir(), "save_screens_tout_tout_arr_wrong_dim.gdf") with open(test_file, "wb") as f: with self.assertRaises(ValueError): - easygdf.save_screens_touts(f, touts=[tout, ]) + easygdf.save_screens_touts( + f, + touts=[ + tout, + ], + ) # Make a tout to write tout = {"scat_x": np.linspace(0, 1, 7), "scat_y": np.linspace(0, 1, 11)} @@ -229,17 +322,31 @@ def test_save_screens_touts_tout_arr_wrong_dim(self): test_file = os.path.join(tempfile.gettempdir(), "save_screens_tout_tout_arr_wrong_dim2.gdf") with open(test_file, "wb") as f: with self.assertRaises(ValueError): - easygdf.save_screens_touts(f, touts=[tout, ]) + easygdf.save_screens_touts( + f, + touts=[ + tout, + ], + ) # Make a tout to write - tout = {"scat_x": np.linspace(0, 1, 7), "scat_y": np.linspace(0, 1, 11), "x": np.linspace(0, 1, 11), - "y": np.linspace(0, 1, 7)} + tout = { + "scat_x": np.linspace(0, 1, 7), + "scat_y": np.linspace(0, 1, 11), + "x": np.linspace(0, 1, 11), + "y": np.linspace(0, 1, 7), + } # Write it to the temp directory test_file = os.path.join(tempfile.gettempdir(), "save_screens_tout_tout_arr_wrong_dim3.gdf") with open(test_file, "wb") as f: with self.assertRaises(ValueError): - easygdf.save_screens_touts(f, touts=[tout, ]) + easygdf.save_screens_touts( + f, + touts=[ + tout, + ], + ) def test_save_screens_touts_tout_aux_elem(self): """ @@ -247,12 +354,19 @@ def test_save_screens_touts_tout_aux_elem(self): :return: """ # Make a tout to write - tout = {"deadbeef": np.linspace(0, 1, 11), } + tout = { + "deadbeef": np.linspace(0, 1, 11), + } # Write it to the temp directory test_file = os.path.join(tempfile.gettempdir(), "save_screens_touts_tout_wrong_elem.gdf") with open(test_file, "wb") as f: - easygdf.save_screens_touts(f, touts=[tout, ]) + easygdf.save_screens_touts( + f, + touts=[ + tout, + ], + ) # Load it back and check we saved it all_data = easygdf.load_screens_touts(test_file) @@ -264,13 +378,22 @@ def test_save_screens_touts_screen_arr(self): :return: """ # Make a tout to write - screen = {"x": np.linspace(0, 1, 11), "y": np.linspace(0, 1, 11), - "Bx": np.linspace(0, 0.1, 11), "By": np.linspace(0, 0.1, 11)} + screen = { + "x": np.linspace(0, 1, 11), + "y": np.linspace(0, 1, 11), + "Bx": np.linspace(0, 0.1, 11), + "By": np.linspace(0, 0.1, 11), + } # Write it to the temp directory test_file = os.path.join(tempfile.gettempdir(), "save_screens_touts_screen_arr.gdf") with open(test_file, "wb") as f: - easygdf.save_screens_touts(f, screens=[screen, ]) + easygdf.save_screens_touts( + f, + screens=[ + screen, + ], + ) # Read it back with open(test_file, "rb") as f: @@ -299,7 +422,12 @@ def test_save_screens_touts_screen_arr_wrong_dim(self): test_file = os.path.join(tempfile.gettempdir(), "save_screens_touts_screen_arr_wrong_dim.gdf") with open(test_file, "wb") as f: with self.assertRaises(ValueError): - easygdf.save_screens_touts(f, screens=[screen, ]) + easygdf.save_screens_touts( + f, + screens=[ + screen, + ], + ) def test_save_screens_touts_screen_aux_elem(self): """ @@ -312,7 +440,12 @@ def test_save_screens_touts_screen_aux_elem(self): # Write it to the temp directory test_file = os.path.join(tempfile.gettempdir(), "save_screens_touts_tout_wrong_elem.gdf") with open(test_file, "wb") as f: - easygdf.save_screens_touts(f, screens=[screen, ]) + easygdf.save_screens_touts( + f, + screens=[ + screen, + ], + ) # Load it back and check we saved it all_data = easygdf.load_screens_touts(test_file) @@ -338,4 +471,4 @@ def test_normalize_screen_float(self): """ # Load a file with load_test_resource("data/normalize_screen_floats.gdf") as f: - all_data = easygdf.load_screens_touts(f) + easygdf.load_screens_touts(f) diff --git a/tests/utils.py b/tests/utils.py index a19ade1..49c6bd8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,5 @@ import os + def load_test_resource(path): - return open(os.path.join(os.path.dirname(os.path.realpath(__file__)), path), 'rb') + return open(os.path.join(os.path.dirname(os.path.realpath(__file__)), path), "rb") From 24603afb7675a5e8b0cb1d70b7fee2345fed7ca6 Mon Sep 17 00:00:00 2001 From: "Christopher M. Pierce" Date: Sun, 12 Jan 2025 22:32:17 -0800 Subject: [PATCH 4/9] cleaning up file headers / imports --- src/easygdf/easygdf.py | 15 ++------------- src/easygdf/initial_distribution.py | 2 -- src/easygdf/screens_touts.py | 2 -- src/easygdf/utils.py | 2 -- tests/test_initial_distribution.py | 2 -- tests/test_screens_touts.py | 2 +- 6 files changed, 3 insertions(+), 22 deletions(-) diff --git a/src/easygdf/easygdf.py b/src/easygdf/easygdf.py index 46919f1..ceb4d9b 100644 --- a/src/easygdf/easygdf.py +++ b/src/easygdf/easygdf.py @@ -1,20 +1,12 @@ -# This file is part of easygdf and is released under the BSD 3-clause license - -######################################################################################################################## -# Imports -######################################################################################################################## import datetime import io +import numpy as np import struct import warnings -import numpy as np - from .utils import GDFIOError -######################################################################################################################## -# GDF Specific Constants -######################################################################################################################## + # Define constants for the GDF specification GDF_NAME_LEN = 16 GDF_MAGIC = 94325877 @@ -96,9 +88,6 @@ GDF_ARRAY = 2048 -######################################################################################################################## -# Library functions start here -######################################################################################################################## def is_gdf(f): """ Determines if a file is GDF formatted or not. diff --git a/src/easygdf/initial_distribution.py b/src/easygdf/initial_distribution.py index 579c1ed..4fe498c 100644 --- a/src/easygdf/initial_distribution.py +++ b/src/easygdf/initial_distribution.py @@ -1,5 +1,3 @@ -# This file is part of easygdf and is released under the BSD 3-clause license - import numpy as np from . import easygdf diff --git a/src/easygdf/screens_touts.py b/src/easygdf/screens_touts.py index 5ef070f..58bc9ca 100644 --- a/src/easygdf/screens_touts.py +++ b/src/easygdf/screens_touts.py @@ -1,5 +1,3 @@ -# This file is part of easygdf and is released under the BSD 3-clause license - import numpy as np from . import easygdf diff --git a/src/easygdf/utils.py b/src/easygdf/utils.py index d5f216e..ee56122 100644 --- a/src/easygdf/utils.py +++ b/src/easygdf/utils.py @@ -1,5 +1,3 @@ -# This file is part of easygdf and is released under the BSD 3-clause license - from importlib.resources import files diff --git a/tests/test_initial_distribution.py b/tests/test_initial_distribution.py index 3e41a86..8cea882 100644 --- a/tests/test_initial_distribution.py +++ b/tests/test_initial_distribution.py @@ -1,5 +1,3 @@ -# This file is part of easygdf and is released under the BSD 3-clause license - import datetime import os import tempfile diff --git a/tests/test_screens_touts.py b/tests/test_screens_touts.py index d18594a..7cf556e 100644 --- a/tests/test_screens_touts.py +++ b/tests/test_screens_touts.py @@ -1,7 +1,7 @@ +import numpy as np import os import tempfile import unittest -import numpy as np import easygdf from .utils import load_test_resource From edc9b612c5e8de699d83ee9877ab82ad7166f91c Mon Sep 17 00:00:00 2001 From: "Christopher M. Pierce" Date: Sun, 12 Jan 2025 22:40:06 -0800 Subject: [PATCH 5/9] move constants to own file --- src/easygdf/constants.py | 88 +++++++++++++++++++++++++++++++++++ src/easygdf/easygdf.py | 99 ++++++++-------------------------------- 2 files changed, 106 insertions(+), 81 deletions(-) create mode 100644 src/easygdf/constants.py diff --git a/src/easygdf/constants.py b/src/easygdf/constants.py new file mode 100644 index 0000000..2ce492e --- /dev/null +++ b/src/easygdf/constants.py @@ -0,0 +1,88 @@ +import numpy as np + + +# Define constants for the GDF specification +GDF_NAME_LEN = 16 +GDF_MAGIC = 94325877 + + +# The GDF data type identifiers +GDF_ASCII = 0x0001 +GDF_DOUBLE = 0x0003 +GDF_FLOAT = 0x0090 +GDF_INT8 = 0x0030 +GDF_INT16 = 0x0050 +GDF_INT32 = 0x0002 +GDF_INT64 = 0x0080 +GDF_NULL = 0x0010 +GDF_UINT8 = 0x0020 +GDF_UINT16 = 0x0040 +GDF_UINT32 = 0x0060 +GDF_UINT64 = 0x0070 +GDF_UNDEFINED = 0x0000 + + +# Conversion from GDF types to information used by struct to convert into a python type. First element of the tuple is +# the identifier for conversion and the second element is the size required by struct (so we can double check the file) +GDF_DTYPES_STRUCT = { + GDF_DOUBLE: ("d", 8), + GDF_FLOAT: ("f", 4), + GDF_INT8: ("b", 1), + GDF_INT16: ("h", 2), + GDF_INT32: ("i", 4), + GDF_INT64: ("q", 8), + GDF_UINT8: ("B", 1), + GDF_UINT16: ("H", 2), + GDF_UINT32: ("I", 4), + GDF_UINT64: ("Q", 8), +} + + +# The same conversion, but for going to numpy data types +GDF_DTYPES_NUMPY = { + GDF_DOUBLE: (np.float64, 8), + GDF_FLOAT: (np.float32, 4), + GDF_INT8: (np.int8, 1), + GDF_INT16: (np.int16, 2), + GDF_INT32: (np.int32, 4), + GDF_INT64: (np.int64, 8), + GDF_UINT8: (np.uint8, 1), + GDF_UINT16: (np.uint16, 2), + GDF_UINT32: (np.uint32, 4), + GDF_UINT64: (np.uint64, 8), +} + + +# Going from numpy data types to GDF types +NUMPY_TO_GDF = { + "int8": GDF_INT8, + "int16": GDF_INT16, + "int32": GDF_INT32, + "int64": GDF_INT64, + "uint8": GDF_UINT8, + "uint16": GDF_UINT16, + "uint32": GDF_UINT32, + "uint64": GDF_UINT64, + "float_": GDF_DOUBLE, + "float32": GDF_FLOAT, + "float64": GDF_DOUBLE, +} + + +# Detect platform specific data types for numpy +for t in ["int_", "intc", "intp"]: + s = np.dtype(t).itemsize + if s == 4: + NUMPY_TO_GDF[t] = GDF_INT32 + elif s == 8: + NUMPY_TO_GDF[t] = GDF_INT64 + else: + raise ValueError('Unable to autodetect GDF flag for numpy data type "{0}" with size {1} bytes'.format(t, s)) + + +# The bit masks for flags in the GDF header +GDF_DTYPE = 255 +GDF_GROUP_BEGIN = 256 +GDF_GROUP_END = 512 +GDF_SINGLE = 1024 +GDF_ARRAY = 2048 diff --git a/src/easygdf/easygdf.py b/src/easygdf/easygdf.py index ceb4d9b..1577281 100644 --- a/src/easygdf/easygdf.py +++ b/src/easygdf/easygdf.py @@ -5,87 +5,24 @@ import warnings from .utils import GDFIOError - - -# Define constants for the GDF specification -GDF_NAME_LEN = 16 -GDF_MAGIC = 94325877 - -# The GDF data type identifiers -GDF_ASCII = 0x0001 -GDF_DOUBLE = 0x0003 -GDF_FLOAT = 0x0090 -GDF_INT8 = 0x0030 -GDF_INT16 = 0x0050 -GDF_INT32 = 0x0002 -GDF_INT64 = 0x0080 -GDF_NULL = 0x0010 -GDF_UINT8 = 0x0020 -GDF_UINT16 = 0x0040 -GDF_UINT32 = 0x0060 -GDF_UINT64 = 0x0070 -GDF_UNDEFINED = 0x0000 - -# Conversion from GDF types to information used by struct to convert into a python type. First element of the tuple is -# the identifier for conversion and the second element is the size required by struct (so we can double check the file) -GDF_DTYPES_STRUCT = { - GDF_DOUBLE: ("d", 8), - GDF_FLOAT: ("f", 4), - GDF_INT8: ("b", 1), - GDF_INT16: ("h", 2), - GDF_INT32: ("i", 4), - GDF_INT64: ("q", 8), - GDF_UINT8: ("B", 1), - GDF_UINT16: ("H", 2), - GDF_UINT32: ("I", 4), - GDF_UINT64: ("Q", 8), -} - -# The same conversion, but for going to numpy data types -GDF_DTYPES_NUMPY = { - GDF_DOUBLE: (np.float64, 8), - GDF_FLOAT: (np.float32, 4), - GDF_INT8: (np.int8, 1), - GDF_INT16: (np.int16, 2), - GDF_INT32: (np.int32, 4), - GDF_INT64: (np.int64, 8), - GDF_UINT8: (np.uint8, 1), - GDF_UINT16: (np.uint16, 2), - GDF_UINT32: (np.uint32, 4), - GDF_UINT64: (np.uint64, 8), -} - -# Going from numpy data types to GDF types -NUMPY_TO_GDF = { - "int8": GDF_INT8, - "int16": GDF_INT16, - "int32": GDF_INT32, - "int64": GDF_INT64, - "uint8": GDF_UINT8, - "uint16": GDF_UINT16, - "uint32": GDF_UINT32, - "uint64": GDF_UINT64, - "float_": GDF_DOUBLE, - "float32": GDF_FLOAT, - "float64": GDF_DOUBLE, -} - -# Detect platform specific data types for numpy -for t in ["int_", "intc", "intp"]: - s = np.dtype(t).itemsize - if s == 4: - NUMPY_TO_GDF[t] = GDF_INT32 - elif s == 8: - NUMPY_TO_GDF[t] = GDF_INT64 - else: - raise ValueError('Unable to autodetect GDF flag for numpy data type "{0}" with size {1} bytes'.format(t, s)) - -# The bit masks for flags in the GDF header -GDF_DTYPE = 255 -GDF_GROUP_BEGIN = 256 -GDF_GROUP_END = 512 -GDF_SINGLE = 1024 -GDF_ARRAY = 2048 +from .constants import ( + GDF_NAME_LEN, + GDF_MAGIC, + GDF_ASCII, + GDF_DOUBLE, + GDF_INT32, + GDF_NULL, + GDF_UINT32, + GDF_UNDEFINED, + GDF_DTYPES_STRUCT, + GDF_DTYPES_NUMPY, + NUMPY_TO_GDF, + GDF_DTYPE, + GDF_GROUP_BEGIN, + GDF_GROUP_END, + GDF_SINGLE, + GDF_ARRAY, +) def is_gdf(f): From 1e7f3e842362c1207ffddb7d1436158e221e2f7c Mon Sep 17 00:00:00 2001 From: "Christopher M. Pierce" Date: Sun, 12 Jan 2025 22:41:08 -0800 Subject: [PATCH 6/9] update imports --- src/easygdf/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/easygdf/__init__.py b/src/easygdf/__init__.py index 161b76b..cbceb9a 100644 --- a/src/easygdf/__init__.py +++ b/src/easygdf/__init__.py @@ -1,4 +1,4 @@ -from .easygdf import ( +from .constants import ( GDF_ASCII, GDF_DOUBLE, GDF_FLOAT, @@ -14,6 +14,8 @@ GDF_UNDEFINED, GDF_NAME_LEN, GDF_MAGIC, +) +from .easygdf import ( is_gdf, load, save, From 0a30d94419494ab441583cd9bfb19f98c19d4646 Mon Sep 17 00:00:00 2001 From: "Christopher M. Pierce" Date: Sun, 12 Jan 2025 22:54:12 -0800 Subject: [PATCH 7/9] break exceptions into own file --- src/easygdf/__init__.py | 2 ++ src/easygdf/easygdf.py | 2 +- src/easygdf/exceptions.py | 6 ++++++ src/easygdf/utils.py | 8 -------- 4 files changed, 9 insertions(+), 9 deletions(-) create mode 100644 src/easygdf/exceptions.py diff --git a/src/easygdf/__init__.py b/src/easygdf/__init__.py index cbceb9a..b9cfa27 100644 --- a/src/easygdf/__init__.py +++ b/src/easygdf/__init__.py @@ -25,6 +25,8 @@ from .utils import ( get_example_screen_tout_filename, get_example_initial_distribution, +) +from .exceptions import ( GDFError, GDFIOError, ) diff --git a/src/easygdf/easygdf.py b/src/easygdf/easygdf.py index 1577281..77229a0 100644 --- a/src/easygdf/easygdf.py +++ b/src/easygdf/easygdf.py @@ -4,7 +4,7 @@ import struct import warnings -from .utils import GDFIOError +from .exceptions import GDFIOError from .constants import ( GDF_NAME_LEN, GDF_MAGIC, diff --git a/src/easygdf/exceptions.py b/src/easygdf/exceptions.py new file mode 100644 index 0000000..5967260 --- /dev/null +++ b/src/easygdf/exceptions.py @@ -0,0 +1,6 @@ +class GDFError(Exception): + pass + + +class GDFIOError(GDFError): + pass diff --git a/src/easygdf/utils.py b/src/easygdf/utils.py index ee56122..5ddaa6f 100644 --- a/src/easygdf/utils.py +++ b/src/easygdf/utils.py @@ -19,11 +19,3 @@ def get_example_initial_distribution(): :return: Path to the example file """ return str(files("easygdf").joinpath("data/initial_distribution.gdf")) - - -class GDFError(Exception): - pass - - -class GDFIOError(GDFError): - pass From d50dde3acb05ea5aec5b4e93a44653165a9d0ec6 Mon Sep 17 00:00:00 2001 From: "Christopher M. Pierce" Date: Sun, 12 Jan 2025 22:57:24 -0800 Subject: [PATCH 8/9] move is_gdf --- src/easygdf/__init__.py | 2 +- src/easygdf/easygdf.py | 28 +--------------------------- src/easygdf/utils.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 32 insertions(+), 28 deletions(-) diff --git a/src/easygdf/__init__.py b/src/easygdf/__init__.py index b9cfa27..58772b0 100644 --- a/src/easygdf/__init__.py +++ b/src/easygdf/__init__.py @@ -16,7 +16,6 @@ GDF_MAGIC, ) from .easygdf import ( - is_gdf, load, save, ) @@ -25,6 +24,7 @@ from .utils import ( get_example_screen_tout_filename, get_example_initial_distribution, + is_gdf, ) from .exceptions import ( GDFError, diff --git a/src/easygdf/easygdf.py b/src/easygdf/easygdf.py index 77229a0..2eb71de 100644 --- a/src/easygdf/easygdf.py +++ b/src/easygdf/easygdf.py @@ -4,6 +4,7 @@ import struct import warnings +from .utils import is_gdf from .exceptions import GDFIOError from .constants import ( GDF_NAME_LEN, @@ -25,33 +26,6 @@ ) -def is_gdf(f): - """ - Determines if a file is GDF formatted or not. - - If binary file is passed, file will be at location four bytes from start after this function is run. - - :param f: filename or open file/stream-like object - :return: True/False whether the file is GDF formatted - """ - # If we were handed a string, then run this function on it with the file opened - if isinstance(f, str): - with open(f, "rb") as ff: - return is_gdf(ff) - - # Rewind the file to the beginning - f.seek(0) - - # Check if file has enough bytes to contain magic number - if len(f.read(4)) != 4: - return False - - # Rewind again to read the magic number - f.seek(0) - (magic_number,) = struct.unpack("i", f.read(4)) - return magic_number == GDF_MAGIC - - def load_blocks(f, level=0, max_recurse=16, max_block=1e6): """ Internal function. Recursively reads groups of blocks in the GDF file. Until group end or file end, block header diff --git a/src/easygdf/utils.py b/src/easygdf/utils.py index 5ddaa6f..a77d4a5 100644 --- a/src/easygdf/utils.py +++ b/src/easygdf/utils.py @@ -1,4 +1,7 @@ from importlib.resources import files +import struct + +from .constants import GDF_MAGIC def get_example_screen_tout_filename(): @@ -19,3 +22,30 @@ def get_example_initial_distribution(): :return: Path to the example file """ return str(files("easygdf").joinpath("data/initial_distribution.gdf")) + + +def is_gdf(f): + """ + Determines if a file is GDF formatted or not. + + If binary file is passed, file will be at location four bytes from start after this function is run. + + :param f: filename or open file/stream-like object + :return: True/False whether the file is GDF formatted + """ + # If we were handed a string, then run this function on it with the file opened + if isinstance(f, str): + with open(f, "rb") as ff: + return is_gdf(ff) + + # Rewind the file to the beginning + f.seek(0) + + # Check if file has enough bytes to contain magic number + if len(f.read(4)) != 4: + return False + + # Rewind again to read the magic number + f.seek(0) + (magic_number,) = struct.unpack("i", f.read(4)) + return magic_number == GDF_MAGIC From ee79d114633859c7c88c330b004136e97273683b Mon Sep 17 00:00:00 2001 From: "Christopher M. Pierce" Date: Sun, 12 Jan 2025 23:01:38 -0800 Subject: [PATCH 9/9] umcomment script --- scripts/generate_test_file.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scripts/generate_test_file.py b/scripts/generate_test_file.py index 892e21f..a42c345 100644 --- a/scripts/generate_test_file.py +++ b/scripts/generate_test_file.py @@ -360,7 +360,6 @@ def write_file(path, b): data_files_path = "easygdf/tests/data" if __name__ == "__main__": write_file(os.path.join(data_files_path, "normalize_screen_floats.gdf"), get_normalize_screen_floats()) - """ write_file(os.path.join(data_files_path, "version_mismatch.gdf"), get_file_version_mismatch()) write_file(os.path.join(data_files_path, "wrong_magic_number.gdf"), get_file_wrong_magic_number()) write_file(os.path.join(data_files_path, "too_much_recursion.gdf"), get_file_too_much_recursion()) @@ -376,4 +375,3 @@ def write_file(path, b): write_file(os.path.join(data_files_path, "invalid_size_array.gdf"), get_file_invalid_array_size()) write_file(os.path.join(data_files_path, "nested_groups.gdf"), get_file_nested_group()) write_file(os.path.join(data_files_path, "null_array.gdf"), get_file_null_array()) - """