Skip to content

Commit

Permalink
tested
Browse files — browse the repository at this point in the history
  • Loading branch information
zhanghang1989 committed May 14, 2017
1 parent 984cce3 commit 55dbd84
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 43 deletions.
25 changes: 14 additions & 11 deletions build.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,22 @@
ENCODING_LIB = os.path.join(package_base, 'lib/libENCODING.so')

def make_relative_rpath(path):
    """Return a linker flag that embeds *path* as a runtime library search path.

    The original code branched on ``platform.system() == 'Darwin'``, but both
    branches produced the identical ``-Wl,-rpath,`` flag, so the platform
    check was dead code.  Collapsing it changes no behavior on any platform.

    Parameters
    ----------
    path : str
        Directory to record in the built extension's rpath.

    Returns
    -------
    str
        A ``-Wl,-rpath,<path>`` argument suitable for ``extra_link_args``.
    """
    # NOTE(review): macOS builds conventionally use '@loader_path' and Linux
    # '$ORIGIN' for relocatable rpaths; this code passes the directory through
    # verbatim on both — confirm that is intended before changing it.
    return '-Wl,-rpath,' + path

extra_link_args = []


ffi = create_extension(
'encoding._ext.encoding_lib',
package=True,
headers=headers,
sources=sources,
define_macros=defines,
relative_to=__file__,
with_cuda=with_cuda,
'encoding._ext.encoding_lib',
package=True,
headers=headers,
sources=sources,
define_macros=defines,
relative_to=__file__,
with_cuda=with_cuda,
include_dirs = include_path,
extra_link_args = [
make_relative_rpath(os.path.join(package_base, 'lib')),
Expand Down
8 changes: 2 additions & 6 deletions encoding/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,13 @@ IF(NOT ENCODING_INSTALL_LIB_SUBDIR)
ENDIF()

SET(CMAKE_MACOSX_RPATH 1)
#SET(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -std=c++11")
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")

FILE(GLOB src-cuda kernel/*.cu)

CUDA_INCLUDE_DIRECTORIES(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/kernel
${Torch_INSTALL_INCLUDE}
)
CUDA_ADD_LIBRARY(ENCODING SHARED ${src-cuda})
Expand All @@ -63,11 +64,6 @@ IF(MSVC)
SET_TARGET_PROPERTIES(ENCODING PROPERTIES PREFIX "lib" IMPORT_PREFIX "lib")
ENDIF()

INCLUDE_DIRECTORIES(
./include
${CMAKE_CURRENT_SOURCE_DIR}
${Torch_INSTALL_INCLUDE}
)
TARGET_LINK_LIBRARIES(ENCODING
${THC_LIBRARIES}
${TH_LIBRARIES}
Expand Down
1 change: 1 addition & 0 deletions encoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torch
from torch.nn.modules.module import Module
from torch.autograd import Function
from ._ext import encoding_lib

class aggregate(Function):
Expand Down
7 changes: 4 additions & 3 deletions encoding/kernel/generic/encoding_kernel.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
*/
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "kernel/generic/encoding_kernel.c"
#define THC_GENERIC_FILE "generic/encoding_kernel.c"
#else
/*
template <int Dim>
THCDeviceTensor<float, Dim> devicetensor(THCState *state, THCTensor *t) {
if (!t) {
Expand All @@ -36,7 +37,7 @@ THCDeviceTensor<float, Dim> devicetensor(THCState *state, THCTensor *t) {
}
return THCDeviceTensor<float, Dim>(THCTensor_(data)(state, t), size);
}

*/
__global__ void Encoding_(Aggregate_Forward_kernel) (
THCDeviceTensor<real, 3> E,
THCDeviceTensor<real, 3> A,
Expand Down Expand Up @@ -71,7 +72,7 @@ void Encoding_(Aggregate_Forward)(THCState *state, THCTensor *E_, THCTensor *A_,
if (THCTensor_(nDimension)(state, E_) != 3 ||
THCTensor_(nDimension)(state, A_) != 3 ||
THCTensor_(nDimension)(state, R_) != 4)
perror("Encoding: incorrect input dims. \n");
THError("Encoding: incorrect input dims. \n");
/* Device tensors */
THCDeviceTensor<real, 3> E = devicetensor<3>(state, E_);
THCDeviceTensor<real, 3> A = devicetensor<3>(state, A_);
Expand Down
39 changes: 36 additions & 3 deletions encoding/kernel/thc_encoding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,48 @@
#include "THCDeviceTensor.cuh"
#include "THCDeviceTensorUtils.cuh"

#include "thc_encoding.h"

// this symbol will be resolved automatically from PyTorch libs
extern THCState *state;

//#define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME)
//#define torch_Tensor TH_CONCAT_STRING_3(torch., Real, Tensor)

#define Encoding_(NAME) TH_CONCAT_4(Encoding_, Real, _, NAME)
#define THCTensor TH_CONCAT_3(TH,CReal,Tensor)
#define THCTensor_(NAME) TH_CONCAT_4(TH,CReal,Tensor_,NAME)

// Wrap a THCudaTensor in a THCDeviceTensor<float, Dim> view with exactly
// Dim dimensions.  A null tensor yields an empty device tensor; otherwise
// trailing dimensions are padded with 1 (when the tensor has fewer than Dim
// dims) or folded into the last view dimension (when it has more).
template <int Dim>
THCDeviceTensor<float, Dim> devicetensor(THCState *state, THCudaTensor *t) {
  if (!t) {
    // No tensor supplied: return a default-constructed (empty) view.
    return THCDeviceTensor<float, Dim>();
  }

  int inDim = THCudaTensor_nDimension(state, t);
  if (inDim == Dim) {
    // Fast path: ranks already match, no reshaping needed.
    return toDeviceTensor<float, Dim>(state, t);
  }

  // View in which the last dimensions are collapsed or expanded as needed
  // Folding trailing dims assumes a flat memory layout, hence the
  // contiguity requirement.
  THAssert(THCudaTensor_isContiguous(state, t));
  int size[Dim];
  // Walk max(Dim, inDim) indices: copy shared sizes, pad missing view dims
  // with 1, and multiply any extra input dims into the last view dim.
  for (int i = 0; i < Dim || i < inDim; ++i) {
    if (i < Dim && i < inDim) {
      size[i] = t->size[i];
    } else if (i < Dim) {
      size[i] = 1;
    } else {
      size[Dim - 1] *= t->size[i];
    }
  }
  // Build the view directly over the tensor's device data pointer.
  return THCDeviceTensor<float, Dim>(THCudaTensor_data(state, t), size);
}

#ifdef __cplusplus
extern "C" {
#endif

#include "generic/encoding_kernel.c"
#include "THC/THCGenerateFloatType.h"

#ifdef __cplusplus
}
#endif
11 changes: 8 additions & 3 deletions encoding/kernel/thc_encoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
// this symbol will be resolved automatically from PyTorch libs
extern THCState *state;

//#define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME)
//#define torch_Tensor TH_CONCAT_STRING_3(torch., Real, Tensor)

#define Encoding_(NAME) TH_CONCAT_4(Encoding_, Real, _, NAME)
#define THCTensor TH_CONCAT_3(TH,CReal,Tensor)
#define THCTensor_(NAME) TH_CONCAT_4(TH,CReal,Tensor_,NAME)

#ifdef __cplusplus
extern "C" {
#endif

#include "generic/encoding_kernel.h"
#include "THC/THCGenerateFloatType.h"

#ifdef __cplusplus
}
#endif
8 changes: 8 additions & 0 deletions encoding/src/encoding_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,13 @@

extern THCState *state;

#ifdef __cplusplus
extern "C" {
#endif

#include "generic/encoding_generic.c"
#include "THC/THCGenerateFloatType.h"

#ifdef __cplusplus
}
#endif
41 changes: 24 additions & 17 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,29 @@

this_file = os.path.dirname(__file__)

extra_compile_args = ['-std=c++11', '-Wno-write-strings']
if os.getenv('PYTORCH_BINARY_BUILD') and platform.system() == 'Linux':
print('PYTORCH_BINARY_BUILD found. Static linking libstdc++ on Linux')
extra_compile_args += ['-static-libstdc++']
extra_link_args += ['-static-libstdc++']

setup(
name="encoding",
version="0.0.1",
description="PyTorch Encoding Layer",
url="https://github.com/zhanghang1989/PyTorch-Encoding-Layer",
author="Hang Zhang",
author_email="zhang.hang@rutgers.edu",
# Require cffi.
install_requires=["cffi>=1.0.0"],
setup_requires=["cffi>=1.0.0"],
# Exclude the build files.
packages=find_packages(exclude=["build"]),
# Package where to put the extensions. Has to be a prefix of build.py.
ext_package="",
# Extensions to compile.
cffi_modules=[
os.path.join(this_file, "build.py:ffi")
],
name="encoding",
version="0.0.1",
description="PyTorch Encoding Layer",
url="https://github.com/zhanghang1989/PyTorch-Encoding-Layer",
author="Hang Zhang",
author_email="zhang.hang@rutgers.edu",
# Require cffi.
install_requires=["cffi>=1.0.0"],
setup_requires=["cffi>=1.0.0"],
# Exclude the build files.
packages=find_packages(exclude=["build"]),
extra_compile_args=extra_compile_args,
# Package where to put the extensions. Has to be a prefix of build.py.
ext_package="",
# Extensions to compile.
cffi_modules=[
os.path.join(this_file, "build.py:ffi")
],
)

0 comments on commit 55dbd84

Please sign in to comment.