diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index a80ef4c2832..11fa9d87e5c 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -11,7 +11,7 @@ from matplotlib.figure import Figure import mesa -from mesa.experimental.cell_space import VoronoiGrid +from mesa.experimental.cell_space import Grid, VoronoiGrid from mesa.space import PropertyLayer from mesa.visualization.utils import update_counter @@ -52,16 +52,20 @@ def SpaceMatplotlib( if space is None: space = getattr(model, "space", None) - if isinstance(space, mesa.space._Grid): - _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model) - elif isinstance(space, mesa.space.ContinuousSpace): - _draw_continuous_space(space, space_ax, agent_portrayal, model) - elif isinstance(space, mesa.space.NetworkGrid): - _draw_network_grid(space, space_ax, agent_portrayal) - elif isinstance(space, VoronoiGrid): - _draw_voronoi(space, space_ax, agent_portrayal) - elif space is None and propertylayer_portrayal: - draw_property_layers(space_ax, space, propertylayer_portrayal, model) + # https://stackoverflow.com/questions/67524641/convert-multiple-isinstance-checks-to-structural-pattern-matching + match space: + case mesa.space._Grid(): + _draw_continuous_space(space, space_ax, agent_portrayal, model) + case mesa.space.NetworkGrid(): + _draw_network_grid(space, space_ax, agent_portrayal) + case VoronoiGrid(): + _draw_voronoi(space, space_ax, agent_portrayal) + case Grid(): # matches OrthogonalMooreGrid, OrthogonalVonNeumannGrid, and Hexgrid + # fixme add a separate draw method for hexgrids in the future + _draw_discrete_space_grid(space, space_ax, agent_portrayal) + case None: + if propertylayer_portrayal: + draw_property_layers(space_ax, space, propertylayer_portrayal, model) solara.FigureMatplotlib( space_fig, format="png", bbox_inches="tight", dependencies=dependencies @@ -291,6 +295,44 @@ def portray(g): space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in black +def _draw_discrete_space_grid(space: Grid, space_ax, agent_portrayal): + if space._ndims != 2: + raise ValueError("Space must be 2D") + + def portray(g): + x = [] + y = [] + s = [] # size + c = [] # color + + for cell in g.all_cells: + for agent in cell.agents: + data = agent_portrayal(agent) + x.append(cell.coordinate[0]) + y.append(cell.coordinate[1]) + if "size" in data: + s.append(data["size"]) + if "color" in data: + c.append(data["color"]) + out = {"x": x, "y": y} + out["s"] = s + if len(c) > 0: + out["c"] = c + + return out + + space_ax.set_xlim(0, space.width) + space_ax.set_ylim(0, space.height) + + # Draw grid lines + for x in range(space.width + 1): + space_ax.axvline(x, color="gray", linestyle=":") + for y in range(space.height + 1): + space_ax.axhline(y, color="gray", linestyle=":") + + space_ax.scatter(**portray(space)) + + def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]): """Create a plotting function for a specified measure.