Skip to content

Commit f124235

Browse files
Adding type annotations to polyhedra.py and matrix.py (#4322)
* Fixed all mypy errors in polyhedra.py * Added type annotations to matrix.py
1 parent 9e74ee7 commit f124235

File tree

3 files changed

+49
-44
lines changed

3 files changed

+49
-44
lines changed

manim/mobject/matrix.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ def construct(self):
4141

4242
import itertools as it
4343
from collections.abc import Iterable, Sequence
44+
from typing import Any, Callable
4445

4546
import numpy as np
47+
from typing_extensions import Self
4648

4749
from manim.mobject.mobject import Mobject
4850
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
@@ -56,7 +58,7 @@ def construct(self):
5658
# Not sure if we should keep it or not.
5759

5860

59-
def matrix_to_tex_string(matrix):
61+
def matrix_to_tex_string(matrix: np.ndarray) -> str:
6062
matrix = np.array(matrix).astype("str")
6163
if matrix.ndim == 1:
6264
matrix = matrix.reshape((matrix.size, 1))
@@ -67,7 +69,7 @@ def matrix_to_tex_string(matrix):
6769
return prefix + " \\\\ ".join(rows) + suffix
6870

6971

70-
def matrix_to_mobject(matrix):
72+
def matrix_to_mobject(matrix: np.ndarray) -> MathTex:
7173
return MathTex(matrix_to_tex_string(matrix))
7274

7375

@@ -170,14 +172,14 @@ def __init__(
170172
bracket_v_buff: float = MED_SMALL_BUFF,
171173
add_background_rectangles_to_entries: bool = False,
172174
include_background_rectangle: bool = False,
173-
element_to_mobject: type[MathTex] = MathTex,
175+
element_to_mobject: type[Mobject] | Callable[..., Mobject] = MathTex,
174176
element_to_mobject_config: dict = {},
175177
element_alignment_corner: Sequence[float] = DR,
176178
left_bracket: str = "[",
177179
right_bracket: str = "]",
178180
stretch_brackets: bool = True,
179181
bracket_config: dict = {},
180-
**kwargs,
182+
**kwargs: Any,
181183
):
182184
self.v_buff = v_buff
183185
self.h_buff = h_buff
@@ -205,7 +207,7 @@ def __init__(
205207
if self.include_background_rectangle:
206208
self.add_background_rectangle()
207209

208-
def _matrix_to_mob_matrix(self, matrix):
210+
def _matrix_to_mob_matrix(self, matrix: np.ndarray) -> list[list[Mobject]]:
209211
return [
210212
[
211213
self.element_to_mobject(item, **self.element_to_mobject_config)
@@ -214,7 +216,7 @@ def _matrix_to_mob_matrix(self, matrix):
214216
for row in matrix
215217
]
216218

217-
def _organize_mob_matrix(self, matrix):
219+
def _organize_mob_matrix(self, matrix: list[list[Mobject]]) -> Self:
218220
for i, row in enumerate(matrix):
219221
for j, _ in enumerate(row):
220222
mob = matrix[i][j]
@@ -224,7 +226,7 @@ def _organize_mob_matrix(self, matrix):
224226
)
225227
return self
226228

227-
def _add_brackets(self, left: str = "[", right: str = "]", **kwargs):
229+
def _add_brackets(self, left: str = "[", right: str = "]", **kwargs: Any) -> Self:
228230
"""Adds the brackets to the Matrix mobject.
229231
230232
See Latex document for various bracket types.
@@ -278,13 +280,13 @@ def _add_brackets(self, left: str = "[", right: str = "]", **kwargs):
278280
self.add(l_bracket, r_bracket)
279281
return self
280282

281-
def get_columns(self):
283+
def get_columns(self) -> VGroup:
282284
r"""Return columns of the matrix as VGroups.
283285
284286
Returns
285287
--------
286-
List[:class:`~.VGroup`]
287-
Each VGroup contains a column of the matrix.
288+
:class:`~.VGroup`
289+
The VGroup contains a nested VGroup for each column of the matrix.
288290
289291
Examples
290292
--------
@@ -305,7 +307,7 @@ def construct(self):
305307
)
306308
)
307309

308-
def set_column_colors(self, *colors: str):
310+
def set_column_colors(self, *colors: str) -> Self:
309311
r"""Set individual colors for each columns of the matrix.
310312
311313
Parameters
@@ -335,13 +337,13 @@ def construct(self):
335337
column.set_color(color)
336338
return self
337339

338-
def get_rows(self):
340+
def get_rows(self) -> VGroup:
339341
r"""Return rows of the matrix as VGroups.
340342
341343
Returns
342344
--------
343-
List[:class:`~.VGroup`]
344-
Each VGroup contains a row of the matrix.
345+
:class:`~.VGroup`
346+
The VGroup contains a nested VGroup for each row of the matrix.
345347
346348
Examples
347349
--------
@@ -357,7 +359,7 @@ def construct(self):
357359
"""
358360
return VGroup(*(VGroup(*row) for row in self.mob_matrix))
359361

360-
def set_row_colors(self, *colors: str):
362+
def set_row_colors(self, *colors: str) -> Self:
361363
r"""Set individual colors for each row of the matrix.
362364
363365
Parameters
@@ -387,7 +389,7 @@ def construct(self):
387389
row.set_color(color)
388390
return self
389391

390-
def add_background_to_entries(self):
392+
def add_background_to_entries(self) -> Self:
391393
"""Add a black background rectangle to the matrix,
392394
see above for an example.
393395
@@ -400,7 +402,7 @@ def add_background_to_entries(self):
400402
mob.add_background_rectangle()
401403
return self
402404

403-
def get_mob_matrix(self) -> list[list[MathTex]]:
405+
def get_mob_matrix(self) -> list[list[Mobject]]:
404406
"""Return the underlying mob matrix mobjects.
405407
406408
Returns
@@ -410,7 +412,7 @@ def get_mob_matrix(self) -> list[list[MathTex]]:
410412
"""
411413
return self.mob_matrix
412414

413-
def get_entries(self):
415+
def get_entries(self) -> VGroup:
414416
"""Return the individual entries of the matrix.
415417
416418
Returns
@@ -483,9 +485,9 @@ def construct(self):
483485
def __init__(
484486
self,
485487
matrix: Iterable,
486-
element_to_mobject: Mobject = DecimalNumber,
487-
element_to_mobject_config: dict[str, Mobject] = {"num_decimal_places": 1},
488-
**kwargs,
488+
element_to_mobject: type[Mobject] = DecimalNumber,
489+
element_to_mobject_config: dict[str, Any] = {"num_decimal_places": 1},
490+
**kwargs: Any,
489491
):
490492
"""
491493
Will round/truncate the decimal places as per the provided config.
@@ -526,7 +528,10 @@ def construct(self):
526528
"""
527529

528530
def __init__(
529-
self, matrix: Iterable, element_to_mobject: Mobject = Integer, **kwargs
531+
self,
532+
matrix: Iterable,
533+
element_to_mobject: type[Mobject] = Integer,
534+
**kwargs: Any,
530535
):
531536
"""
532537
Will round if there are decimal entries in the matrix.
@@ -560,7 +565,12 @@ def construct(self):
560565
self.add(m0)
561566
"""
562567

563-
def __init__(self, matrix, element_to_mobject=lambda m: m, **kwargs):
568+
def __init__(
569+
self,
570+
matrix: Iterable,
571+
element_to_mobject: type[Mobject] | Callable[..., Mobject] = lambda m: m,
572+
**kwargs: Any,
573+
):
564574
super().__init__(matrix, element_to_mobject=element_to_mobject, **kwargs)
565575

566576

@@ -569,7 +579,7 @@ def get_det_text(
569579
determinant: int | str | None = None,
570580
background_rect: bool = False,
571581
initial_scale_factor: float = 2,
572-
):
582+
) -> VGroup:
573583
r"""Helper function to create determinant.
574584
575585
Parameters

manim/mobject/three_d/polyhedra.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING
5+
from collections.abc import Hashable
6+
from typing import TYPE_CHECKING, Any
67

78
import numpy as np
89

@@ -14,7 +15,7 @@
1415

1516
if TYPE_CHECKING:
1617
from manim.mobject.mobject import Mobject
17-
from manim.typing import Point3D
18+
from manim.typing import Point3D, Point3DLike_Array
1819

1920
__all__ = [
2021
"Polyhedron",
@@ -96,10 +97,10 @@ def construct(self):
9697

9798
def __init__(
9899
self,
99-
vertex_coords: list[list[float] | np.ndarray],
100+
vertex_coords: Point3DLike_Array,
100101
faces_list: list[list[int]],
101102
faces_config: dict[str, str | int | float | bool] = {},
102-
graph_config: dict[str, str | int | float | bool] = {},
103+
graph_config: dict[str, Any] = {},
103104
):
104105
super().__init__()
105106
self.faces_config = dict(
@@ -116,7 +117,7 @@ def __init__(
116117
)
117118
self.vertex_coords = vertex_coords
118119
self.vertex_indices = list(range(len(self.vertex_coords)))
119-
self.layout = dict(enumerate(self.vertex_coords))
120+
self.layout: dict[Hashable, Any] = dict(enumerate(self.vertex_coords))
120121
self.faces_list = faces_list
121122
self.face_coords = [[self.layout[j] for j in i] for i in faces_list]
122123
self.edges = self.get_edges(self.faces_list)
@@ -129,27 +130,27 @@ def __init__(
129130

130131
def get_edges(self, faces_list: list[list[int]]) -> list[tuple[int, int]]:
131132
"""Creates list of cyclic pairwise tuples."""
132-
edges = []
133+
edges: list[tuple[int, int]] = []
133134
for face in faces_list:
134135
edges += zip(face, face[1:] + face[:1])
135136
return edges
136137

137138
def create_faces(
138139
self,
139-
face_coords: list[list[list | np.ndarray]],
140+
face_coords: Point3DLike_Array,
140141
) -> VGroup:
141142
"""Creates VGroup of faces from a list of face coordinates."""
142143
face_group = VGroup()
143144
for face in face_coords:
144145
face_group.add(Polygon(*face, **self.faces_config))
145146
return face_group
146147

147-
def update_faces(self, m: Mobject):
148+
def update_faces(self, m: Mobject) -> None:
148149
face_coords = self.extract_face_coords()
149150
new_faces = self.create_faces(face_coords)
150151
self.faces.match_points(new_faces)
151152

152-
def extract_face_coords(self) -> list[list[np.ndarray]]:
153+
def extract_face_coords(self) -> Point3DLike_Array:
153154
"""Extracts the coordinates of the vertices in the graph.
154155
Used for updating faces.
155156
"""
@@ -181,7 +182,7 @@ def construct(self):
181182
self.add(obj)
182183
"""
183184

184-
def __init__(self, edge_length: float = 1, **kwargs):
185+
def __init__(self, edge_length: float = 1, **kwargs: Any):
185186
unit = edge_length * np.sqrt(2) / 4
186187
super().__init__(
187188
vertex_coords=[
@@ -216,7 +217,7 @@ def construct(self):
216217
self.add(obj)
217218
"""
218219

219-
def __init__(self, edge_length: float = 1, **kwargs):
220+
def __init__(self, edge_length: float = 1, **kwargs: Any):
220221
unit = edge_length * np.sqrt(2) / 2
221222
super().__init__(
222223
vertex_coords=[
@@ -262,7 +263,7 @@ def construct(self):
262263
self.add(obj)
263264
"""
264265

265-
def __init__(self, edge_length: float = 1, **kwargs):
266+
def __init__(self, edge_length: float = 1, **kwargs: Any):
266267
unit_a = edge_length * ((1 + np.sqrt(5)) / 4)
267268
unit_b = edge_length * (1 / 2)
268269
super().__init__(
@@ -327,7 +328,7 @@ def construct(self):
327328
self.add(obj)
328329
"""
329330

330-
def __init__(self, edge_length: float = 1, **kwargs):
331+
def __init__(self, edge_length: float = 1, **kwargs: Any):
331332
unit_a = edge_length * ((1 + np.sqrt(5)) / 4)
332333
unit_b = edge_length * ((3 + np.sqrt(5)) / 4)
333334
unit_c = edge_length * (1 / 2)
@@ -427,7 +428,7 @@ def construct(self):
427428
self.add(dots)
428429
"""
429430

430-
def __init__(self, *points: Point3D, tolerance: float = 1e-5, **kwargs):
431+
def __init__(self, *points: Point3D, tolerance: float = 1e-5, **kwargs: Any):
431432
# Build Convex Hull
432433
array = np.array(points)
433434
hull = QuickHull(tolerance)

mypy.ini

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,6 @@ ignore_errors = True
120120
[mypy-manim.mobject.logo]
121121
ignore_errors = True
122122

123-
[mypy-manim.mobject.matrix]
124-
ignore_errors = True
125-
126123
[mypy-manim.mobject.mobject]
127124
ignore_errors = True
128125

@@ -171,9 +168,6 @@ ignore_errors = True
171168
[mypy-manim.mobject.text.text_mobject]
172169
ignore_errors = True
173170

174-
[mypy-manim.mobject.three_d.polyhedra]
175-
ignore_errors = True
176-
177171
[mypy-manim.mobject.three_d.three_dimensions]
178172
ignore_errors = True
179173

0 commit comments

Comments
 (0)