@@ -41,8 +41,10 @@ def construct(self):
41
41
42
42
import itertools as it
43
43
from collections .abc import Iterable , Sequence
44
+ from typing import Any , Callable
44
45
45
46
import numpy as np
47
+ from typing_extensions import Self
46
48
47
49
from manim .mobject .mobject import Mobject
48
50
from manim .mobject .opengl .opengl_compatibility import ConvertToOpenGL
@@ -56,7 +58,7 @@ def construct(self):
56
58
# Not sure if we should keep it or not.
57
59
58
60
59
- def matrix_to_tex_string (matrix ) :
61
+ def matrix_to_tex_string (matrix : np . ndarray ) -> str :
60
62
matrix = np .array (matrix ).astype ("str" )
61
63
if matrix .ndim == 1 :
62
64
matrix = matrix .reshape ((matrix .size , 1 ))
@@ -67,7 +69,7 @@ def matrix_to_tex_string(matrix):
67
69
return prefix + " \\ \\ " .join (rows ) + suffix
68
70
69
71
70
- def matrix_to_mobject (matrix ) :
72
+ def matrix_to_mobject (matrix : np . ndarray ) -> MathTex :
71
73
return MathTex (matrix_to_tex_string (matrix ))
72
74
73
75
@@ -170,14 +172,14 @@ def __init__(
170
172
bracket_v_buff : float = MED_SMALL_BUFF ,
171
173
add_background_rectangles_to_entries : bool = False ,
172
174
include_background_rectangle : bool = False ,
173
- element_to_mobject : type [MathTex ] = MathTex ,
175
+ element_to_mobject : type [Mobject ] | Callable [..., Mobject ] = MathTex ,
174
176
element_to_mobject_config : dict = {},
175
177
element_alignment_corner : Sequence [float ] = DR ,
176
178
left_bracket : str = "[" ,
177
179
right_bracket : str = "]" ,
178
180
stretch_brackets : bool = True ,
179
181
bracket_config : dict = {},
180
- ** kwargs ,
182
+ ** kwargs : Any ,
181
183
):
182
184
self .v_buff = v_buff
183
185
self .h_buff = h_buff
@@ -205,7 +207,7 @@ def __init__(
205
207
if self .include_background_rectangle :
206
208
self .add_background_rectangle ()
207
209
208
- def _matrix_to_mob_matrix (self , matrix ) :
210
+ def _matrix_to_mob_matrix (self , matrix : np . ndarray ) -> list [ list [ Mobject ]] :
209
211
return [
210
212
[
211
213
self .element_to_mobject (item , ** self .element_to_mobject_config )
@@ -214,7 +216,7 @@ def _matrix_to_mob_matrix(self, matrix):
214
216
for row in matrix
215
217
]
216
218
217
- def _organize_mob_matrix (self , matrix ) :
219
+ def _organize_mob_matrix (self , matrix : list [ list [ Mobject ]]) -> Self :
218
220
for i , row in enumerate (matrix ):
219
221
for j , _ in enumerate (row ):
220
222
mob = matrix [i ][j ]
@@ -224,7 +226,7 @@ def _organize_mob_matrix(self, matrix):
224
226
)
225
227
return self
226
228
227
- def _add_brackets (self , left : str = "[" , right : str = "]" , ** kwargs ) :
229
+ def _add_brackets (self , left : str = "[" , right : str = "]" , ** kwargs : Any ) -> Self :
228
230
"""Adds the brackets to the Matrix mobject.
229
231
230
232
See Latex document for various bracket types.
@@ -278,13 +280,13 @@ def _add_brackets(self, left: str = "[", right: str = "]", **kwargs):
278
280
self .add (l_bracket , r_bracket )
279
281
return self
280
282
281
- def get_columns (self ):
283
+ def get_columns (self ) -> VGroup :
282
284
r"""Return columns of the matrix as VGroups.
283
285
284
286
Returns
285
287
--------
286
- List[ :class:`~.VGroup`]
287
- Each VGroup contains a column of the matrix.
288
+ :class:`~.VGroup`
289
+ The VGroup contains a nested VGroup for each column of the matrix.
288
290
289
291
Examples
290
292
--------
@@ -305,7 +307,7 @@ def construct(self):
305
307
)
306
308
)
307
309
308
- def set_column_colors (self , * colors : str ):
310
+ def set_column_colors (self , * colors : str ) -> Self :
309
311
r"""Set individual colors for each columns of the matrix.
310
312
311
313
Parameters
@@ -335,13 +337,13 @@ def construct(self):
335
337
column .set_color (color )
336
338
return self
337
339
338
- def get_rows (self ):
340
+ def get_rows (self ) -> VGroup :
339
341
r"""Return rows of the matrix as VGroups.
340
342
341
343
Returns
342
344
--------
343
- List[ :class:`~.VGroup`]
344
- Each VGroup contains a row of the matrix.
345
+ :class:`~.VGroup`
346
+ The VGroup contains a nested VGroup for each row of the matrix.
345
347
346
348
Examples
347
349
--------
@@ -357,7 +359,7 @@ def construct(self):
357
359
"""
358
360
return VGroup (* (VGroup (* row ) for row in self .mob_matrix ))
359
361
360
- def set_row_colors (self , * colors : str ):
362
+ def set_row_colors (self , * colors : str ) -> Self :
361
363
r"""Set individual colors for each row of the matrix.
362
364
363
365
Parameters
@@ -387,7 +389,7 @@ def construct(self):
387
389
row .set_color (color )
388
390
return self
389
391
390
- def add_background_to_entries (self ):
392
+ def add_background_to_entries (self ) -> Self :
391
393
"""Add a black background rectangle to the matrix,
392
394
see above for an example.
393
395
@@ -400,7 +402,7 @@ def add_background_to_entries(self):
400
402
mob .add_background_rectangle ()
401
403
return self
402
404
403
- def get_mob_matrix (self ) -> list [list [MathTex ]]:
405
+ def get_mob_matrix (self ) -> list [list [Mobject ]]:
404
406
"""Return the underlying mob matrix mobjects.
405
407
406
408
Returns
@@ -410,7 +412,7 @@ def get_mob_matrix(self) -> list[list[MathTex]]:
410
412
"""
411
413
return self .mob_matrix
412
414
413
- def get_entries (self ):
415
+ def get_entries (self ) -> VGroup :
414
416
"""Return the individual entries of the matrix.
415
417
416
418
Returns
@@ -483,9 +485,9 @@ def construct(self):
483
485
def __init__ (
484
486
self ,
485
487
matrix : Iterable ,
486
- element_to_mobject : Mobject = DecimalNumber ,
487
- element_to_mobject_config : dict [str , Mobject ] = {"num_decimal_places" : 1 },
488
- ** kwargs ,
488
+ element_to_mobject : type [ Mobject ] = DecimalNumber ,
489
+ element_to_mobject_config : dict [str , Any ] = {"num_decimal_places" : 1 },
490
+ ** kwargs : Any ,
489
491
):
490
492
"""
491
493
Will round/truncate the decimal places as per the provided config.
@@ -526,7 +528,10 @@ def construct(self):
526
528
"""
527
529
528
530
def __init__ (
529
- self , matrix : Iterable , element_to_mobject : Mobject = Integer , ** kwargs
531
+ self ,
532
+ matrix : Iterable ,
533
+ element_to_mobject : type [Mobject ] = Integer ,
534
+ ** kwargs : Any ,
530
535
):
531
536
"""
532
537
Will round if there are decimal entries in the matrix.
@@ -560,7 +565,12 @@ def construct(self):
560
565
self.add(m0)
561
566
"""
562
567
563
- def __init__ (self , matrix , element_to_mobject = lambda m : m , ** kwargs ):
568
+ def __init__ (
569
+ self ,
570
+ matrix : Iterable ,
571
+ element_to_mobject : type [Mobject ] | Callable [..., Mobject ] = lambda m : m ,
572
+ ** kwargs : Any ,
573
+ ):
564
574
super ().__init__ (matrix , element_to_mobject = element_to_mobject , ** kwargs )
565
575
566
576
@@ -569,7 +579,7 @@ def get_det_text(
569
579
determinant : int | str | None = None ,
570
580
background_rect : bool = False ,
571
581
initial_scale_factor : float = 2 ,
572
- ):
582
+ ) -> VGroup :
573
583
r"""Helper function to create determinant.
574
584
575
585
Parameters
0 commit comments