forked from 3b1b/manim
-
Notifications
You must be signed in to change notification settings - Fork 0
/
shader_wrapper.py
165 lines (140 loc) · 5.09 KB
/
shader_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
import re
import moderngl
import numpy as np
import copy
from manimlib.utils.directories import get_shader_dir
from manimlib.utils.file_ops import find_file
# Mobjects that should be rendered with
# the same shader will be organized and
# clumped together based on keeping track
# of a dict holding all the information
# relevant to that shader
class ShaderWrapper(object):
    """Vertex data, GLSL program code and render settings for one shader.

    Mobjects that should be rendered with the same shader are clumped
    together based on the information captured here; ``get_id`` joins the
    program-code hash, uniforms, texture paths, depth-test flag and render
    primitive into one string identifying that combination.
    """

    def __init__(self,
                 vert_data=None,
                 vert_indices=None,
                 shader_folder=None,
                 uniforms=None,  # A dictionary mapping names of uniform variables
                 texture_paths=None,  # A dictionary mapping names to filepaths for textures.
                 depth_test=False,
                 render_primitive=moderngl.TRIANGLE_STRIP,
                 ):
        """Store data and settings, load the shader code and compute ids.

        ``vert_data`` is read via ``.dtype.names`` for its attribute names,
        so it is presumably a numpy structured array (TODO confirm with
        callers). It may be None (the default), in which case ``is_valid``
        reports False. ``shader_folder`` is joined with ``vert.glsl``,
        ``geom.glsl`` and ``frag.glsl`` to locate the shader sources.
        """
        self.vert_data = vert_data
        self.vert_indices = vert_indices
        # Guard the default argument: None has no dtype.
        self.vert_attributes = None if vert_data is None else vert_data.dtype.names
        self.shader_folder = shader_folder
        self.uniforms = uniforms or dict()
        self.texture_paths = texture_paths or dict()
        self.depth_test = depth_test
        self.render_primitive = str(render_primitive)
        self.init_program_code()
        self.refresh_id()

    def copy(self):
        """Return a shallow copy whose arrays and dicts are independent of self."""
        result = copy.copy(self)
        # Keep None as None; np.array(None) would yield a 0-d object array
        # and make is_valid() wrongly report True on the copy.
        if self.vert_data is not None:
            result.vert_data = np.array(self.vert_data)
        if self.vert_indices is not None:
            result.vert_indices = np.array(self.vert_indices)
        # Copy the dicts even when empty so that mutating the copy can never
        # affect the original.
        if self.uniforms is not None:
            result.uniforms = dict(self.uniforms)
        if self.texture_paths is not None:
            result.texture_paths = dict(self.texture_paths)
        # replace_code() rewrites program_code entries in place; without a
        # separate dict the copy and the original would share those edits.
        result.program_code = dict(self.program_code)
        return result

    def is_valid(self):
        """Return True when there is vertex data and both vertex and fragment code."""
        return all([
            self.vert_data is not None,
            self.program_code["vertex_shader"] is not None,
            self.program_code["fragment_shader"] is not None,
        ])

    def get_id(self):
        """Return the id string last computed by refresh_id()."""
        return self.id

    def get_program_id(self):
        """Return the program-code hash last computed by refresh_id()."""
        return self.program_id

    def create_id(self):
        """Build the id: program hash, uniforms, textures, depth test, primitive."""
        # A unique id for a shader
        return "|".join(map(str, [
            self.program_id,
            self.uniforms,
            self.texture_paths,
            self.depth_test,
            self.render_primitive,
        ]))

    def refresh_id(self):
        """Recompute program_id and id; call after changing program code."""
        self.program_id = self.create_program_id()
        self.id = self.create_id()

    def create_program_id(self):
        """Hash the concatenated vertex/geometry/fragment code (missing -> "")."""
        return hash("".join((
            self.program_code[f"{name}_shader"] or ""
            for name in ("vertex", "geometry", "fragment")
        )))

    def init_program_code(self):
        """Load vert/geom/frag GLSL from shader_folder into self.program_code.

        Stages whose file cannot be loaded are stored as None.
        """
        def get_code(name):
            # With no shader folder there is nothing to load; report the stage
            # as missing instead of failing inside os.path.join(None, ...).
            if self.shader_folder is None:
                return None
            return get_shader_code_from_file(
                os.path.join(self.shader_folder, f"{name}.glsl")
            )

        self.program_code = {
            "vertex_shader": get_code("vert"),
            "geometry_shader": get_code("geom"),
            "fragment_shader": get_code("frag"),
        }

    def get_program_code(self):
        """Return the dict mapping '<stage>_shader' to GLSL source (or None)."""
        return self.program_code

    def replace_code(self, old, new):
        """Apply re.sub(old, new, code) to every loaded stage, then refresh ids."""
        code_map = self.program_code
        for name, code in code_map.items():
            if code is None:
                continue
            # Reassigning an existing key does not change the dict's size, so
            # this is safe while iterating.
            code_map[name] = re.sub(old, new, code)
        self.refresh_id()

    def combine_with(self, *shader_wrappers):
        """Append the vertex data (and indices) of shader_wrappers onto self.

        Assumes all wrappers are of the same type, i.e. they either all have
        vert_indices or none do. Indices of each appended wrapper are offset
        by the number of vertices preceding it. Returns self, including when
        no wrappers are given.
        """
        if len(shader_wrappers) == 0:
            return self
        if self.vert_indices is not None:
            num_verts = len(self.vert_data)
            indices_list = [self.vert_indices]
            data_list = [self.vert_data]
            for sw in shader_wrappers:
                indices_list.append(sw.vert_indices + num_verts)
                data_list.append(sw.vert_data)
                num_verts += len(sw.vert_data)
            self.vert_indices = np.hstack(indices_list)
            self.vert_data = np.hstack(data_list)
        else:
            self.vert_data = np.hstack([self.vert_data, *[sw.vert_data for sw in shader_wrappers]])
        return self
# For caching
filename_to_code_map = {}


def get_shader_code_from_file(filename):
    """Return GLSL source for ``filename`` with ``#INSERT`` lines expanded.

    ``filename`` is resolved with ``find_file`` against ``get_shader_dir()``
    and ``"/"``. Successful results are cached in ``filename_to_code_map``,
    keyed by ``filename``.

    Returns None when ``filename`` is falsy, or when resolving it raises
    IOError (presumably: file not found).

    Raises IOError when the file contains an ``#INSERT <name>.glsl`` line
    whose target (looked up under ``inserts/``) cannot be loaded.
    """
    if not filename:
        return None
    if filename in filename_to_code_map:
        return filename_to_code_map[filename]
    try:
        filepath = find_file(
            filename,
            directories=[get_shader_dir(), "/"],
            extensions=[],
        )
    except IOError:
        return None

    with open(filepath, "r") as f:
        result = f.read()

    # To share functionality between shaders, some functions are read in
    # from other files and inserted into the relevant strings before
    # passing to ctx.program for compiling.
    # Replace "#INSERT " lines with relevant code; the inserted file is
    # itself loaded through this function, so its own inserts are expanded.
    insertions = re.findall(r"^#INSERT .*\.glsl$", result, flags=re.MULTILINE)
    for line in insertions:
        insert_name = os.path.join("inserts", line.replace("#INSERT ", ""))
        inserted_code = get_shader_code_from_file(insert_name)
        if inserted_code is None:
            # Previously str.replace(line, None) failed with an opaque
            # TypeError; report which insert is missing instead.
            raise IOError(
                f"Shader file {filename!r} requests {insert_name!r} via "
                "#INSERT, but it could not be found"
            )
        result = result.replace(line, inserted_code)
    filename_to_code_map[filename] = result
    return result
def get_colormap_code(rgb_list):
    """Build a GLSL array-constructor literal of vec3 colors.

    Each entry of ``rgb_list`` fills a ``vec3(r, g, b)`` constructor (only
    its first three components are used by ``str.format``). The
    constructors are joined with commas and wrapped as ``vec3[N](...)``,
    where N is ``len(rgb_list)``.
    """
    vec_template = "vec3({}, {}, {})"
    pieces = [vec_template.format(*rgb) for rgb in rgb_list]
    joined = ",".join(pieces)
    return f"vec3[{len(rgb_list)}]({joined})"