From eb122173df12cfcbcff1070a1a6177a581de248b Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 31 Dec 2024 12:37:55 +0000 Subject: [PATCH] better Key.unit formatting: replace unicode sup/superscripts with /+ASCII - Improve unit formatting in enums.py to include HTML tags for sub/superscripts - keys.yml replace ASCII middle dot with proper Unicode character - Adjust border width in ptable_plotly.py for better visualization - more tests in test_enums.py and test_ptable_plotly.py --- .pre-commit-config.yaml | 2 +- pymatviz/enums.py | 75 +++++++++++- pymatviz/keys.yml | 28 +++-- pymatviz/ptable/ptable_plotly.py | 2 +- tests/ptable/plotly/test_ptable_plotly.py | 41 +++++++ tests/test_enums.py | 141 ++++++++++++++++++---- 6 files changed, 250 insertions(+), 39 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fbbd2ee2..2b9690a7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: types_or: [python, jupyter] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.13.0 + rev: v1.14.0 hooks: - id: mypy additional_dependencies: [types-requests, types-PyYAML] diff --git a/pymatviz/enums.py b/pymatviz/enums.py index 2468b39b..dc91610a 100644 --- a/pymatviz/enums.py +++ b/pymatviz/enums.py @@ -152,8 +152,78 @@ def label(self) -> str: @property def unit(self) -> str | None: - """Unit associated with the key.""" - return _keys[self.value].get("unit") + """Unit associated with the key with HTML tags for sub/superscripts.""" + if not (unit := _keys[self.value].get("unit")): + return None + + # Map Unicode characters to their ASCII equivalents + superscript_map = { + "⁰": "0", + "¹": "1", + "²": "2", + "³": "3", + "⁴": "4", + "⁵": "5", + "⁶": "6", + "⁷": "7", + "⁸": "8", + "⁹": "9", + "⁻": "-", + "½": "1/2", + } + subscript_map = { + "₀": "0", + "₁": "1", + "₂": "2", + "₃": "3", + "₄": "4", + "₅": "5", + "₆": "6", + "₇": "7", + "₈": "8", + "₉": "9", + "₋": "-", + "½": "1/2", + } + + # Process character by character + html_str = "" + in_super = in_sub = False + + idx = 0 + while idx < len(unit): + char = unit[idx] + + # Check if character is superscript + if new_char := superscript_map.get(char): + if not in_super: + html_str += "" + in_super = True + html_str += new_char + # Check if character is subscript + elif new_char := subscript_map.get(char): + if not in_sub: + html_str += "" + in_sub = True + html_str += new_char + else: + # Close any open tags + if in_super: + html_str += "" + in_super = False + if in_sub: + html_str += "" + in_sub = False + html_str += char + idx += 1 + + # Close any remaining open tags + if in_super: + html_str += "" + if in_sub: + html_str += "" + + return html_str @property def category(self) -> str: @@ -296,6 +366,7 @@ def __reduce_ex__(self, proto: object) -> tuple[type, tuple[str]]: cohesive_energy_per_atom = "cohesive_energy_per_atom" heat_of_formation = "heat_of_formation" heat_of_reaction = "heat_of_reaction" + e_form = "e_form" e_form_per_atom = "e_form_per_atom" e_form_pred = "e_form_pred" e_form_true = "e_form_true" diff --git a/pymatviz/keys.yml b/pymatviz/keys.yml index 58641b58..7e5c050e 100644 --- a/pymatviz/keys.yml +++ b/pymatviz/keys.yml @@ -151,7 +151,7 @@ electronic: mobility: label: Carrier Mobility symbol: μ - unit: cm²/V·s + unit: cm²/V⋅s effective_mass: label: Effective Mass symbol: me @@ -184,11 +184,11 @@ electronic: electron_mobility: label: Electron Mobility symbol: μe - unit: cm²/V·s + unit: cm²/V⋅s hole_mobility: label: Hole Mobility symbol: μh - unit: cm²/V·s + unit: cm²/V⋅s thermodynamic: energy: @@ -248,6 +248,11 @@ thermodynamic: symbol: Ecoh/Natoms unit: eV/atom description: Energy required to break crystal into isolated neutral atoms + e_form: + label: Eform + symbol: Eform + unit: eV/atom + description: All-atoms formation energy relative to constituent element reference states e_form_per_atom: label: Eform symbol: Eform/Natoms @@ -354,7 +359,7 @@ mechanical: fracture_toughness: label: Fracture Toughness symbol: KIC - unit: MPa·m1/2 + unit: MPa⋅m½ sound_velocity: label: Sound Velocity symbol: vs @@ -362,7 +367,7 @@ mechanical: thermal_conductivity: label: Thermal Conductivity symbol: κ - unit: W/m·K + unit: W/m⋅K thermal_expansion: label: Thermal Expansion symbol: α @@ -370,11 +375,11 @@ mechanical: lattice_thermal_conductivity: label: Lattice Thermal Conductivity symbol: κlattice - unit: W/m·K + unit: W/m⋅K electronic_thermal_conductivity: label: Electronic Thermal Conductivity symbol: κelectronic - unit: W/m·K + unit: W/m⋅K heat_capacity: label: Heat Capacity symbol: CV @@ -382,7 +387,7 @@ mechanical: specific_heat_capacity: label: Specific Heat Capacity symbol: cp - unit: J/kg·K + unit: J/kg⋅K thermal_expansion_coefficient: label: Thermal Expansion Coefficient symbol: αthermal @@ -398,7 +403,7 @@ mechanical: viscosity: label: Viscosity symbol: η - unit: Pa·s + unit: Pa⋅s strain: label: Strain symbol: ε @@ -456,7 +461,7 @@ thermal: thermal_resistivity: label: Thermal Resistivity symbol: ρth - unit: K·m/W + unit: K⋅m/W description: Resistance to heat flow, inverse of thermal conductivity thermal_time_constant: label: Thermal Time Constant @@ -473,7 +478,7 @@ magnetic: magnetic_moment: label: Magnetic Moment symbol: μB - unit: μB + unit: μB magmoms: label: Magnetic Moments symbol: μ @@ -1134,6 +1139,7 @@ metrics: volume_error: label: Volume Error symbol: Verr + unit: ų max_force_error: label: Max Force Error max_stress_error: diff --git a/pymatviz/ptable/ptable_plotly.py b/pymatviz/ptable/ptable_plotly.py index 6f2d154e..95944acb 100644 --- a/pymatviz/ptable/ptable_plotly.py +++ b/pymatviz/ptable/ptable_plotly.py @@ -312,7 +312,7 @@ def ptable_heatmap_plotly( if border is not False: border = border or {} border_color = border.pop("color", "darkgray") - border_width = border.pop("width", 2) + border_width = border.pop("width", 0.5) common_kwargs = dict( z=np.where(tile_texts, 1, np.nan), showscale=False, hoverinfo="none" diff --git a/tests/ptable/plotly/test_ptable_plotly.py b/tests/ptable/plotly/test_ptable_plotly.py index b207ffc1..737f6fdb 100644 --- a/tests/ptable/plotly/test_ptable_plotly.py +++ b/tests/ptable/plotly/test_ptable_plotly.py @@ -170,3 +170,44 @@ def custom_colorscale(_element_symbol: str, _value: float, split_idx: int) -> st for trace in fig.data: # Each element tile should have a color array with custom colors assert trace.fillcolor in {"rgb(255,0,0)", "rgb(0,0,255)"} + + +def test_ptable_heatmap_plotly_colorbar() -> None: + """Test colorbar customization in ptable_heatmap_plotly.""" + data = {"Fe": 1.234, "O": 5.678} + + # Test colorbar title and formatting + colorbar = dict( + title="Test Title", tickformat=".2f", orientation="v", len=0.8, x=1.1 + ) + + fig = pmv.ptable_heatmap_plotly(data, colorbar=colorbar) + + # Get the colorbar from the figure + colorbar_trace = next(trace for trace in fig.data if hasattr(trace, "colorbar")) + actual_colorbar = colorbar_trace.colorbar + + # Check colorbar properties were set correctly + assert actual_colorbar.title.text == "

Test Title" + assert actual_colorbar.tickformat == ".2f" + assert actual_colorbar.orientation == "v" + assert actual_colorbar.len == 0.8 + assert actual_colorbar.x == 1.1 + + # Test horizontal colorbar title formatting + h_colorbar = dict(title="Horizontal Title", orientation="h", y=0.8) + + fig = pmv.ptable_heatmap_plotly(data, colorbar=h_colorbar) + h_colorbar_trace = next(trace for trace in fig.data if hasattr(trace, "colorbar")) + actual_h_colorbar = h_colorbar_trace.colorbar + + # Check horizontal colorbar properties + assert ( + actual_h_colorbar.title.text == "Horizontal Title
" + ) # Horizontal title has break after + assert actual_h_colorbar.orientation == "h" + assert actual_h_colorbar.y == 0.8 + + # Test disabling colorbar + fig = pmv.ptable_heatmap_plotly(data, show_scale=False) + assert not any(trace.showscale for trace in fig.data) diff --git a/tests/test_enums.py b/tests/test_enums.py index bdfcee15..e052ca43 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -223,43 +223,42 @@ def test_key_units_are_consistent() -> None: def test_key_label_formatting() -> None: - """Test that labels are properly formatted.""" + """Test that all labels are properly formatted.""" for key in Key: label = key.label - assert ( # Label should be capitalized + assert ( label[0].isupper() or label[0].isdigit() or label.startswith("r2SCAN") - ), f"Label should be capitalized: {label}" + ), f"{label=} should be capitalized" - assert not label.endswith((".", ",")), f"{label=} ends with period or comma" + assert not label.endswith((".", ",")), f"{label=} ends with punctuation" assert label.strip() == label, f"{label=} has outer whitespace" - assert " " not in label, f"Label has multiple spaces: {label}" + assert " " not in label, f"{label=} has multiple spaces" def test_key_value_matches_name() -> None: - """Test that Key enum values match their names.""" + """Test that all Key enum values match their names.""" for key in Key: # yield is a reserved word in Python, so had to be suffixed, hence can't match # key name can't match value if key in (Key.yield_, Key.mat_id): continue - assert str(key) == key.name.lower(), f"Value doesn't match name for {key}" + name, value = key.name, str(key) + assert name == value, f"{name=} doesn't match {value=}" -def test_key_symbol_html_validity() -> None: +def test_key_html_validity() -> None: """Test that HTML subscripts in symbols are properly formatted.""" for key in Key: - symbol = key.symbol - if symbol and "" in symbol: - # Check matching closing tags - assert symbol.count("") == symbol.count( - "" - ), f"Mismatched sub tags in {symbol}" - # Check proper nesting - assert ( - "" not in symbol[: symbol.index("")] - ), f"Improper tag nesting in {symbol}" + for field in ("symbol", "unit"): + value = getattr(key, field) + if value and "" in value: + # Check matching closing tags + n_sub, n_sub_end = value.count(""), value.count("") + assert n_sub == n_sub_end, f"Mismatched sub tags in {value}" + # Check proper nesting + assert "" not in value[: value.index("")] @pytest.mark.parametrize( @@ -279,10 +278,104 @@ def test_key_descriptions(key: Key, expected_description: str | None) -> None: assert key.desc == expected_description, f"Unexpected description for {key}" -def test_key_description_field_name() -> None: - """Test that the description field in YAML is accessed correctly.""" - from pymatviz.enums import _keys +@pytest.mark.parametrize( + ("key", "expected_unit"), + [ + # Basic units + (Key.volume, "Å3"), + (Key.energy, "eV"), + (Key.temperature, "K"), + (Key.pressure, "Pa"), + # Complex units + (Key.carrier_concentration, "cm-3"), + (Key.mobility, "cm2/V⋅s"), + (Key.thermal_conductivity, "W/m⋅K"), + (Key.fracture_toughness, "MPa⋅m1/2"), + # Units with subscripts + (Key.magnetic_moment, "μB"), + (Key.energy_per_atom, "eV/atom"), + (Key.heat_capacity, "J/K"), + # Mixed sub/superscripts + (Key.specific_heat_capacity, "J/kg⋅K"), + (Key.thermal_resistivity, "K⋅m/W"), + # No units + (Key.formula, None), + (Key.structure, None), + (Key.crystal_system, None), + ], +) +def test_key_units(key: Key, expected_unit: str | None) -> None: + """Test that Key enum units are correctly formatted with HTML tags.""" + assert key.unit == expected_unit + + +def test_unit_html_consistency() -> None: + """Test that all units follow consistent HTML formatting rules.""" + for key in Key: + unit = key.unit + if unit is None: + continue - # Check that descriptions use "description" not "desc" in YAML data - keys_with_desc = [k for k, v in _keys.items() if "description" in v] - assert len(keys_with_desc) >= 50, "No descriptions found in YAML data" + # Check proper HTML tag nesting + if "" in unit: + assert unit.count("") == unit.count( + "" + ), f"Mismatched sup tags in {unit}" + assert ( + "" not in unit[: unit.index("")] + ), f"Improper sup tag nesting in {unit}" + + if "" in unit: + assert unit.count("") == unit.count( + "" + ), f"Mismatched sub tags in {unit}" + assert ( + "" not in unit[: unit.index("")] + ), f"Improper sub tag nesting in {unit}" + + # Check no nested tags + if "" in unit and "" in unit: + sup_start = unit.index("") + sup_end = unit.index("") + sub_start = unit.index("") + sub_end = unit.index("") + assert not (sup_start < sub_start < sup_end), "Nested sup/sub tags" + assert not (sub_start < sup_start < sub_end), "Nested sub/sup tags" + + +def test_unit_special_characters() -> None: + """Test that special characters in units are consistently used.""" + for key in Key: + unit = key.unit + if unit is None: + continue + + # Check for proper middle dot usage + assert "·" not in unit, f"ASCII middle dot in {unit}, use ⋅ instead" + + # Check for proper minus sign usage + if "-" in unit: + assert ( + "-" in unit + ), f"ASCII hyphen in {unit}, use -... for exponents" + + # Common units should use standard symbols + if "angstrom" in unit: + assert "Å" in unit, f"Use Å symbol in {unit} of {key}" + if "micro" in unit: + assert "μ" in unit, f"Use μ symbol in {unit} of {key}" + if "ohm" in unit: + assert "Ω" in unit, f"Use Ω symbol in {unit} of {key}" + if "kelvin" in unit: + assert "K" in unit, f"Use K symbol in {unit} of {key}" + if "pascal" in unit: + assert "Pa" in unit, f"Use Pa symbol in {unit} of {key}" + + +def test_unit_formatting_consistency() -> None: + """Test that similar quantities use consistent unit formatting.""" + ev_keys = {k for k in Key if k.unit and "eV" in k.unit} + valid_ev_units = {"eV", "eV/atom", "eV/K", "eV/Å", "meV"} + for key in ev_keys: + unit = key.unit + assert unit in valid_ev_units, f"Unexpected {unit=} for {key=}"