# shader_wrapper.py (forked from 3b1b/manim)
from __future__ import annotations
import copy
import os
import re
import OpenGL.GL as gl
import moderngl
import numpy as np
from manimlib.utils.iterables import resize_array
from manimlib.utils.shaders import get_shader_code_from_file
from manimlib.utils.shaders import get_shader_program
from manimlib.utils.shaders import image_path_to_texture
from manimlib.utils.shaders import get_texture_id
from manimlib.utils.shaders import get_fill_canvas
from manimlib.utils.shaders import release_texture
from manimlib.utils.shaders import set_program_uniform
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import List, Optional, Dict
from manimlib.typing import UniformDict
# Mobjects that should be rendered with the same shader
# are organized and clumped together by keeping track of
# an object (a ShaderWrapper) holding all the information
# relevant to that shader.
class ShaderWrapper(object):
    """
    Bundle of everything needed to draw one batch of vertex data with
    one shader: the vertex array (and optional index array), the shader
    source code and compiled program, the uniforms, and the context
    render settings (depth test, clip plane, render primitive).

    ``mobject_uniforms`` is stored by reference, not copied, so later
    changes made to that dictionary by its owner are seen here.

    Note: because ``__eq__`` is defined without ``__hash__``, instances
    are not hashable.
    """

    def __init__(
        self,
        ctx: moderngl.context.Context,
        vert_data: np.ndarray,
        vert_indices: Optional[np.ndarray] = None,
        shader_folder: Optional[str] = None,
        mobject_uniforms: Optional[UniformDict] = None,  # A dictionary mapping names of uniform variables to their values
        texture_paths: Optional[dict[str, str]] = None,  # A dictionary mapping names to filepaths for textures.
        depth_test: bool = False,
        render_primitive: int = moderngl.TRIANGLE_STRIP,
    ):
        self.ctx = ctx
        self.vert_data = vert_data
        # Test against None explicitly: the truth value of an array with
        # more than one element is ambiguous (raises ValueError), and a
        # one-element array holding 0 is falsy.
        if vert_indices is None:
            vert_indices = np.zeros(0)
        self.vert_indices = np.asarray(vert_indices).astype(int)
        self.vert_attributes = vert_data.dtype.names
        self.shader_folder = shader_folder
        self.depth_test = depth_test
        self.render_primitive = render_primitive

        self.program_uniform_mirror: UniformDict = dict()
        # Only substitute a fresh dict when nothing was passed, so that
        # an (initially empty) dict owned by the caller stays bound.
        self.bind_to_mobject_uniforms(
            dict() if mobject_uniforms is None else mobject_uniforms
        )

        self.init_program_code()
        self.init_program()
        if texture_paths is not None:
            self.init_textures(texture_paths)
        self.init_vao()
        self.refresh_id()

    def init_program_code(self) -> None:
        """
        Load the vertex, geometry and fragment shader code from
        ``vert.glsl``, ``geom.glsl`` and ``frag.glsl`` in the shader
        folder. Without a shader folder every entry is None.
        """
        def get_code(name: str) -> str | None:
            if not self.shader_folder:
                # No folder to look in; avoid os.path.join(None, ...)
                return None
            return get_shader_code_from_file(
                os.path.join(self.shader_folder, f"{name}.glsl")
            )

        self.program_code: dict[str, str | None] = {
            "vertex_shader": get_code("vert"),
            "geometry_shader": get_code("geom"),
            "fragment_shader": get_code("frag"),
        }

    def init_program(self):
        """
        Build the shader program from ``self.program_code`` and detect
        the vertex format. Both are None when there is no shader folder.
        """
        if not self.shader_folder:
            self.program = None
            self.vert_format = None
            return
        self.program = get_shader_program(self.ctx, **self.program_code)
        self.vert_format = moderngl.detect_format(self.program, self.vert_attributes)

    def init_textures(self, texture_paths: dict[str, str]):
        """
        Turn each texture file path into a texture id, and pass the
        name -> id mapping to the program as uniforms.
        """
        names_to_ids = {
            name: get_texture_id(image_path_to_texture(path, self.ctx))
            for name, path in texture_paths.items()
        }
        self.update_program_uniforms(names_to_ids)

    def init_vao(self):
        """Reset the buffer and vertex array handles without releasing them."""
        self.vbo = None
        self.ibo = None
        self.vao = None

    def bind_to_mobject_uniforms(self, mobject_uniforms: UniformDict):
        """Keep a reference (not a copy) to the given uniform dictionary."""
        self.mobject_uniforms = mobject_uniforms

    def __eq__(self, shader_wrapper: ShaderWrapper):
        """
        Compare vertex data, indices, shader folder, uniforms, depth
        test and render primitive.

        Only the uniform keys of ``self`` are inspected; a key missing
        from the other wrapper counts as a mismatch.
        """
        if not isinstance(shader_wrapper, ShaderWrapper):
            return NotImplemented

        def uniforms_match() -> bool:
            others = shader_wrapper.mobject_uniforms
            # np.array_equal copes with array-valued uniforms, where a
            # plain `==` would give an array with no single truth value
            return all(
                key in others and np.array_equal(value, others[key])
                for key, value in self.mobject_uniforms.items()
            )

        # np.array_equal checks shapes first, so arrays of different
        # lengths compare unequal instead of failing to broadcast
        return bool(
            np.array_equal(self.vert_data, shader_wrapper.vert_data)
            and np.array_equal(self.vert_indices, shader_wrapper.vert_indices)
            and self.shader_folder == shader_wrapper.shader_folder
            and uniforms_match()
            and self.depth_test == shader_wrapper.depth_test
            and self.render_primitive == shader_wrapper.render_primitive
        )

    def copy(self):
        """
        Return a shallow copy with its own vertex data, vertex indices
        and program code dictionary, and with no buffers or vertex
        array attached. The context, program and uniform dictionary
        are shared with the original.
        """
        result = copy.copy(self)
        result.ctx = self.ctx
        result.vert_data = self.vert_data.copy()
        result.vert_indices = self.vert_indices.copy()
        # replace_code edits this dict in place, so the copy needs its
        # own one to avoid changing the original's code behind its back
        result.program_code = dict(self.program_code)
        result.init_vao()
        return result

    def is_valid(self) -> bool:
        """True if there is vertex data plus vertex and fragment shader code."""
        return all([
            self.vert_data is not None,
            self.program_code["vertex_shader"] is not None,
            self.program_code["fragment_shader"] is not None,
        ])

    def get_id(self) -> str:
        """Return the id last computed by ``refresh_id``."""
        return self.id

    def create_id(self) -> str:
        """
        Build an id string from a hash of the shader code, the uniforms,
        the depth test flag and the render primitive.
        """
        # A unique id for a shader
        program_id = hash("".join(
            self.program_code[f"{name}_shader"] or ""
            for name in ("vertex", "geometry", "fragment")
        ))
        return "|".join(map(str, [
            program_id,
            self.mobject_uniforms,
            self.depth_test,
            self.render_primitive,
        ]))

    def refresh_id(self) -> None:
        """Recompute and store the id."""
        self.id = self.create_id()

    def replace_code(self, old: str, new: str) -> None:
        """
        Substitute ``new`` for every match of ``old`` in all of the
        shader code, then rebuild the program and the id.

        ``old`` is handed to ``re.sub``, so it is treated as a regular
        expression pattern rather than a literal string.
        """
        code_map = self.program_code
        for name in code_map:
            if code_map[name] is None:
                continue
            code_map[name] = re.sub(old, new, code_map[name])
        self.init_program()
        self.refresh_id()

    # Changing context
    def use_clip_plane(self):
        """True if a "clip_plane" uniform exists with any nonzero entry."""
        if "clip_plane" not in self.mobject_uniforms:
            return False
        return any(self.mobject_uniforms["clip_plane"])

    def set_ctx_depth_test(self, enable: bool = True) -> None:
        """Enable or disable depth testing on the context."""
        if enable:
            self.ctx.enable(moderngl.DEPTH_TEST)
        else:
            self.ctx.disable(moderngl.DEPTH_TEST)

    def set_ctx_clip_plane(self, enable: bool = True) -> None:
        """Enable or disable the first clip distance (GL_CLIP_DISTANCE0)."""
        if enable:
            gl.glEnable(gl.GL_CLIP_DISTANCE0)
        else:
            # Turn it back off, so that the setting from a previously
            # rendered wrapper does not carry over to this one
            gl.glDisable(gl.GL_CLIP_DISTANCE0)

    # Adding data
    def combine_with(self, *shader_wrappers: ShaderWrapper) -> ShaderWrapper:
        """
        Append the vertex data and indices of the given wrappers to
        this one, in place. Returns self.
        """
        if len(shader_wrappers) > 0:
            data_list = [self.vert_data, *(sw.vert_data for sw in shader_wrappers)]
            indices_list = [self.vert_indices, *(sw.vert_indices for sw in shader_wrappers)]
            self.read_in(data_list, indices_list)
        return self

    def read_in(
        self,
        data_list: List[np.ndarray],
        indices_list: List[np.ndarray] | None = None
    ) -> ShaderWrapper:
        """
        Replace this wrapper's vertex data with the concatenation of
        ``data_list`` and its indices with the concatenation of
        ``indices_list``, where each index array is shifted by the
        number of points that come before its data array.

        With no data, no ``indices_list``, or only empty index arrays,
        the stored indices are emptied. Returns self.
        """
        # Assume all are of the same type
        total_len = sum(len(data) for data in data_list)
        self.vert_data = resize_array(self.vert_data, total_len)
        if total_len == 0 or indices_list is None:
            if total_len > 0:
                # Stack the data
                np.concatenate(data_list, out=self.vert_data)
            # Indices left over from earlier data would no longer
            # refer to the right points
            self.vert_indices = resize_array(self.vert_indices, 0)
            return self

        # Stack the data
        np.concatenate(data_list, out=self.vert_data)

        total_verts = sum(len(vi) for vi in indices_list)
        # Resize before the empty check, so that no stale indices remain
        self.vert_indices = resize_array(self.vert_indices, total_verts)
        if total_verts == 0:
            return self

        # Stack vert_indices, but adding the appropriate offset
        # along the way
        n_points = 0
        n_verts = 0
        for data, indices in zip(data_list, indices_list):
            new_n_verts = n_verts + len(indices)
            self.vert_indices[n_verts:new_n_verts] = indices + n_points
            n_verts = new_n_verts
            n_points += len(data)
        return self

    # Related to data and rendering
    def pre_render(self):
        """Apply this wrapper's depth test and clip plane settings to the context."""
        self.set_ctx_depth_test(self.depth_test)
        self.set_ctx_clip_plane(self.use_clip_plane())

    def render(self):
        """Render the vertex array, which generate_vao must have created."""
        assert self.vao is not None
        self.vao.render()

    def update_program_uniforms(self, camera_uniforms: UniformDict):
        """
        Set every mobject uniform, followed by every given uniform, on
        the program. Does nothing if there is no program.
        """
        if self.program is None:
            return
        for name, value in (*self.mobject_uniforms.items(), *camera_uniforms.items()):
            set_program_uniform(self.program, name, value)

    def get_vertex_buffer_object(self, refresh: bool = True):
        """
        Return the vertex buffer, creating it from the current vertex
        data if ``refresh`` is set or no buffer exists yet.
        """
        if refresh or self.vbo is None:
            self.vbo = self.ctx.buffer(self.vert_data)
        return self.vbo

    def get_index_buffer_object(self, refresh: bool = True):
        """
        Return the index buffer, creating it from the current indices
        (as uint32) if ``refresh`` is set or no buffer exists yet.
        No buffer is created when there are no indices.
        """
        if (refresh or self.ibo is None) and len(self.vert_indices) > 0:
            self.ibo = self.ctx.buffer(self.vert_indices.astype(np.uint32))
        return self.ibo

    def generate_vao(self, refresh: bool = True):
        """
        Create, store and return a vertex array object for the program.

        With ``refresh`` the old buffers are released and rebuilt from
        the current data. Without it, only the old vertex array is
        released and the existing buffers are reused.
        """
        if refresh:
            self.release()
        elif self.vao is not None:
            # Keep the buffers alive, since they are about to be reused
            self.vao.release()
            self.vao = None

        # Data buffer
        vbo = self.get_vertex_buffer_object(refresh)
        ibo = self.get_index_buffer_object(refresh)

        # Vertex array object
        self.vao = self.ctx.vertex_array(
            program=self.program,
            content=[(vbo, self.vert_format, *self.vert_attributes)],
            index_buffer=ibo,
            mode=self.render_primitive,
        )
        return self.vao

    def release(self):
        """Release the vertex buffer, index buffer and vertex array, if any."""
        for obj in (self.vbo, self.ibo, self.vao):
            if obj is not None:
                obj.release()
        self.vbo = None
        self.ibo = None
        self.vao = None
class FillShaderWrapper(ShaderWrapper):
    """
    ShaderWrapper which, when it has no vertex indices, renders in
    "winding" mode: first into the separate fill canvas framebuffer,
    and from there onto the framebuffer that was active beforehand.
    """

    def __init__(
        self,
        ctx: moderngl.context.Context,
        *args,
        **kwargs
    ):
        super().__init__(ctx, *args, **kwargs)
        # A (framebuffer, vertex array) pair, unpacked in render
        self.fill_canvas = get_fill_canvas(self.ctx)

    def render(self):
        """
        Render directly if there are vertex indices. Otherwise render
        through the fill canvas with a special blend function.

        The originally active framebuffer and the ordinary blend
        function are restored even if one of the draw calls raises.
        """
        winding = (len(self.vert_indices) == 0)
        self.program['winding'].value = winding
        if not winding:
            super().render()
            return

        original_fbo = self.ctx.fbo
        texture_fbo, texture_vao = self.fill_canvas

        texture_fbo.clear()
        texture_fbo.use()
        gl.glBlendFuncSeparate(
            # Ordinary blending for colors
            gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA,
            # The effect of blending with -a / (1 - a)
            # should be to cancel out
            gl.GL_ONE_MINUS_DST_ALPHA, gl.GL_ONE,
        )

        try:
            try:
                super().render()
            finally:
                # Never leave the fill canvas bound as the render target
                original_fbo.use()
            gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA)
            texture_vao.render()
        finally:
            # Back to ordinary blending for whatever renders next
            gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA)