MSM improvements #372
Conversation
.to_ark();
Self::ArkEquivalent::new_unchecked(proj_x, proj_y)
if *self == Self::zero() {
    Self::ArkEquivalent::zero()
is it more efficient than the else case?
The else case just doesn't cover zero: `new_unchecked` assumes that its inputs represent a valid non-zero point.
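A minimal sketch of the special case under discussion. `Affine`, `ArkAffine`, and the `u64` coordinates are illustrative stand-ins, not the actual icicle or arkworks API:

```rust
#[derive(Clone, Copy, PartialEq)]
struct Affine {
    x: u64, // stand-in for a real field element
    y: u64,
}

#[derive(Debug, PartialEq)]
enum ArkAffine {
    Zero,
    Point { x: u64, y: u64 },
}

impl Affine {
    fn zero() -> Self {
        // zero is encoded as the affine point (0, 0), as described in the PR
        Affine { x: 0, y: 0 }
    }

    fn to_ark(self) -> ArkAffine {
        if self == Self::zero() {
            // a new_unchecked-style constructor assumes a valid non-zero
            // point, so the zero encoding must be special-cased here
            ArkAffine::Zero
        } else {
            ArkAffine::Point { x: self.x, y: self.y }
        }
    }
}
```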
@@ -168,14 +166,20 @@ where
for batch_size in batch_sizes {
    let mut points = C::generate_random_affine_points(test_size * batch_size);
    let mut scalars = vec![C::ScalarField::zero(); test_size * batch_size];

    // add some zero points
    for _ in 0..100 {
Consider moving this logic to generate_random_...()
On the one hand, we can. On the other, I don't think users of a function called `generate_random_affine_points` expect zero points to be sprinkled in. We're not promising a cryptographic RNG here, and this function shouldn't be used to create secure randomness, but still...
Ok, I just thought that if you need it more than once, it's worth writing it once. You could wrap this function in a test util function that accepts the probability of each point being zeroed, but you should decide if it makes sense to you. If not, that's fine too.
Yes, creating a separate test util function makes sense, will do
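A rough sketch of what such a test util could look like. The `Affine` type is a placeholder, and a deterministic xorshift stands in for a real RNG since secure randomness isn't needed in tests; all names here are hypothetical:

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
struct Affine {
    x: u64,
    y: u64,
}

impl Affine {
    fn zero() -> Self {
        Affine { x: 0, y: 0 }
    }
}

// simple deterministic xorshift64 step; good enough for test data
fn xorshift(state: &mut u64) -> u64 {
    *state ^= *state << 13;
    *state ^= *state >> 7;
    *state ^= *state << 17;
    *state
}

// zero out each point with probability zero_prob_percent / 100
fn with_random_zeros(mut points: Vec<Affine>, zero_prob_percent: u64, seed: u64) -> Vec<Affine> {
    let mut state = seed.max(1); // xorshift state must be non-zero
    for p in points.iter_mut() {
        if xorshift(&mut state) % 100 < zero_prob_percent {
            *p = Affine::zero();
        }
    }
    points
}
```

Tests would then call this on the output of `generate_random_affine_points`, keeping the generator itself free of zeros.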
unsigned start = (sorted_bucket_sizes_sum[tid] + nof_pts_per_thread - 1) / nof_pts_per_thread + tid;
unsigned end = (sorted_bucket_sizes_sum[tid + 1] + nof_pts_per_thread - 1) / nof_pts_per_thread + tid + 1;
for (unsigned i = start; i < end; i++) {
  bucket_indices[i] = tid | ((i - start) << log_nof_large_buckets);
Are you sure this is correct when nof_buckets is not a power of two? If you assume it is, please add a comment and maybe also an assert where the kernel is called.
It should be correct for non-powers-of-2, and in most cases `nof_large_buckets` is not a power of 2. What this line does is just pack two values into one number, `bucket_indices[i]`: `tid` goes into the lowest `log_nof_large_buckets` bits and `i - start` goes into the rest. So `log_nof_large_buckets` in this case just means the number of bits needed to represent `tid`, which varies from 0 to `nof_large_buckets`.
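The packing scheme can be sketched in plain Rust (the kernel line above does the same bit manipulation):

```rust
// pack: tid in the low bits, the per-bucket offset in the high bits
fn pack(tid: u32, offset: u32, log_nof_large_buckets: u32) -> u32 {
    tid | (offset << log_nof_large_buckets)
}

// unpack recovers both fields, provided log_nof_large_buckets is at least
// the number of bits needed to represent any tid
fn unpack(packed: u32, log_nof_large_buckets: u32) -> (u32, u32) {
    let tid = packed & ((1u32 << log_nof_large_buckets) - 1);
    let offset = packed >> log_nof_large_buckets;
    (tid, offset)
}
```

With, say, 100 large buckets and 7 bits reserved for `tid`, `pack` and `unpack` round-trip every `tid` in 0..100 even though 100 is not a power of two.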
Right, as long as `log_nof_large_buckets = ceil(log2(nof_large_buckets))` this is correct. Since it's not verified inside the kernel, it could mix the two fields, which is why I suggested verifying it at kernel launch.
In principle we can just take the log inside the kernel; I just wanted to avoid all threads doing identical work. Verification is definitely cheaper though, can do it inside the kernel if you want.
My motivation is to avoid debugging cases like nof_buckets=100, (int)log2(nof_buckets)=6.
I would personally add a comment about this assumption next to the kernel param and make sure the host is calling correctly. If you prefer assert inside the kernel it's fine too.
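A small illustration of the pitfall: truncating log2 of a non-power-of-two like 100 gives 6, while 7 bits are needed to hold every `tid` in 0..100. A ceiling version (a sketch, not the kernel's code) avoids this:

```rust
// truncating log2: the position of the highest set bit
fn floor_log2(n: u32) -> u32 {
    31 - n.leading_zeros()
}

// number of bits needed to represent every value in 0..n (n exclusive)
fn ceil_log2(n: u32) -> u32 {
    if n <= 1 { 0 } else { 32 - (n - 1).leading_zeros() }
}
```

For powers of two the two agree (64 needs 6 bits either way), which is exactly why the bug only shows up on inputs like nof_buckets=100.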
Added a comment, though maybe I should've written a fuller doc comment for each kernel (but they're internal, so I don't feel the need to spend too much time documenting them, tbh).
// sort by bucket sizes
unsigned h_nof_buckets_to_compute;
CHK_IF_RETURN(cudaMemcpyAsync(
  &h_nof_buckets_to_compute, nof_buckets_to_compute, sizeof(unsigned), cudaMemcpyDeviceToHost, stream));

// if all points are 0 just return point 0
There was a question about removing this block. Turned out it was there for a good reason. If all the scalars are zero, there's a division by zero error inside. I will push a more elegant fix for this tomorrow.
Tried to explain the parts which raised questions during live review with comments, and also dealt with the case of all scalars being zero. See the diagram with a visual explanation of key changes in this PR:

Benchmarks

We discussed that the change in large bucket accumulation demands measuring how performance changes with
One weird detail is that for uniform distributions
const unsigned c,
const unsigned threads_per_bucket,
const unsigned max_run_length)
const int points_per_thread,
why not unsigned?
Honestly I don't fully understand the details of the large bucket accumulation, but overall it looks good to me. I approve, but you may want Hadar to review too.
lgtm
## Contents of this release

- [FEAT]: support for multi-device execution: #356
- [FEAT]: full support for new mixed-radix NTT: #367, #368 and #371
- [FEAT]: examples for Poseidon hash and tree builder based on it (currently only on C++ side): #375
- [PERF]: MSM performance upgrades & zero point handling: #372
Describe the changes
MSM can now handle zero base points. They are represented as affine points with `x` and `y` coordinates equal to zero, which is (as far as I know) consistent with gnark and rapidsnark but not arkworks. Rust tests are changed accordingly.

A number of performance and memory improvements have been made:

`large_bucket_size` to make each thread in large bucket accumulation do the same amount of work as expected in normal accumulation. At the same time, this worsens the potential memory bottleneck in large bucket accumulation, because the amount of memory allocated here in the old version is proportional to the number of threads in the largest bucket times the number of large buckets. So if the largest bucket is really large and there are lots of much smaller large buckets, we might run out of memory. So I allocated only as much memory as necessary for each large bucket, depending on its size. This complicates the code, but I think the speedup and memory savings are worth it.
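The allocation idea can be sketched as follows (hypothetical names, not the actual kernel code): instead of sizing the buffer as (threads in the largest bucket) x (number of large buckets), each large bucket gets only ceil(size / points_per_thread) slots:

```rust
// total thread slots when each large bucket gets ceil(size / points_per_thread)
// threads, rather than every bucket getting as many as the largest one
fn total_thread_slots(bucket_sizes: &[u32], points_per_thread: u32) -> u32 {
    bucket_sizes
        .iter()
        .map(|&s| (s + points_per_thread - 1) / points_per_thread) // ceil division
        .sum()
}
```

For bucket sizes [1000, 10, 10] with 10 points per thread this needs 102 slots, where sizing every bucket like the largest would need 100 * 3 = 300.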
Benchmarks

Measurements are made on an RTX 3090 Ti card for the bn254 curve; H2D memory operations are not included.
In the first experiment, ~30% of scalars are equal to 1, and there are also 10 random scalars, each with frequency around 1%. The rest of the scalars are chosen uniformly at random. In the second experiment, all scalars are chosen uniformly at random.
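A sketch of how the first experiment's scalar distribution could be generated. This is a hypothetical helper: `u64` stands in for the real scalar field type, and a deterministic xorshift replaces a proper RNG:

```rust
// simple deterministic xorshift64 step; a stand-in for a real RNG
fn xorshift(state: &mut u64) -> u64 {
    *state ^= *state << 13;
    *state ^= *state >> 7;
    *state ^= *state << 17;
    *state
}

// ~30% ones, ten "popular" scalars at ~1% each, the rest uniform
fn skewed_scalars(n: usize, seed: u64) -> Vec<u64> {
    let mut state = seed.max(1); // xorshift state must be non-zero
    let popular: Vec<u64> = (0..10).map(|_| xorshift(&mut state)).collect();
    (0..n)
        .map(|_| {
            let roll = xorshift(&mut state) % 100;
            if roll < 30 {
                1 // ~30% of scalars equal to 1
            } else if roll < 40 {
                popular[(roll - 30) as usize] // one of 10 frequent scalars
            } else {
                xorshift(&mut state) // uniform remainder
            }
        })
        .collect()
}
```

The repeated scalars are what make the large-bucket path matter: every duplicate lands in the same bucket, so a handful of buckets end up far larger than average.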
[Benchmark tables: entries include sizes 2^22, 2^22, 2^24, 2^14, 2^8; numeric results not recoverable from the page]
Failures and future work
I spent quite a bit of time trying to make reduction and scan methods from CUB and thrust work with our point types. I thought that it would be nice to let well-optimised CUDA libraries handle load balancing in bucket accumulation for us. As it turns out, there are many issues with this approach:
Overall, I think this is a dead end, and it's not worth trying to swap our custom bucket accumulation for anything CUB or thrust provide nowadays.
In terms of future work, arbitrary choice of `c` is still not supported as of this PR (though I think I can implement it pretty quickly), and a bunch of other improvements, like xyzz accumulation, signed digits/scalars and computing MSM in chunks as described earlier, are still in TODO status.