Skip to content

Commit

Permalink
Clean up updater matters, prune unused functions
Browse files Browse the repository at this point in the history
  • Loading branch information
3b1b committed Mar 7, 2024
1 parent 70b839e commit 83cd5d6
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 37 deletions.
32 changes: 10 additions & 22 deletions manimlib/mobject/mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from manimlib.utils.bezier import interpolate
from manimlib.utils.paths import straight_path
from manimlib.utils.simple_functions import get_parameters
from manimlib.utils.simple_functions import get_num_args
from manimlib.utils.shaders import get_colormap_code
from manimlib.utils.space_ops import angle_of_vector
from manimlib.utils.space_ops import get_norm
Expand All @@ -51,7 +52,7 @@


if TYPE_CHECKING:
from typing import Callable, Iterator, Union, Tuple, Optional
from typing import Callable, Iterator, Union, Tuple, Optional, Any
import numpy.typing as npt
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict, Self
from moderngl.context import Context
Expand Down Expand Up @@ -831,9 +832,9 @@ def get_grid(
# Updating

def init_updaters(self):
self.time_based_updaters: list[TimeBasedUpdater] = []
self.non_time_updaters: list[NonTimeUpdater] = []
self.has_updaters: bool = False
self.time_based_updaters: list[TimeBasedUpdater] = list()
self.non_time_updaters: list[NonTimeUpdater] = list()
self._has_updaters_in_family: Optional[bool] = False
self.updating_suspended: bool = False

def update(self, dt: float = 0, recurse: bool = True) -> Self:
Expand All @@ -848,36 +849,23 @@ def update(self, dt: float = 0, recurse: bool = True) -> Self:
updater(self)
return self

def get_time_based_updaters(self) -> list[TimeBasedUpdater]:
return self.time_based_updaters

def has_time_based_updater(self) -> bool:
return len(self.time_based_updaters) > 0

def get_updaters(self) -> list[Updater]:
return self.time_based_updaters + self.non_time_updaters

def get_family_updaters(self) -> list[Updater]:
return list(it.chain(*[sm.get_updaters() for sm in self.get_family()]))
return [*self.time_based_updaters, *self.non_time_updaters]

def add_updater(
self,
update_func: Updater,
call_updater: bool = True
) -> Self:
if "dt" in get_parameters(update_func):
def add_updater(self, update_func: Updater, call: bool = True) -> Self:
if get_num_args(update_func) > 1:
self.time_based_updaters.append(update_func)
else:
self.non_time_updaters.append(update_func)

if call_updater:
if call:
self.update(dt=0)

self.refresh_has_updater_status()
return self

def insert_updater(self, update_func: Updater, index=0):
if "dt" in get_parameters(update_func):
if get_num_args(update_func) > 1:
self.time_based_updaters.insert(index, update_func)
else:
self.non_time_updaters.insert(index, update_func)
Expand Down
14 changes: 3 additions & 11 deletions manimlib/scene/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,17 +345,9 @@ def update_mobjects(self, dt: float) -> None:
mobject.update(dt)

def should_update_mobjects(self) -> bool:
return self.always_update_mobjects or any([
len(mob.get_family_updaters()) > 0
for mob in self.mobjects
])

def has_time_based_updaters(self) -> bool:
return any([
sm.has_time_based_updater()
for mob in self.mobjects()
for sm in mob.get_family()
])
return self.always_update_mobjects or any(
mob.has_updaters() for mob in self.mobjects
)

# Related to time

Expand Down
8 changes: 4 additions & 4 deletions manimlib/utils/simple_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, TypeVar
from typing import Callable, TypeVar, Iterable
from manimlib.typing import FloatArray

Scalable = TypeVar("Scalable", float, FloatArray)
Expand All @@ -30,11 +30,11 @@ def gen_choose(n: int, r: int) -> int:


def get_num_args(function: Callable) -> int:
return len(get_parameters(function))
return len(list(get_parameters(function)))


def get_parameters(function: Callable) -> list:
return list(inspect.signature(function).parameters.keys())
def get_parameters(function: Callable) -> Iterable[str]:
return inspect.signature(function).parameters.keys()

# Just to have a less heavyweight name for this extremely common operation
#
Expand Down

0 comments on commit 83cd5d6

Please sign in to comment.