2D Heat Equation Solution using MLX and some observations #83
sck-at-ucy started this conversation in Show and tell
I have implemented a simple solution of the 2D heat conduction equation with two Neumann and two Dirichlet boundary conditions. I have the code implemented in both PyTorch and MLX, and I am testing the relative performance on an M2 Ultra with 128GB memory.
I started with a relatively small problem and found the MLX version to be roughly 2X faster than the Torch version. Then I used a much higher grid resolution than actually needed; the idea was to make the problem large enough to utilize the GPU substantially. Higher resolution means smaller time steps, so more time steps are needed to reach the same final state, because of the explicit Euler time discretization and its stability criterion.
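For context, here is a minimal sketch of the kind of explicit update I am describing: a 5-point Laplacian with forward Euler time stepping in MLX. The grid size, diffusivity, and boundary values below are illustrative placeholders, not the exact values from my attached code.

```python
import mlx.core as mx

# Illustrative parameters (placeholders, not the values from the attached code)
nx = ny = 256                # grid resolution
dx = 1.0 / (nx - 1)          # grid spacing on a unit square
alpha = 1.0e-4               # thermal diffusivity
dt = 0.2 * dx * dx / alpha   # below the explicit 2D stability limit dt <= dx^2 / (4*alpha)

T_left, T_right = 100.0, 0.0  # assumed Dirichlet temperatures on left/right edges

def step(T):
    # 5-point Laplacian on interior nodes, explicit (forward) Euler in time
    lap = (
        T[1:-1, 2:] + T[1:-1, :-2] + T[2:, 1:-1] + T[:-2, 1:-1]
        - 4 * T[1:-1, 1:-1]
    ) / dx**2
    interior = T[1:-1, 1:-1] + alpha * dt * lap

    # Zero-flux Neumann BCs on top/bottom: copy the adjacent interior row
    body = mx.concatenate([interior[0:1, :], interior, interior[-1:, :]], axis=0)

    # Dirichlet BCs on left/right: fixed temperature columns
    left = mx.full((body.shape[0], 1), T_left)
    right = mx.full((body.shape[0], 1), T_right)
    return mx.concatenate([left, body, right], axis=1)

T = mx.zeros((ny, nx))
```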
Here are some observations from this simple exercise; I would love to get comments/insights from @awni and the team.
Observation 1. For the small-sized problem the MLX version was about 2X faster than the Torch version.
Observation 2. For the large-sized problem, the MLX code would seem to hang if run for sufficiently long times (many steps). It would run fine below a certain number of steps; how many steps it could take before hanging depended on the problem size.
Observation 3. Adding an mx.eval() call,
if step % nsteps == 0: mx.eval(T)
where nsteps defines how often mx.eval(T) is applied, solved the problem: the code runs fine and faster than Torch, provided nsteps is smaller than the number of steps that causes the code to hang. However, if nsteps is larger than that number, I get an error:
Process finished with exit code 139 (interrupted by signal 11:SIGSEGV)
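For concreteness, this is roughly the shape of the time loop with the periodic forced evaluation I am describing. It assumes a step(T) function like the sketch above; max_steps and nsteps are placeholder values.

```python
import mlx.core as mx

max_steps = 200_000   # placeholder total step count
nsteps = 1_000        # force evaluation every nsteps iterations (placeholder)

T = mx.zeros((256, 256))
for i in range(max_steps):
    T = step(T)       # single explicit-Euler update (see sketch above)
    if i % nsteps == 0:
        mx.eval(T)    # materialize T so the lazily-built graph is collapsed periodically
mx.eval(T)            # ensure the final state is actually computed
```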
Observation 4. The number of steps that causes the segmentation fault depends on the resolution used (problem size).
Observation 5. In fact, for max_steps below the number of steps at which the segmentation fault shows up, the MLX version is almost 10X faster than the Torch version for this large-sized problem.
These observations suggest to me that both the good performance (2X to 10X!) and the problems are perhaps tied to lazy evaluation. Is there a better way to handle this?
Any insights about these observations, or ideas to help me understand what is going on behind the scenes, would be much appreciated.
Overall, really impressed!
2D_Heat_Equation_MLX.txt