Dev v2 #95
Merged · 6 commits · Jun 1, 2023
Changes from 1 commit
MSM for large sizes (#88)
* bugs fixed

* bugs fixed

* remove short msm from extern call

* code cleaning
HadarIngonyama authored and jeremyfelder committed Jun 1, 2023
commit 9ebf3d4f340a4bdf460f0c080e4db9293d584266
4 changes: 4 additions & 0 deletions icicle/appUtils/msm/Makefile
@@ -0,0 +1,4 @@
test_msm:
	mkdir -p work
	nvcc -o work/test_msm -I. tests/msm_test.cu
	work/test_msm
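To build and run the new test locally (assuming nvcc is on the PATH and the default GPU architecture suits your device):

cd icicle/appUtils/msm && make test_msm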
43 changes: 20 additions & 23 deletions icicle/appUtils/msm/msm.cu
@@ -83,16 +83,16 @@ __global__ void split_scalars_kernel(unsigned *buckets_indices, unsigned *point_

//this kernel adds up the points in each bucket
template <typename P, typename A>
__global__ void accumulate_buckets_kernel(P *__restrict__ buckets, unsigned *__restrict__ bucket_offsets,
unsigned *__restrict__ bucket_sizes, unsigned *__restrict__ single_bucket_indices, unsigned *__restrict__ point_indices, A *__restrict__ points, unsigned nof_buckets, unsigned batch_size, unsigned msm_idx_shift){
// __global__ void accumulate_buckets_kernel(P *__restrict__ buckets, unsigned *__restrict__ bucket_offsets,
// unsigned *__restrict__ bucket_sizes, unsigned *__restrict__ single_bucket_indices, unsigned *__restrict__ point_indices, A *__restrict__ points, unsigned nof_buckets, unsigned batch_size, unsigned msm_idx_shift){
__global__ void accumulate_buckets_kernel(P *buckets, unsigned *bucket_offsets, unsigned *bucket_sizes, unsigned *single_bucket_indices, unsigned *point_indices, A *points, unsigned nof_buckets, unsigned *nof_buckets_to_compute, unsigned msm_idx_shift){

unsigned tid = (blockIdx.x * blockDim.x) + threadIdx.x;
unsigned msm_index = single_bucket_indices[tid]>>msm_idx_shift;
unsigned bucket_index = msm_index * nof_buckets + (single_bucket_indices[tid]&((1<<msm_idx_shift)-1));
unsigned bucket_size = bucket_sizes[tid];
if (tid>=nof_buckets*batch_size || bucket_size == 0){ //if the bucket is empty we don't need to continue
if (tid>=*nof_buckets_to_compute){
return;
}
unsigned msm_index = single_bucket_indices[tid]>>msm_idx_shift;
unsigned bucket_index = msm_index * nof_buckets + (single_bucket_indices[tid]&((1<<msm_idx_shift)-1));
unsigned bucket_offset = bucket_offsets[tid];
for (unsigned i = 0; i < bucket_sizes[tid]; i++) //add the relevant points starting from the relevant offset up to the bucket size
{
@@ -106,7 +106,7 @@ template <typename P>
__global__ void big_triangle_sum_kernel(P* buckets, P* final_sums, unsigned nof_bms, unsigned c){

unsigned tid = (blockIdx.x * blockDim.x) + threadIdx.x;
if (tid>nof_bms) return;
if (tid>=nof_bms) return;
P line_sum = buckets[(tid+1)*(1<<c)-1];
final_sums[tid] = line_sum;
for (unsigned i = (1<<c)-2; i >0; i--)
@@ -195,12 +195,13 @@ void bucket_method_msm(unsigned bitsize, unsigned c, S *scalars, A *points, unsi
NUM_BLOCKS = (size * (nof_bms+1) + NUM_THREADS - 1) / NUM_THREADS;
split_scalars_kernel<<<NUM_BLOCKS, NUM_THREADS>>>(bucket_indices + size, point_indices + size, d_scalars, size, msm_log_size,
nof_bms, bm_bitsize, c); //+size - leaving the first bm free for the out of place sort later

//sort indices - the indices are sorted from smallest to largest in order to group together the points that belong to each bucket
unsigned *sort_indices_temp_storage{};
size_t sort_indices_temp_storage_bytes;
cub::DeviceRadixSort::SortPairs(sort_indices_temp_storage, sort_indices_temp_storage_bytes, bucket_indices + size, bucket_indices,
point_indices + size, point_indices, size);
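// the call above passes a null temp-storage pointer, so CUB only writes the
// required byte count into sort_indices_temp_storage_bytes (CUB's standard
// two-phase pattern); the actual per-window sorts happen in the loop below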

cudaMalloc(&sort_indices_temp_storage, sort_indices_temp_storage_bytes);
for (unsigned i = 0; i < nof_bms; i++) {
unsigned offset_out = i * size;
@@ -240,7 +241,7 @@ void bucket_method_msm(unsigned bitsize, unsigned c, S *scalars, A *points, unsi
NUM_THREADS = 1 << 8;
NUM_BLOCKS = (nof_buckets + NUM_THREADS - 1) / NUM_THREADS;
accumulate_buckets_kernel<<<NUM_BLOCKS, NUM_THREADS>>>(buckets, bucket_offsets, bucket_sizes, single_bucket_indices, point_indices,
d_points, nof_buckets, 1, c+bm_bitsize);
d_points, nof_buckets, nof_buckets_to_compute, c+bm_bitsize);
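// note: the grid is still sized for the full nof_buckets range; the
// `tid >= *nof_buckets_to_compute` guard inside the kernel retires the excess threads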

#ifdef SSM_SUM
//sum each bucket
@@ -351,12 +352,6 @@ void batched_bucket_method_msm(unsigned bitsize, unsigned c, S *scalars, A *poin
cub::DeviceRadixSort::SortPairs(sort_indices_temp_storage, sort_indices_temp_storage_bytes, bucket_indices + msm_size, sorted_bucket_indices,
point_indices + msm_size, sorted_point_indices, total_size * nof_bms);
cudaMalloc(&sort_indices_temp_storage, sort_indices_temp_storage_bytes);
// for (unsigned i = 0; i < nof_bms*batch_size; i++) {
// unsigned offset_out = i * msm_size;
// unsigned offset_in = offset_out + msm_size;
// cub::DeviceRadixSort::SortPairs(sort_indices_temp_storage, sort_indices_temp_storage_bytes, bucket_indices + offset_in,
// bucket_indices + offset_out, point_indices + offset_in, point_indices + offset_out, msm_size);
// }
cub::DeviceRadixSort::SortPairs(sort_indices_temp_storage, sort_indices_temp_storage_bytes, bucket_indices + msm_size, sorted_bucket_indices,
point_indices + msm_size, sorted_point_indices, total_size * nof_bms);
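// the first SortPairs call above only sized the temp storage; this second,
// identical call performs the actual batched sort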
cudaFree(sort_indices_temp_storage);
@@ -391,7 +386,7 @@ void batched_bucket_method_msm(unsigned bitsize, unsigned c, S *scalars, A *poin
NUM_THREADS = 1 << 8;
NUM_BLOCKS = (total_nof_buckets + NUM_THREADS - 1) / NUM_THREADS;
accumulate_buckets_kernel<<<NUM_BLOCKS, NUM_THREADS>>>(buckets, bucket_offsets, bucket_sizes, single_bucket_indices, sorted_point_indices,
d_points, nof_buckets, batch_size, c+bm_bitsize);
d_points, nof_buckets, total_nof_buckets_to_compute, c+bm_bitsize);

#ifdef SSM_SUM
//sum each bucket
@@ -424,7 +419,7 @@ void batched_bucket_method_msm(unsigned bitsize, unsigned c, S *scalars, A *poin
NUM_THREADS = 1<<8;
NUM_BLOCKS = (batch_size + NUM_THREADS - 1) / NUM_THREADS;
final_accumulation_kernel<P, S><<<NUM_BLOCKS,NUM_THREADS>>>(bm_sums, on_device ? final_results : d_final_results, batch_size, nof_bms, c);

//copy final result to host
cudaDeviceSynchronize();
if (!on_device)
@@ -461,8 +456,7 @@ __global__ void to_proj_kernel(A* affine_points, P* proj_points, unsigned N){

//the function computes msm using ssm
template <typename S, typename P, typename A>
void short_msm(S *h_scalars, A *h_points, unsigned size, P* h_final_result, bool on_device){ //works up to 2^8

void short_msm(S *h_scalars, A *h_points, unsigned size, P* h_final_result){ //works up to 2^8
S *scalars;
A *a_points;
P *p_points;
@@ -507,12 +501,12 @@ void short_msm(S *h_scalars, A *h_points, unsigned size, P* h_final_result, bool
template <typename A, typename S, typename P>
void reference_msm(S* scalars, A* a_points, unsigned size){

P points[size];
P *points = new P[size];
// P points[size];
for (unsigned i = 0; i < size ; i++)
{
points[i] = P::from_affine(a_points[i]);
}


P res = P::zero();

@@ -527,7 +521,10 @@ void reference_msm(S* scalars, A* a_points, unsigned size){
}

unsigned get_optimal_c(const unsigned size) {
return 10;
if (size < 17)
return 1;
// return 15;
return ceil(log2(size))-4;
}
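// illustrative check of the new heuristic: for size = 2^15,
// ceil(log2(32768)) - 4 = 15 - 4 = 11, so 11-bit bucket windows replace the old fixed c = 10;
// the size < 17 guard avoids c <= 0, which the formula would yield for tiny inputs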

//this function is used to compute msms of size larger than 256
@@ -549,4 +546,4 @@ void batched_large_msm(S* scalars, A* points, unsigned batch_size, unsigned msm_
unsigned bitsize = 255;
batched_bucket_method_msm(bitsize, c, scalars, points, batch_size, msm_size, result, on_device);
}
#endif
#endif
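The functional core of this diff is that accumulate_buckets_kernel is now bounded by a device-side count of non-empty buckets (nof_buckets_to_compute) instead of scanning all nof_buckets * batch_size slots. A minimal sketch of that launch pattern — hypothetical names, stripped of the MSM specifics, not icicle's actual API:

// toy kernel: each thread handles one compacted work item and exits early
// once tid passes the device-side count, mirroring the
// `if (tid >= *nof_buckets_to_compute) return;` guard in the diff
__global__ void process_compacted(const unsigned *work, const unsigned *d_count, unsigned *out) {
  unsigned tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= *d_count) return; // count is read on-device, so no host synchronization is needed
  out[tid] = work[tid] * 2;    // stand-in for the per-bucket point accumulation
}

// host side: size the grid for the worst case; the guard retires surplus threads
// process_compacted<<<(max_items + 255) / 256, 256>>>(work, d_count, out);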
2 changes: 2 additions & 0 deletions icicle/appUtils/msm/msm.cuh
@@ -17,4 +17,6 @@ void large_msm(S* scalars, A* points, unsigned size, P* result, bool on_device);
template <typename S, typename P, typename A>
void short_msm(S *h_scalars, A *h_points, unsigned size, P* h_final_result, bool on_device);

template <typename A, typename S, typename P>
void reference_msm(S* scalars, A* a_points, unsigned size);
#endif
188 changes: 188 additions & 0 deletions icicle/appUtils/msm/tests/msm_test.cu
@@ -0,0 +1,188 @@
#include <iostream>
#include <chrono>
#include <vector>
#include "msm.cu"
#include "../../utils/cuda_utils.cuh"
#include "../../primitives/projective.cuh"
#include "../../primitives/field.cuh"
#include "../../curves/bls12_381/curve_config.cuh"

using namespace BLS12_381;

class Dummy_Scalar {
public:
static constexpr unsigned NBITS = 32;

unsigned x;

friend HOST_INLINE std::ostream& operator<<(std::ostream& os, const Dummy_Scalar& scalar) {
os << scalar.x;
return os;
}

HOST_DEVICE_INLINE unsigned get_scalar_digit(unsigned digit_num, unsigned digit_width) {
return (x>>(digit_num*digit_width))&((1<<digit_width)-1);
}

friend HOST_DEVICE_INLINE Dummy_Scalar operator+(Dummy_Scalar p1, const Dummy_Scalar& p2) {
return {p1.x+p2.x};
}

friend HOST_DEVICE_INLINE bool operator==(const Dummy_Scalar& p1, const Dummy_Scalar& p2) {
return (p1.x == p2.x);
}

friend HOST_DEVICE_INLINE bool operator==(const Dummy_Scalar& p1, const unsigned p2) {
return (p1.x == p2);
}

// static HOST_DEVICE_INLINE Dummy_Scalar neg(const Dummy_Scalar &scalar) {
// return {Dummy_Scalar::neg(point.x)};
// }
static HOST_INLINE Dummy_Scalar rand_host() {
return {(unsigned)rand()};
}
};

class Dummy_Projective {

public:
Dummy_Scalar x;

static HOST_DEVICE_INLINE Dummy_Projective zero() {
return {0};
}

static HOST_DEVICE_INLINE Dummy_Projective to_affine(const Dummy_Projective &point) {
return {point.x};
}

static HOST_DEVICE_INLINE Dummy_Projective from_affine(const Dummy_Projective &point) {
return {point.x};
}

// static HOST_DEVICE_INLINE Dummy_Projective neg(const Dummy_Projective &point) {
// return {Dummy_Scalar::neg(point.x)};
// }

friend HOST_DEVICE_INLINE Dummy_Projective operator+(Dummy_Projective p1, const Dummy_Projective& p2) {
return {p1.x+p2.x};
}

// friend HOST_DEVICE_INLINE Dummy_Projective operator-(Dummy_Projective p1, const Dummy_Projective& p2) {
// return p1 + neg(p2);
// }

friend HOST_INLINE std::ostream& operator<<(std::ostream& os, const Dummy_Projective& point) {
os << point.x;
return os;
}

friend HOST_DEVICE_INLINE Dummy_Projective operator*(Dummy_Scalar scalar, const Dummy_Projective& point) {
Dummy_Projective res = zero();
#ifdef __CUDA_ARCH__ // the double-underscore form is the macro nvcc actually defines
#pragma unroll
#endif
for (int i = 0; i < Dummy_Scalar::NBITS; i++) {
if (i > 0) {
res = res + res;
}
if (scalar.get_scalar_digit(Dummy_Scalar::NBITS - i - 1, 1)) {
res = res + point;
}
}
return res;
}

friend HOST_DEVICE_INLINE bool operator==(const Dummy_Projective& p1, const Dummy_Projective& p2) {
return (p1.x == p2.x);
}

static HOST_DEVICE_INLINE bool is_zero(const Dummy_Projective &point) {
return point.x == 0;
}

static HOST_INLINE Dummy_Projective rand_host() {
return {(unsigned)rand()};
}
};

//switch between dummy and real:

typedef scalar_t test_scalar;
typedef projective_t test_projective;
typedef affine_t test_affine;

// typedef Dummy_Scalar test_scalar;
// typedef Dummy_Projective test_projective;
// typedef Dummy_Projective test_affine;

int main()
{
unsigned batch_size = 4;
unsigned msm_size = 1<<15;
unsigned N = batch_size*msm_size;

test_scalar *scalars = new test_scalar[N];
test_affine *points = new test_affine[N];

for (unsigned i=0;i<N;i++){
scalars[i] = (i%msm_size < 10)? test_scalar::rand_host() : scalars[i-10];
points[i] = (i%msm_size < 10)? test_projective::to_affine(test_projective::rand_host()): points[i-10];
// scalars[i] = test_scalar::rand_host();
// points[i] = test_projective::to_affine(test_projective::rand_host());
}
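// only the first 10 scalar/point pairs of each MSM are freshly sampled;
// the rest repeat earlier entries, which keeps test-data generation fast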
std::cout<<"finished generating"<<std::endl;

// projective_t *short_res = (projective_t*)malloc(sizeof(projective_t));
// test_projective *large_res = (test_projective*)malloc(sizeof(test_projective));
test_projective large_res[batch_size];
test_projective batched_large_res[batch_size];
// fake_point *large_res = (fake_point*)malloc(sizeof(fake_point));
// fake_point batched_large_res[256];


// short_msm<scalar_t, projective_t, affine_t>(scalars, points, N, short_res);
for (unsigned i=0;i<batch_size;i++){
large_msm<test_scalar, test_projective, test_affine>(scalars+msm_size*i, points+msm_size*i, msm_size, large_res+i, false);
// std::cout<<"final result large"<<std::endl;
// std::cout<<test_projective::to_affine(*large_res)<<std::endl;
}
auto begin = std::chrono::high_resolution_clock::now();
batched_large_msm<test_scalar, test_projective, test_affine>(scalars, points, batch_size, msm_size, batched_large_res, false);
// large_msm<test_scalar, test_projective, test_affine>(scalars, points, msm_size, large_res, false);
auto end = std::chrono::high_resolution_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
printf("Time measured: %.3f seconds.\n", elapsed.count() * 1e-9);
std::cout<<test_projective::to_affine(large_res[0])<<std::endl;

// reference_msm<test_affine, test_scalar, test_projective>(scalars, points, msm_size);

std::cout<<"final results batched large"<<std::endl;
bool success = true;
for (unsigned i = 0; i < batch_size; i++)
{
std::cout<<test_projective::to_affine(batched_large_res[i])<<std::endl;
if (test_projective::to_affine(large_res[i])==test_projective::to_affine(batched_large_res[i])){
std::cout<<"good"<<std::endl;
}
else{
std::cout<<"miss"<<std::endl;
std::cout<<test_projective::to_affine(large_res[i])<<std::endl;
success = false;
}
}
if (success){
std::cout<<"success!"<<std::endl;
}

// std::cout<<batched_large_res[0]<<std::endl;
// std::cout<<batched_large_res[1]<<std::endl;
// std::cout<<projective_t::to_affine(batched_large_res[0])<<std::endl;
// std::cout<<projective_t::to_affine(batched_large_res[1])<<std::endl;

// std::cout<<"final result short"<<std::endl;
// std::cout<<pr<<std::endl;

return 0;
}
8 changes: 1 addition & 7 deletions icicle/curves/bls12_377/msm.cu
@@ -12,13 +12,7 @@ int msm_cuda_bls12_377(BLS12_377::projective_t *out, BLS12_377::affine_t points[
{
try
{
if (count>256){
large_msm<BLS12_377::scalar_t, BLS12_377::projective_t, BLS12_377::affine_t>(scalars, points, count, out, false);
}
else{
short_msm<BLS12_377::scalar_t, BLS12_377::projective_t, BLS12_377::affine_t>(scalars, points, count, out, false);
}

large_msm<BLS12_377::scalar_t, BLS12_377::projective_t, BLS12_377::affine_t>(scalars, points, count, out, false);
return CUDA_SUCCESS;
}
catch (const std::runtime_error &ex)
8 changes: 1 addition & 7 deletions icicle/curves/bls12_381/msm.cu
@@ -12,13 +12,7 @@ int msm_cuda_bls12_381(BLS12_381::projective_t *out, BLS12_381::affine_t points[
{
try
{
if (count>256){
large_msm<BLS12_381::scalar_t, BLS12_381::projective_t, BLS12_381::affine_t>(scalars, points, count, out, false);
}
else{
short_msm<BLS12_381::scalar_t, BLS12_381::projective_t, BLS12_381::affine_t>(scalars, points, count, out, false);
}

large_msm<BLS12_381::scalar_t, BLS12_381::projective_t, BLS12_381::affine_t>(scalars, points, count, out, false);
return CUDA_SUCCESS;
}
catch (const std::runtime_error &ex)
8 changes: 1 addition & 7 deletions icicle/curves/bn254/msm.cu
@@ -12,13 +12,7 @@ int msm_cuda_bn254(BN254::projective_t *out, BN254::affine_t points[],
{
try
{
if (count>256){
large_msm<BN254::scalar_t, BN254::projective_t, BN254::affine_t>(scalars, points, count, out, false);
}
else{
short_msm<BN254::scalar_t, BN254::projective_t, BN254::affine_t>(scalars, points, count, out, false);
}

large_msm<BN254::scalar_t, BN254::projective_t, BN254::affine_t>(scalars, points, count, out, false);
return CUDA_SUCCESS;
}
catch (const std::runtime_error &ex)
8 changes: 1 addition & 7 deletions icicle/curves/curve_template/msm.cu
@@ -12,13 +12,7 @@ int msm_cuda_CURVE_NAME_L(CURVE_NAME_U::projective_t *out, CURVE_NAME_U::affine_
{
try
{
if (count>256){
large_msm<CURVE_NAME_U::scalar_t, CURVE_NAME_U::projective_t, CURVE_NAME_U::affine_t>(scalars, points, count, out, false);
}
else{
short_msm<CURVE_NAME_U::scalar_t, CURVE_NAME_U::projective_t, CURVE_NAME_U::affine_t>(scalars, points, count, out, false);
}

large_msm<CURVE_NAME_U::scalar_t, CURVE_NAME_U::projective_t, CURVE_NAME_U::affine_t>(scalars, points, count, out, false);
return CUDA_SUCCESS;
}
catch (const std::runtime_error &ex)