diff --git a/docs/examples/units.ipynb b/docs/examples/units.ipynb index aef5dea..35beea2 100644 --- a/docs/examples/units.ipynb +++ b/docs/examples/units.ipynb @@ -218,7 +218,7 @@ "source": [ "x = 1e-4\n", "unit = \"m\"\n", - "nice_array(x)" + "nice_array(x, unit)" ] }, { @@ -229,7 +229,7 @@ }, "outputs": [], "source": [ - "nice_array([-0.01, 0.01])" + "nice_array([-0.01, 0.01], \"m\")" ] }, { @@ -254,6 +254,45 @@ "nice_scale_prefix(0.009)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prefixes are not applied to sqrt units\n", + "\n", + "When a unit involves a square root, the nice_array function will not add a prefix:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Regular unit - gets a prefix\n", + "nice_array(1e-6, \"m\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sqrt unit - does NOT get a prefix\n", + "nice_array(1e-6, \"sqrt(m)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Another example with a different magnitude\n", + "nice_array(0.005, \"√(m)\")" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/pmd_beamphysics/labels.py b/pmd_beamphysics/labels.py index 8ecb461..9186f90 100644 --- a/pmd_beamphysics/labels.py +++ b/pmd_beamphysics/labels.py @@ -128,7 +128,7 @@ def texlabel(key: str): if key.startswith("bunching"): wavelength = parse_bunching_str(key) - x, _, prefix = nice_array(wavelength) + x, _, prefix = nice_array(wavelength, unit_symbol='m') return rf"\mathrm{{bunching~at}}~{x:.1f}~\mathrm{{ {prefix}m }}" return rf"\mathrm{{ {key} }}" diff --git a/pmd_beamphysics/plot.py b/pmd_beamphysics/plot.py index 1665aca..0f93340 100644 --- a/pmd_beamphysics/plot.py +++ b/pmd_beamphysics/plot.py @@ -106,7 +106,7 @@ def slice_plot( x -= particle_group["mean_" + slice_key] slice_key = "delta_" + slice_key # restore - x, f1, p1, xmin, xmax = plottable_array(x, nice=nice, lim=xlim) + x, f1, p1, xmin, xmax = plottable_array(x, nice=nice, lim=xlim, unit_symbol=particle_group.units(slice_key).unitSymbol) ux = p1 + str(particle_group.units(slice_key)) # Y-axis @@ -120,7 +120,7 @@ def slice_plot( ymin = max([slice_dat[k].min() for k in keys]) ymax = max([slice_dat[k].max() for k in keys]) - _, f2, p2, ymin, ymax = plottable_array(np.array([ymin, ymax]), nice=nice, lim=ylim) + _, f2, p2, ymin, ymax = plottable_array(np.array([ymin, ymax]), nice=nice, lim=ylim, unit_symbol=uy) uy = p2 + uy # Form Figure @@ -139,7 +139,8 @@ def slice_plot( ax.legend() # Density on r.h.s - y2, _, prey2, _, _ = plottable_array(slice_dat[y2_key], nice=nice, lim=None) + y2_unit_symbol = f"C/{particle_group.units(x_key).unitSymbol}" + y2, _, prey2, _, _ = plottable_array(slice_dat[y2_key], nice=nice, lim=None, unit_symbol=y2_unit_symbol) # Convert to Amps if possible y2_units = f"C/{particle_group.units(x_key)}" @@ -187,7 +188,7 @@ def density_plot( bins = int(n / 100) # Scale to nice units and get the factor, unit prefix - x, f1, p1, xmin, xmax = plottable_array(particle_group[key], nice=nice, lim=xlim) + x, f1, p1, xmin, xmax = plottable_array(particle_group[key], nice=nice, lim=xlim, unit_symbol=particle_group.units(key).unitSymbol) w = particle_group["weight"] u1 = particle_group.units(key).unitSymbol ux = p1 + u1 @@ -200,11 +201,13 @@ def density_plot( hist, bin_edges = np.histogram(x, bins=bins, weights=w) hist_x = bin_edges[:-1] + np.diff(bin_edges) / 2 hist_width = np.diff(bin_edges) - hist_y, hist_f, hist_prefix = nice_array(hist / hist_width) + # Unit for histogram is charge per unit of x-axis + hist_unit = f"C/{ux}" if u1 != "s" else "A" + hist_y, hist_f, hist_prefix = nice_array(hist / hist_width, unit_symbol=hist_unit) ax.bar(hist_x, hist_y, hist_width, color="grey") # Special label for C/s = A if u1 == "s": - _, hist_prefix = nice_scale_prefix(hist_f / f1) + _, hist_prefix = nice_scale_prefix(hist_f / f1, unit_symbol="A") ax.set_ylabel(f"{hist_prefix}A") else: ax.set_ylabel(f"{hist_prefix}C/{ux}") @@ -299,13 +302,12 @@ def marginal_plot( ylim = tuple(sorted((0.9 * y0, 1.1 * y0))) # Form nice arrays - x, f1, p1, xmin, xmax = plottable_array(x, nice=nice, lim=xlim) - y, f2, p2, ymin, ymax = plottable_array(y, nice=nice, lim=ylim) - - w = particle_group["weight"] - u1 = particle_group.units(key1).unitSymbol u2 = particle_group.units(key2).unitSymbol + x, f1, p1, xmin, xmax = plottable_array(x, nice=nice, lim=xlim, unit_symbol=u1) + y, f2, p2, ymin, ymax = plottable_array(y, nice=nice, lim=ylim, unit_symbol=u2) + + w = particle_group["weight"] ux = p1 + u1 uy = p2 + u2 @@ -363,11 +365,12 @@ def marginal_plot( hist, bin_edges = np.histogram(x, bins=bins, weights=w) hist_x = bin_edges[:-1] + np.diff(bin_edges) / 2 hist_width = np.diff(bin_edges) - hist_y, hist_f, hist_prefix = nice_array(hist / hist_width) + hist_unit = f"C/{ux}" if u1 != "s" else "A" + hist_y, hist_f, hist_prefix = nice_array(hist / hist_width, unit_symbol=hist_unit) ax_marg_x.bar(hist_x, hist_y, hist_width, color="gray") # Special label for C/s = A if u1 == "s": - _, hist_prefix = nice_scale_prefix(hist_f / f1) + _, hist_prefix = nice_scale_prefix(hist_f / f1, unit_symbol="A") ax_marg_x.set_ylabel(f"{hist_prefix}A") else: ax_marg_x.set_ylabel(f"{hist_prefix}" + mathlabel(f"C/{ux}")) # Always use tex @@ -379,7 +382,8 @@ def marginal_plot( hist, bin_edges = np.histogram(y, bins=bins, weights=w) hist_x = bin_edges[:-1] + np.diff(bin_edges) / 2 hist_width = np.diff(bin_edges) - hist_y, hist_f, hist_prefix = nice_array(hist / hist_width) + hist_unit = f"C/{uy}" if u2 != "s" else "A" + hist_y, hist_f, hist_prefix = nice_array(hist / hist_width, unit_symbol=hist_unit) ax_marg_y.barh(hist_x, hist_y, hist_width, color="gray") ax_marg_y.set_xlabel(f"{hist_prefix}" + mathlabel(f"C/{uy}")) # Always use tex @@ -422,12 +426,11 @@ def density_and_slice_plot( """ # Scale to nice units and get the factor, unit prefix - x, f1, p1, xmin, xmax = plottable_array(particle_group[key1]) - y, f2, p2, ymin, ymax = plottable_array(particle_group[key2]) - w = particle_group["weight"] - u1 = particle_group.units(key1).unitSymbol u2 = particle_group.units(key2).unitSymbol + x, f1, p1, xmin, xmax = plottable_array(particle_group[key1], unit_symbol=u1) + y, f2, p2, ymin, ymax = plottable_array(particle_group[key2], unit_symbol=u2) + w = particle_group["weight"] ux = p1 + u1 uy = p2 + u2 @@ -465,7 +468,7 @@ def density_and_slice_plot( max2 = max([np.ptp(slice_dat[k]) for k in stat_keys]) - f3, p3 = nice_scale_prefix(max2) + f3, p3 = nice_scale_prefix(max2, unit_symbol=ulist[0]) u2 = ulist[0] assert all([u == u2 for u in ulist]) @@ -889,7 +892,7 @@ def plot_fieldmesh_rectangular_2d( field_2d = interpolated_values.reshape(len(x), len(y)) if nice: - field_2d, _, prefix = nice_array(field_2d) + field_2d, _, prefix = nice_array(field_2d, unit_symbol=unit.unitSymbol) else: prefix = "" diff --git a/pmd_beamphysics/units.py b/pmd_beamphysics/units.py index ac1b5a5..3e4b653 100644 --- a/pmd_beamphysics/units.py +++ b/pmd_beamphysics/units.py @@ -440,7 +440,7 @@ def unit(symbol: str) -> pmd_unit: # Nice scaling -def nice_scale_prefix(scale: float) -> tuple[float, str]: +def nice_scale_prefix(scale: float, unit_symbol: str = None) -> tuple[float, str]: """ Returns a nice factor and an SI prefix string. @@ -448,6 +448,9 @@ def nice_scale_prefix(scale: float) -> tuple[float, str]: ---------- scale : float The scale to be converted into a nice factor and SI prefix. + unit_symbol : str, optional + The unit symbol to check for special cases. If provided, prefixes + will be suppressed for units containing sqrt symbols (√, sqrt, \\sqrt). Returns ------- @@ -460,7 +463,15 @@ def nice_scale_prefix(scale: float) -> tuple[float, str]: -------- >>> nice_scale_prefix(scale=2e-10) (1e-12, 'p') + + >>> nice_scale_prefix(scale=2e-10, unit_symbol='√m') + (1, '') """ + + # Check if unit contains sqrt symbols - suppress prefixes for these + if unit_symbol is not None: + if 'sqrt' in unit_symbol.lower() or '√' in unit_symbol or '\\sqrt' in unit_symbol: + return 1, "" if scale == 0: return 1, "" @@ -479,7 +490,7 @@ def nice_scale_prefix(scale: float) -> tuple[float, str]: return f, SHORT_PREFIX[f] -def nice_array(a: np.ndarray) -> tuple[np.ndarray, float, str]: +def nice_array(a: np.ndarray, unit_symbol: str = None) -> tuple[np.ndarray, float, str]: """ Scale an input array and return the scaled array, the scaling factor, and the corresponding unit prefix. @@ -488,6 +499,9 @@ def nice_array(a: np.ndarray) -> tuple[np.ndarray, float, str]: ---------- a : array-like, or float Input array to be scaled. + unit_symbol : str, optional + The unit symbol to check for special cases. If provided, prefixes + may be suppressed for certain unit types (e.g., sqrt units). Returns ------- @@ -512,11 +526,11 @@ def nice_array(a: np.ndarray) -> tuple[np.ndarray, float, str]: a = np.asarray(a) x = max(np.ptp(a), abs(np.mean(a))) # Account for tiny spread - fac, prefix = nice_scale_prefix(x) + fac, prefix = nice_scale_prefix(x, unit_symbol) return a / fac, fac, prefix -def plottable_array(x: np.ndarray, nice: bool = True, lim: Limit | None = None): +def plottable_array(x: np.ndarray, nice: bool = True, lim: Limit | None = None, unit_symbol: str = None): """ Similar to nice_array, but also considers limits for plotting @@ -525,7 +539,11 @@ def plottable_array(x: np.ndarray, nice: bool = True, lim: Limit | None = None): x: array-like nice: bool, default = True Scale array by some nice factor. - xlim: tuple, default = None + lim: tuple, default = None + Optional limits (min, max) + unit_symbol: str, optional + Unit symbol to check for special cases (e.g., sqrt). If provided, + prefixes may be suppressed for certain unit types. Returns ------- @@ -552,7 +570,7 @@ def plottable_array(x: np.ndarray, nice: bool = True, lim: Limit | None = None): xmax = x.max() if nice: - _, factor, p1 = nice_array([xmin, xmax]) + _, factor, p1 = nice_array([xmin, xmax], unit_symbol) else: factor, p1 = 1, ""