Skip to content

Commit

Permalink
chore: add type hints to manimlib.mobject
Browse files Browse the repository at this point in the history
  • Loading branch information
TonyCrane committed Feb 15, 2022
1 parent 3744844 commit 4c16bfc
Show file tree
Hide file tree
Showing 15 changed files with 736 additions and 360 deletions.
17 changes: 11 additions & 6 deletions manimlib/mobject/boolean_ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import numpy as np
import pathops

Expand All @@ -7,7 +9,7 @@
# Boolean operations between 2D mobjects
# Borrowed from from https://github.com/ManimCommunity/manim/

def _convert_vmobject_to_skia_path(vmobject):
def _convert_vmobject_to_skia_path(vmobject: VMobject) -> pathops.Path:
path = pathops.Path()
subpaths = vmobject.get_subpaths_from_points(vmobject.get_all_points())
for subpath in subpaths:
Expand All @@ -21,7 +23,10 @@ def _convert_vmobject_to_skia_path(vmobject):
return path


def _convert_skia_path_to_vmobject(path, vmobject):
def _convert_skia_path_to_vmobject(
path: pathops.Path,
vmobject: VMobject
) -> VMobject:
PathVerb = pathops.PathVerb
current_path_start = np.array([0.0, 0.0, 0.0])
for path_verb, points in path:
Expand All @@ -45,7 +50,7 @@ def _convert_skia_path_to_vmobject(path, vmobject):


class Union(VMobject):
def __init__(self, *vmobjects, **kwargs):
def __init__(self, *vmobjects: VMobject, **kwargs):
if len(vmobjects) < 2:
raise ValueError("At least 2 mobjects needed for Union.")
super().__init__(**kwargs)
Expand All @@ -59,7 +64,7 @@ def __init__(self, *vmobjects, **kwargs):


class Difference(VMobject):
def __init__(self, subject, clip, **kwargs):
def __init__(self, subject: VMobject, clip: VMobject, **kwargs):
super().__init__(**kwargs)
outpen = pathops.Path()
pathops.difference(
Expand All @@ -71,7 +76,7 @@ def __init__(self, subject, clip, **kwargs):


class Intersection(VMobject):
def __init__(self, *vmobjects, **kwargs):
def __init__(self, *vmobjects: VMobject, **kwargs):
if len(vmobjects) < 2:
raise ValueError("At least 2 mobjects needed for Intersection.")
super().__init__(**kwargs)
Expand All @@ -94,7 +99,7 @@ def __init__(self, *vmobjects, **kwargs):


class Exclusion(VMobject):
def __init__(self, *vmobjects, **kwargs):
def __init__(self, *vmobjects: VMobject, **kwargs):
if len(vmobjects) < 2:
raise ValueError("At least 2 mobjects needed for Exclusion.")
super().__init__(**kwargs)
Expand Down
37 changes: 26 additions & 11 deletions manimlib/mobject/changing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from __future__ import annotations

from typing import Callable

import numpy as np

from manimlib.constants import BLUE_D
from manimlib.constants import BLUE_B
from manimlib.constants import BLUE_E
Expand All @@ -20,23 +25,23 @@ class AnimatedBoundary(VGroup):
"fade_rate_func": smooth,
}

def __init__(self, vmobject, **kwargs):
def __init__(self, vmobject: VMobject, **kwargs):
super().__init__(**kwargs)
self.vmobject = vmobject
self.boundary_copies = [
self.vmobject: VMobject = vmobject
self.boundary_copies: list[VMobject] = [
vmobject.copy().set_style(
stroke_width=0,
fill_opacity=0
)
for x in range(2)
]
self.add(*self.boundary_copies)
self.total_time = 0
self.total_time: float = 0
self.add_updater(
lambda m, dt: self.update_boundary_copies(dt)
)

def update_boundary_copies(self, dt):
def update_boundary_copies(self, dt: float) -> None:
# Not actual time, but something which passes at
# an altered rate to make the implementation below
# cleaner
Expand Down Expand Up @@ -67,7 +72,13 @@ def update_boundary_copies(self, dt):

self.total_time += dt

def full_family_become_partial(self, mob1, mob2, a, b):
def full_family_become_partial(
self,
mob1: VMobject,
mob2: VMobject,
a: float,
b: float
):
family1 = mob1.family_members_with_points()
family2 = mob2.family_members_with_points()
for sm1, sm2 in zip(family1, family2):
Expand All @@ -84,14 +95,14 @@ class TracedPath(VMobject):
"time_per_anchor": 1 / 15,
}

def __init__(self, traced_point_func, **kwargs):
def __init__(self, traced_point_func: Callable[[], np.ndarray], **kwargs):
super().__init__(**kwargs)
self.traced_point_func = traced_point_func
self.time = 0
self.traced_points = []
self.time: float = 0
self.traced_points: list[np.ndarray] = []
self.add_updater(lambda m, dt: m.update_path(dt))

def update_path(self, dt):
def update_path(self, dt: float):
if dt == 0:
return self
point = self.traced_point_func().copy()
Expand Down Expand Up @@ -133,7 +144,11 @@ class TracingTail(TracedPath):
"time_traced": 1.0,
}

def __init__(self, mobject_or_func, **kwargs):
def __init__(
self,
mobject_or_func: Mobject | Callable[[], np.ndarray],
**kwargs
):
if isinstance(mobject_or_func, Mobject):
func = mobject_or_func.get_center
else:
Expand Down
Loading

0 comments on commit 4c16bfc

Please sign in to comment.