Skip to content

Commit

Permalink
Change from tracking time_based_updater and non_time_updater lists se…
Browse files Browse the repository at this point in the history
…parately to just tracking one list
  • Loading branch information
3b1b committed Mar 7, 2024
1 parent 83cd5d6 commit d3ba101
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 39 deletions.
65 changes: 27 additions & 38 deletions manimlib/mobject/mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@
from manimlib.utils.bezier import integer_interpolate
from manimlib.utils.bezier import interpolate
from manimlib.utils.paths import straight_path
from manimlib.utils.simple_functions import get_parameters
from manimlib.utils.simple_functions import get_num_args
from manimlib.utils.shaders import get_colormap_code
from manimlib.utils.space_ops import angle_of_vector
from manimlib.utils.space_ops import get_norm
Expand Down Expand Up @@ -687,8 +685,7 @@ def copy(self, deep: bool = False) -> Self:

# Similarly, instead of calling match_updaters, since we know the status
# won't have changed, just directly match.
result.non_time_updaters = list(self.non_time_updaters)
result.time_based_updaters = list(self.time_based_updaters)
result.updaters = list(self.updaters)
result._data_has_changed = True
result._shaders_initialized = False

Expand Down Expand Up @@ -832,8 +829,7 @@ def get_grid(
# Updating

def init_updaters(self):
self.time_based_updaters: list[TimeBasedUpdater] = list()
self.non_time_updaters: list[NonTimeUpdater] = list()
self.updaters: list[Updater] = list()
self._has_updaters_in_family: Optional[bool] = False
self.updating_suspended: bool = False

Expand All @@ -843,56 +839,47 @@ def update(self, dt: float = 0, recurse: bool = True) -> Self:
if recurse:
for submob in self.submobjects:
submob.update(dt, recurse)
for updater in self.time_based_updaters:
updater(self, dt)
for updater in self.non_time_updaters:
updater(self)
for updater in self.updaters:
# This is hacky, but if an updater takes dt as an arg,
# it will be passed the change in time from here
if "dt" in updater.__code__.co_varnames:
updater(self, dt=dt)
else:
updater(self)
return self

def get_updaters(self) -> list[Updater]:
return [*self.time_based_updaters, *self.non_time_updaters]
return self.updaters

def add_updater(self, update_func: Updater, call: bool = True) -> Self:
if get_num_args(update_func) > 1:
self.time_based_updaters.append(update_func)
else:
self.non_time_updaters.append(update_func)

self.updaters.append(update_func)
if call:
self.update(dt=0)

self.refresh_has_updater_status()
return self

def insert_updater(self, update_func: Updater, index=0):
if get_num_args(update_func) > 1:
self.time_based_updaters.insert(index, update_func)
else:
self.non_time_updaters.insert(index, update_func)

self.updaters.insert(index, update_func)
self.refresh_has_updater_status()
return self

def remove_updater(self, update_func: Updater) -> Self:
for updater_list in [self.time_based_updaters, self.non_time_updaters]:
while update_func in updater_list:
updater_list.remove(update_func)
while update_func in self.updaters:
self.updaters.remove(update_func)
self.refresh_has_updater_status()
return self

def clear_updaters(self, recurse: bool = True) -> Self:
self.time_based_updaters = []
self.non_time_updaters = []
if recurse:
for submob in self.submobjects:
submob.clear_updaters()
self.refresh_has_updater_status()
for mob in self.get_family(recurse):
mob.updaters = []
mob._has_updaters_in_family = False
for parent in self.get_ancestors():
parent._has_updaters_in_family = False
return self

def match_updaters(self, mobject: Mobject) -> Self:
self.clear_updaters()
for updater in mobject.get_updaters():
self.add_updater(updater)
self.updaters = list(mobject.updaters)
self.refresh_has_updater_status()
return self

def suspend_updating(self, recurse: bool = True) -> Self:
Expand All @@ -916,13 +903,15 @@ def resume_updating(self, recurse: bool = True, call_updater: bool = True) -> Se
def has_updaters(self) -> bool:
if self._has_updaters_in_family is None:
# Recompute and save
res = bool(self.time_based_updaters or self.non_time_updaters)
self._has_updaters_in_family = res or any(sm.has_updaters() for sm in self.submobjects)
self._has_updaters_in_family = bool(self.updaters) or any(
sm.has_updaters() for sm in self.submobjects
)
return self._has_updaters_in_family

def refresh_has_updater_status(self) -> Self:
for mob in (self, *self.parents):
mob._has_updaters_in_family = None
self._has_updaters_in_family = None
for parent in self.parents:
parent.refresh_has_updater_status()
return self

# Check if mark as static or not for camera
Expand Down
2 changes: 1 addition & 1 deletion manimlib/utils/simple_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def gen_choose(n: int, r: int) -> int:


def get_num_args(function: Callable) -> int:
return len(list(get_parameters(function)))
return function.__code__.co_argcount


def get_parameters(function: Callable) -> Iterable[str]:
Expand Down

0 comments on commit d3ba101

Please sign in to comment.