Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions docs/examples/units.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@
"source": [
"x = 1e-4\n",
"unit = \"m\"\n",
"nice_array(x)"
"nice_array(x, unit)"
]
},
{
Expand All @@ -229,7 +229,7 @@
},
"outputs": [],
"source": [
"nice_array([-0.01, 0.01])"
"nice_array([-0.01, 0.01], \"m\")"
]
},
{
Expand All @@ -254,6 +254,45 @@
"nice_scale_prefix(0.009)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prefixes are not applied to sqrt units\n",
"\n",
"When a unit involves a square root, the nice_array function will not add a prefix:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Regular unit - gets a prefix\n",
"nice_array(1e-6, \"m\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sqrt unit - does NOT get a prefix\n",
"nice_array(1e-6, \"sqrt(m)\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Another example with a different magnitude\n",
"nice_array(0.005, \"√(m)\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
2 changes: 1 addition & 1 deletion pmd_beamphysics/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def texlabel(key: str):

if key.startswith("bunching"):
wavelength = parse_bunching_str(key)
x, _, prefix = nice_array(wavelength)
x, _, prefix = nice_array(wavelength, unit_symbol='m')
return rf"\mathrm{{bunching~at}}~{x:.1f}~\mathrm{{ {prefix}m }}"

return rf"\mathrm{{ {key} }}"
Expand Down
43 changes: 23 additions & 20 deletions pmd_beamphysics/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def slice_plot(
x -= particle_group["mean_" + slice_key]
slice_key = "delta_" + slice_key # restore

x, f1, p1, xmin, xmax = plottable_array(x, nice=nice, lim=xlim)
x, f1, p1, xmin, xmax = plottable_array(x, nice=nice, lim=xlim, unit_symbol=particle_group.units(slice_key).unitSymbol)
ux = p1 + str(particle_group.units(slice_key))

# Y-axis
Expand All @@ -120,7 +120,7 @@ def slice_plot(
ymin = max([slice_dat[k].min() for k in keys])
ymax = max([slice_dat[k].max() for k in keys])

_, f2, p2, ymin, ymax = plottable_array(np.array([ymin, ymax]), nice=nice, lim=ylim)
_, f2, p2, ymin, ymax = plottable_array(np.array([ymin, ymax]), nice=nice, lim=ylim, unit_symbol=uy)
uy = p2 + uy

# Form Figure
Expand All @@ -139,7 +139,8 @@ def slice_plot(
ax.legend()

# Density on r.h.s
y2, _, prey2, _, _ = plottable_array(slice_dat[y2_key], nice=nice, lim=None)
y2_unit_symbol = f"C/{particle_group.units(x_key).unitSymbol}"
y2, _, prey2, _, _ = plottable_array(slice_dat[y2_key], nice=nice, lim=None, unit_symbol=y2_unit_symbol)

# Convert to Amps if possible
y2_units = f"C/{particle_group.units(x_key)}"
Expand Down Expand Up @@ -187,7 +188,7 @@ def density_plot(
bins = int(n / 100)

# Scale to nice units and get the factor, unit prefix
x, f1, p1, xmin, xmax = plottable_array(particle_group[key], nice=nice, lim=xlim)
x, f1, p1, xmin, xmax = plottable_array(particle_group[key], nice=nice, lim=xlim, unit_symbol=particle_group.units(key).unitSymbol)
w = particle_group["weight"]
u1 = particle_group.units(key).unitSymbol
ux = p1 + u1
Expand All @@ -200,11 +201,13 @@ def density_plot(
hist, bin_edges = np.histogram(x, bins=bins, weights=w)
hist_x = bin_edges[:-1] + np.diff(bin_edges) / 2
hist_width = np.diff(bin_edges)
hist_y, hist_f, hist_prefix = nice_array(hist / hist_width)
# Unit for histogram is charge per unit of x-axis
hist_unit = f"C/{ux}" if u1 != "s" else "A"
hist_y, hist_f, hist_prefix = nice_array(hist / hist_width, unit_symbol=hist_unit)
ax.bar(hist_x, hist_y, hist_width, color="grey")
# Special label for C/s = A
if u1 == "s":
_, hist_prefix = nice_scale_prefix(hist_f / f1)
_, hist_prefix = nice_scale_prefix(hist_f / f1, unit_symbol="A")
ax.set_ylabel(f"{hist_prefix}A")
else:
ax.set_ylabel(f"{hist_prefix}C/{ux}")
Expand Down Expand Up @@ -299,13 +302,12 @@ def marginal_plot(
ylim = tuple(sorted((0.9 * y0, 1.1 * y0)))

# Form nice arrays
x, f1, p1, xmin, xmax = plottable_array(x, nice=nice, lim=xlim)
y, f2, p2, ymin, ymax = plottable_array(y, nice=nice, lim=ylim)

w = particle_group["weight"]

u1 = particle_group.units(key1).unitSymbol
u2 = particle_group.units(key2).unitSymbol
x, f1, p1, xmin, xmax = plottable_array(x, nice=nice, lim=xlim, unit_symbol=u1)
y, f2, p2, ymin, ymax = plottable_array(y, nice=nice, lim=ylim, unit_symbol=u2)

w = particle_group["weight"]
ux = p1 + u1
uy = p2 + u2

Expand Down Expand Up @@ -363,11 +365,12 @@ def marginal_plot(
hist, bin_edges = np.histogram(x, bins=bins, weights=w)
hist_x = bin_edges[:-1] + np.diff(bin_edges) / 2
hist_width = np.diff(bin_edges)
hist_y, hist_f, hist_prefix = nice_array(hist / hist_width)
hist_unit = f"C/{ux}" if u1 != "s" else "A"
hist_y, hist_f, hist_prefix = nice_array(hist / hist_width, unit_symbol=hist_unit)
ax_marg_x.bar(hist_x, hist_y, hist_width, color="gray")
# Special label for C/s = A
if u1 == "s":
_, hist_prefix = nice_scale_prefix(hist_f / f1)
_, hist_prefix = nice_scale_prefix(hist_f / f1, unit_symbol="A")
ax_marg_x.set_ylabel(f"{hist_prefix}A")
else:
ax_marg_x.set_ylabel(f"{hist_prefix}" + mathlabel(f"C/{ux}")) # Always use tex
Expand All @@ -379,7 +382,8 @@ def marginal_plot(
hist, bin_edges = np.histogram(y, bins=bins, weights=w)
hist_x = bin_edges[:-1] + np.diff(bin_edges) / 2
hist_width = np.diff(bin_edges)
hist_y, hist_f, hist_prefix = nice_array(hist / hist_width)
hist_unit = f"C/{uy}" if u2 != "s" else "A"
hist_y, hist_f, hist_prefix = nice_array(hist / hist_width, unit_symbol=hist_unit)
ax_marg_y.barh(hist_x, hist_y, hist_width, color="gray")
ax_marg_y.set_xlabel(f"{hist_prefix}" + mathlabel(f"C/{uy}")) # Always use tex

Expand Down Expand Up @@ -422,12 +426,11 @@ def density_and_slice_plot(
"""

# Scale to nice units and get the factor, unit prefix
x, f1, p1, xmin, xmax = plottable_array(particle_group[key1])
y, f2, p2, ymin, ymax = plottable_array(particle_group[key2])
w = particle_group["weight"]

u1 = particle_group.units(key1).unitSymbol
u2 = particle_group.units(key2).unitSymbol
x, f1, p1, xmin, xmax = plottable_array(particle_group[key1], unit_symbol=u1)
y, f2, p2, ymin, ymax = plottable_array(particle_group[key2], unit_symbol=u2)
w = particle_group["weight"]
ux = p1 + u1
uy = p2 + u2

Expand Down Expand Up @@ -465,7 +468,7 @@ def density_and_slice_plot(

max2 = max([np.ptp(slice_dat[k]) for k in stat_keys])

f3, p3 = nice_scale_prefix(max2)
f3, p3 = nice_scale_prefix(max2, unit_symbol=ulist[0])

u2 = ulist[0]
assert all([u == u2 for u in ulist])
Expand Down Expand Up @@ -889,7 +892,7 @@ def plot_fieldmesh_rectangular_2d(
field_2d = interpolated_values.reshape(len(x), len(y))

if nice:
field_2d, _, prefix = nice_array(field_2d)
field_2d, _, prefix = nice_array(field_2d, unit_symbol=unit.unitSymbol)
else:
prefix = ""

Expand Down
30 changes: 24 additions & 6 deletions pmd_beamphysics/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,14 +440,17 @@ def unit(symbol: str) -> pmd_unit:
# Nice scaling


def nice_scale_prefix(scale: float) -> tuple[float, str]:
def nice_scale_prefix(scale: float, unit_symbol: str = None) -> tuple[float, str]:
"""
Returns a nice factor and an SI prefix string.

Parameters
----------
scale : float
The scale to be converted into a nice factor and SI prefix.
unit_symbol : str, optional
The unit symbol to check for special cases. If provided, prefixes
will be suppressed for units containing sqrt symbols (√, sqrt, \\sqrt).

Returns
-------
Expand All @@ -460,7 +463,15 @@ def nice_scale_prefix(scale: float) -> tuple[float, str]:
--------
>>> nice_scale_prefix(scale=2e-10)
(1e-12, 'p')

>>> nice_scale_prefix(scale=2e-10, unit_symbol='√m')
(1, '')
"""

# Check if unit contains sqrt symbols - suppress prefixes for these
if unit_symbol is not None:
if 'sqrt' in unit_symbol.lower() or '√' in unit_symbol or '\\sqrt' in unit_symbol:
return 1, ""

if scale == 0:
return 1, ""
Expand All @@ -479,7 +490,7 @@ def nice_scale_prefix(scale: float) -> tuple[float, str]:
return f, SHORT_PREFIX[f]


def nice_array(a: np.ndarray) -> tuple[np.ndarray, float, str]:
def nice_array(a: np.ndarray, unit_symbol: str = None) -> tuple[np.ndarray, float, str]:
"""
Scale an input array and return the scaled array, the scaling factor, and the
corresponding unit prefix.
Expand All @@ -488,6 +499,9 @@ def nice_array(a: np.ndarray) -> tuple[np.ndarray, float, str]:
----------
a : array-like, or float
Input array to be scaled.
unit_symbol : str, optional
The unit symbol to check for special cases. If provided, prefixes
may be suppressed for certain unit types (e.g., sqrt units).

Returns
-------
Expand All @@ -512,11 +526,11 @@ def nice_array(a: np.ndarray) -> tuple[np.ndarray, float, str]:
a = np.asarray(a)
x = max(np.ptp(a), abs(np.mean(a))) # Account for tiny spread

fac, prefix = nice_scale_prefix(x)
fac, prefix = nice_scale_prefix(x, unit_symbol)
return a / fac, fac, prefix


def plottable_array(x: np.ndarray, nice: bool = True, lim: Limit | None = None):
def plottable_array(x: np.ndarray, nice: bool = True, lim: Limit | None = None, unit_symbol: str = None):
"""
Similar to nice_array, but also considers limits for plotting

Expand All @@ -525,7 +539,11 @@ def plottable_array(x: np.ndarray, nice: bool = True, lim: Limit | None = None):
x: array-like
nice: bool, default = True
Scale array by some nice factor.
xlim: tuple, default = None
lim: tuple, default = None
Optional limits (min, max)
unit_symbol: str, optional
Unit symbol to check for special cases (e.g., sqrt). If provided,
prefixes may be suppressed for certain unit types.

Returns
-------
Expand All @@ -552,7 +570,7 @@ def plottable_array(x: np.ndarray, nice: bool = True, lim: Limit | None = None):
xmax = x.max()

if nice:
_, factor, p1 = nice_array([xmin, xmax])
_, factor, p1 = nice_array([xmin, xmax], unit_symbol)
else:
factor, p1 = 1, ""

Expand Down