diff --git a/marimo/_runtime/reload/autoreload.py b/marimo/_runtime/reload/autoreload.py index 5292a86a60d..7ffd040d0a0 100644 --- a/marimo/_runtime/reload/autoreload.py +++ b/marimo/_runtime/reload/autoreload.py @@ -558,12 +558,20 @@ def superreload( continue new_refs = [] + kept_old_obj = None for old_ref in old_objects[key]: old_obj = old_ref() if old_obj is None: continue new_refs.append(old_ref) update_generic(old_obj, new_obj) + kept_old_obj = old_obj + + # Keep the updated old object in the module namespace so that + # dependent modules (e.g. `from enums import Fruits`) still refer + # to the same class/function instances after reload. + if kept_old_obj is not None: + module.__dict__[name] = kept_old_obj if new_refs: old_objects[key] = new_refs diff --git a/tests/_runtime/reload/test_autoreload.py b/tests/_runtime/reload/test_autoreload.py index d30c43cec56..53edd100874 100644 --- a/tests/_runtime/reload/test_autoreload.py +++ b/tests/_runtime/reload/test_autoreload.py @@ -846,6 +846,73 @@ def func2(): assert isinstance2(func1, "not_func", types.FunctionType) is False +def test_reload_enum_across_dependent_modules( + tmp_path: pathlib.Path, py_modname: str +): + """Reloading a module with an Enum must not break equality in dependents. + + Regression test for https://github.com/marimo-team/marimo/issues/9808 + """ + sys.path.append(str(tmp_path)) + enums_name = f"{py_modname}_enums" + utils_name = f"{py_modname}_utils" + enums_file = tmp_path / f"{enums_name}.py" + utils_file = tmp_path / f"{utils_name}.py" + + enums_file.write_text( + textwrap.dedent( + """ + from enum import Enum + + class Fruits(Enum): + APPLE = "apple" + BANANA = "banana" + ORANGE = "orange" + + class A: + a = "A" + """ + ) + ) + utils_file.write_text( + textwrap.dedent( + f""" + from {enums_name} import Fruits + + def is_orange(fruit: Fruits): + return fruit == Fruits.ORANGE + """ + ) + ) + + enums_mod = importlib.import_module(enums_name) + utils_mod = importlib.import_module(utils_name) + reloader = ModuleReloader() + reloader.check(sys.modules, reload=False) + + assert utils_mod.is_orange(enums_mod.Fruits.ORANGE) is True + + update_file( + enums_file, + """ + from enum import Enum + + class Fruits(Enum): + APPLE = "apple" + BANANA = "banana" + ORANGE = "orange" + + class A: + a = "B" + """, + ) + reloader.check(sys.modules, reload=True) + + assert utils_mod.Fruits is enums_mod.Fruits + assert utils_mod.is_orange(enums_mod.Fruits.ORANGE) is True + assert enums_mod.A.a == "B" + + class TestSuperreload: """Tests for superreload function""" diff --git a/tests/_runtime/reload/test_module_watcher.py b/tests/_runtime/reload/test_module_watcher.py index 0a7cb7ba617..27bf57ecc87 100644 --- a/tests/_runtime/reload/test_module_watcher.py +++ b/tests/_runtime/reload/test_module_watcher.py @@ -39,6 +39,100 @@ def _setup_test_sleep(): # these tests use random filenames for modules because they share # the same sys.modules object, and each test needs fresh modules +@pytest.mark.flaky(reruns=3) +async def test_reload_enum_across_dependent_modules( + tmp_path: pathlib.Path, + py_modname: str, + execution_kernel: Kernel, + exec_req: ExecReqProvider, +): + """Reloading an Enum module must not break comparisons in dependents. + + Regression test for https://github.com/marimo-team/marimo/issues/9808 + """ + k = execution_kernel + sys.path.append(str(tmp_path)) + enums_name = f"{py_modname}_enums" + utils_name = f"{py_modname}_utils" + enums_file = tmp_path / f"{enums_name}.py" + utils_file = tmp_path / f"{utils_name}.py" + + enums_file.write_text( + textwrap.dedent( + """ + from enum import Enum + + class Fruits(Enum): + APPLE = "apple" + BANANA = "banana" + ORANGE = "orange" + + class A: + a = "A" + """ + ) + ) + utils_file.write_text( + textwrap.dedent( + f""" + from {enums_name} import Fruits + + def is_orange(fruit: Fruits): + return fruit == Fruits.ORANGE + """ + ) + ) + + config = copy.deepcopy(DEFAULT_CONFIG) + config["runtime"]["auto_reload"] = "lazy" + k.set_user_config(UpdateUserConfigCommand(config=config)) + await k.run( + [ + er_1 := exec_req.get( + f"from {enums_name} import Fruits; " + f"from {utils_name} import is_orange" + ), + er_2 := exec_req.get("result = is_orange(Fruits.ORANGE)"), + er_3 := exec_req.get("pass"), + ] + ) + assert k.globals["result"] is True + + update_file( + enums_file, + """ + from enum import Enum + + class Fruits(Enum): + APPLE = "apple" + BANANA = "banana" + ORANGE = "orange" + + class A: + a = "B" + """, + ) + + retries = 0 + while retries < 10: + await asyncio.sleep(INTERVAL) + retries += 1 + if ( + k.graph.cells[er_1.cell_id].stale + and k.graph.cells[er_2.cell_id].stale + ): + break + + assert k.graph.cells[er_1.cell_id].stale + assert k.graph.cells[er_2.cell_id].stale + assert not k.graph.cells[er_3.cell_id].stale + await k.run_stale_cells() + assert not k.graph.cells[er_1.cell_id].stale + assert not k.graph.cells[er_2.cell_id].stale + assert not k.graph.cells[er_3.cell_id].stale + assert k.globals["result"] is True + + @pytest.mark.flaky(reruns=3) async def test_reload_function( tmp_path: pathlib.Path,