Skip to content

Commit

Permalink
Treat Group and VGroup more like list types
Browse files Browse the repository at this point in the history
This may not be the best way to address it, but at least temporarily it prevents linting issues for calls like VGroup(Circle())[0].get_radius()
  • Loading branch information
3b1b committed Feb 5, 2024
1 parent 7009f0f commit 100b108
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
12 changes: 8 additions & 4 deletions manimlib/mobject/mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Callable, Iterable, Iterator, Union, Tuple, Optional, TypeVar
from typing import Callable, Iterable, Iterator, Union, Tuple, Optional, TypeVar, Generic, List
import numpy.typing as npt
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict, Self
from moderngl.context import Context
Expand All @@ -57,6 +57,7 @@
TimeBasedUpdater = Callable[["Mobject", float], "Mobject" | None]
NonTimeUpdater = Callable[["Mobject"], "Mobject" | None]
Updater = Union[TimeBasedUpdater, NonTimeUpdater]
SubmobjectType = TypeVar('SubmobjectType', bound='Mobject')


class Mobject(object):
Expand Down Expand Up @@ -352,7 +353,7 @@ def is_touching(self, mobject: Mobject, buff: float = 1e-2) -> bool:

# Family matters

def __getitem__(self, value: int | slice) -> Self:
def __getitem__(self, value: int | slice) -> Mobject:
if isinstance(value, slice):
GroupClass = self.get_group_class()
return GroupClass(*self.split().__getitem__(value))
Expand Down Expand Up @@ -2132,8 +2133,8 @@ def throw_error_if_no_points(self):
raise Exception(message.format(caller_name))


class Group(Mobject):
def __init__(self, *mobjects: Mobject, **kwargs):
class Group(Mobject, Generic[SubmobjectType]):
def __init__(self, *mobjects: SubmobjectType, **kwargs):
if not all([isinstance(m, Mobject) for m in mobjects]):
raise Exception("All submobjects must be of type Mobject")
super().__init__(**kwargs)
Expand All @@ -2143,6 +2144,9 @@ def __add__(self, other: Mobject | Group) -> Self:
assert isinstance(other, Mobject)
return self.add(other)

# This is just here to make linters happy with references to things like Group(...)[0]
def __getitem__(self, index) -> SubmobjectType:
return super().__getitem__(index)


class Point(Mobject):
Expand Down
12 changes: 9 additions & 3 deletions manimlib/mobject/types/vectorized_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,15 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Callable, Iterable, Tuple, Any
from typing import Callable, Iterable, Tuple, Any, Generic, TypeVar
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array, Self
from moderngl.context import Context
SubVmobjectType = TypeVar('SubVmobjectType', bound='VMobject')

DEFAULT_STROKE_COLOR = GREY_A
DEFAULT_FILL_COLOR = GREY_C


class VMobject(Mobject):
fill_shader_folder: str = "quadratic_bezier_fill"
stroke_shader_folder: str = "quadratic_bezier_stroke"
Expand Down Expand Up @@ -1415,8 +1417,8 @@ def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]


class VGroup(VMobject):
def __init__(self, *vmobjects: VMobject, **kwargs):
class VGroup(VMobject, Generic[SubVmobjectType]):
def __init__(self, *vmobjects: SubVmobjectType, **kwargs):
super().__init__(**kwargs)
self.add(*vmobjects)
if vmobjects:
Expand All @@ -1426,6 +1428,10 @@ def __add__(self, other: VMobject) -> Self:
assert isinstance(other, VMobject)
return self.add(other)

# This is just here to make linters happy with references to things like VGroup(...)[0]
def __getitem__(self, index) -> SubVmobjectType:
return super().__getitem__(index)


class VectorizedPoint(Point, VMobject):
def __init__(
Expand Down

0 comments on commit 100b108

Please sign in to comment.