Skip to content

Commit cbfc509

Browse files
committed
implement VideoViz to record model runs in a video
1 parent abc97b3 commit cbfc509

File tree

3 files changed

+244
-0
lines changed

3 files changed

+244
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,6 @@ dmypy.json
9292
# JS dependencies
9393
mesa/visualization/templates/external/
9494
mesa/visualization/templates/js/external/
95+
96+
# Video
97+
**/*.mp4
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Example of using VideoViz with the Schelling model."""
2+
3+
from mesa.examples.basic.schelling.model import Schelling
4+
from mesa.visualization.video_viz import (
5+
VideoViz,
6+
make_measure_component,
7+
make_space_component,
8+
)
9+
10+
# Create model
11+
model = Schelling(10, 10)
12+
13+
14+
def agent_portrayal(agent):
15+
"""Portray agents based on their type."""
16+
if agent is None:
17+
return {}
18+
19+
portrayal = {
20+
"color": "red" if agent.type == 0 else "blue",
21+
"size": 25,
22+
"marker": "s", # square marker
23+
}
24+
return portrayal
25+
26+
27+
# Create visualization with space and some metrics
28+
viz = VideoViz(
29+
model,
30+
[
31+
make_space_component(agent_portrayal=agent_portrayal),
32+
make_measure_component("happy"),
33+
],
34+
title="Schelling's Segregation Model",
35+
)
36+
37+
# Record simulation
38+
if __name__ == "__main__":
39+
viz.record(steps=50, filepath="schelling.mp4")
40+
print("Video saved to: schelling.mp4")

mesa/visualization/video_viz.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
"""Mesa visualization module for recording videos of model simulations.
2+
3+
This module uses Matplotlib to create visualizations of model spaces and
4+
measures, and records them as videos.
5+
6+
Please install FFmpeg to use this module:
7+
- macOS: brew install ffmpeg
8+
- Linux: sudo apt-get install ffmpeg
9+
- Windows: download from https://ffmpeg.org/download.html
10+
"""
11+
12+
import shutil
13+
from collections.abc import Callable, Sequence
14+
from pathlib import Path
15+
16+
import matplotlib.animation as animation
17+
import matplotlib.pyplot as plt
18+
import numpy as np
19+
20+
import mesa
21+
from mesa.visualization.mpl_drawing import (
22+
draw_plot,
23+
draw_space,
24+
)
25+
26+
27+
def make_space_component(
28+
agent_portrayal: Callable | None = None,
29+
propertylayer_portrayal: dict | None = None,
30+
post_process: Callable | None = None,
31+
**space_drawing_kwargs,
32+
) -> Callable[[mesa.Model, plt.Axes | None], plt.Axes]:
33+
"""Create a Matplotlib-based space visualization component.
34+
35+
Args:
36+
agent_portrayal: Function to portray agents.
37+
propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications
38+
post_process : a callable that will be called with the Axes instance. Allows for fine tuning plots (e.g., control ticks)
39+
space_drawing_kwargs : additional keyword arguments to be passed on to the underlying space drawer function. See
40+
the functions for drawing the various spaces for further details.
41+
42+
``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
43+
"size", "marker", and "zorder". Other field are ignored and will result in a user warning.
44+
45+
46+
Returns:
47+
function: A function that returns a Axes instance with the space drawn
48+
"""
49+
if agent_portrayal is None:
50+
51+
def agent_portrayal(a):
52+
return {}
53+
54+
def _make_space_component(model, ax=None):
55+
space = getattr(model, "grid", None) or getattr(model, "space", None)
56+
ax = draw_space(
57+
space,
58+
agent_portrayal,
59+
propertylayer_portrayal,
60+
ax,
61+
**space_drawing_kwargs,
62+
)
63+
if post_process:
64+
post_process(ax)
65+
return ax
66+
67+
return _make_space_component
68+
69+
70+
def make_measure_component(
71+
measure: Callable,
72+
post_process: Callable | None = None,
73+
**kwargs,
74+
) -> Callable[[mesa.Model, plt.Axes | None], plt.Axes]:
75+
"""Create a plotting function for a specified measure.
76+
77+
Args:
78+
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
79+
post_process : a callable that will be called with the Axes instance. Allows for fine tuning plots (e.g., control ticks)
80+
kwargs: Additional keyword arguments to pass to the MeasureRendererMatplotlib constructor.
81+
82+
Returns:
83+
function: A function that returns a Axes instance with the measure(s) drawn
84+
"""
85+
86+
def _make_measure_component(model, ax=None):
87+
ax = draw_plot(model, measure, ax, **kwargs)
88+
if post_process:
89+
post_process(ax)
90+
return ax
91+
92+
return _make_measure_component
93+
94+
95+
class VideoViz:
96+
"""Create high-quality video recordings of model simulations."""
97+
98+
def __init__(
99+
self,
100+
model: mesa.Model,
101+
components: Sequence[Callable[[mesa.Model, plt.Axes | None], plt.Axes]],
102+
*,
103+
title: str | None = None,
104+
figsize: tuple[float, float] | None = None,
105+
grid: tuple[int, int] | None = None,
106+
):
107+
"""Initialize video visualization configuration.
108+
109+
Args:
110+
model: The model to simulate and record
111+
components: Sequence of component objects defining what to visualize
112+
title: Optional title for the video
113+
figsize: Optional figure size in inches (width, height)
114+
grid: Optional (rows, cols) for custom layout. Auto-calculated if None.
115+
"""
116+
# Check if FFmpeg is available
117+
if not shutil.which("ffmpeg"):
118+
raise RuntimeError(
119+
"FFmpeg not found. Please install FFmpeg to save animations:\n"
120+
" - macOS: brew install ffmpeg\n"
121+
" - Linux: sudo apt-get install ffmpeg\n"
122+
" - Windows: download from https://ffmpeg.org/download.html"
123+
)
124+
self.model = model
125+
self.components = components
126+
self.title = title
127+
self.figsize = figsize
128+
self.grid = grid or self._calculate_grid(len(components))
129+
130+
# Setup figure and axes
131+
self.fig, self.axes = self._setup_figure()
132+
133+
def record(
134+
self,
135+
*,
136+
steps: int,
137+
filepath: str | Path,
138+
dpi: int = 100,
139+
fps: int = 10,
140+
codec: str = "h264",
141+
bitrate: int = 2000,
142+
) -> None:
143+
"""Record model simulation to video file.
144+
145+
Args:
146+
steps: Number of simulation steps to record
147+
filepath: Where to save the video file
148+
dpi: Resolution of the output video
149+
fps: Frames per second in the output video
150+
codec: Video codec to use
151+
bitrate: Video bitrate in kbps (default: 2000)
152+
153+
Raises:
154+
RuntimeError: If FFmpeg is not installed
155+
"""
156+
filepath = Path(filepath)
157+
158+
def update(frame_num):
159+
# Update model state
160+
self.model.step()
161+
162+
# Render all visualization frames
163+
for component, ax in zip(self.components, self.axes):
164+
ax.clear()
165+
component(self.model, ax)
166+
return self.axes
167+
168+
# Create and save animation
169+
anim = animation.FuncAnimation(
170+
self.fig, update, frames=steps, interval=1000 / fps, blit=False
171+
)
172+
173+
writer = animation.FFMpegWriter(
174+
fps=fps,
175+
codec=codec,
176+
bitrate=bitrate, # Now passing as integer
177+
)
178+
179+
anim.save(filepath, writer=writer, dpi=dpi)
180+
181+
def _calculate_grid(self, n_frames: int) -> tuple[int, int]:
182+
"""Calculate optimal grid layout for given number of frames."""
183+
cols = min(3, n_frames) # Max 3 columns
184+
rows = int(np.ceil(n_frames / cols))
185+
return (rows, cols)
186+
187+
def _setup_figure(self):
188+
"""Setup matplotlib figure and axes."""
189+
if not self.figsize:
190+
self.figsize = (5 * self.grid[1], 5 * self.grid[0])
191+
fig = plt.figure(figsize=self.figsize)
192+
axes = []
193+
194+
for i in range(len(self.components)):
195+
ax = fig.add_subplot(self.grid[0], self.grid[1], i + 1)
196+
axes.append(ax)
197+
198+
if self.title:
199+
fig.suptitle(self.title, fontsize=16)
200+
fig.tight_layout()
201+
return fig, axes

0 commit comments

Comments
 (0)