Skip to content

Commit

Permalink
FIx Mobject.replace_shader_code
Browse files Browse the repository at this point in the history
  • Loading branch information
3b1b committed Feb 3, 2023
1 parent d10745a commit c477701
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
13 changes: 8 additions & 5 deletions manimlib/mobject/mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def __init__(
self.bounding_box: Vect3Array = np.zeros((3, 3))
self._shaders_initialized: bool = False
self._data_has_changed: bool = True
self.shader_code_replacements: dict[str, str] = dict()

self.init_data()
self._data_defaults = np.ones(1, dtype=self.data.dtype)
Expand Down Expand Up @@ -1895,12 +1896,12 @@ def deactivate_depth_test(self, recurse: bool = True) -> Self:

# Shader code manipulation

@affects_data
def replace_shader_code(self, old: str, new: str) -> Self:
# TODO, will this work with VMobject structure, given
# that it does not simpler return shader_wrappers of
# family?
for wrapper in self.get_shader_wrapper_list():
wrapper.replace_code(old, new)
self.shader_code_replacements[old] = new
self._shaders_initialized = False
for mob in self.get_ancestors():
mob._shaders_initialized = False
return self

def set_color_by_code(self, glsl_code: str) -> Self:
Expand Down Expand Up @@ -1969,6 +1970,8 @@ def get_shader_wrapper(self, ctx: Context) -> ShaderWrapper:
self.shader_wrapper.vert_indices = self.get_shader_vert_indices()
self.shader_wrapper.bind_to_mobject_uniforms(self.get_uniforms())
self.shader_wrapper.depth_test = self.depth_test
for old, new in self.shader_code_replacements.items():
self.shader_wrapper.replace_code(old, new)
return self.shader_wrapper

def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
Expand Down
9 changes: 7 additions & 2 deletions manimlib/mobject/types/vectorized_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,6 +1292,10 @@ def init_shader_data(self, ctx: Context):
self.fill_shader_wrapper,
self.stroke_shader_wrapper,
]
for sw in self.shader_wrappers:
rep = self.family_members_with_points()[0]
for old, new in rep.shader_code_replacements.items():
sw.replace_code(old, new)

def refresh_shader_wrapper_id(self) -> Self:
if not self._shaders_initialized:
Expand Down Expand Up @@ -1355,8 +1359,9 @@ def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
self.stroke_shader_wrapper.read_in(stroke_datas),
]
for sw in shader_wrappers:
sw.bind_to_mobject_uniforms(family[0].get_uniforms())
sw.depth_test = family[0].depth_test
rep = family[0] # Representative family member
sw.bind_to_mobject_uniforms(rep.get_uniforms())
sw.depth_test = rep.depth_test
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]


Expand Down

0 comments on commit c477701

Please sign in to comment.