Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creating threads to update visualization asynchronously #2656

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 123 additions & 19 deletions mesa/visualization/solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

import asyncio
import inspect
import threading
import time
from collections.abc import Callable
from typing import TYPE_CHECKING, Literal

Expand Down Expand Up @@ -56,6 +58,7 @@ def SolaraViz(
simulator: Simulator | None = None,
model_params=None,
name: str | None = None,
use_threads: bool = False,
):
"""Solara visualization component.

Expand All @@ -75,6 +78,8 @@ def SolaraViz(
This controls the speed of the model's automatic stepping. Defaults to 100 ms.
render_interval (int, optional): Controls how often plots are updated during a simulation,
allowing users to skip intermediate steps and update graphs less frequently.
use_threads: Flag for indicating whether to utilize multi-threading for model execution.
When checked, the model will utilize multiple threads,adjust based on system capabilities.
simulator: A simulator that controls the model (optional)
model_params (dict, optional): Parameters for (re-)instantiating a model.
Can include user-adjustable parameters and fixed parameters. Defaults to None.
Expand Down Expand Up @@ -113,6 +118,7 @@ def SolaraViz(
reactive_model_parameters = solara.use_reactive({})
reactive_play_interval = solara.use_reactive(play_interval)
reactive_render_interval = solara.use_reactive(render_interval)
reactive_use_threads = solara.use_reactive(use_threads)
with solara.AppBar():
solara.AppBarTitle(name if name else model.value.__class__.__name__)

Expand All @@ -134,12 +140,21 @@ def SolaraViz(
max=100,
step=2,
)
if reactive_use_threads.value:
solara.Text("Increase play interval to avoid skipping plots")

solara.Checkbox(
label="Use Threads",
value=reactive_use_threads,
on_value=lambda v: reactive_use_threads.set(v),
)
if not isinstance(simulator, Simulator):
ModelController(
model,
model_parameters=reactive_model_parameters,
play_interval=reactive_play_interval,
render_interval=reactive_render_interval,
use_threads=reactive_use_threads,
)
else:
SimulatorController(
Expand All @@ -148,6 +163,7 @@ def SolaraViz(
model_parameters=reactive_model_parameters,
play_interval=reactive_play_interval,
render_interval=reactive_render_interval,
use_threads=reactive_use_threads,
)
with solara.Card("Model Parameters"):
ModelCreator(
Expand Down Expand Up @@ -209,6 +225,7 @@ def ModelController(
model_parameters: dict | solara.Reactive[dict] = None,
play_interval: int | solara.Reactive[int] = 100,
render_interval: int | solara.Reactive[int] = 1,
use_threads: bool | solara.Reactive[bool] = False,
):
"""Create controls for model execution (step, play, pause, reset).

Expand All @@ -217,37 +234,77 @@ def ModelController(
model_parameters: Reactive parameters for (re-)instantiating a model.
play_interval: Interval for playing the model steps in milliseconds.
render_interval: Controls how often the plots are updated during simulation steps.Higher value reduce update frequency.
use_threads: Flag for indicating whether to utilize multi-threading for model execution.
"""
playing = solara.use_reactive(False)
running = solara.use_reactive(True)
if model_parameters is None:
model_parameters = {}
model_parameters = solara.use_reactive(model_parameters)

async def step():
while playing.value and running.value:
await asyncio.sleep(play_interval.value / 1000)
do_step()
visualization_pause_event = threading.Event()

def step():
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
while running.value and playing.value:
time.sleep(play_interval.value / 1000)
do_step()
if use_threads.value:
visualization_pause_event.set()
except Exception as e:
print(f"Error in step: {e}")
return
finally:
loop.close()

def visualization_task():
if use_threads.value:
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
while playing.value and running.value:
visualization_pause_event.wait()
visualization_pause_event.clear()
force_update()
except Exception as e:
print(f"Error in visualization_task: {e}")
finally:
loop.close()

solara.lab.use_task(
step, dependencies=[playing.value, running.value], prefer_threaded=False
step, dependencies=[playing.value, running.value], prefer_threaded=True
)

solara.use_thread(
visualization_task,
dependencies=[playing.value, running.value],
)

@function_logger(__name__)
def do_step():
"""Advance the model by the number of steps specified by the render_interval slider."""
for _ in range(render_interval.value):
model.value.step()

running.value = model.value.running
if playing.value:
for _ in range(render_interval.value):
model.value.step()
running.value = model.value.running
if not playing.value:
break
if not use_threads.value:
force_update()

force_update()
else:
for _ in range(render_interval.value):
model.value.step()
running.value = model.value.running
force_update()

@function_logger(__name__)
def do_reset():
"""Reset the model to its initial state."""
playing.value = False
running.value = True
visualization_pause_event.clear()
_mesa_logger.log(
10,
f"creating new {model.value.__class__} instance with {model_parameters.value}",
Expand Down Expand Up @@ -283,6 +340,7 @@ def SimulatorController(
model_parameters: dict | solara.Reactive[dict] = None,
play_interval: int | solara.Reactive[int] = 100,
render_interval: int | solara.Reactive[int] = 1,
use_threads: bool | solara.Reactive[bool] = False,
):
"""Create controls for model execution (step, play, pause, reset).

Expand All @@ -292,6 +350,7 @@ def SimulatorController(
model_parameters: Reactive parameters for (re-)instantiating a model.
play_interval: Interval for playing the model steps in milliseconds.
render_interval: Controls how often the plots are updated during simulation steps.Higher values reduce update frequency.
use_threads: Flag for indicating whether to utilize multi-threading for model execution.

Notes:
The `step button` increments the step by the value specified in the `render_interval` slider.
Expand All @@ -302,27 +361,72 @@ def SimulatorController(
if model_parameters is None:
model_parameters = {}
model_parameters = solara.use_reactive(model_parameters)

async def step():
while playing.value and running.value:
await asyncio.sleep(play_interval.value / 1000)
do_step()
visualization_pause_event = threading.Event()
pause_step_event = threading.Event()

def step():
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
while running.value and playing.value:
time.sleep(play_interval.value / 1000)
if use_threads.value:
pause_step_event.wait()
pause_step_event.clear()
do_step()
if use_threads.value:
visualization_pause_event.set()
except Exception as e:
print(f"Error in step: {e}")
finally:
loop.close()

def visualization_task():
if use_threads.value:
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
pause_step_event.set()
while playing.value and running.value:
visualization_pause_event.wait()
visualization_pause_event.clear()
force_update()
pause_step_event.set()
except Exception as e:
print(f"Error in visualization_task: {e}")
return
finally:
loop.close()

solara.lab.use_task(
step, dependencies=[playing.value, running.value], prefer_threaded=False
)
solara.lab.use_task(visualization_task, dependencies=[playing.value])

def do_step():
"""Advance the model by the number of steps specified by the render_interval slider."""
simulator.run_for(render_interval.value)
running.value = model.value.running
force_update()
if playing.value:
for _ in range(render_interval.value):
simulator.run_for(1)
running.value = model.value.running
if not playing.value:
break
if not use_threads.value:
force_update()

else:
for _ in range(render_interval.value):
simulator.run_for(1)
running.value = model.value.running
force_update()

def do_reset():
"""Reset the model to its initial state."""
playing.value = False
running.value = True
simulator.reset()
visualization_pause_event.clear()
pause_step_event.clear()
model.value = model.value = model.value.__class__(
simulator=simulator, **model_parameters.value
)
Expand Down
Loading