Optimize quantization kernels (openvinotoolkit#51)
Switched to using a 2D grid for the per-channel kernels and cleaned up
code duplication across the extensions.
vshampor committed Jul 10, 2020
1 parent cc16aa2 commit a0c1c2b
Showing 28 changed files with 957 additions and 945 deletions (only the first several changed files are reproduced below).
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -0,0 +1 @@
graft nncf/extensions/*
12 changes: 0 additions & 12 deletions nncf/binarization/cpu/__init__.py

This file was deleted.

12 changes: 0 additions & 12 deletions nncf/binarization/cuda/__init__.py

This file was deleted.

39 changes: 22 additions & 17 deletions nncf/binarization/extensions.py
@@ -11,33 +11,38 @@
  limitations under the License.
 """
 
-import os
-import pathlib
+import os.path
-from nncf.definitions import get_install_type
 
 from torch.utils.cpp_extension import load
 
+from nncf.utils import set_build_dir_for_venv
+from nncf.definitions import get_install_type, NNCF_PACKAGE_ROOT_DIR
+
+set_build_dir_for_venv()
+
+
+BASE_EXT_DIR = os.path.join(NNCF_PACKAGE_ROOT_DIR, "extensions/src/binarization")
+
+EXT_INCLUDE_DIRS = [
+    os.path.join(NNCF_PACKAGE_ROOT_DIR, "extensions/include"),
+]
+
+CPU_EXT_SRC_LIST = [
+    os.path.join(BASE_EXT_DIR, "cpu/functions_cpu.cpp"),
+    os.path.join(NNCF_PACKAGE_ROOT_DIR, "extensions/src/common/cpu/tensor_funcs.cpp")
+]
 
-if "VIRTUAL_ENV" in os.environ:
-    build_dir = os.path.join(os.environ["VIRTUAL_ENV"], "torch_extensions")
-    pathlib.Path(build_dir).mkdir(parents=True, exist_ok=True)
-    os.environ["TORCH_EXTENSIONS_DIR"] = build_dir
+CUDA_EXT_SRC_LIST = [
+    os.path.join(BASE_EXT_DIR, "cuda/functions_cuda.cpp"),
+    os.path.join(BASE_EXT_DIR, "cuda/functions_cuda_impl.cu")
+]
 
-ext_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cpu")
 BinarizedFunctionsCPU = load(
-    'binarized_functions_cpu', [
-        os.path.join(ext_dir, 'functions_cpu.cpp')
-    ],
+    'binarized_functions_cpu', CPU_EXT_SRC_LIST, extra_include_paths=EXT_INCLUDE_DIRS,
     verbose=False
 )
 
 if get_install_type() == "GPU":
-    ext_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cuda")
     BinarizedFunctionsCUDA = load(
-        'binarized_functions_cuda', [
-            os.path.join(ext_dir, 'functions_cuda.cpp'),
-            os.path.join(ext_dir, 'functions_cuda_kernel.cu')
-        ],
+        'binarized_functions_cuda', CUDA_EXT_SRC_LIST, extra_include_paths=EXT_INCLUDE_DIRS,
         verbose=False
     )
19 changes: 19 additions & 0 deletions nncf/extensions/include/binarization/functions_cuda_impl.h
@@ -0,0 +1,19 @@
#ifndef _BINARIZATION_FUNCTIONS_CUDA_IMPL_H_
#define _BINARIZATION_FUNCTIONS_CUDA_IMPL_H_

at::Tensor wb_cuda_forward(
at::Tensor input,
bool per_channel);

at::Tensor ab_cuda_forward(
at::Tensor input,
at::Tensor scale,
at::Tensor thresholds);

std::vector<at::Tensor> ab_cuda_backward(
at::Tensor grad_output,
at::Tensor input,
at::Tensor scale,
at::Tensor output);

#endif // _BINARIZATION_FUNCTIONS_CUDA_IMPL_H_
9 changes: 9 additions & 0 deletions nncf/extensions/include/common_cpu_funcs.h
@@ -0,0 +1,9 @@
#ifndef _COMMON_CPU_FUNCS_H_
#define _COMMON_CPU_FUNCS_H_

#include <torch/torch.h>

void sum_like(at::Tensor& target_tensor, const at::Tensor& ref_tensor);
void sum_to_act_channels(at::Tensor& target_tensor);

#endif // _COMMON_CPU_FUNCS_H_
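
Judging by the signatures, sum_like presumably reduces its first argument in place by summation until it matches the shape of the reference tensor (for example, collapsing an element-wise gradient to the shape of a per-channel scale). A minimal sketch of how a caller in the CPU extension might use it; the function and tensor names below are illustrative and not taken from this diff:

#include <torch/torch.h>
#include "common_cpu_funcs.h"

// Illustrative only: reduce an element-wise gradient contribution down to the
// shape of a per-channel scale tensor, assuming sum_like() sums in place.
at::Tensor example_scale_grad(const at::Tensor& grad_output, const at::Tensor& scale) {
    at::Tensor grad_scale = grad_output.clone();
    sum_like(grad_scale, scale);   // in-place reduction to scale's shape (assumed semantics)
    return grad_scale;
}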
70 changes: 70 additions & 0 deletions nncf/extensions/include/common_cuda_defs.cuh
@@ -0,0 +1,70 @@
#ifndef _COMMON_CUDA_DEFS_CUH_
#define _COMMON_CUDA_DEFS_CUH_

#include <ATen/ATen.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>
#include <THC/THC.h>


const uint32_t CUDA_WARP_SIZE = 32;
const uint32_t CUDA_TARGET_NUM_THREADS_PER_SM = 2048; // Will decide upon a number of threads per block and blocks per grid based on the workload to hit this target
const uint32_t CUDA_TARGET_SM_COUNT = 72; // RTX 2080 Ti
const uint32_t CUDA_MAX_NUM_THREADS_PER_BLOCK = 1024; // Maximum for all CUDA compute capabilities up to 8.0
const uint16_t CUDA_MAX_WARPS_PER_BLOCK = CUDA_MAX_NUM_THREADS_PER_BLOCK / CUDA_WARP_SIZE;
const uint32_t CUDA_BLOCKS_PER_GRID_FOR_UNIFORM_ELTWISE = CUDA_TARGET_SM_COUNT * CUDA_TARGET_NUM_THREADS_PER_SM / CUDA_MAX_NUM_THREADS_PER_BLOCK;
const uint16_t CUDA_MAX_GRID_SIZE_Y = 65535;

inline uint32_t GET_BLOCKS(const uint32_t total_required_threads) {
return (total_required_threads + CUDA_MAX_NUM_THREADS_PER_BLOCK - 1) / CUDA_MAX_NUM_THREADS_PER_BLOCK;
}


template<class I>
inline I align(I num, I alignment)
{
return (num & ~(alignment - 1)) + alignment;
}

inline dim3 get_2d_grid_size_for_per_channel(const uint32_t scale_count)
{
// X will correspond to scale count, Y will be determined in order to hit the thread-per-SM target
uint32_t grid_size_x = scale_count;
uint32_t available_threads_per_scale = static_cast<uint32_t>((CUDA_TARGET_SM_COUNT * CUDA_TARGET_NUM_THREADS_PER_SM + 0.0) / grid_size_x);
uint32_t available_warps_per_scale = align(available_threads_per_scale, CUDA_WARP_SIZE) / CUDA_WARP_SIZE;
uint32_t blocks_per_scale = std::max(1U, available_warps_per_scale / static_cast<uint32_t>(CUDA_MAX_WARPS_PER_BLOCK));
uint16_t grid_size_y = std::min(blocks_per_scale, static_cast<uint32_t>(CUDA_MAX_GRID_SIZE_Y));

return dim3(grid_size_x, grid_size_y);
}



#ifdef DO_PROFILE
#define PROFILE(CODE) \
int iter = 10; \
for (int i = 0; i < iter; i++) { \
CODE \
} \
cudaDeviceSynchronize(); \
auto start = std::chrono::steady_clock::now(); \
for (int i = 0; i < iter; i++) { \
CODE \
} \
cudaDeviceSynchronize(); \
auto end = std::chrono::steady_clock::now(); \
std::chrono::duration<double> diff = (end - start) / iter; \
std::cout << "PROFILE: avg kernel runtime = " << \
std::chrono::duration_cast<std::chrono::nanoseconds>(diff).count() \
<< " ns" << std::endl; \
cudaError_t err = cudaGetLastError(); \
if (err != cudaSuccess) \
std::cout << "CUDA error: " << cudaGetErrorString(err) << std::endl;
#else
#define PROFILE(CODE) CODE
#endif

#endif // _COMMON_CUDA_DEFS_CUH_
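
The commit message mentions switching to a 2D grid for the per-channel kernels. A minimal sketch of how get_2d_grid_size_for_per_channel might be used to launch such a kernel, with blockIdx.x selecting the scale (channel) and the grid's y-dimension striding over that channel's elements; the kernel and launcher below are hypothetical and only illustrate the launch geometry:

#include "common_cuda_defs.cuh"

// Hypothetical per-channel kernel: blockIdx.x indexes the channel/scale,
// blockIdx.y and threadIdx.x together stride over that channel's elements.
template <typename scalar_t>
__global__ void example_per_channel_kernel(
        scalar_t* __restrict__ output,
        const scalar_t* __restrict__ input,
        const scalar_t* __restrict__ scale,
        const uint64_t elements_per_channel) {
    const uint32_t ch = blockIdx.x;
    const scalar_t s = scale[ch];
    for (uint64_t i = static_cast<uint64_t>(blockIdx.y) * blockDim.x + threadIdx.x;
            i < elements_per_channel;
            i += static_cast<uint64_t>(gridDim.y) * blockDim.x) {
        const uint64_t idx = static_cast<uint64_t>(ch) * elements_per_channel + i;
        output[idx] = input[idx] * s;   // placeholder per-channel operation
    }
}

void example_launch(float* output, const float* input, const float* scale,
                    uint32_t scale_count, uint64_t elements_per_channel) {
    const dim3 grid = get_2d_grid_size_for_per_channel(scale_count);
    // The block size matches the helper's assumption of full 1024-thread blocks.
    example_per_channel_kernel<float><<<grid, CUDA_MAX_NUM_THREADS_PER_BLOCK>>>(
        output, input, scale, elements_per_channel);
}

When DO_PROFILE is defined, the PROFILE macro above times whatever code it wraps and reports the average runtime; otherwise it expands to the wrapped code unchanged.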
174 changes: 174 additions & 0 deletions nncf/extensions/include/common_cuda_funcs.cuh
@@ -0,0 +1,174 @@
#ifndef _COMMON_CUDA_FUNCS_CUH_
#define _COMMON_CUDA_FUNCS_CUH_

// Have to define common CUDA __device__ funcs in headers because moving them
// to separate translation units will require relocatable device code compilation,
// which is rumoured to degrade performance.

#include "common_cuda_defs.cuh"

// support only warp size = 32
template <typename scalar_t>
__device__ void sum_warp(volatile scalar_t* sharr) {
int tidx = threadIdx.x & 31;
if (tidx < 16) {
sharr[tidx] += sharr[tidx + 16];
sharr[tidx] += sharr[tidx + 8];
sharr[tidx] += sharr[tidx + 4];
sharr[tidx] += sharr[tidx + 2];
sharr[tidx] += sharr[tidx + 1];
}
}

// Since volatile c10::Half arithmetic is not supported, will have to sacrifice
// the implicit warp-synchronous programming in favor of explicit intra-warp thread
// synchronization

template <typename scalar_t>
__device__ void sum_warp_with_explicit_sync(scalar_t* sharr) {
uint16_t tidx = threadIdx.x & 31;
if (tidx < 16) {
sharr[tidx] += sharr[tidx + 16];
}
__syncwarp();
if (tidx < 16) {
sharr[tidx] += sharr[tidx + 8];
}
__syncwarp();
if (tidx < 16) {
sharr[tidx] += sharr[tidx + 4];
}
__syncwarp();
if (tidx < 16) {
sharr[tidx] += sharr[tidx + 2];
}
__syncwarp();
if (tidx < 16) {
sharr[tidx] += sharr[tidx + 1];
}
__syncwarp();
}

template <typename scalar_t>
__device__ inline void gather_warp_execution_results(scalar_t* sharr, const uint16_t tidx) {
sharr[tidx] = tidx * CUDA_WARP_SIZE < CUDA_MAX_NUM_THREADS_PER_BLOCK ? sharr[tidx * CUDA_WARP_SIZE] : static_cast<scalar_t>(0.0);
}


// Reduces the contents of a shared memory array of CUDA_MAX_NUM_THREADS_PER_BLOCK using
// warp-powered reduction. The final sum will be stored in the 0-th element of the shared memory array.
template <typename scalar_t>
__device__ void reduce_in_block_using_warp_sums(scalar_t* __restrict__ sh_mem,
uint16_t tidx) {
__syncthreads();
// Will reduce the summation to CUDA_MAX_WARPS_PER_BLOCK elements that are
// spaced CUDA_WARP_SIZE elements apart in the shared memory
sum_warp(sh_mem + (tidx & ~(CUDA_WARP_SIZE - 1)));

__syncthreads();
if (tidx < CUDA_MAX_WARPS_PER_BLOCK) {
// Do warp reduction again - because currently CUDA_MAX_WARPS_PER_BLOCK == CUDA_WARP_SIZE, this
// will lead to the 0-th element of the shared memory containing the entire per-block sum
gather_warp_execution_results(sh_mem, tidx);
sum_warp(sh_mem);
}
}


__device__ bool last_block(int32_t* counter, uint32_t total_blocks_count) {
__threadfence();

int last = 0;
if (threadIdx.x == 0) {
last = atomicAdd(counter, 1);
}

return __syncthreads_or(last == total_blocks_count - 1);
}


template <typename scalar_t>
__device__ void reduce_with_shared_memory(
scalar_t* __restrict__ sh_arr,
scalar_t current_thread_sum,
const uint16_t tidx,
const uint32_t bidx,
scalar_t* __restrict__ dev_tmp,
int32_t* __restrict__ dev_last_block_counter,
scalar_t* __restrict__ grad,
uint32_t total_number_of_blocks) {

// Put each thread sum element into shared memory (CUDA_MAX_NUM_THREADS_PER_BLOCK elements in total)
sh_arr[tidx] = current_thread_sum;

// Do warp reduction on the entire shared memory of a single block
reduce_in_block_using_warp_sums(sh_arr, tidx);

// Store the per-block sum for each block in the pre-allocated array (which has dimensions equal to grid dimensions)
if (tidx == 0) {
dev_tmp[bidx] = sh_arr[0];
}

// Synchronize blocks and make the last block of the grid do the reduction across the per-block sums
// to obtain final sums
if (last_block(dev_last_block_counter, total_number_of_blocks)) {

// WARNING: seems like this will only work for total number of blocks to reduce across that is < CUDA_MAX_NUM_THREADS_PER_BLOCK
sh_arr[tidx] = tidx < total_number_of_blocks ? dev_tmp[tidx] : static_cast<scalar_t>(0.0);
reduce_in_block_using_warp_sums(sh_arr, tidx);

if (tidx == 0) {
grad[0] = sh_arr[0];
}
}
}



// Remove this and other FP16 template specializations once arithmetic operators are implemented in c10
// for volatile c10::Half

__device__ void reduce_in_block_using_warp_sums_with_explicit_sync(c10::Half* __restrict__ sh_mem,
uint16_t tidx) {
__syncthreads();
sum_warp_with_explicit_sync(sh_mem + (tidx & ~(CUDA_WARP_SIZE - 1)));

__syncthreads();
if (tidx < CUDA_MAX_WARPS_PER_BLOCK) {
gather_warp_execution_results(sh_mem, tidx);
sum_warp_with_explicit_sync(sh_mem);
}

}

template <>
__device__ void reduce_with_shared_memory<c10::Half>(
c10::Half* __restrict__ sh_arr,
c10::Half sum,
const uint16_t tidx,
const uint32_t bidx,
c10::Half* __restrict__ dev_tmp,
int32_t* __restrict__ dev_last_block_counter,
c10::Half* __restrict__ grad,
uint32_t total_number_of_blocks) {
sh_arr[tidx] = sum;

reduce_in_block_using_warp_sums_with_explicit_sync(sh_arr, tidx);

if (tidx == 0) {
dev_tmp[bidx] = sh_arr[0];
}

if (last_block(dev_last_block_counter, total_number_of_blocks)) {
sh_arr[tidx] = tidx < gridDim.x ? dev_tmp[tidx] : static_cast<c10::Half>(0.0);

reduce_in_block_using_warp_sums_with_explicit_sync(sh_arr, tidx);

if (tidx == 0) {
grad[0] = sh_arr[0];
}
}
}


#endif // _COMMON_CUDA_FUNCS_CUH_
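
A sketch of how reduce_with_shared_memory is presumably driven from a kernel: each thread accumulates a private partial sum, the block reduces it through shared memory, per-block results land in dev_tmp (one slot per block), and the last block to finish (tracked via a zero-initialized dev_last_block_counter) folds the per-block sums into grad[0]. Everything other than the header's own names is illustrative, and the kernel must be launched with CUDA_MAX_NUM_THREADS_PER_BLOCK threads per block for the shared-memory layout to hold:

#include "common_cuda_funcs.cuh"

// Hypothetical kernel that sums `input` into grad[0]. dev_tmp must provide one
// element per block and dev_last_block_counter must be zeroed before launch.
template <typename scalar_t>
__global__ void example_total_sum_kernel(
        scalar_t* __restrict__ grad,
        scalar_t* __restrict__ dev_tmp,
        int32_t* __restrict__ dev_last_block_counter,
        const scalar_t* __restrict__ input,
        const uint64_t size) {
    __shared__ scalar_t sh_arr[CUDA_MAX_NUM_THREADS_PER_BLOCK];

    const uint16_t tidx = threadIdx.x;
    const uint32_t bidx = blockIdx.x;

    // Grid-stride accumulation into a per-thread partial sum
    scalar_t thread_sum = static_cast<scalar_t>(0.0);
    for (uint64_t i = static_cast<uint64_t>(bidx) * blockDim.x + tidx;
            i < size;
            i += static_cast<uint64_t>(gridDim.x) * blockDim.x) {
        thread_sum += input[i];
    }

    // Block-level reduction, then grid-level reduction performed by the last block
    reduce_with_shared_memory<scalar_t>(
        sh_arr, thread_sum, tidx, bidx, dev_tmp, dev_last_block_counter, grad, gridDim.x);
}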
9 changes: 9 additions & 0 deletions nncf/extensions/include/common_defs.h
@@ -0,0 +1,9 @@
#ifndef _COMMON_DEFS_H_
#define _COMMON_DEFS_H_


#define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor")
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")

#endif // _COMMON_DEFS_H_
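
These checks are typically combined into a single per-argument guard at the top of each extension entry point; a small illustrative example follows (CHECK_INPUT is a common PyTorch-extension convention and is not part of this commit):

#include <torch/extension.h>
#include "common_defs.h"

// Illustrative only: the usual combined guard used in PyTorch C++/CUDA extensions.
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

void example_check(at::Tensor input) {
    CHECK_INPUT(input);   // raises if `input` is not a contiguous CUDA tensor
}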
21 changes: 21 additions & 0 deletions nncf/extensions/include/quantization/functions_cuda_impl.h
@@ -0,0 +1,21 @@
#ifndef _QUANTIZATION_FUNCTIONS_CUDA_IMPL_H_
#define _QUANTIZATION_FUNCTIONS_CUDA_IMPL_H_

at::Tensor q_cuda_forward(
at::Tensor input,
at::Tensor input_low,
at::Tensor input_range,
int levels);


std::vector<at::Tensor> q_cuda_backward(
at::Tensor grad_output,
at::Tensor input,
at::Tensor input_low,
at::Tensor input_range,
int levels,
int level_low,
int level_high);

#endif // _QUANTIZATION_FUNCTIONS_CUDA_IMPL_H_
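
For context, these declarations are presumably consumed by a host-side binding file (cuda/functions_cuda.cpp, not reproduced in this excerpt) that validates the inputs and exposes the functions to Python. A minimal sketch of what such a binding could look like; the wrapper and binding names are assumptions, not taken from the diff:

#include <torch/extension.h>
#include <vector>

#include "common_defs.h"
#include "quantization/functions_cuda_impl.h"

namespace {
// Helper combining the checks from common_defs.h for each tensor argument.
void check_input(const at::Tensor& t) {
    CHECK_CUDA(t);
    CHECK_CONTIGUOUS(t);
}
}  // namespace

at::Tensor q_forward(at::Tensor input, at::Tensor input_low, at::Tensor input_range, int levels) {
    check_input(input);
    check_input(input_low);
    check_input(input_range);
    return q_cuda_forward(input, input_low, input_range, levels);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Binding name is illustrative; the real module registers its own set of functions.
    m.def("Quantize_forward", &q_forward, "Fake-quantize forward (CUDA)");
}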
