Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions marimo/_runtime/reload/autoreload.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,12 +558,20 @@ def superreload(
continue

new_refs = []
kept_old_obj = None
for old_ref in old_objects[key]:
old_obj = old_ref()
if old_obj is None:
continue
new_refs.append(old_ref)
update_generic(old_obj, new_obj)
kept_old_obj = old_obj

# Keep the updated old object in the module namespace so that
# dependent modules (e.g. `from enums import Fruits`) still refer
# to the same class/function instances after reload.
if kept_old_obj is not None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1: The new namespace rebinding restores stale objects even when update_generic cannot update them, so unsupported module attributes can be reverted to pre-reload state.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At marimo/_runtime/reload/autoreload.py, line 573:

<comment>The new namespace rebinding restores stale objects even when `update_generic` cannot update them, so unsupported module attributes can be reverted to pre-reload state.</comment>

<file context>
@@ -558,12 +558,20 @@ def superreload(
+        # Keep the updated old object in the module namespace so that
+        # dependent modules (e.g. `from enums import Fruits`) still refer
+        # to the same class/function instances after reload.
+        if kept_old_obj is not None:
+            module.__dict__[name] = kept_old_obj
 
</file context>

module.__dict__[name] = kept_old_obj

if new_refs:
old_objects[key] = new_refs
Expand Down
67 changes: 67 additions & 0 deletions tests/_runtime/reload/test_autoreload.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,73 @@ def func2():
assert isinstance2(func1, "not_func", types.FunctionType) is False


def test_reload_enum_across_dependent_modules(
tmp_path: pathlib.Path, py_modname: str
):
"""Reloading a module with an Enum must not break equality in dependents.

Regression test for https://github.com/marimo-team/marimo/issues/9808
"""
sys.path.append(str(tmp_path))
enums_name = f"{py_modname}_enums"
utils_name = f"{py_modname}_utils"
enums_file = tmp_path / f"{enums_name}.py"
utils_file = tmp_path / f"{utils_name}.py"

enums_file.write_text(
textwrap.dedent(
"""
from enum import Enum

class Fruits(Enum):
APPLE = "apple"
BANANA = "banana"
ORANGE = "orange"

class A:
a = "A"
"""
)
)
utils_file.write_text(
textwrap.dedent(
f"""
from {enums_name} import Fruits

def is_orange(fruit: Fruits):
return fruit == Fruits.ORANGE
"""
)
)

enums_mod = importlib.import_module(enums_name)
utils_mod = importlib.import_module(utils_name)
reloader = ModuleReloader()
reloader.check(sys.modules, reload=False)

assert utils_mod.is_orange(enums_mod.Fruits.ORANGE) is True

update_file(
enums_file,
"""
from enum import Enum

class Fruits(Enum):
APPLE = "apple"
BANANA = "banana"
ORANGE = "orange"

class A:
a = "B"
""",
)
reloader.check(sys.modules, reload=True)

assert utils_mod.Fruits is enums_mod.Fruits
assert utils_mod.is_orange(enums_mod.Fruits.ORANGE) is True
assert enums_mod.A.a == "B"


class TestSuperreload:
"""Tests for superreload function"""

Expand Down
94 changes: 94 additions & 0 deletions tests/_runtime/reload/test_module_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,100 @@ def _setup_test_sleep():

# these tests use random filenames for modules because they share
# the same sys.modules object, and each test needs fresh modules
@pytest.mark.flaky(reruns=3)
async def test_reload_enum_across_dependent_modules(
tmp_path: pathlib.Path,
py_modname: str,
execution_kernel: Kernel,
exec_req: ExecReqProvider,
):
"""Reloading an Enum module must not break comparisons in dependents.

Regression test for https://github.com/marimo-team/marimo/issues/9808
"""
k = execution_kernel
sys.path.append(str(tmp_path))
enums_name = f"{py_modname}_enums"
utils_name = f"{py_modname}_utils"
enums_file = tmp_path / f"{enums_name}.py"
utils_file = tmp_path / f"{utils_name}.py"

enums_file.write_text(
textwrap.dedent(
"""
from enum import Enum

class Fruits(Enum):
APPLE = "apple"
BANANA = "banana"
ORANGE = "orange"

class A:
a = "A"
"""
)
)
utils_file.write_text(
textwrap.dedent(
f"""
from {enums_name} import Fruits

def is_orange(fruit: Fruits):
return fruit == Fruits.ORANGE
"""
)
)

config = copy.deepcopy(DEFAULT_CONFIG)
config["runtime"]["auto_reload"] = "lazy"
k.set_user_config(UpdateUserConfigCommand(config=config))
await k.run(
[
er_1 := exec_req.get(
f"from {enums_name} import Fruits; "
f"from {utils_name} import is_orange"
),
er_2 := exec_req.get("result = is_orange(Fruits.ORANGE)"),
er_3 := exec_req.get("pass"),
]
)
assert k.globals["result"] is True

update_file(
enums_file,
"""
from enum import Enum

class Fruits(Enum):
APPLE = "apple"
BANANA = "banana"
ORANGE = "orange"

class A:
a = "B"
""",
)

retries = 0
while retries < 10:
await asyncio.sleep(INTERVAL)
retries += 1
if (
k.graph.cells[er_1.cell_id].stale
and k.graph.cells[er_2.cell_id].stale
):
break

assert k.graph.cells[er_1.cell_id].stale
assert k.graph.cells[er_2.cell_id].stale
assert not k.graph.cells[er_3.cell_id].stale
await k.run_stale_cells()
assert not k.graph.cells[er_1.cell_id].stale
assert not k.graph.cells[er_2.cell_id].stale
assert not k.graph.cells[er_3.cell_id].stale
assert k.globals["result"] is True


@pytest.mark.flaky(reruns=3)
async def test_reload_function(
tmp_path: pathlib.Path,
Expand Down
Loading