From 970a3da75f8e08e480e3cfad5ac33bcf82f1e0a6 Mon Sep 17 00:00:00 2001 From: koaning Date: Tue, 16 Jun 2026 14:18:18 +0200 Subject: [PATCH 1/6] adding a flax formatter --- marimo/_dependencies/dependencies.py | 1 + marimo/_output/formatters/_nn_tree.py | 383 ++++++++++++++++++ marimo/_output/formatters/flax_formatters.py | 336 +++++++++++++++ marimo/_output/formatters/formatters.py | 2 + .../_output/formatters/pytorch_formatters.py | 376 +---------------- .../formatters/flax_formatters.py | 163 ++++++++ pyproject.toml | 1 + .../formatters/test_flax_formatters.py | 263 ++++++++++++ 8 files changed, 1158 insertions(+), 367 deletions(-) create mode 100644 marimo/_output/formatters/_nn_tree.py create mode 100644 marimo/_output/formatters/flax_formatters.py create mode 100644 marimo/_smoke_tests/formatters/flax_formatters.py create mode 100644 tests/_output/formatters/test_flax_formatters.py diff --git a/marimo/_dependencies/dependencies.py b/marimo/_dependencies/dependencies.py index 40a3a167bf4..8b3e32006e4 100644 --- a/marimo/_dependencies/dependencies.py +++ b/marimo/_dependencies/dependencies.py @@ -255,6 +255,7 @@ class DependencyManager: pydantic = Dependency("pydantic") zmq = Dependency("zmq") # pyzmq for sandbox IPC kernels torch = Dependency("torch") + flax = Dependency("flax") weave = Dependency("weave") # Storage obstore = Dependency("obstore") diff --git a/marimo/_output/formatters/_nn_tree.py b/marimo/_output/formatters/_nn_tree.py new file mode 100644 index 00000000000..8f7478375a5 --- /dev/null +++ b/marimo/_output/formatters/_nn_tree.py @@ -0,0 +1,383 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Shared presentation layer for neural-network module tree formatters. + +Both the PyTorch (`pytorch_formatters`) and Flax NNX (`flax_formatters`) +formatters render a model as the same collapsible HTML tree. The data +extraction differs per framework, but the CSS, layout helpers, and footer +legend are identical and live here so they are defined once. +""" + +from __future__ import annotations + +import re +import typing + +ModuleCategory = typing.Literal["weight", "act", "norm", "reg"] + +_LABELS: dict[ModuleCategory, str] = { + "weight": "Weight", + "act": "Activation", + "norm": "Normalization", + "reg": "Regularization", +} + +# Matches a comma followed by a space that is NOT inside parentheses. +_TOP_COMMA_RE = re.compile(r",\s+(?![^()]*\))") + + +def _comma_to_br(html_str: str) -> str: + """Replace top-level comma separators with
for multi-line display. + + Also replaces the `=` between key/value pairs with a space for the + expanded view, without touching `=` inside HTML attributes. + """ + result = _TOP_COMMA_RE.sub("
", html_str) + return result.replace("=", " ") + + +def _frozen_attr(is_frozen: bool) -> str: + """Build the HTML data-frozen attribute string when needed.""" + if is_frozen: + return ' data-frozen="true"' + return "" + + +def _fmt_integer(n: int) -> str: + """Format int into a human readable string.""" + if n >= 1_000_000: + return f"{n / 1_000_000:.1f}M" + if n >= 1_000: + return f"{n / 1_000:.1f}K" + return str(n) + + +def _footer_html() -> str: + """Build the footer with the info-hover module-type legend.""" + legend_title = 'Module types' + legend_items = "".join( + f'' + f'' + f'{label}' + f"" + for cat, label in _LABELS.items() + ) + # Lucide "info" icon (ISC license) + info_svg = ( + '' + '' + '' + '' + "" + ) + return ( + f'" + ) + + +_CSS = """\ +.nn-t { + font-size: 0.8125rem; + line-height: 1.5; + background-color: var(--slate-1); + color: var(--slate-12); + border-radius: 6px; +} + +/* Header */ +.nn-t-header { + display: flex; + align-items: center; + gap: 0.5rem; + padding: 0.625rem 0.75rem 0.5rem 0.75rem; +} +.nn-t-root { + font-family: monospace; + font-size: 0.875rem; + font-weight: 600; + color: var(--slate-12); +} +.nn-t-summary { + font-family: monospace; + font-size: 0.75rem; + color: var(--slate-11); + margin-left: auto; +} +.nn-t-divider { + height: 1px; + background-color: var(--slate-3); + margin: 0 0.75rem; +} + +/* Body */ +.nn-t-body { + padding: 0.5rem 0 0.5rem 0.75rem; +} + +/* Shared row layout */ +.nn-t-leaf, +.nn-t-node > summary, +.nn-t-expand > summary { + display: flex; + align-items: center; + gap: 0.5rem; + padding: 0.1875rem 0.75rem 0.1875rem 0; + white-space: nowrap; +} +.nn-t-leaf:hover, +.nn-t-node > summary:hover, +.nn-t-expand > summary:hover { + background: var(--slate-2); +} + +/* Expandable nodes */ +.nn-t-node { + margin: 0; + padding: 0; +} +.nn-t-node > summary { + cursor: pointer; + list-style: none; +} +.nn-t-node > summary::-webkit-details-marker { + display: none; +} + +/* Disclosure arrow */ +.nn-t-arrow { + display: inline-flex; + align-items: center; + justify-content: center; + width: 1rem; + flex-shrink: 0; + color: var(--slate-9); + transition: transform 0.12s; + font-size: 0.5rem; +} +.nn-t-node[open] > summary .nn-t-arrow { + transform: rotate(90deg); +} + +/* Leaf spacer matches arrow width */ +.nn-t-spacer { + display: inline-block; + width: 1rem; + flex-shrink: 0; +} + +/* Children with indent guide */ +.nn-t-children { + margin-left: calc(0.5rem - 1px); + padding-left: 0.75rem; + border-left: 1px solid var(--slate-3); +} + +/* Text elements */ +.nn-t-name { + font-family: monospace; + font-size: 0.8125rem; + font-weight: 500; + color: var(--slate-12); +} +.nn-t-type { + font-family: monospace; + font-size: 0.8125rem; + font-weight: 600; + color: var(--slate-12); + padding: 0.0625rem 0.375rem; + border-radius: 0.1875rem; + background: var(--slate-3); +} +.nn-t-type[data-cat="weight"] { --pill-bg: var(--blue-3); --pill-fg: var(--blue-11); } +.nn-t-type[data-cat="norm"] { --pill-bg: var(--green-3); --pill-fg: var(--green-11); } +.nn-t-type[data-cat="act"] { --pill-bg: var(--orange-3); --pill-fg: var(--orange-11); } +.nn-t-type[data-cat="reg"] { --pill-bg: var(--crimson-3); --pill-fg: var(--crimson-11); } +.nn-t-type[data-cat] { + background: var(--pill-bg); + color: var(--pill-fg); +} +/* Positional args (always visible, never truncated) */ +.nn-t-pos { + font-family: monospace; + font-size: 0.8125rem; + color: var(--slate-11); + flex-shrink: 0; +} + +/* Keyword args (truncated with ellipsis) */ +.nn-t-args { + font-family: monospace; + font-size: 0.8125rem; + color: var(--slate-11); + overflow: hidden; + text-overflow: ellipsis; + min-width: 0; +} + +/* Expandable args */ +.nn-t-expand { + margin: 0; + padding: 0; +} +.nn-t-expand > summary { + cursor: pointer; + list-style: none; +} +.nn-t-expand > summary::-webkit-details-marker { + display: none; +} +.nn-t-expand[open] > summary .nn-t-args { + display: none; +} +.nn-t-expand-body { + font-family: monospace; + font-size: 0.8125rem; + color: var(--slate-11); + padding: 0 0.75rem 0.25rem 2.75rem; + line-height: 1.6; +} +.nn-t-key { + color: var(--slate-9); +} +.nn-t-expand-sep { + display: flex; + align-items: center; + gap: 0.25rem; + margin: 0.125rem 0 0 0; +} +.nn-t-expand-sep::after { + content: ""; + flex: 1; + height: 1px; + background: var(--slate-3); +} +.nn-t-expand-sep-label { + font-size: 0.5625rem; + text-transform: uppercase; + letter-spacing: 0.04em; + color: var(--slate-8); + flex-shrink: 0; +} + +/* Param count */ +.nn-t-params { + color: var(--slate-10); + font-family: monospace; + font-size: 0.75rem; + margin-left: auto; + padding-left: 1rem; + flex-shrink: 0; +} +[data-frozen] > .nn-t-type, +[data-frozen] > .nn-t-pos, +[data-frozen] > .nn-t-args, +[data-frozen] > .nn-t-params, +[data-frozen] > .nn-t-spacer, +[data-frozen] > summary > .nn-t-type, +[data-frozen] > summary > .nn-t-pos, +[data-frozen] > summary > .nn-t-args, +[data-frozen] > summary > .nn-t-params, +[data-frozen] > summary > .nn-t-arrow { + opacity: 0.55; +} + +/* Footer with info-hover legend */ +.nn-t-footer { + display: flex; + justify-content: flex-end; + padding: 0.25rem 0.75rem 0.375rem 0.75rem; +} +.nn-t-info { + position: relative; + display: inline-flex; + align-items: center; + justify-content: center; + color: var(--slate-8); + cursor: default; +} +.nn-t-info:hover { color: var(--slate-10); } +.nn-t-info:hover .nn-t-legend { + visibility: visible; + opacity: 1; +} +.nn-t-info svg { + width: 0.875rem; + height: 0.875rem; +} +.nn-t-legend { + visibility: hidden; + opacity: 0; + position: absolute; + bottom: calc(100% + 6px); + right: 0; + z-index: 10; + max-height: 12rem; + overflow-y: auto; + display: flex; + flex-direction: column; + gap: 0.25rem; + padding: 0.375rem 0.625rem; + background: var(--slate-1); + border: 1px solid var(--slate-3); + border-radius: 6px; + white-space: nowrap; + transition: opacity 0.12s, visibility 0.12s; + font-family: monospace; + font-size: 0.75rem; + color: var(--slate-11); +} +.nn-t-legend-title { + font-size: 0.6875rem; + text-transform: uppercase; + letter-spacing: 0.04em; + color: var(--slate-9); + margin-bottom: 0.0625rem; +} +.nn-t-legend-item { + display: flex; + align-items: center; + gap: 0.375rem; +} +.nn-t-swatch { + display: inline-flex; + align-items: center; + justify-content: center; + width: 0.875rem; + height: 0.8125rem; + border-radius: 0.1875rem; + flex-shrink: 0; + background: var(--slate-3); +} +.nn-t-swatch[data-cat="weight"] { background: var(--blue-3); } +.nn-t-swatch[data-cat="norm"] { background: var(--green-3); } +.nn-t-swatch[data-cat="act"] { background: var(--orange-3); } +.nn-t-swatch[data-cat="reg"] { background: var(--crimson-3); } +.nn-t-swatch-dot { + width: 0.25rem; + height: 0.25rem; + border-radius: 50%; + background: var(--slate-8); +} +.nn-t-swatch[data-cat="weight"] .nn-t-swatch-dot { background: var(--blue-11); } +.nn-t-swatch[data-cat="norm"] .nn-t-swatch-dot { background: var(--green-11); } +.nn-t-swatch[data-cat="act"] .nn-t-swatch-dot { background: var(--orange-11); } +.nn-t-swatch[data-cat="reg"] .nn-t-swatch-dot { background: var(--crimson-11); } +.nn-t-swatch[data-dim] { opacity: 0.55; } +.nn-t-legend-sep { + height: 1px; + background: var(--slate-3); + margin: 0.125rem 0; +}""" diff --git a/marimo/_output/formatters/flax_formatters.py b/marimo/_output/formatters/flax_formatters.py new file mode 100644 index 00000000000..eabef75e35a --- /dev/null +++ b/marimo/_output/formatters/flax_formatters.py @@ -0,0 +1,336 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Rich formatter for Flax NNX modules (`flax.nnx.Module`). + +Renders an `nnx.Module` as the same collapsible tree as the PyTorch +formatter (shared presentation lives in `_nn_tree`). NNX is pythonically +close to PyTorch -- submodules are plain attributes -- but its parameter +model differs: variables are typed (`nnx.Param`, `nnx.BatchStat`, ...) and +there is no per-parameter `requires_grad`/frozen concept. We therefore show +the `nnx.Param` count as the primary number and surface any other state +(BatchStat, RngState, ...) as a secondary "+N state" note. +""" + +from __future__ import annotations + +import html +import typing + +from marimo._output.formatters._nn_tree import ( + _CSS, + ModuleCategory, + _comma_to_br, + _fmt_integer, + _footer_html, +) +from marimo._output.formatters.formatter_factory import FormatterFactory +from marimo._output.hypertext import Html + +if typing.TYPE_CHECKING: + from flax import nnx # type: ignore[import-not-found] + +# Map flax.nnx.nn. to our display category. We derive the +# category from `type(module).__module__` without enumerating every class. +# +# This map is shorter than the PyTorch equivalent by design, not because +# layers are missing: NNX consolidates into a handful of subpackages what +# PyTorch splits across many (e.g. conv/sparse/linear all live in `linear`; +# every norm lives in `normalization`). It also keeps stateless ops +# (`relu`, pooling, padding, ...) as plain functions rather than modules, +# so they never appear in the tree. The keys below cover every +# `flax.nnx.nn.*` subpackage that contains Module classes; the structural +# containers (`Dict`, `List`, `Sequential`) are intentionally left +# uncategorized. +_MODULE_CATEGORY: dict[str, ModuleCategory] = { + # Learnable / weighted layers + "linear": "weight", # Linear, LinearGeneral, Conv, ConvTranspose, Einsum, Embed + "attention": "weight", # MultiHeadAttention + "recurrent": "weight", # LSTMCell, GRUCell, RNN, Bidirectional, ... + "lora": "weight", # LoRA, LoRALinear + # Parametric activations (most activations are functions, not modules) + "activations": "act", # PReLU + # Normalization + "normalization": "norm", # BatchNorm, LayerNorm, RMSNorm, GroupNorm, InstanceNorm, ... + # Regularization + "stochastic": "reg", # Dropout +} + + +def _layer_category(module: nnx.Module) -> ModuleCategory | None: + """Classify a module for color-coding using its source subpackage.""" + mod_path = type(module).__module__ or "" + if mod_path.startswith("flax.nnx.nn."): + submod = mod_path.split("flax.nnx.nn.", 1)[1].split(".", 1)[0] + return _MODULE_CATEGORY.get(submod) + return None + + +def _child_modules(mod: nnx.Module) -> list[tuple[str, nnx.Module]]: + """Direct submodules of `mod`, in definition order. + + NNX stores submodules as plain attributes, so we read them from the + instance dict: this preserves the order they were assigned in + `__init__` (matching PyTorch's `named_children`), whereas + `nnx.iter_children` returns them sorted alphabetically by name. List + and dict containers (e.g. `nnx.Sequential.layers`) are themselves + modules whose items are stored under "0", "1", ... attributes, so the + recursion handles them naturally. + """ + from flax import nnx + + children: list[tuple[str, nnx.Module]] = [] + for name, value in vars(mod).items(): + if name.startswith("_"): + continue + if isinstance(value, nnx.Module): + children.append((str(name), value)) + return children + + +def _param_leaves(mod: nnx.Module) -> list[typing.Any]: + """Leaves (arrays) of the module's trainable `nnx.Param` state.""" + import jax # type: ignore[import-not-found, unused-ignore] + from flax import nnx + + try: + return list(jax.tree.leaves(nnx.state(mod, nnx.Param))) + except ValueError: + # No matching state for the filter. + return [] + + +def _sum_size(leaves: typing.Iterable[typing.Any]) -> int: + """Sum the number of elements across array leaves.""" + return sum(int(getattr(leaf, "size", 0)) for leaf in leaves) + + +def _counts(mod: nnx.Module) -> tuple[int, int, int]: + """Return `(param_count, other_state_count, param_bytes)` for a subtree.""" + import jax # type: ignore[import-not-found, unused-ignore] + from flax import nnx + + param_leaves = _param_leaves(mod) + param_count = _sum_size(param_leaves) + param_bytes = sum( + int(getattr(leaf, "size", 0)) * int(getattr(leaf, "itemsize", 0)) + for leaf in param_leaves + ) + try: + total = _sum_size(jax.tree.leaves(nnx.state(mod))) + except ValueError: + total = param_count + return (param_count, max(total - param_count, 0), param_bytes) + + +def _collect_dtype_device( + leaves: typing.Iterable[typing.Any], +) -> tuple[str, str]: + """Summarise dtype and device across array leaves. + + Mirrors the PyTorch formatter: a single token when all leaves agree, + unique values joined with `"/"` when mixed, and `"–"` when empty. + """ + dtypes: set[str] = set() + devices: set[str] = set() + for leaf in leaves: + dtype = getattr(leaf, "dtype", None) + if dtype is not None: + dtypes.add(str(dtype)) + try: + for device in leaf.devices(): + devices.add(str(device)) + except (AttributeError, TypeError): + pass + if not dtypes: + return ("–", "–") + return ( + "/".join(sorted(dtypes)), + "/".join(sorted(devices)) or "–", + ) + + +def _config_kwargs(mod: nnx.Module) -> str: + """Build the HTML key=value config string from a module's attributes. + + NNX modules have no `extra_repr` hook, so we read the plain + configuration attributes set in `__init__`. We keep only simple scalar + values (skipping submodules, variables, callables, and `None`) to avoid + noise, and highlight the keys like the PyTorch formatter does. + """ + from flax import nnx + + parts: list[str] = [] + for key, value in vars(mod).items(): + if key.startswith("_"): + continue + if isinstance(value, (nnx.Module, nnx.Variable)): + continue + if not isinstance(value, (bool, int, float, str, tuple, list)): + continue + parts.append( + f'{html.escape(key)}' + f"={html.escape(repr(value))}" + ) + return ", ".join(parts) + + +def _count_note(param_count: int, other_count: int) -> str: + """Render the param/state count for a row's right-hand summary.""" + if param_count > 0: + note = _fmt_integer(param_count) + if other_count > 0: + note += f" +{_fmt_integer(other_count)} state" + return note + if other_count > 0: + return f"{_fmt_integer(other_count)} state" + return "" + + +def _walk(name: str, mod: nnx.Module) -> str: + """Recursively build HTML tree for an nnx.Module (non-root nodes).""" + children = _child_modules(mod) + type_name = mod.__class__.__name__ + cat = _layer_category(mod) + + name_html = f'{html.escape(name)} ' + cat_attr = f' data-cat="{cat}"' if cat is not None else "" + type_span = f'{type_name}' + + if not children: + param_count, other_count, _ = _counts(mod) + kwargs = _config_kwargs(mod) + note = _count_note(param_count, other_count) + params = f'{note}' if note else "" + + # Build expand body: kwargs first, then dtype/device. + body_parts: list[str] = [] + if kwargs: + body_parts.append(_comma_to_br(kwargs)) + param_leaves = _param_leaves(mod) + if param_leaves: + dtype_s, device_s = _collect_dtype_device(param_leaves) + if body_parts: + body_parts.append( + '
' + 'array' + "
" + ) + body_parts.append( + f'dtype {dtype_s}' + f"
" + f'device {device_s}' + ) + + if body_parts: + kw_inline = ( + f' {kwargs}' if kwargs else "" + ) + return ( + f'
' + f"" + f'' + f"{name_html}{type_span}{kw_inline}" + f"{params}" + f"" + f'
{"".join(body_parts)}
' + f"
" + ) + return ( + f'
' + f'' + f"{name_html}{type_span}" + f"{params}" + f"
" + ) + + # Container node: aggregate all descendant parameters and state. + param_count, other_count, _ = _counts(mod) + note = _count_note(param_count, other_count) + total_params = f'{note}' if note else "" + + children_html = "\n".join( + _walk(child_name, child_mod) for child_name, child_mod in children + ) + + return ( + f'
' + f"" + f'' + f"{name_html}{type_span}" + f"{total_params}" + f"" + f'
{children_html}
' + f"
" + ) + + +def format(module: nnx.Module) -> Html: # noqa: A001 + """Render a Flax NNX module as a collapsible tree. + + The output shows the model name and summary in a fixed header, + with child modules rendered as an expandable tree below. + + Args: + module: A `flax.nnx.Module` instance. + + Returns: + A `marimo.Html` object with the rendered tree. + """ + children = _child_modules(module) + total_params, total_other, total_bytes = _counts(module) + size_mb = total_bytes / (1024 * 1024) + + state_note = f" +{_fmt_integer(total_other)} state" if total_other else "" + header = ( + f'
' + f'{module.__class__.__name__}' + f'' + f"{_fmt_integer(total_params)} params{state_note}" + f" · {size_mb:.1f} MB" + f"" + f"
" + ) + + if children: + body_html = "\n".join( + _walk(child_name, child_mod) for child_name, child_mod in children + ) + body = f'
{body_html}
' + else: + kwargs = _config_kwargs(module) + extra_html = ( + f'{kwargs}' if kwargs else "" + ) + body = ( + f'
' + f'
{extra_html}
' + f"
" + ) + + divider = '
' + footer = _footer_html() + + html_str = ( + f'
' + f"{header}{divider}{body}{footer}" + f"
" + ) + return Html(html_str) + + +class FlaxFormatter(FormatterFactory): + @staticmethod + def package_name() -> str: + return "flax" + + def register(self) -> None: + from flax import nnx # type: ignore[import-not-found,unused-ignore] + + from marimo._messaging.mimetypes import KnownMimeType + from marimo._output import formatting + from marimo._output.formatters.flax_formatters import format as fmt + + @formatting.formatter(nnx.Module) + def _format_module( + module: nnx.Module, + ) -> tuple[KnownMimeType, str]: + return ("text/html", fmt(module).text) diff --git a/marimo/_output/formatters/formatters.py b/marimo/_output/formatters/formatters.py index ee375220989..ff4222dce82 100644 --- a/marimo/_output/formatters/formatters.py +++ b/marimo/_output/formatters/formatters.py @@ -24,6 +24,7 @@ PyArrowFormatter, PySparkFormatter, ) +from marimo._output.formatters.flax_formatters import FlaxFormatter from marimo._output.formatters.formatter_factory import FormatterFactory from marimo._output.formatters.holoviews_formatters import HoloViewsFormatter from marimo._output.formatters.ipython_formatters import IPythonFormatter @@ -80,6 +81,7 @@ OpenAIFormatter.package_name(): OpenAIFormatter(), TransformersFormatter.package_name(): TransformersFormatter(), PyTorchFormatter.package_name(): PyTorchFormatter(), + FlaxFormatter.package_name(): FlaxFormatter(), } # Formatters for builtin types and other things that don't require a diff --git a/marimo/_output/formatters/pytorch_formatters.py b/marimo/_output/formatters/pytorch_formatters.py index 8692ecfc23a..6f94e465ebb 100644 --- a/marimo/_output/formatters/pytorch_formatters.py +++ b/marimo/_output/formatters/pytorch_formatters.py @@ -6,21 +6,20 @@ import re import typing +from marimo._output.formatters._nn_tree import ( + _CSS, + ModuleCategory, + _comma_to_br, + _fmt_integer, + _footer_html, + _frozen_attr, +) from marimo._output.formatters.formatter_factory import FormatterFactory from marimo._output.hypertext import Html if typing.TYPE_CHECKING: import torch # type: ignore[import-not-found] -ModuleCategory = typing.Literal["weight", "act", "norm", "reg"] - -_LABELS: dict[ModuleCategory, str] = { - "weight": "Weight", - "act": "Activation", - "norm": "Normalization", - "reg": "Regularization", -} - # Map torch.nn.modules. to our display category. # # PyTorch organises its layers into subpackages by purpose @@ -52,9 +51,6 @@ # Matches "key=" at the start of a key=value token inside extra_repr(). _KEY_RE = re.compile(r"(? str: - """Replace top-level comma separators with
for multi-line display. - - Also replaces the `=` between key/value pairs with a space for the - expanded view, without touching `=` inside HTML attributes. - """ - result = _TOP_COMMA_RE.sub("
", html_str) - return result.replace("=", " ") - - -def _frozen_attr(is_frozen: bool) -> str: - """Build the HTML data-frozen attribute string when needed.""" - if is_frozen: - return ' data-frozen="true"' - return "" - - def _trainable_info(total: int, trainable: int) -> TrainableInfo: """Compute trainability note and frozen flag from parameter counts.""" if total > 0 and trainable == 0: @@ -165,310 +144,6 @@ def _layer_category(module: torch.nn.Module) -> ModuleCategory | None: return None -def _fmt_integer(n: int) -> str: - """Format int into a human readable string.""" - if n >= 1_000_000: - return f"{n / 1_000_000:.1f}M" - if n >= 1_000: - return f"{n / 1_000:.1f}K" - return str(n) - - -_CSS = """\ -.nn-t { - font-size: 0.8125rem; - line-height: 1.5; - background-color: var(--slate-1); - color: var(--slate-12); - border-radius: 6px; -} - -/* Header */ -.nn-t-header { - display: flex; - align-items: center; - gap: 0.5rem; - padding: 0.625rem 0.75rem 0.5rem 0.75rem; -} -.nn-t-root { - font-family: monospace; - font-size: 0.875rem; - font-weight: 600; - color: var(--slate-12); -} -.nn-t-summary { - font-family: monospace; - font-size: 0.75rem; - color: var(--slate-11); - margin-left: auto; -} -.nn-t-divider { - height: 1px; - background-color: var(--slate-3); - margin: 0 0.75rem; -} - -/* Body */ -.nn-t-body { - padding: 0.5rem 0 0.5rem 0.75rem; -} - -/* Shared row layout */ -.nn-t-leaf, -.nn-t-node > summary, -.nn-t-expand > summary { - display: flex; - align-items: center; - gap: 0.5rem; - padding: 0.1875rem 0.75rem 0.1875rem 0; - white-space: nowrap; -} -.nn-t-leaf:hover, -.nn-t-node > summary:hover, -.nn-t-expand > summary:hover { - background: var(--slate-2); -} - -/* Expandable nodes */ -.nn-t-node { - margin: 0; - padding: 0; -} -.nn-t-node > summary { - cursor: pointer; - list-style: none; -} -.nn-t-node > summary::-webkit-details-marker { - display: none; -} - -/* Disclosure arrow */ -.nn-t-arrow { - display: inline-flex; - align-items: center; - justify-content: center; - width: 1rem; - flex-shrink: 0; - color: var(--slate-9); - transition: transform 0.12s; - font-size: 0.5rem; -} -.nn-t-node[open] > summary .nn-t-arrow { - transform: rotate(90deg); -} - -/* Leaf spacer matches arrow width */ -.nn-t-spacer { - display: inline-block; - width: 1rem; - flex-shrink: 0; -} - -/* Children with indent guide */ -.nn-t-children { - margin-left: calc(0.5rem - 1px); - padding-left: 0.75rem; - border-left: 1px solid var(--slate-3); -} - -/* Text elements */ -.nn-t-name { - font-family: monospace; - font-size: 0.8125rem; - font-weight: 500; - color: var(--slate-12); -} -.nn-t-type { - font-family: monospace; - font-size: 0.8125rem; - font-weight: 600; - color: var(--slate-12); - padding: 0.0625rem 0.375rem; - border-radius: 0.1875rem; - background: var(--slate-3); -} -.nn-t-type[data-cat="weight"] { --pill-bg: var(--blue-3); --pill-fg: var(--blue-11); } -.nn-t-type[data-cat="norm"] { --pill-bg: var(--green-3); --pill-fg: var(--green-11); } -.nn-t-type[data-cat="act"] { --pill-bg: var(--orange-3); --pill-fg: var(--orange-11); } -.nn-t-type[data-cat="reg"] { --pill-bg: var(--crimson-3); --pill-fg: var(--crimson-11); } -.nn-t-type[data-cat] { - background: var(--pill-bg); - color: var(--pill-fg); -} -/* Positional args (always visible, never truncated) */ -.nn-t-pos { - font-family: monospace; - font-size: 0.8125rem; - color: var(--slate-11); - flex-shrink: 0; -} - -/* Keyword args (truncated with ellipsis) */ -.nn-t-args { - font-family: monospace; - font-size: 0.8125rem; - color: var(--slate-11); - overflow: hidden; - text-overflow: ellipsis; - min-width: 0; -} - -/* Expandable args */ -.nn-t-expand { - margin: 0; - padding: 0; -} -.nn-t-expand > summary { - cursor: pointer; - list-style: none; -} -.nn-t-expand > summary::-webkit-details-marker { - display: none; -} -.nn-t-expand[open] > summary .nn-t-args { - display: none; -} -.nn-t-expand-body { - font-family: monospace; - font-size: 0.8125rem; - color: var(--slate-11); - padding: 0 0.75rem 0.25rem 2.75rem; - line-height: 1.6; -} -.nn-t-key { - color: var(--slate-9); -} -.nn-t-expand-sep { - display: flex; - align-items: center; - gap: 0.25rem; - margin: 0.125rem 0 0 0; -} -.nn-t-expand-sep::after { - content: ""; - flex: 1; - height: 1px; - background: var(--slate-3); -} -.nn-t-expand-sep-label { - font-size: 0.5625rem; - text-transform: uppercase; - letter-spacing: 0.04em; - color: var(--slate-8); - flex-shrink: 0; -} - -/* Param count */ -.nn-t-params { - color: var(--slate-10); - font-family: monospace; - font-size: 0.75rem; - margin-left: auto; - padding-left: 1rem; - flex-shrink: 0; -} -[data-frozen] > .nn-t-type, -[data-frozen] > .nn-t-pos, -[data-frozen] > .nn-t-args, -[data-frozen] > .nn-t-params, -[data-frozen] > .nn-t-spacer, -[data-frozen] > summary > .nn-t-type, -[data-frozen] > summary > .nn-t-pos, -[data-frozen] > summary > .nn-t-args, -[data-frozen] > summary > .nn-t-params, -[data-frozen] > summary > .nn-t-arrow { - opacity: 0.55; -} - -/* Footer with info-hover legend */ -.nn-t-footer { - display: flex; - justify-content: flex-end; - padding: 0.25rem 0.75rem 0.375rem 0.75rem; -} -.nn-t-info { - position: relative; - display: inline-flex; - align-items: center; - justify-content: center; - color: var(--slate-8); - cursor: default; -} -.nn-t-info:hover { color: var(--slate-10); } -.nn-t-info:hover .nn-t-legend { - visibility: visible; - opacity: 1; -} -.nn-t-info svg { - width: 0.875rem; - height: 0.875rem; -} -.nn-t-legend { - visibility: hidden; - opacity: 0; - position: absolute; - bottom: calc(100% + 6px); - right: 0; - z-index: 10; - max-height: 12rem; - overflow-y: auto; - display: flex; - flex-direction: column; - gap: 0.25rem; - padding: 0.375rem 0.625rem; - background: var(--slate-1); - border: 1px solid var(--slate-3); - border-radius: 6px; - white-space: nowrap; - transition: opacity 0.12s, visibility 0.12s; - font-family: monospace; - font-size: 0.75rem; - color: var(--slate-11); -} -.nn-t-legend-title { - font-size: 0.6875rem; - text-transform: uppercase; - letter-spacing: 0.04em; - color: var(--slate-9); - margin-bottom: 0.0625rem; -} -.nn-t-legend-item { - display: flex; - align-items: center; - gap: 0.375rem; -} -.nn-t-swatch { - display: inline-flex; - align-items: center; - justify-content: center; - width: 0.875rem; - height: 0.8125rem; - border-radius: 0.1875rem; - flex-shrink: 0; - background: var(--slate-3); -} -.nn-t-swatch[data-cat="weight"] { background: var(--blue-3); } -.nn-t-swatch[data-cat="norm"] { background: var(--green-3); } -.nn-t-swatch[data-cat="act"] { background: var(--orange-3); } -.nn-t-swatch[data-cat="reg"] { background: var(--crimson-3); } -.nn-t-swatch-dot { - width: 0.25rem; - height: 0.25rem; - border-radius: 50%; - background: var(--slate-8); -} -.nn-t-swatch[data-cat="weight"] .nn-t-swatch-dot { background: var(--blue-11); } -.nn-t-swatch[data-cat="norm"] .nn-t-swatch-dot { background: var(--green-11); } -.nn-t-swatch[data-cat="act"] .nn-t-swatch-dot { background: var(--orange-11); } -.nn-t-swatch[data-cat="reg"] .nn-t-swatch-dot { background: var(--crimson-11); } -.nn-t-swatch[data-dim] { opacity: 0.55; } -.nn-t-legend-sep { - height: 1px; - background: var(--slate-3); - margin: 0.125rem 0; -}""" - - def _walk(mod: torch.nn.Module, name: str = "") -> str: """Recursively build HTML tree for an nn.Module (non-root nodes).""" children = list(mod.named_children()) @@ -623,40 +298,7 @@ def format(module: torch.nn.Module) -> Html: # noqa: A001 ) divider = '
' - - legend_title = 'Module types' - legend_items = "".join( - f'' - f'' - f'{label}' - f"" - for cat, label in _LABELS.items() - ) - # Lucide "info" icon (ISC license) - info_svg = ( - '' - '' - '' - '' - "" - ) - footer = ( - f'" - ) + footer = _footer_html() html = ( f'
' diff --git a/marimo/_smoke_tests/formatters/flax_formatters.py b/marimo/_smoke_tests/formatters/flax_formatters.py new file mode 100644 index 00000000000..816e876fda3 --- /dev/null +++ b/marimo/_smoke_tests/formatters/flax_formatters.py @@ -0,0 +1,163 @@ +# /// script +# dependencies = [ +# "marimo", +# "flax>=0.10.0", +# "jax>=0.4.0", +# ] +# requires-python = ">=3.11" +# /// + +import marimo + +__generated_with = "0.19.11" +app = marimo.App() + +with app.setup: + import jax.numpy as jnp + from flax import nnx + + +@app.cell +def _(): + # Simple MLP. Stateless ops (relu) are plain functions in NNX, so they + # don't appear in the tree -- only the Linear and Dropout modules do. + mlp = nnx.Sequential( + nnx.Linear(784, 256, rngs=nnx.Rngs(0)), + nnx.relu, + nnx.Dropout(0.2, rngs=nnx.Rngs(1)), + nnx.Linear(256, 128, rngs=nnx.Rngs(2)), + nnx.relu, + nnx.Dropout(0.1, rngs=nnx.Rngs(3)), + nnx.Linear(128, 10, rngs=nnx.Rngs(4)), + ) + mlp + return + + +@app.cell +def _(): + # CNN for image classification. BatchNorm carries non-trainable state + # (running stats), surfaced as a "+N state" note alongside the params. + class SimpleCNN(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.conv1 = nnx.Conv(3, 32, kernel_size=(3, 3), padding="SAME", rngs=rngs) + self.bn1 = nnx.BatchNorm(32, rngs=rngs) + self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), padding="SAME", rngs=rngs) + self.bn2 = nnx.BatchNorm(64, rngs=rngs) + self.linear = nnx.Linear(64 * 8 * 8, 256, rngs=rngs) + self.out = nnx.Linear(256, 10, rngs=rngs) + + def __call__(self, x): + x = nnx.max_pool(nnx.relu(self.bn1(self.conv1(x))), (2, 2), strides=(2, 2)) + x = nnx.max_pool(nnx.relu(self.bn2(self.conv2(x))), (2, 2), strides=(2, 2)) + x = x.reshape(x.shape[0], -1) + x = nnx.relu(self.linear(x)) + return self.out(x) + + cnn = SimpleCNN(nnx.Rngs(0)) + cnn + return + + +@app.cell +def _(): + # Mini ResNet with skip connections + class ResBlock(nnx.Module): + def __init__(self, channels: int, rngs: nnx.Rngs): + self.conv1 = nnx.Conv(channels, channels, kernel_size=(3, 3), padding="SAME", rngs=rngs) + self.bn1 = nnx.BatchNorm(channels, rngs=rngs) + self.conv2 = nnx.Conv(channels, channels, kernel_size=(3, 3), padding="SAME", rngs=rngs) + self.bn2 = nnx.BatchNorm(channels, rngs=rngs) + + def __call__(self, x): + residual = x + out = nnx.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + return nnx.relu(out + residual) + + class MiniResNet(nnx.Module): + def __init__(self, rngs: nnx.Rngs, num_classes: int = 10): + self.stem = nnx.Conv(3, 64, kernel_size=(7, 7), strides=(2, 2), padding="SAME", rngs=rngs) + self.stem_norm = nnx.BatchNorm(64, rngs=rngs) + self.layer1 = nnx.Sequential(ResBlock(64, rngs), ResBlock(64, rngs)) + self.fc = nnx.Linear(64, num_classes, rngs=rngs) + + def __call__(self, x): + x = nnx.relu(self.stem_norm(self.stem(x))) + x = self.layer1(x) + x = jnp.mean(x, axis=(1, 2)) + return self.fc(x) + + resnet = MiniResNet(nnx.Rngs(0)) + resnet + return + + +@app.cell +def _(): + # Mini Transformer (attention is its own category in the legend) + class TransformerBlock(nnx.Module): + def __init__(self, d_model: int, nhead: int, rngs: nnx.Rngs): + self.attn = nnx.MultiHeadAttention( + num_heads=nhead, in_features=d_model, decode=False, rngs=rngs + ) + self.norm1 = nnx.LayerNorm(d_model, rngs=rngs) + self.ffn = nnx.Sequential( + nnx.Linear(d_model, 512, rngs=rngs), + nnx.relu, + nnx.Linear(512, d_model, rngs=rngs), + ) + self.norm2 = nnx.LayerNorm(d_model, rngs=rngs) + + def __call__(self, x): + x = self.norm1(x + self.attn(x)) + return self.norm2(x + self.ffn(x)) + + class MiniTransformer(nnx.Module): + def __init__( + self, + rngs: nnx.Rngs, + vocab_size: int = 10000, + d_model: int = 256, + nhead: int = 4, + num_layers: int = 2, + num_classes: int = 5, + ): + self.embedding = nnx.Embed(vocab_size, d_model, rngs=rngs) + self.layers = nnx.Sequential( + *[TransformerBlock(d_model, nhead, rngs) for _ in range(num_layers)] + ) + self.classifier = nnx.Linear(d_model, num_classes, rngs=rngs) + + def __call__(self, x): + x = self.embedding(x) + x = self.layers(x) + x = jnp.mean(x, axis=1) + return self.classifier(x) + + transformer = MiniTransformer(nnx.Rngs(0)) + transformer + return + + +@app.cell +def _(): + # NNX-specific layers: LoRA adapters (weight) and the parametric + # PReLU activation (activation), alongside a BatchNorm whose running + # statistics show up as "+N state". + class Adapted(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.lora = nnx.LoRALinear(128, 128, lora_rank=8, rngs=rngs) + self.act = nnx.PReLU() + self.norm = nnx.BatchNorm(128, rngs=rngs) + self.head = nnx.Linear(128, 10, rngs=rngs) + + def __call__(self, x): + return self.head(self.norm(self.act(self.lora(x)))) + + Adapted(nnx.Rngs(0)) + return + + +if __name__ == "__main__": + app.run() diff --git a/pyproject.toml b/pyproject.toml index fa90f603ac2..0860dfc824c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -217,6 +217,7 @@ test-optional = [ "jupytext>=1.17.2", # For standard scientific computing/ ML "jax>=0.4.0; python_version == '3.12'", + "flax>=0.10.0; python_version == '3.12'", "torch>=2.4.0; python_version == '3.12'", "scikit-bio>=0.6.3; python_version == '3.12'", "mcp>=1.0.0", diff --git a/tests/_output/formatters/test_flax_formatters.py b/tests/_output/formatters/test_flax_formatters.py new file mode 100644 index 00000000000..5081f75e656 --- /dev/null +++ b/tests/_output/formatters/test_flax_formatters.py @@ -0,0 +1,263 @@ +# Copyright 2026 Marimo. All rights reserved. +from __future__ import annotations + +import re + +import pytest + +from marimo._dependencies.dependencies import DependencyManager +from marimo._output.formatters.formatters import register_formatters + +HAS_DEPS = DependencyManager.flax.has() + + +def _names_in_order(html: str) -> list[str]: + """Child names as rendered, top-to-bottom.""" + return re.findall(r'nn-t-name">([^<]+)', html) + + +@pytest.mark.skipif(not HAS_DEPS, reason="flax not installed") +class TestFlaxFormatter: + def test_format_simple_module(self) -> None: + from flax import nnx + + from marimo._output.formatters.flax_formatters import format + + model = nnx.Linear(10, 5, rngs=nnx.Rngs(0)) + html = format(model).text + + assert "nn-t" in html + assert "Linear" in html + assert "nn-t-summary" in html + + def test_format_sequential(self) -> None: + from flax import nnx + + from marimo._output.formatters.flax_formatters import format + + model = nnx.Sequential( + nnx.Linear(784, 256, rngs=nnx.Rngs(0)), + nnx.relu, + nnx.Linear(256, 10, rngs=nnx.Rngs(1)), + ) + html = format(model).text + + assert "Sequential" in html + assert "Linear" in html + # Tree structure elements + assert "nn-t-node" in html + assert "nn-t-arrow" in html + # dtype/device in expand bodies of Linear layers + assert "float32" in html + assert "cpu" in html + + def test_format_nested_module(self) -> None: + from flax import nnx + + from marimo._output.formatters.flax_formatters import format + + class SimpleNet(nnx.Module): + def __init__(self, rngs: nnx.Rngs) -> None: + self.features = nnx.Sequential( + nnx.Conv(3, 16, kernel_size=(3, 3), rngs=rngs), + ) + self.classifier = nnx.Linear(16, 10, rngs=rngs) + + html = format(SimpleNet(nnx.Rngs(0))).text + + assert "SimpleNet" in html + assert "features" in html + assert "classifier" in html + assert "Conv" in html + + def test_children_in_definition_order(self) -> None: + """Children render in __init__ order, not alphabetical. + + `nnx.iter_children` sorts by name; the formatter must instead + preserve definition order to match the model's actual structure. + """ + from flax import nnx + + from marimo._output.formatters.flax_formatters import format + + class Model(nnx.Module): + def __init__(self, rngs: nnx.Rngs) -> None: + self.linear = nnx.Linear(4, 8, rngs=rngs) + self.bn = nnx.BatchNorm(8, rngs=rngs) + self.dropout = nnx.Dropout(0.2, rngs=rngs) + self.linear_out = nnx.Linear(8, 2, rngs=rngs) + + html = format(Model(nnx.Rngs(0))).text + assert _names_in_order(html) == [ + "linear", + "bn", + "dropout", + "linear_out", + ] + + def test_state_note(self) -> None: + """Non-trainable state (BatchNorm running stats) is surfaced.""" + from flax import nnx + + from marimo._output.formatters.flax_formatters import format + + html = format(nnx.BatchNorm(16, rngs=nnx.Rngs(0))).text + assert "state" in html + + def test_category_badges(self) -> None: + from flax import nnx + + from marimo._output.formatters.flax_formatters import format + + class Model(nnx.Module): + def __init__(self, rngs: nnx.Rngs) -> None: + self.linear = nnx.Linear(8, 8, rngs=rngs) # weight + self.act = nnx.PReLU() # activation + self.norm = nnx.BatchNorm(8, rngs=rngs) # normalization + self.drop = nnx.Dropout(0.5, rngs=rngs) # regularization + + html = format(Model(nnx.Rngs(0))).text + + # Assert on the layer type-pills specifically -- the footer legend + # always contains a swatch for every category, so a bare + # `data-cat="..."` check would pass even without categorized layers. + for cat in ("weight", "act", "norm", "reg"): + assert f'class="nn-t-type" data-cat="{cat}"' in html + + def test_legend_present(self) -> None: + from flax import nnx + + from marimo._output.formatters.flax_formatters import format + + html = format(nnx.Linear(10, 5, rngs=nnx.Rngs(0))).text + + assert "nn-t-legend" in html + assert "Module types" in html + + def test_returns_html_type(self) -> None: + from flax import nnx + + from marimo._output.formatters.flax_formatters import format + from marimo._output.hypertext import Html + + result = format(nnx.Linear(10, 5, rngs=nnx.Rngs(0))) + assert isinstance(result, Html) + + def test_layer_category(self) -> None: + from flax import nnx + + from marimo._output.formatters.flax_formatters import _layer_category + + r = nnx.Rngs(0) + assert _layer_category(nnx.Linear(1, 1, rngs=r)) == "weight" + assert ( + _layer_category(nnx.Conv(1, 1, kernel_size=(1,), rngs=r)) + == "weight" + ) + assert _layer_category(nnx.Embed(2, 2, rngs=r)) == "weight" + assert ( + _layer_category(nnx.LoRALinear(2, 2, lora_rank=1, rngs=r)) + == "weight" + ) + assert _layer_category(nnx.PReLU()) == "act" + assert _layer_category(nnx.BatchNorm(1, rngs=r)) == "norm" + assert _layer_category(nnx.LayerNorm(1, rngs=r)) == "norm" + assert _layer_category(nnx.Dropout(0.5, rngs=r)) == "reg" + # Container has no category + assert _layer_category(nnx.Sequential()) is None + + def test_counts_params_only(self) -> None: + from flax import nnx + + from marimo._output.formatters.flax_formatters import _counts + + # Linear(4, 8): kernel 4*8 + bias 8 = 40 params, no other state. + param_count, other_count, param_bytes = _counts( + nnx.Linear(4, 8, rngs=nnx.Rngs(0)) + ) + assert param_count == 40 + assert other_count == 0 + assert param_bytes == 40 * 4 # float32 + + def test_counts_with_state(self) -> None: + from flax import nnx + + from marimo._output.formatters.flax_formatters import _counts + + # BatchNorm(8): scale + bias = 16 params; mean + var = 16 state. + param_count, other_count, _ = _counts( + nnx.BatchNorm(8, rngs=nnx.Rngs(0)) + ) + assert param_count == 16 + assert other_count == 16 + + def test_config_kwargs(self) -> None: + from flax import nnx + + from marimo._output.formatters.flax_formatters import _config_kwargs + + kwargs = _config_kwargs(nnx.Linear(10, 5, rngs=nnx.Rngs(0))) + assert "in_features" in kwargs + assert "out_features" in kwargs + + def test_child_modules_order(self) -> None: + from flax import nnx + + from marimo._output.formatters.flax_formatters import _child_modules + + class Model(nnx.Module): + def __init__(self, rngs: nnx.Rngs) -> None: + self.first = nnx.Linear(2, 2, rngs=rngs) + self.second = nnx.BatchNorm(2, rngs=rngs) + self.third = nnx.Linear(2, 2, rngs=rngs) + + names = [name for name, _ in _child_modules(Model(nnx.Rngs(0)))] + assert names == ["first", "second", "third"] + + def test_collect_dtype_device(self) -> None: + import jax.numpy as jnp + + from marimo._output.formatters.flax_formatters import ( + _collect_dtype_device, + ) + + dtype_str, device_str = _collect_dtype_device( + [jnp.zeros(2), jnp.ones(3)] + ) + assert dtype_str == "float32" + assert "cpu" in device_str + + # Mixed dtypes are joined with "/". + dtype_str, _ = _collect_dtype_device( + [ + jnp.zeros(2, dtype=jnp.float32), + jnp.zeros(2, dtype=jnp.float16), + ] + ) + assert dtype_str == "float16/float32" + + # Empty -> en-dash placeholders. + assert _collect_dtype_device([]) == ("–", "–") + + def test_fmt_integer(self) -> None: + from marimo._output.formatters._nn_tree import _fmt_integer + + assert _fmt_integer(500) == "500" + assert _fmt_integer(1_500) == "1.5K" + assert _fmt_integer(1_500_000) == "1.5M" + + def test_formatter_registration(self) -> None: + """Smoke test: the formatter registers and produces output.""" + register_formatters() + + from flax import nnx + + from marimo._output.formatting import get_formatter + + model = nnx.Linear(10, 5, rngs=nnx.Rngs(0)) + formatter = get_formatter(model) + assert formatter is not None + mimetype, data = formatter(model) + assert mimetype == "text/html" + assert "nn-t" in data + assert "Linear" in data From 41e713550c468da95f26e1bf719abfe339c87e34 Mon Sep 17 00:00:00 2001 From: koaning Date: Tue, 16 Jun 2026 14:43:19 +0200 Subject: [PATCH 2/6] Count only trainable params in Flax formatter for PyTorch parity The Flax formatter previously surfaced non-trainable state (BatchNorm running stats, dropout PRNG keys) as a "+N state" note. PyTorch's formatter ignores buffers, so the two were inconsistent. Drop the state note and count only nnx.Param, matching the PyTorch behavior. Updates tests and the smoke-test notebook accordingly. Co-Authored-By: Claude Opus 4.8 --- marimo/_output/formatters/flax_formatters.py | 51 +++++++------------ .../formatters/flax_formatters.py | 7 ++- .../formatters/test_flax_formatters.py | 41 +++++++++------ 3 files changed, 48 insertions(+), 51 deletions(-) diff --git a/marimo/_output/formatters/flax_formatters.py b/marimo/_output/formatters/flax_formatters.py index eabef75e35a..e5ad65ebb9f 100644 --- a/marimo/_output/formatters/flax_formatters.py +++ b/marimo/_output/formatters/flax_formatters.py @@ -5,9 +5,9 @@ formatter (shared presentation lives in `_nn_tree`). NNX is pythonically close to PyTorch -- submodules are plain attributes -- but its parameter model differs: variables are typed (`nnx.Param`, `nnx.BatchStat`, ...) and -there is no per-parameter `requires_grad`/frozen concept. We therefore show -the `nnx.Param` count as the primary number and surface any other state -(BatchStat, RngState, ...) as a secondary "+N state" note. +there is no per-parameter `requires_grad`/frozen concept. We count only +trainable parameters (`nnx.Param`), mirroring the PyTorch formatter's +handling of parameters vs. buffers, so the two stay consistent. """ from __future__ import annotations @@ -103,22 +103,22 @@ def _sum_size(leaves: typing.Iterable[typing.Any]) -> int: return sum(int(getattr(leaf, "size", 0)) for leaf in leaves) -def _counts(mod: nnx.Module) -> tuple[int, int, int]: - """Return `(param_count, other_state_count, param_bytes)` for a subtree.""" - import jax # type: ignore[import-not-found, unused-ignore] - from flax import nnx +def _counts(mod: nnx.Module) -> tuple[int, int]: + """Return `(param_count, param_bytes)` for a subtree's trainable params. + Like the PyTorch formatter, only trainable parameters (`nnx.Param`) + are counted. Non-trainable state -- BatchNorm running statistics + (`nnx.BatchStat`), PRNG keys (`nnx.RngState`), caches, etc. -- is left + out, mirroring PyTorch's handling of buffers and keeping the two + formatters consistent. + """ param_leaves = _param_leaves(mod) param_count = _sum_size(param_leaves) param_bytes = sum( int(getattr(leaf, "size", 0)) * int(getattr(leaf, "itemsize", 0)) for leaf in param_leaves ) - try: - total = _sum_size(jax.tree.leaves(nnx.state(mod))) - except ValueError: - total = param_count - return (param_count, max(total - param_count, 0), param_bytes) + return (param_count, param_bytes) def _collect_dtype_device( @@ -173,18 +173,6 @@ def _config_kwargs(mod: nnx.Module) -> str: return ", ".join(parts) -def _count_note(param_count: int, other_count: int) -> str: - """Render the param/state count for a row's right-hand summary.""" - if param_count > 0: - note = _fmt_integer(param_count) - if other_count > 0: - note += f" +{_fmt_integer(other_count)} state" - return note - if other_count > 0: - return f"{_fmt_integer(other_count)} state" - return "" - - def _walk(name: str, mod: nnx.Module) -> str: """Recursively build HTML tree for an nnx.Module (non-root nodes).""" children = _child_modules(mod) @@ -196,9 +184,9 @@ def _walk(name: str, mod: nnx.Module) -> str: type_span = f'{type_name}' if not children: - param_count, other_count, _ = _counts(mod) + param_count, _ = _counts(mod) kwargs = _config_kwargs(mod) - note = _count_note(param_count, other_count) + note = _fmt_integer(param_count) if param_count > 0 else "" params = f'{note}' if note else "" # Build expand body: kwargs first, then dtype/device. @@ -242,9 +230,9 @@ def _walk(name: str, mod: nnx.Module) -> str: f"
" ) - # Container node: aggregate all descendant parameters and state. - param_count, other_count, _ = _counts(mod) - note = _count_note(param_count, other_count) + # Container node: aggregate all descendant parameters. + param_count, _ = _counts(mod) + note = _fmt_integer(param_count) if param_count > 0 else "" total_params = f'{note}' if note else "" children_html = "\n".join( @@ -276,15 +264,14 @@ def format(module: nnx.Module) -> Html: # noqa: A001 A `marimo.Html` object with the rendered tree. """ children = _child_modules(module) - total_params, total_other, total_bytes = _counts(module) + total_params, total_bytes = _counts(module) size_mb = total_bytes / (1024 * 1024) - state_note = f" +{_fmt_integer(total_other)} state" if total_other else "" header = ( f'
' f'{module.__class__.__name__}' f'' - f"{_fmt_integer(total_params)} params{state_note}" + f"{_fmt_integer(total_params)} params" f" · {size_mb:.1f} MB" f"" f"
" diff --git a/marimo/_smoke_tests/formatters/flax_formatters.py b/marimo/_smoke_tests/formatters/flax_formatters.py index 816e876fda3..a93ce9ce0ec 100644 --- a/marimo/_smoke_tests/formatters/flax_formatters.py +++ b/marimo/_smoke_tests/formatters/flax_formatters.py @@ -36,8 +36,8 @@ def _(): @app.cell def _(): - # CNN for image classification. BatchNorm carries non-trainable state - # (running stats), surfaced as a "+N state" note alongside the params. + # CNN for image classification. Only trainable params are counted + # (BatchNorm running statistics, like PyTorch buffers, are not). class SimpleCNN(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.conv1 = nnx.Conv(3, 32, kernel_size=(3, 3), padding="SAME", rngs=rngs) @@ -143,8 +143,7 @@ def __call__(self, x): @app.cell def _(): # NNX-specific layers: LoRA adapters (weight) and the parametric - # PReLU activation (activation), alongside a BatchNorm whose running - # statistics show up as "+N state". + # PReLU activation (activation), plus a BatchNorm for the norm color. class Adapted(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.lora = nnx.LoRALinear(128, 128, lora_rank=8, rngs=rngs) diff --git a/tests/_output/formatters/test_flax_formatters.py b/tests/_output/formatters/test_flax_formatters.py index 5081f75e656..7787fcb5b29 100644 --- a/tests/_output/formatters/test_flax_formatters.py +++ b/tests/_output/formatters/test_flax_formatters.py @@ -95,14 +95,26 @@ def __init__(self, rngs: nnx.Rngs) -> None: "linear_out", ] - def test_state_note(self) -> None: - """Non-trainable state (BatchNorm running stats) is surfaced.""" + def test_non_trainable_state_not_shown(self) -> None: + """Only trainable params are counted, matching PyTorch buffers. + + BatchNorm running statistics and a Dropout's PRNG state must not + appear in the output, so the count stays consistent with the + PyTorch formatter (which ignores buffers). + """ from flax import nnx from marimo._output.formatters.flax_formatters import format - html = format(nnx.BatchNorm(16, rngs=nnx.Rngs(0))).text - assert "state" in html + class Model(nnx.Module): + def __init__(self, rngs: nnx.Rngs) -> None: + self.bn = nnx.BatchNorm(16, rngs=rngs) + self.dropout = nnx.Dropout(0.2, rngs=rngs) + + html = format(Model(nnx.Rngs(0))).text + assert "state" not in html + # Trainable params (BatchNorm scale + bias = 32) are still shown. + assert "params" in html def test_category_badges(self) -> None: from flax import nnx @@ -171,25 +183,24 @@ def test_counts_params_only(self) -> None: from marimo._output.formatters.flax_formatters import _counts - # Linear(4, 8): kernel 4*8 + bias 8 = 40 params, no other state. - param_count, other_count, param_bytes = _counts( - nnx.Linear(4, 8, rngs=nnx.Rngs(0)) - ) + # Linear(4, 8): kernel 4*8 + bias 8 = 40 params. + param_count, param_bytes = _counts(nnx.Linear(4, 8, rngs=nnx.Rngs(0))) assert param_count == 40 - assert other_count == 0 assert param_bytes == 40 * 4 # float32 - def test_counts_with_state(self) -> None: + def test_counts_ignores_non_trainable_state(self) -> None: + """Only trainable params count; buffers/state are not included.""" from flax import nnx from marimo._output.formatters.flax_formatters import _counts - # BatchNorm(8): scale + bias = 16 params; mean + var = 16 state. - param_count, other_count, _ = _counts( - nnx.BatchNorm(8, rngs=nnx.Rngs(0)) - ) + # BatchNorm(8): scale + bias = 16 params; the 16 running-stat + # elements (BatchStat) are not counted. + param_count, _ = _counts(nnx.BatchNorm(8, rngs=nnx.Rngs(0))) assert param_count == 16 - assert other_count == 16 + + # Dropout has no trainable params (only PRNG state) -> 0. + assert _counts(nnx.Dropout(0.2, rngs=nnx.Rngs(0))) == (0, 0) def test_config_kwargs(self) -> None: from flax import nnx From eff0095596495d339cda84f0f31ebdaf7041f290 Mon Sep 17 00:00:00 2001 From: koaning Date: Tue, 16 Jun 2026 16:43:29 +0200 Subject: [PATCH 3/6] Share the NN-tree renderer between PyTorch and Flax formatters The tree-walking and node HTML construction was duplicated nearly line-for-line between the two formatters; only CSS/helpers had been shared. Introduce framework-agnostic TreeNode/LeafBody types and render_node/render_model in _nn_tree.py, so each formatter now only extracts data from its own modules. Also address PR review feedback: dim params-less Flax leaves (data-frozen) so the shared legend is accurate, raise the jax floor to >=0.4.27 (required by flax 0.10), and fix a stale smoke-test comment about the attention category. Co-Authored-By: Claude Opus 4.8 --- marimo/_output/formatters/_nn_tree.py | 173 +++++++++++++++++ marimo/_output/formatters/flax_formatters.py | 150 +++++---------- .../_output/formatters/pytorch_formatters.py | 178 ++++++------------ .../formatters/flax_formatters.py | 2 +- pyproject.toml | 3 +- .../formatters/test_flax_formatters.py | 23 +++ 6 files changed, 300 insertions(+), 229 deletions(-) diff --git a/marimo/_output/formatters/_nn_tree.py b/marimo/_output/formatters/_nn_tree.py index 8f7478375a5..22329c0c23d 100644 --- a/marimo/_output/formatters/_nn_tree.py +++ b/marimo/_output/formatters/_nn_tree.py @@ -9,9 +9,12 @@ from __future__ import annotations +import dataclasses import re import typing +from marimo._output.hypertext import Html + ModuleCategory = typing.Literal["weight", "act", "norm", "reg"] _LABELS: dict[ModuleCategory, str] = { @@ -88,6 +91,176 @@ def _footer_html() -> str: ) +@dataclasses.dataclass +class LeafBody: + """Expand-panel content for a leaf row: kwargs and dtype/device. + + `kwargs_inline` is shown (truncated) on the summary line; `kwargs_block` + is the multi-line version for the expanded panel. `array_label` is the + divider label between the kwargs and the dtype/device section + (`"tensor"` for PyTorch, `"array"` for Flax). + """ + + kwargs_inline: str = "" + kwargs_block: str = "" + dtype: str | None = None + device: str | None = None + array_label: str = "array" + + def has_content(self) -> bool: + return bool(self.kwargs_block) or self.dtype is not None + + +@dataclasses.dataclass +class TreeNode: + """A framework-agnostic node in the module tree, ready to render. + + Each formatter extracts these fields from its own module objects; the + renderer turns them into the shared HTML. A node is a container when + `children is not None`, otherwise a leaf. + """ + + name: str + type_name: str + category: ModuleCategory | None = None + params_note: str = "" + is_frozen: bool = False + positional: str = "" + body: LeafBody | None = None + children: list[TreeNode] | None = None + + +def _name_html(name: str) -> str: + return f'{name} ' if name else "" + + +def _type_html(type_name: str, category: ModuleCategory | None) -> str: + cat_attr = f' data-cat="{category}"' if category is not None else "" + return f'{type_name}' + + +def _expand_body_html(body: LeafBody) -> str: + """Assemble the expand panel: kwargs block, then dtype/device.""" + parts: list[str] = [] + if body.kwargs_block: + parts.append(body.kwargs_block) + if body.dtype is not None: + if parts: + parts.append( + '
' + f'{body.array_label}' + "
" + ) + parts.append( + f'dtype {body.dtype}' + f"
" + f'device {body.device}' + ) + return "".join(parts) + + +def render_node(node: TreeNode) -> str: + """Recursively render a non-root node (leaf or container) to HTML.""" + name_html = _name_html(node.name) + type_span = _type_html(node.type_name, node.category) + + if node.children is not None: + params = ( + f'{node.params_note}' + if node.params_note + else "" + ) + children_html = "\n".join(render_node(c) for c in node.children) + return ( + f'
' + f"" + f'' + f"{name_html}{type_span}" + f"{params}" + f"" + f'
{children_html}
' + f"
" + ) + + # Leaf row. + frozen = _frozen_attr(node.is_frozen) + pos = ( + f' {node.positional}' + if node.positional + else "" + ) + params = ( + f'{node.params_note}' + if node.params_note + else "" + ) + if node.body is not None and node.body.has_content(): + kw_inline = ( + f' {node.body.kwargs_inline}' + if node.body.kwargs_inline + else "" + ) + return ( + f'
' + f"" + f'' + f"{name_html}{type_span}{pos}{kw_inline}" + f"{params}" + f"" + f'
' + f"{_expand_body_html(node.body)}
" + f"
" + ) + return ( + f'
' + f'' + f"{name_html}{type_span}{pos}" + f"{params}" + f"
" + ) + + +def render_model( + *, + root_type: str, + summary: str, + nodes: list[TreeNode], + leaf_fallback: str = "", +) -> Html: + """Assemble the full tree: header, body, footer legend. + + `summary` is the right-aligned header text (e.g. `"36.1K params · …"`). + `nodes` are the root's direct children; when empty, the root is itself + a leaf and `leaf_fallback` (its inline args) is shown instead. + """ + header = ( + f'
' + f'{root_type}' + f'{summary}' + f"
" + ) + if nodes: + body_html = "\n".join(render_node(n) for n in nodes) + body = f'
{body_html}
' + else: + inner = ( + f'{leaf_fallback}' + if leaf_fallback + else "" + ) + body = ( + f'
' + f'
{inner}
' + f"
" + ) + divider = '
' + return Html( + f'
' + f"{header}{divider}{body}{_footer_html()}" + f"
" + ) + + _CSS = """\ .nn-t { font-size: 0.8125rem; diff --git a/marimo/_output/formatters/flax_formatters.py b/marimo/_output/formatters/flax_formatters.py index e5ad65ebb9f..782140f5734 100644 --- a/marimo/_output/formatters/flax_formatters.py +++ b/marimo/_output/formatters/flax_formatters.py @@ -16,11 +16,12 @@ import typing from marimo._output.formatters._nn_tree import ( - _CSS, + LeafBody, ModuleCategory, + TreeNode, _comma_to_br, _fmt_integer, - _footer_html, + render_model, ) from marimo._output.formatters.formatter_factory import FormatterFactory from marimo._output.hypertext import Html @@ -173,81 +174,48 @@ def _config_kwargs(mod: nnx.Module) -> str: return ", ".join(parts) -def _walk(name: str, mod: nnx.Module) -> str: - """Recursively build HTML tree for an nnx.Module (non-root nodes).""" - children = _child_modules(mod) +def _node(name: str, mod: nnx.Module) -> TreeNode: + """Build a `TreeNode` for an nnx.Module (recursing into children).""" type_name = mod.__class__.__name__ cat = _layer_category(mod) + children = _child_modules(mod) + param_count, _ = _counts(mod) + params_note = _fmt_integer(param_count) if param_count > 0 else "" - name_html = f'{html.escape(name)} ' - cat_attr = f' data-cat="{cat}"' if cat is not None else "" - type_span = f'{type_name}' - - if not children: - param_count, _ = _counts(mod) - kwargs = _config_kwargs(mod) - note = _fmt_integer(param_count) if param_count > 0 else "" - params = f'{note}' if note else "" - - # Build expand body: kwargs first, then dtype/device. - body_parts: list[str] = [] - if kwargs: - body_parts.append(_comma_to_br(kwargs)) - param_leaves = _param_leaves(mod) - if param_leaves: - dtype_s, device_s = _collect_dtype_device(param_leaves) - if body_parts: - body_parts.append( - '
' - 'array' - "
" - ) - body_parts.append( - f'dtype {dtype_s}' - f"
" - f'device {device_s}' - ) - - if body_parts: - kw_inline = ( - f' {kwargs}' if kwargs else "" - ) - return ( - f'
' - f"" - f'' - f"{name_html}{type_span}{kw_inline}" - f"{params}" - f"" - f'
{"".join(body_parts)}
' - f"
" - ) - return ( - f'
' - f'' - f"{name_html}{type_span}" - f"{params}" - f"
" + if children: + return TreeNode( + name=html.escape(name), + type_name=type_name, + category=cat, + params_note=params_note, + children=[_node(n, c) for n, c in children], ) - # Container node: aggregate all descendant parameters. - param_count, _ = _counts(mod) - note = _fmt_integer(param_count) if param_count > 0 else "" - total_params = f'{note}' if note else "" - - children_html = "\n".join( - _walk(child_name, child_mod) for child_name, child_mod in children - ) + # Leaf module. + kwargs = _config_kwargs(mod) + param_leaves = _param_leaves(mod) + body: LeafBody | None = None + if kwargs or param_leaves: + dtype = device = None + if param_leaves: + dtype, device = _collect_dtype_device(param_leaves) + body = LeafBody( + kwargs_inline=kwargs, + kwargs_block=_comma_to_br(kwargs) if kwargs else "", + dtype=dtype, + device=device, + array_label="array", + ) - return ( - f'
' - f"" - f'' - f"{name_html}{type_span}" - f"{total_params}" - f"" - f'
{children_html}
' - f"
" + return TreeNode( + name=html.escape(name), + type_name=type_name, + category=cat, + params_note=params_note, + # NNX has no frozen concept, but -- like PyTorch -- we dim + # params-less leaves (e.g. Dropout) to match the legend. + is_frozen=param_count == 0, + body=body, ) @@ -266,42 +234,16 @@ def format(module: nnx.Module) -> Html: # noqa: A001 children = _child_modules(module) total_params, total_bytes = _counts(module) size_mb = total_bytes / (1024 * 1024) + summary = f"{_fmt_integer(total_params)} params · {size_mb:.1f} MB" - header = ( - f'
' - f'{module.__class__.__name__}' - f'' - f"{_fmt_integer(total_params)} params" - f" · {size_mb:.1f} MB" - f"" - f"
" - ) - - if children: - body_html = "\n".join( - _walk(child_name, child_mod) for child_name, child_mod in children - ) - body = f'
{body_html}
' - else: - kwargs = _config_kwargs(module) - extra_html = ( - f'{kwargs}' if kwargs else "" - ) - body = ( - f'
' - f'
{extra_html}
' - f"
" - ) - - divider = '
' - footer = _footer_html() + leaf_fallback = "" if children else _config_kwargs(module) - html_str = ( - f'
' - f"{header}{divider}{body}{footer}" - f"
" + return render_model( + root_type=module.__class__.__name__, + summary=summary, + nodes=[_node(n, c) for n, c in children], + leaf_fallback=leaf_fallback, ) - return Html(html_str) class FlaxFormatter(FormatterFactory): diff --git a/marimo/_output/formatters/pytorch_formatters.py b/marimo/_output/formatters/pytorch_formatters.py index 6f94e465ebb..54c5486d085 100644 --- a/marimo/_output/formatters/pytorch_formatters.py +++ b/marimo/_output/formatters/pytorch_formatters.py @@ -7,12 +7,12 @@ import typing from marimo._output.formatters._nn_tree import ( - _CSS, + LeafBody, ModuleCategory, + TreeNode, _comma_to_br, _fmt_integer, - _footer_html, - _frozen_attr, + render_model, ) from marimo._output.formatters.formatter_factory import FormatterFactory from marimo._output.hypertext import Html @@ -144,102 +144,55 @@ def _layer_category(module: torch.nn.Module) -> ModuleCategory | None: return None -def _walk(mod: torch.nn.Module, name: str = "") -> str: - """Recursively build HTML tree for an nn.Module (non-root nodes).""" - children = list(mod.named_children()) +def _node(mod: torch.nn.Module, name: str = "") -> TreeNode: + """Build a `TreeNode` for an nn.Module (recursing into children).""" type_name = mod.__class__.__name__ - extra = _extra_repr_html(mod) cat = _layer_category(mod) + children = list(mod.named_children()) - name_html = f'{name} ' if name else "" - cat_attr = f' data-cat="{cat}"' if cat is not None else "" - type_span = f'{type_name}' - pos_args = ( - f' {extra.positional}' - if extra.positional - else "" - ) - - if not children: - own_params = list(mod.parameters(recurse=False)) - num_params = sum(p.numel() for p in own_params) - num_trainable = sum(p.numel() for p in own_params if p.requires_grad) - info = _trainable_info(num_params, num_trainable) - frozen = _frozen_attr(info.is_frozen or num_params == 0) - - params = ( - f'' - f"{_fmt_integer(num_params)}{info.note}" - if num_params > 0 - else "" - ) - - # Build expand body: kwargs first, then dtype/device - body_parts: list[str] = [] - if extra.kwargs: - body_parts.append(_comma_to_br(extra.kwargs)) - if own_params: - dtype_s, device_s = _collect_dtype_device(own_params) - if body_parts: - body_parts.append( - '
' - 'tensor' - "
" - ) - body_parts.append( - f'dtype {dtype_s}' - f"
" - f'device {device_s}' - ) - - if body_parts: - kw_inline = ( - f' {extra.kwargs}' - if extra.kwargs - else "" - ) - return ( - f'
' - f"" - f'' - f"{name_html}{type_span}{pos_args}{kw_inline}" - f"{params}" - f"" - f'
{"".join(body_parts)}
' - f"
" - ) - return ( - f'
' - f'' - f"{name_html}{type_span}{pos_args}" - f"{params}" - f"
" + if children: + all_sub = list(mod.parameters()) + total = sum(p.numel() for p in all_sub) + trainable = sum(p.numel() for p in all_sub if p.requires_grad) + info = _trainable_info(total, trainable) + return TreeNode( + name=name, + type_name=type_name, + category=cat, + params_note=f"{_fmt_integer(total)}{info.note}", + is_frozen=info.is_frozen, + children=[_node(c, n) for n, c in children], ) - # Container node: aggregate all descendant parameters - all_sub = list(mod.parameters()) - total_sub = sum(p.numel() for p in all_sub) - total_trainable = sum(p.numel() for p in all_sub if p.requires_grad) - info = _trainable_info(total_sub, total_trainable) - - total_params = ( - f'' - f"{_fmt_integer(total_sub)}{info.note}" - ) + # Leaf module. + own = list(mod.parameters(recurse=False)) + num = sum(p.numel() for p in own) + trainable = sum(p.numel() for p in own if p.requires_grad) + info = _trainable_info(num, trainable) + extra = _extra_repr_html(mod) - children_html = "\n".join( - _walk(child_mod, child_name) for child_name, child_mod in children - ) + body: LeafBody | None = None + if extra.kwargs or own: + dtype = device = None + if own: + dtype, device = _collect_dtype_device(own) + body = LeafBody( + kwargs_inline=extra.kwargs, + kwargs_block=_comma_to_br(extra.kwargs) if extra.kwargs else "", + dtype=dtype, + device=device, + array_label="tensor", + ) - return ( - f'
' - f"" - f'' - f"{name_html}{type_span}" - f"{total_params}" - f"" - f'
{children_html}
' - f"
" + return TreeNode( + name=name, + type_name=type_name, + category=cat, + params_note=f"{_fmt_integer(num)}{info.note}" if num > 0 else "", + # Dim frozen layers and params-less layers (e.g. activations). + is_frozen=info.is_frozen or num == 0, + positional=extra.positional, + body=body, ) @@ -256,8 +209,6 @@ def format(module: torch.nn.Module) -> Html: # noqa: A001 A `marimo.Html` object with the rendered tree. """ all_params = list(module.parameters()) - children = list(module.named_children()) - total_params = sum(p.numel() for p in all_params) trainable_params = sum(p.numel() for p in all_params if p.requires_grad) size_bytes = sum(p.numel() * p.element_size() for p in all_params) @@ -268,44 +219,25 @@ def format(module: torch.nn.Module) -> Html: # noqa: A001 if trainable_params != total_params else "" ) - header = ( - f'
' - f'{module.__class__.__name__}' - f'' + summary = ( f"{_fmt_integer(total_params)} params{trainable_note}" f" \u00b7 {size_mb:.1f} MB" - f"" - f"
" ) - if children: - body_html = "\n".join( - _walk(child_mod, child_name) for child_name, child_mod in children - ) - body = f'
{body_html}
' - else: + children = list(module.named_children()) + leaf_fallback = "" + if not children: extra = _extra_repr_html(module) - combined = ", ".join( + leaf_fallback = ", ".join( part for part in (extra.positional, extra.kwargs) if part ) - extra_html = ( - f'{combined}' if combined else "" - ) - body = ( - f'
' - f'
{extra_html}
' - f"
" - ) - - divider = '
' - footer = _footer_html() - html = ( - f'
' - f"{header}{divider}{body}{footer}" - f"
" + return render_model( + root_type=module.__class__.__name__, + summary=summary, + nodes=[_node(c, n) for n, c in children], + leaf_fallback=leaf_fallback, ) - return Html(html) class PyTorchFormatter(FormatterFactory): diff --git a/marimo/_smoke_tests/formatters/flax_formatters.py b/marimo/_smoke_tests/formatters/flax_formatters.py index a93ce9ce0ec..4be77ef62f8 100644 --- a/marimo/_smoke_tests/formatters/flax_formatters.py +++ b/marimo/_smoke_tests/formatters/flax_formatters.py @@ -95,7 +95,7 @@ def __call__(self, x): @app.cell def _(): - # Mini Transformer (attention is its own category in the legend) + # Mini Transformer (MultiHeadAttention maps to the "weight" category) class TransformerBlock(nnx.Module): def __init__(self, d_model: int, nhead: int, rngs: nnx.Rngs): self.attn = nnx.MultiHeadAttention( diff --git a/pyproject.toml b/pyproject.toml index 0860dfc824c..d8a1c648553 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -216,7 +216,8 @@ test-optional = [ "sympy>=1.13.3", "jupytext>=1.17.2", # For standard scientific computing/ ML - "jax>=0.4.0; python_version == '3.12'", + # flax 0.10 requires jax>=0.4.27, so keep the floors compatible. + "jax>=0.4.27; python_version == '3.12'", "flax>=0.10.0; python_version == '3.12'", "torch>=2.4.0; python_version == '3.12'", "scikit-bio>=0.6.3; python_version == '3.12'", diff --git a/tests/_output/formatters/test_flax_formatters.py b/tests/_output/formatters/test_flax_formatters.py index 7787fcb5b29..0db48edcb20 100644 --- a/tests/_output/formatters/test_flax_formatters.py +++ b/tests/_output/formatters/test_flax_formatters.py @@ -116,6 +116,29 @@ def __init__(self, rngs: nnx.Rngs) -> None: # Trainable params (BatchNorm scale + bias = 32) are still shown. assert "params" in html + def test_paramless_leaf_is_dimmed(self) -> None: + """Params-less leaves (e.g. Dropout) get `data-frozen`, matching the + shared "Frozen / no params" legend and the PyTorch formatter.""" + from flax import nnx + + from marimo._output.formatters.flax_formatters import format + + class Model(nnx.Module): + def __init__(self, rngs: nnx.Rngs) -> None: + self.linear = nnx.Linear(8, 8, rngs=rngs) + self.drop = nnx.Dropout(0.2, rngs=rngs) + + html = format(Model(nnx.Rngs(0))).text + + def row_open_tag(name: str) -> str: + """The opening tag of the leaf row containing `name`.""" + before = html[: html.index(f'nn-t-name">{name}<')] + return before[before.rindex(" None: from flax import nnx From 086bd7e3a3ecd9b9104581fe55e92a61d2cd166b Mon Sep 17 00:00:00 2001 From: koaning Date: Tue, 16 Jun 2026 17:16:19 +0200 Subject: [PATCH 4/6] use nodes, reply to copilot feedback --- marimo/_output/formatters/_nn_tree.py | 11 ++++++++--- marimo/_output/formatters/flax_formatters.py | 4 ++-- .../_output/formatters/test_flax_formatters.py | 7 ------- .../formatters/test_pytorch_formatters.py | 17 +++++++++++++++++ 4 files changed, 27 insertions(+), 12 deletions(-) diff --git a/marimo/_output/formatters/_nn_tree.py b/marimo/_output/formatters/_nn_tree.py index 22329c0c23d..298c4115dc1 100644 --- a/marimo/_output/formatters/_nn_tree.py +++ b/marimo/_output/formatters/_nn_tree.py @@ -10,6 +10,7 @@ from __future__ import annotations import dataclasses +import html import re import typing @@ -131,12 +132,16 @@ class TreeNode: def _name_html(name: str) -> str: - return f'{name} ' if name else "" + # Names can be arbitrary strings (e.g. torch `add_module`), so escape. + return f'{html.escape(name)} ' if name else "" def _type_html(type_name: str, category: ModuleCategory | None) -> str: + # Class names are arbitrary strings (e.g. via `type()`), so escape. cat_attr = f' data-cat="{category}"' if category is not None else "" - return f'{type_name}' + return ( + f'{html.escape(type_name)}' + ) def _expand_body_html(body: LeafBody) -> str: @@ -235,7 +240,7 @@ def render_model( """ header = ( f'
' - f'{root_type}' + f'{html.escape(root_type)}' f'{summary}' f"
" ) diff --git a/marimo/_output/formatters/flax_formatters.py b/marimo/_output/formatters/flax_formatters.py index 782140f5734..382043c6d8f 100644 --- a/marimo/_output/formatters/flax_formatters.py +++ b/marimo/_output/formatters/flax_formatters.py @@ -184,7 +184,7 @@ def _node(name: str, mod: nnx.Module) -> TreeNode: if children: return TreeNode( - name=html.escape(name), + name=name, type_name=type_name, category=cat, params_note=params_note, @@ -208,7 +208,7 @@ def _node(name: str, mod: nnx.Module) -> TreeNode: ) return TreeNode( - name=html.escape(name), + name=name, type_name=type_name, category=cat, params_note=params_note, diff --git a/tests/_output/formatters/test_flax_formatters.py b/tests/_output/formatters/test_flax_formatters.py index 0db48edcb20..250878bd0ce 100644 --- a/tests/_output/formatters/test_flax_formatters.py +++ b/tests/_output/formatters/test_flax_formatters.py @@ -273,13 +273,6 @@ def test_collect_dtype_device(self) -> None: # Empty -> en-dash placeholders. assert _collect_dtype_device([]) == ("–", "–") - def test_fmt_integer(self) -> None: - from marimo._output.formatters._nn_tree import _fmt_integer - - assert _fmt_integer(500) == "500" - assert _fmt_integer(1_500) == "1.5K" - assert _fmt_integer(1_500_000) == "1.5M" - def test_formatter_registration(self) -> None: """Smoke test: the formatter registers and produces output.""" register_formatters() diff --git a/tests/_output/formatters/test_pytorch_formatters.py b/tests/_output/formatters/test_pytorch_formatters.py index 09da499d08d..cab3ee9023e 100644 --- a/tests/_output/formatters/test_pytorch_formatters.py +++ b/tests/_output/formatters/test_pytorch_formatters.py @@ -297,3 +297,20 @@ def test_formatter_registration(self) -> None: assert mimetype == "text/html" assert "nn-t" in data assert "Linear" in data + + def test_escapes_html_in_names(self) -> None: + """Module/class names are escaped to prevent HTML injection. + + marimo renders this output as HTML in the browser (incl. served + apps and exported notebooks), and names can be arbitrary strings + (e.g. `add_module` / `type()`), so they must be escaped. + """ + from torch import nn + + from marimo._output.formatters.pytorch_formatters import format + + parent = nn.Module() + parent.add_module("", nn.Linear(2, 2)) + html = format(parent).text + assert " Date: Tue, 16 Jun 2026 15:17:01 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- marimo/_output/formatters/_nn_tree.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/marimo/_output/formatters/_nn_tree.py b/marimo/_output/formatters/_nn_tree.py index 298c4115dc1..b8e4e7c726d 100644 --- a/marimo/_output/formatters/_nn_tree.py +++ b/marimo/_output/formatters/_nn_tree.py @@ -133,15 +133,15 @@ class TreeNode: def _name_html(name: str) -> str: # Names can be arbitrary strings (e.g. torch `add_module`), so escape. - return f'{html.escape(name)} ' if name else "" + return ( + f'{html.escape(name)} ' if name else "" + ) def _type_html(type_name: str, category: ModuleCategory | None) -> str: # Class names are arbitrary strings (e.g. via `type()`), so escape. cat_attr = f' data-cat="{category}"' if category is not None else "" - return ( - f'{html.escape(type_name)}' - ) + return f'{html.escape(type_name)}' def _expand_body_html(body: LeafBody) -> str: From 8f9825a39efddb97a21014ba1946b115352ca466 Mon Sep 17 00:00:00 2001 From: koaning Date: Wed, 17 Jun 2026 10:35:55 +0200 Subject: [PATCH 6/6] Format billion-scale param counts as "B" in the NN-tree formatters _fmt_integer only handled K/M, so large (LLM-scale) models showed unwieldy raw counts. Add a billions branch and extend the test. Co-Authored-By: Claude Opus 4.8 --- marimo/_output/formatters/_nn_tree.py | 2 ++ tests/_output/formatters/test_pytorch_formatters.py | 1 + 2 files changed, 3 insertions(+) diff --git a/marimo/_output/formatters/_nn_tree.py b/marimo/_output/formatters/_nn_tree.py index b8e4e7c726d..ffe1aeaf9dd 100644 --- a/marimo/_output/formatters/_nn_tree.py +++ b/marimo/_output/formatters/_nn_tree.py @@ -48,6 +48,8 @@ def _frozen_attr(is_frozen: bool) -> str: def _fmt_integer(n: int) -> str: """Format int into a human readable string.""" + if n >= 1_000_000_000: + return f"{n / 1_000_000_000:.1f}B" if n >= 1_000_000: return f"{n / 1_000_000:.1f}M" if n >= 1_000: diff --git a/tests/_output/formatters/test_pytorch_formatters.py b/tests/_output/formatters/test_pytorch_formatters.py index cab3ee9023e..a2b8dc30603 100644 --- a/tests/_output/formatters/test_pytorch_formatters.py +++ b/tests/_output/formatters/test_pytorch_formatters.py @@ -148,6 +148,7 @@ def test_param_count_formatting(self) -> None: assert _fmt_integer(500) == "500" assert _fmt_integer(1_500) == "1.5K" assert _fmt_integer(1_500_000) == "1.5M" + assert _fmt_integer(1_500_000_000) == "1.5B" def test_extra_repr_html(self) -> None: from torch import nn