Skip to content

Commit

Permalink
reduce memory consumption in hash_blocks (ingonyama-zk#100)
Browse files Browse the repository at this point in the history
* reduce memory consumption in hash_blocks
  • Loading branch information
ChickenLover committed Jun 8, 2023
1 parent 434ab70 commit 26f2f5c
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 28 deletions.
33 changes: 18 additions & 15 deletions icicle/appUtils/poseidon/poseidon.cu
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "poseidon.cuh"

template <typename S>
__global__ void prepare_poseidon_states(S * inp, S * states, size_t number_of_states, S domain_tag, const PoseidonConfiguration<S> config) {
__global__ void prepare_poseidon_states(S * states, size_t number_of_states, S domain_tag, const PoseidonConfiguration<S> config) {
int idx = (blockIdx.x * blockDim.x) + threadIdx.x;
int state_number = idx / config.t;
if (state_number >= number_of_states) {
Expand All @@ -15,7 +15,7 @@ __global__ void prepare_poseidon_states(S * inp, S * states, size_t number_of_st
if (element_number == 0) {
prepared_element = domain_tag;
} else {
prepared_element = inp[state_number * (config.t - 1) + element_number - 1];
prepared_element = states[state_number * config.t + element_number - 1];
}

// Add pre-round constant
Expand Down Expand Up @@ -149,16 +149,20 @@ __global__ void get_hash_results(S * states, size_t number_of_states, S * out, i

template <typename S>
__host__ void Poseidon<S>::hash_blocks(const S * inp, size_t blocks, S * out, HashType hash_type) {
// Used in matrix multiplication

S * states, * inp_device;
S * states;

// allocate memory for {blocks} states of {t} scalars each
cudaMalloc(&states, blocks * this->t * sizeof(S));
if (cudaMalloc(&states, blocks * this->t * sizeof(S)) != CUDA_SUCCESS) {
throw std::runtime_error("Failed memory allocation on the device");
}

// Move input to cuda
cudaMalloc(&inp_device, blocks * (this->t - 1) * sizeof(S));
cudaMemcpy(inp_device, inp, blocks * (this->t - 1) * sizeof(S), cudaMemcpyHostToDevice);
// To-Do: consider using async to not wait for the huge memcpy to finish
// This is where the input matrix of size Arity x NumberOfBlocks is
// padded and coppied to device in a T x NumberOfBlocks matrix
cudaMemcpy2D(states, this->t * sizeof(S), // Device pointer and device pitch
inp, (this->t - 1) * sizeof(S), // Host pointer and pitch
(this->t - 1) * sizeof(S), blocks, // Size of the source matrix (Arity x NumberOfBlocks)
cudaMemcpyHostToDevice);

size_t rc_offset = 0;

Expand Down Expand Up @@ -191,14 +195,13 @@ __host__ void Poseidon<S>::hash_blocks(const S * inp, size_t blocks, S * out, Ha
#endif

// Domain separation and adding pre-round constants
prepare_poseidon_states <<< number_of_blocks, number_of_threads >>> (inp_device, states, blocks, domain_tag, this->config);
prepare_poseidon_states <<< number_of_blocks, number_of_threads >>> (states, blocks, domain_tag, this->config);
rc_offset += this->t;
cudaFree(inp_device);

#if !defined(__CUDA_ARCH__) && defined(DEBUG)
cudaDeviceSynchronize();
std::cout << "Domain separation: " << rc_offset << std::endl;
print_buffer_from_cuda<S>(states, blocks * this->t);
//print_buffer_from_cuda<S>(states, blocks * this->t);

auto end_time = std::chrono::high_resolution_clock::now();
auto elapsed_time = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
Expand All @@ -213,7 +216,7 @@ __host__ void Poseidon<S>::hash_blocks(const S * inp, size_t blocks, S * out, Ha
#if !defined(__CUDA_ARCH__) && defined(DEBUG)
cudaDeviceSynchronize();
std::cout << "Full rounds 1. RCOFFSET: " << rc_offset << std::endl;
print_buffer_from_cuda<S>(states, blocks * this->t);
// print_buffer_from_cuda<S>(states, blocks * this->t);

end_time = std::chrono::high_resolution_clock::now();
elapsed_time = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
Expand All @@ -228,7 +231,7 @@ __host__ void Poseidon<S>::hash_blocks(const S * inp, size_t blocks, S * out, Ha
#if !defined(__CUDA_ARCH__) && defined(DEBUG)
cudaDeviceSynchronize();
std::cout << "Partial rounds. RCOFFSET: " << rc_offset << std::endl;
print_buffer_from_cuda<S>(states, blocks * this->t);
//print_buffer_from_cuda<S>(states, blocks * this->t);

end_time = std::chrono::high_resolution_clock::now();
elapsed_time = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
Expand All @@ -242,7 +245,7 @@ __host__ void Poseidon<S>::hash_blocks(const S * inp, size_t blocks, S * out, Ha
#if !defined(__CUDA_ARCH__) && defined(DEBUG)
cudaDeviceSynchronize();
std::cout << "Full rounds 2. RCOFFSET: " << rc_offset << std::endl;
print_buffer_from_cuda<S>(states, blocks * this->t);
//print_buffer_from_cuda<S>(states, blocks * this->t);
end_time = std::chrono::high_resolution_clock::now();
elapsed_time = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
std::cout << "Elapsed time: " << elapsed_time.count() << " ms" << std::endl;
Expand Down
10 changes: 4 additions & 6 deletions icicle/appUtils/poseidon/poseidon.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,16 @@
#include <sstream>
#include <chrono>

#define ARITY 3

template <typename S>
__host__ void print_buffer_from_cuda(S * device_ptr, size_t size) {
__host__ void print_buffer_from_cuda(S * device_ptr, size_t size, size_t t) {
S * buffer = static_cast< S * >(malloc(size * sizeof(S)));
cudaMemcpy(buffer, device_ptr, size * sizeof(S), cudaMemcpyDeviceToHost);

std::cout << "Start print" << std::endl;
for(int i = 0; i < size / ARITY; i++) {
for(int i = 0; i < size / t; i++) {
std::cout << "State #" << i << std::endl;
for (int j = 0; j < ARITY; j++) {
std::cout << buffer[i * ARITY + j] << std::endl;
for (int j = 0; j < t; j++) {
std::cout << buffer[i * t + j] << std::endl;
}
std::cout << std::endl;
}
Expand Down
3 changes: 1 addition & 2 deletions icicle/appUtils/poseidon/poseidon_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@ int main(int argc, char* argv[]) {

Poseidon<BLS12_381::scalar_t> poseidon(arity);

int number_of_blocks = 4;
int number_of_blocks = 1024;

BLS12_381::scalar_t input = BLS12_381::scalar_t::zero();
BLS12_381::scalar_t * in_ptr = static_cast< BLS12_381::scalar_t * >(malloc(number_of_blocks * arity * sizeof(BLS12_381::scalar_t)));
for (uint32_t i = 0; i < number_of_blocks * arity; i++) {
// std::cout << input << std::endl;
in_ptr[i] = input;
input = input + BLS12_381::scalar_t::one();
}
Expand Down
16 changes: 11 additions & 5 deletions src/test_bls12_381.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,19 +234,25 @@ extern "C" {
) -> c_int;
}

pub fn poseidon_multi_bls12_381(input: &[ScalarField_BLS12_381], arity: usize, device_id: usize) -> Vec<ScalarField_BLS12_381> {
pub fn poseidon_multi_bls12_381(input: &[ScalarField_BLS12_381], arity: usize, device_id: usize) -> Result<Vec<ScalarField_BLS12_381>, std::io::Error> {
let number_of_blocks = input.len() / arity;
let mut out = vec![ScalarField_BLS12_381::zero(); number_of_blocks];
unsafe {
poseidon_multi_cuda_bls12_381(
let res = poseidon_multi_cuda_bls12_381(
input as *const _ as *const ScalarField_BLS12_381,
out.as_mut_slice() as *mut _ as *mut ScalarField_BLS12_381,
number_of_blocks,
arity as c_uint,
device_id);
device_id
);

// TO-DO: go for better expression of error types
if res != 0 {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Error executing poseidon_multi"));
}
}

out
Ok(out)
}

pub fn msm_bls12_381(points: &[PointAffineNoInfinity_BLS12_381], scalars: &[ScalarField_BLS12_381], device_id: usize) -> Point_BLS12_381 {
Expand Down Expand Up @@ -925,7 +931,7 @@ pub(crate) mod tests_bls12_381 {
for arity in arities {
// Generate scalars sequence [0, 1, ... arity * number_of_blocks]
let scalars: Vec<ScalarField_BLS12_381> = (0..arity * number_of_blocks).map(|i| ScalarField_BLS12_381::from_ark(Fr_BLS12_381::from(i as i32).into_repr())).collect();
let out = poseidon_multi_bls12_381(&scalars, arity, 0);
let out = poseidon_multi_bls12_381(&scalars, arity, 0).unwrap();

let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push(format!("test_vectors/poseidon_1024_{}", arity));
Expand Down

0 comments on commit 26f2f5c

Please sign in to comment.