diff --git a/pmd_beamphysics/interfaces/cst.py b/pmd_beamphysics/interfaces/cst.py index 19ecb95..9222839 100644 --- a/pmd_beamphysics/interfaces/cst.py +++ b/pmd_beamphysics/interfaces/cst.py @@ -23,6 +23,7 @@ def get_scale(unit): The scaling factor corresponding to the unit. - 1e-3 for millimeters ('[mm]') - 1 for volts per meter ('[V/m]') + - 1 for Tesla ('[T]') - mu0 (vacuum permeability) for amperes per meter ('[A/m]') """ @@ -30,6 +31,8 @@ def get_scale(unit): return 1e-3 elif unit == "[V/m]": return 1 + elif unit == "[T]": + return 1 elif unit == "[A/m]": return mu0 @@ -63,9 +66,15 @@ def get_vec(x): sx = set(x) nx = len(sx) xlist = np.array(sorted(list(sx))) - dx = np.diff(xlist) - assert np.allclose(dx, dx[0]) - dx = dx[0] + if nx == 1: + dx = 0 + elif nx > 1: + dx = np.diff(xlist) + assert np.allclose(dx, dx[0]) + dx = dx[0] + else: + raise ValueError("Length of x vector was < 1!") + return min(x), max(x), dx, nx @@ -129,12 +138,20 @@ def read_cst_ascii_3d_field(filePath, n_header=2): # print(columns, units) field_columns = list( - set([c[:2] for c in columns if c.startswith("E") or c.startswith("H")]) + set( + [ + c[:2] + for c in columns + if c.startswith("E") or c.startswith("H") or c.startswith("B") + ] + ) ) if all([f.startswith("E") for f in field_columns]): field_type = "electric" - elif all([f.startswith("H") for f in field_columns]): + elif all([f.startswith("H") for f in field_columns]) or all( + [f.startswith("B") for f in field_columns] + ): field_type = "magnetic" else: raise ValueError("Mixed CST mode not curretly supported.")