forked from zhanghang1989/PyTorch-Encoding
-
Notifications
You must be signed in to change notification settings - Fork 0
/
build.py
74 lines (63 loc) · 2.16 KB
/
build.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
import torch
import platform
import subprocess
from torch.utils.ffi import create_extension
# Library directory shipped inside the installed torch package.
lib_path = os.path.join(os.path.dirname(torch.__file__), 'lib')
# Directory containing this script. The shell scripts below are invoked with
# paths relative to it, so they are run with this as their working directory.
cwd = os.path.dirname(os.path.realpath(__file__))
encoding_lib_path = os.path.join(cwd, "encoding", "lib")

# clean the build files
# Run from the script directory so 'clean.sh' (and any relative paths inside
# it) resolve correctly even when build.py is launched from elsewhere.
clean_cmd = ['bash', 'clean.sh']
subprocess.check_call(clean_cmd, cwd=cwd)

# build CUDA library
# These environment variables are exported for the make.sh invocation below
# (presumably read by encoding/make.sh / its build files -- verify there).
os.environ['TORCH_BUILD_DIR'] = lib_path
if platform.system() == 'Darwin':
    os.environ['TH_LIBRARIES'] = os.path.join(lib_path, 'libATen.1.dylib')
    ENCODING_LIB = os.path.join(encoding_lib_path, 'libENCODING.dylib')
else:
    os.environ['TH_LIBRARIES'] = os.path.join(lib_path, 'libATen.so.1')
    ENCODING_LIB = os.path.join(encoding_lib_path, 'libENCODING.so')
build_all_cmd = ['bash', 'encoding/make.sh']
# Same working-directory fix as above: 'encoding/make.sh' is relative to cwd.
subprocess.check_call(build_all_cmd, cwd=cwd, env=dict(os.environ))
# build FFI
# Wrapper sources/headers given as relative paths; presumably resolved
# against this file by create_extension(relative_to=__file__) -- verify.
sources = ['encoding/src/encoding_lib.cpp']
headers = ['encoding/src/encoding_lib.h']

# The wrapper is compiled with CUDA enabled and WITH_CUDA defined.
with_cuda = True
defines = [('WITH_CUDA', None)]

# Header search path: torch's bundled include dir first, then the project's
# kernel and wrapper directories (relative to this script's directory).
include_path = [os.path.join(lib_path, 'include')]
include_path += [
    os.path.join(cwd, subdir)
    for subdir in ('encoding/kernel', 'encoding/kernel/include', 'encoding/src/')
]
def make_relative_rpath(path):
    """Return a linker flag that adds ``path`` to the runtime library search path.

    Despite the name, ``path`` is embedded exactly as given, so an absolute
    path stays absolute. The same ``-Wl,-rpath,<dir>`` form is used on macOS
    and on other platforms, so no platform branching is needed.

    Args:
        path: Directory to add to the binary's rpath.

    Returns:
        The string ``'-Wl,-rpath,' + path``, suitable for ``extra_link_args``.
    """
    return '-Wl,-rpath,' + path
# Declare the FFI extension module 'encoding._ext.encoding_lib' via torch's
# create_extension, using the sources/headers/includes configured above.
ffi = create_extension(
    'encoding._ext.encoding_lib',
    sources=sources,
    headers=headers,
    include_dirs=include_path,
    define_macros=defines,
    with_cuda=with_cuda,
    relative_to=__file__,
    package=True,
    # rpath entries for torch's lib dir and the project's lib dir, followed by
    # the prebuilt libENCODING shared library itself.
    extra_link_args=(
        list(map(make_relative_rpath, (lib_path, encoding_lib_path)))
        + [ENCODING_LIB]
    ),
)

# Compile the extension only when run as a script. Note that importing this
# module still executes the top-level clean/build subprocess calls.
if __name__ == '__main__':
    ffi.build()