
MSM improvements #372

Merged: 14 commits into dev, Feb 15, 2024
Conversation

DmytroTym (Contributor):

Describe the changes

MSM can now handle zero base points. They are represented as affine points with both x and y coordinates equal to zero, which is (as far as I know) consistent with gnark and rapidsnark, but not with arkworks. The Rust tests are changed accordingly.
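
To make the representation concrete, here is a minimal sketch of the convention; `field_t`/`affine_t` and the limb layout are hypothetical stand-ins, not the library's actual types:

```cpp
#include <cstdint>

// Hypothetical stand-ins for the library's field/point types.
struct field_t {
  uint64_t limbs[4];
  bool is_zero() const { return limbs[0] == 0 && limbs[1] == 0 && limbs[2] == 0 && limbs[3] == 0; }
};
struct affine_t { field_t x, y; };

// Per the convention above, the zero (identity) base point is the affine pair (0, 0).
bool is_point_at_infinity(const affine_t& p) { return p.x.is_zero() && p.y.is_zero(); }
```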

A number of performance and memory improvements have been made:

  1. Kernel launches, allocations and frees are moved around to minimise the memory footprint and to better parallelise copying bases from host to device with sorting scalars. Still, compute often has to wait for the bases to finish copying, because the copy takes significantly longer than scalar sorting. The likely future fix is to do what Matter Labs does: compute the MSM in chunks, so that uploading the next chunk of base points is masked by bucket accumulation over the previous chunk (see the sketch after this list).
  2. Speaking of sorting scalars: instead of sorting indices for each bucket module individually, it's now done for all indices at once. While this requires a bit more memory and should, in theory, take more operations than the one-by-one approach, it turns out to be faster in practice, not just for smaller MSMs but for large ones as well. It also lets us remove zero buckets automatically.
  3. Removing zero buckets allows us to avoid using them in reduction. Though this makes the code more convoluted and doesn't help much for large MSMs, it boosts small MSMs quite a bit.
  4. For large buckets, I increased the number of threads per bucket by a factor of large_bucket_size, so that each thread in large-bucket accumulation does the same amount of work as expected in normal accumulation. On its own this worsens the potential memory bottleneck in large-bucket accumulation: in the old version, the memory allocated here is proportional to the number of threads in the largest bucket times the number of large buckets, so if the largest bucket is really large and there are lots of much smaller large buckets, we might run out of memory. I therefore allocated only as much memory as necessary for each large bucket, depending on its size. This complicates the code, but I think the speedup and memory savings are worth it.
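
To illustrate the chunking idea from item 1, here is a rough double-buffering sketch. This is not code from this PR: `affine_t`, `accumulate_chunk`, and the launch configuration are placeholders, and `h_bases` is assumed to be pinned memory so the copies are truly asynchronous.

```cuda
#include <algorithm>
#include <cuda_runtime.h>

// Stand-in for the library's affine point type (two bn254 field elements).
struct affine_t { unsigned long long x[4], y[4]; };

// Placeholder for bucket accumulation over one chunk of bases.
__global__ void accumulate_chunk(const affine_t* bases, size_t len) { /* elided */ }

// Double-buffered chunking: while chunk k is accumulated on compute_stream,
// the bases of chunk k+1 are uploaded on copy_stream.
void msm_chunked(const affine_t* h_bases, size_t n, size_t chunk_size, affine_t* d_buf[2])
{
  cudaStream_t copy_stream, compute_stream;
  cudaStreamCreate(&copy_stream);
  cudaStreamCreate(&compute_stream);
  cudaEvent_t uploaded[2], consumed[2];
  for (int k = 0; k < 2; k++) {
    cudaEventCreateWithFlags(&uploaded[k], cudaEventDisableTiming);
    cudaEventCreateWithFlags(&consumed[k], cudaEventDisableTiming);
  }

  unsigned buf = 0;
  for (size_t i = 0; i < n; i += chunk_size, buf ^= 1) {
    size_t len = std::min(chunk_size, n - i);
    // Don't overwrite a buffer the compute stream may still be reading.
    cudaStreamWaitEvent(copy_stream, consumed[buf], 0);
    cudaMemcpyAsync(d_buf[buf], h_bases + i, len * sizeof(affine_t),
                    cudaMemcpyHostToDevice, copy_stream);
    cudaEventRecord(uploaded[buf], copy_stream);
    // Accumulation waits only for its own chunk; the next upload overlaps it.
    cudaStreamWaitEvent(compute_stream, uploaded[buf], 0);
    accumulate_chunk<<<256, 256, 0, compute_stream>>>(d_buf[buf], len);
    cudaEventRecord(consumed[buf], compute_stream);
  }
  cudaStreamSynchronize(compute_stream);
}
```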

Benchmarks

Measurements are made on an RTX 3090Ti card for the bn254 curve; H2D memory operations are not included.

In the first experiment, ~30% of the scalars are equal to 1, and there are also 10 random scalars, each with a frequency of around 1%. The rest of the scalars are chosen uniformly at random. In the second experiment, all scalars are chosen uniformly at random.
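
For reference, a small sketch of how such a skewed distribution can be generated. This is a hypothetical reconstruction that models scalars as `uint64_t` rather than bn254 field elements:

```cpp
#include <cstdint>
#include <random>
#include <vector>

// ~30% ones, 10 "heavy" scalars at ~1% frequency each, the rest uniform.
std::vector<uint64_t> skewed_scalars(size_t n, uint64_t seed = 0)
{
  std::mt19937_64 rng(seed);
  std::uniform_real_distribution<double> coin(0.0, 1.0);
  uint64_t heavy[10];
  for (auto& h : heavy) h = rng();

  std::vector<uint64_t> scalars(n);
  for (auto& s : scalars) {
    double r = coin(rng);
    if (r < 0.30)      s = 1;                  // ~30% ones
    else if (r < 0.40) s = heavy[rng() % 10];  // 10 scalars, ~1% each
    else               s = rng();              // uniform remainder
  }
  return scalars;
}
```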

| Experiment | MSM size | Batch size | Old version, ms | New version, ms |
|---|---|---|---|---|
| 1 | 2^22 | 1 | 33.8 | 29.8 |
| 1 | 2^22 | 3 | 68.2 | 56.3 |
| 2 | 2^24 | 1 | 151.5 | 149.8 |
| 2 | 2^14 | 2^8 | 79.5 | 73.1 |

Failures and future work

I spent quite a bit of time trying to make the reduction and scan methods from CUB and thrust work with our point types. I thought it would be nice to let well-optimised CUDA libraries handle load balancing in bucket accumulation for us. As it turns out, this approach has a number of issues:

  • CUB and thrust seem to be optimised for minimising memory movement rather than compute, but our EC addition is very much compute-bound;
  • compile times grow a lot when our primitives are used inside CUB/thrust functions;
  • I wasn't even able to get correct results: they always seem to be zero or random, and I have no idea why.

Overall I think this is a dead end, and it's not worth trying to swap our custom bucket accumulation for anything CUB or thrust provide nowadays.
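
For reference, the kind of usage that was attempted looks roughly like this. This is a sketch only: `projective_t` and its addition are trivial stand-ins for our EC types (the real formulas are elided), and as noted above this path never produced correct results for us.

```cuda
#include <thrust/device_vector.h>
#include <thrust/reduce.h>

// Trivial stand-in for our projective point type.
struct projective_t {
  unsigned long long x[4], y[4], z[4];
  __host__ __device__ static projective_t zero() { return projective_t{}; }
};

struct add_points {
  __host__ __device__ projective_t operator()(const projective_t& a, const projective_t& b) const
  {
    projective_t r{}; // real EC addition would go here; it is compute-bound
    (void)a; (void)b;
    return r;
  }
};

// The attempted pattern: let thrust handle load balancing of the point reduction.
projective_t sum_points(const thrust::device_vector<projective_t>& points)
{
  return thrust::reduce(points.begin(), points.end(), projective_t::zero(), add_points());
}
```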

In terms of future work, an arbitrary choice of c is still not supported as of this PR (though I think I can implement it pretty quickly), and a number of other improvements, such as xyzz accumulation, signed digits/scalars, and computing the MSM in chunks as described earlier, are still in TODO status.

```rust
// Review context, reconstructed from the diff fragment:
if *self == Self::zero() {
    Self::ArkEquivalent::zero()
} else {
    // ... coordinates converted via .to_ark();
    Self::ArkEquivalent::new_unchecked(proj_x, proj_y)
}
```
Collaborator:

is it more efficient than the else case?

DmytroTym (Author):

The else case just doesn't cover zero: new_unchecked assumes that its inputs represent a valid non-zero point.

```diff
@@ -168,14 +166,20 @@ where
 for batch_size in batch_sizes {
     let mut points = C::generate_random_affine_points(test_size * batch_size);
     let mut scalars = vec![C::ScalarField::zero(); test_size * batch_size];

     // add some zero points
     for _ in 0..100 {
```
@yshekel (Collaborator), Feb 12, 2024:

Consider moving this logic to generate_random_...()

DmytroTym (Author):

On the one hand, we can. On the other, I don't think users of a function called generate_random_affine_points expect zero points to be sprinkled in. We're not promising a cryptographic RNG here, and this function shouldn't be used to create secure randomness, but still...

@yshekel (Collaborator), Feb 13, 2024:


OK, I just thought that if you need it more than once, it's worth writing it once. You could wrap this function in a test util function that accepts the probability of each point being zeroed, but you should decide if that makes sense to you. If not, that's fine too.

DmytroTym (Author):

Yes, creating a separate test util function makes sense; will do.

```cuda
// Each large bucket `tid` gets ceil(size / nof_pts_per_thread) index slots;
// each slot packs (tid, slot index within the bucket) into one unsigned:
// tid in the low log_nof_large_buckets bits, the slot index in the rest.
unsigned start = (sorted_bucket_sizes_sum[tid] + nof_pts_per_thread - 1) / nof_pts_per_thread + tid;
unsigned end = (sorted_bucket_sizes_sum[tid + 1] + nof_pts_per_thread - 1) / nof_pts_per_thread + tid + 1;
for (unsigned i = start; i < end; i++) {
  bucket_indices[i] = tid | ((i - start) << log_nof_large_buckets);
```
@yshekel (Collaborator), Feb 12, 2024:


Are you sure this is correct when nof_buckets is not a power of two?
If you assume it is, please add a comment, and maybe also assert it where the kernel is called.

DmytroTym (Author):

It should be correct for non-powers-of-2, and in most cases nof_large_buckets is not a power of 2.
What this line does is pack two values into one number, bucket_indices[i]: tid goes into the lowest log_nof_large_buckets bits, and i - start goes into the rest. So log_nof_large_buckets here is just the number of bits needed to represent tid, which varies from 0 to nof_large_buckets - 1.
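
To make the packing concrete, a small illustrative sketch (these helpers are not the PR's code):

```cuda
#include <cassert>

// Pack (tid, run_index) into one unsigned: tid occupies the low
// log_nof_large_buckets bits, run_index the remaining high bits.
unsigned pack(unsigned tid, unsigned run_index, unsigned log_nof_large_buckets)
{
  assert(tid < (1u << log_nof_large_buckets)); // tid must fit its bit field
  return tid | (run_index << log_nof_large_buckets);
}

unsigned unpack_tid(unsigned packed, unsigned log_nof_large_buckets)
{
  return packed & ((1u << log_nof_large_buckets) - 1);
}

unsigned unpack_run(unsigned packed, unsigned log_nof_large_buckets)
{
  return packed >> log_nof_large_buckets;
}
```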

Collaborator:

Right, as long as log_nof_large_buckets = ceil(log2(nof_large_buckets)) this is correct. Since it's not verified inside the kernel, it could mix the two fields, which is why I suggested verifying it at kernel launch.

DmytroTym (Author):

In principle we can just take the log inside the kernel; I just wanted to avoid all threads doing identical work. Verification is definitely cheaper, though. I can do it inside the kernel if you want.

Collaborator:

My motivation is to avoid debugging cases like nof_buckets=100, (int)log2(nof_buckets)=6.

I would personally add a comment about this assumption next to the kernel param and make sure the host calls it correctly. If you prefer an assert inside the kernel, that's fine too.
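
A minimal host-side guard along those lines might look like this (a sketch, not the PR's code):

```cuda
#include <cassert>
#include <cmath>

// The bit width must be the *ceiling* log so every tid in
// [0, nof_large_buckets) fits: for 100 buckets, (int)log2(100) = 6
// is too small, while the ceiling gives 7.
unsigned checked_log_nof_large_buckets(unsigned nof_large_buckets)
{
  unsigned log_nof = (unsigned)std::ceil(std::log2((double)nof_large_buckets));
  assert(nof_large_buckets <= (1u << log_nof));
  return log_nof;
}
```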

DmytroTym (Author):

Added a comment, though maybe I should've written a fuller doc comment for each kernel (but they're internal, so I don't feel the need to spend too much time documenting them, tbh).

```cuda
// sort by bucket sizes
unsigned h_nof_buckets_to_compute;
CHK_IF_RETURN(cudaMemcpyAsync(
  &h_nof_buckets_to_compute, nof_buckets_to_compute, sizeof(unsigned), cudaMemcpyDeviceToHost, stream));

// if all points are 0 just return point 0
```
@DmytroTym (Author), Feb 14, 2024:


There was a question about removing this block. It turned out it was there for a good reason: if all the scalars are zero, there's a division-by-zero error inside. I will push a more elegant fix for this tomorrow.
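
For context, the shape of the guard under discussion might be roughly as follows. This is a sketch, not the PR's code: `final_result`, `projective_t`, and the `CHK_LAST` macro are assumed names following the snippet above.

```cuda
// After h_nof_buckets_to_compute has been copied back (and the stream
// synchronised so the value is valid), an all-zero input leaves nothing
// to accumulate, and the result is simply the point at infinity.
CHK_IF_RETURN(cudaStreamSynchronize(stream));
if (h_nof_buckets_to_compute == 0) {
  projective_t zero = projective_t::zero();
  CHK_IF_RETURN(cudaMemcpyAsync(final_result, &zero, sizeof(projective_t),
                                cudaMemcpyHostToDevice, stream));
  return CHK_LAST();
}
```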

@DmytroTym (Author):

I tried to use comments to explain the parts that raised questions during the live review, and also dealt with the case where all scalars are zero. See the diagram below for a visual explanation of the key changes in this PR:

(diagram: MSM_PR)

Benchmarks

We discussed that the change in large-bucket accumulation demands measuring how performance varies with large_bucket_factor, for both skewed and uniform distributions. For skewed distributions, I looked at Lurk MSM (cc: @omershlo) and tried to emulate the data in their test number 1 (their raw data is unavailable, plus afaik it's on the grumpkin curve, which we don't support yet, though it's in the works). For the second experiment, I just used a uniform distribution. The bn254 curve is used; it should have performance similar to grumpkin. An RTX A6000, the same GPU as in the Lurk MSM experiments, was used.

| | Lurk MSM test 1 (size 9699051) | Uniform MSM |
|---|---|---|
| pasta-msm on uniform distribution, ms | 119.79 | 144.74 |
| lurkrs, ms | 552.09 | - |
| lurkrs compressed (hypothetical), ms | 19.74 | - |
| ICICLE on dev branch, ms | 384.99 | 134.18 |
| ICICLE on dev with optimal large_bucket_factor, ms | 164.62 | 127.2 |
| Optimal large_bucket_factor for ICICLE on dev | 4 | 0 |
| ICICLE on this branch, ms | 133.4 | 174.85 |
| ICICLE on this branch with optimal large_bucket_factor, ms | 132.69 | 126.9 |
| Optimal large_bucket_factor for ICICLE on this branch | 15 | 0 |

One weird detail is that for uniform distributions, large_bucket_factor=0 is optimal for both the old and (especially) the new version. This was not the case on the GPUs I tested on before, only on the RTX A6000, and I didn't have access to thorough profiling because the card was rented in the cloud. I'd treat this as an oddity that needs further investigation in the future. Otherwise, the results for the new version vary very little across most positive values of large_bucket_factor, and 10 seems like a reasonable default choice to me; I don't see any reason to change it.

```cuda
const unsigned c,
const unsigned threads_per_bucket,
const unsigned max_run_length)
```

```cuda
const int points_per_thread,
```
Contributor:

why not unsigned?

@yshekel (Collaborator) left a comment:


Honestly, I don't fully understand the details of the large-bucket accumulation, but overall it looks good to me.
I approve, but you may want Hadar to review too.

@LeonHibnik (Contributor) left a comment:


lgtm

@LeonHibnik merged commit a91397e into dev on Feb 15, 2024
14 checks passed
@LeonHibnik deleted the develop/dima/msm_improvements branch on February 15, 2024 at 18:02
@DmytroTym mentioned this pull request on Feb 15, 2024
DmytroTym added a commit that referenced this pull request on Feb 15, 2024:
## Contents of this release

- [FEAT]: support for multi-device execution: #356
- [FEAT]: full support for the new mixed-radix NTT: #367, #368 and #371
- [FEAT]: examples for Poseidon hash and a tree builder based on it (currently only on the C++ side): #375
- [PERF]: MSM performance upgrades & zero point handling: #372