diff --git a/ff/baby_bear.hpp b/ff/baby_bear.hpp
index c7439f0..183e720 100644
--- a/ff/baby_bear.hpp
+++ b/ff/baby_bear.hpp
@@ -2,60 +2,380 @@
 // Licensed under the Apache License, Version 2.0, see LICENSE for details.
 // SPDX-License-Identifier: Apache-2.0
 
-#ifdef __NVCC__
-# include "bb31_t.cuh" // device-side field types
-# ifndef __CUDA_ARCH__ // host-side stand-in to make CUDA code compile,
-# include <cstdint> // and provide some debugging support, but
- // not to produce correct computational result...
-
-# if defined(__GNUC__) || defined(__clang__)
-# pragma GCC diagnostic push
-# pragma GCC diagnostic ignored "-Wunused-parameter"
-# endif
-class bb31_t {
- uint32_t val;
+#ifndef __SPPARK_FF_BABY_BEAR_HPP__
+#define __SPPARK_FF_BABY_BEAR_HPP__
 
- static const uint32_t M = 0x77ffffff;
-public:
+#ifdef __CUDACC__ // CUDA device-side field types
+# include <cassert>
+# include "mont32_t.cuh"
+# define inline __device__ __forceinline__
+
+using bb31_base = mont32_t<31, 0x78000001, 0x77ffffff, 0x45dddde3, 0x0ffffffe>;
+
+struct bb31_t : public bb31_base {
 using mem_t = bb31_t;
- static const uint32_t degree = 1;
- static const uint32_t nbits = 31;
- static const uint32_t MOD = 0x78000001;
- inline bb31_t() {}
- inline bb31_t(uint32_t a) : val(a) {}
+ inline bb31_t() {}
+ inline bb31_t(const bb31_base& a) : bb31_base(a) {}
+ inline bb31_t(const uint32_t *p) : bb31_base(p) {}
 // this is used in constant declaration, e.g. as bb31_t{11}
- inline constexpr bb31_t(int a) : val(((uint64_t)a << 32) % MOD) {}
-
- static inline const bb31_t one() { return bb31_t(1); }
- inline bb31_t& operator+=(bb31_t b) { return *this; }
- inline bb31_t& operator-=(bb31_t b) { return *this; }
- inline bb31_t& operator*=(bb31_t b) { return *this; }
- inline bb31_t& operator^=(int b) { return *this; }
- inline bb31_t& sqr() { return *this; }
- friend bb31_t operator+(bb31_t a, bb31_t b) { return a += b; }
- friend bb31_t operator-(bb31_t a, bb31_t b) { return a -= b; }
- friend bb31_t operator*(bb31_t a, bb31_t b) { return a *= b; }
- friend bb31_t operator^(bb31_t a, uint32_t b) { return a ^= b; }
- inline void zero() { val = 0; }
- inline bool is_zero() const { return val==0; }
- inline operator uint32_t() const
- { return ((val*M)*(uint64_t)MOD + val) >> 32; }
- inline void to() { val = ((uint64_t)val<<32) % MOD; }
- inline void from() { val = *this; }
-# if defined(_GLIBCXX_IOSTREAM) || defined(_IOSTREAM_) // non-standard
- friend std::ostream& operator<<(std::ostream& os, const bb31_t& obj)
- {
- auto f = os.flags();
- os << "0x" << std::hex << (uint32_t)obj;
- os.flags(f);
- return os;
+ __host__ __device__ constexpr bb31_t(int a) : bb31_base(a) {}
+ __host__ __device__ constexpr bb31_t(uint32_t a) : bb31_base(a) {}
+
+ inline bb31_t reciprocal() const
+ {
+ bb31_t x11, xff, ret = *this;
+
+ x11 = sqr_n_mul(ret, 4, ret); // 0b10001
+ ret = sqr_n_mul(x11, 1, x11); // 0b110011
+ ret = sqr_n_mul(ret, 1, x11); // 0b1110111
+ xff = sqr_n_mul(ret, 1, x11); // 0b11111111
+ ret = sqr_n_mul(ret, 8, xff); // 0b111011111111111
+ ret = sqr_n_mul(ret, 8, xff); // 0b11101111111111111111111
+ ret = sqr_n_mul(ret, 8, xff); // 0b1110111111111111111111111111111
+
+ return ret;
+ }
+ friend inline bb31_t operator/(int one, bb31_t a)
+ { assert(one == 1); return a.reciprocal(); }
+ friend inline bb31_t operator/(bb31_t a, bb31_t b)
+ { return a * b.reciprocal(); }
+ inline bb31_t& operator/=(const bb31_t a)
+ { *this *= a.reciprocal(); return *this; }
+
+ inline bb31_t heptaroot() const
+ {
+ bb31_t x03, x18, x1b, ret = *this;
+
+ x03 = 
sqr_n_mul(ret, 1, ret); // 0b11 + x18 = sqr_n(x03, 3); // 0b11000 + x1b = x18*x03; // 0b11011 + ret = x18*x1b; // 0b110011 + ret = sqr_n_mul(ret, 6, x1b); // 0b110011011011 + ret = sqr_n_mul(ret, 6, x1b); // 0b110011011011011011 + ret = sqr_n_mul(ret, 6, x1b); // 0b110011011011011011011011 + ret = sqr_n_mul(ret, 6, x1b); // 0b110011011011011011011011011011 + ret = sqr_n_mul(ret, 1, *this); // 0b1100110110110110110110110110111 + + return ret; } -# endif }; -# if defined(__GNUC__) || defined(__clang__) -# pragma GCC diagnostic pop + +class __align__(16) bb31_4_t { + union { bb31_t c[4]; uint32_t u[4]; }; + + static const uint32_t MOD = 0x78000001; + static const uint32_t M = 0x77ffffff; +#ifdef BABY_BEAR_CANONICAL + static const uint32_t BETA = 0x37ffffe9; // (11<<32)%MOD +#else // such as RISC Zero + static const uint32_t BETA = 0x40000018; // (-11<<32)%MOD +#endif + +public: + static const uint32_t degree = 4; + using mem_t = bb31_4_t; + + inline bb31_t& operator[](size_t i) { return c[i]; } + inline const bb31_t& operator[](size_t i) const { return c[i]; } + inline size_t len() const { return 4; } + + inline bb31_4_t() {} + inline bb31_4_t(bb31_t a) { c[0] = a; u[1] = u[2] = u[3] = 0; } + // this is used in constant declaration, e.g. as bb31_4_t{1, 2, 3, 4} + __host__ __device__ __forceinline__ bb31_4_t(int a) + { c[0] = bb31_t{a}; u[1] = u[2] = u[3] = 0; } + __host__ __device__ __forceinline__ bb31_4_t(int d, int f, int g, int h) + { c[0] = bb31_t{d}; c[1] = bb31_t{f}; c[2] = bb31_t{g}; c[3] = bb31_t{h}; } + + // Polynomial multiplication modulo x^4 - BETA + friend __device__ __noinline__ bb31_4_t operator*(bb31_4_t a, bb31_4_t b) + { + bb31_4_t ret; + +# ifdef __CUDA_ARCH__ +# ifdef __GNUC__ +# define asm __asm__ __volatile__ +# else +# define asm asm volatile # endif + // ret[0] = a[0]*b[0] + BETA*(a[1]*b[3] + a[2]*b[2] + a[3]*b[1]); + asm("{ .reg.b32 %lo, %hi, %m; .reg.pred %p;\n\t" + "mul.lo.u32 %lo, %4, %6; mul.hi.u32 %hi, %4, %6;\n\t" + "mad.lo.cc.u32 %lo, %3, %7, %lo; madc.hi.u32 %hi, %3, %7, %hi;\n\t" + "mad.lo.cc.u32 %lo, %2, %8, %lo; madc.hi.u32 %hi, %2, %8, %hi;\n\t" + "setp.ge.u32 %p, %hi, %9;\n\t" + "@%p sub.u32 %hi, %hi, %9;\n\t" + + "mul.lo.u32 %m, %lo, %10;\n\t" + "mad.lo.cc.u32 %lo, %m, %9, %lo; madc.hi.u32 %hi, %m, %9, %hi;\n\t" + //"setp.ge.u32 %p, %hi, %9;\n\t" + //"@%p sub.u32 %hi, %hi, %9;\n\t" + + "mul.lo.u32 %lo, %hi, %11; mul.hi.u32 %hi, %hi, %11;\n\t" + "mad.lo.cc.u32 %lo, %1, %5, %lo; madc.hi.u32 %hi, %1, %5, %hi;\n\t" + + "mul.lo.u32 %m, %lo, %10;\n\t" + "mad.lo.cc.u32 %lo, %m, %9, %lo; madc.hi.u32 %0, %m, %9, %hi;\n\t" + "setp.ge.u32 %p, %0, %9;\n\t" + "@%p sub.u32 %0, %0, %9;\n\t" + "}" : "=r"(ret.u[0]) + : "r"(a.u[0]), "r"(a.u[1]), "r"(a.u[2]), "r"(a.u[3]), + "r"(b.u[0]), "r"(b.u[1]), "r"(b.u[2]), "r"(b.u[3]), + "r"(MOD), "r"(M), "r"(BETA)); + + // ret[1] = a[0]*b[1] + a[1]*b[0] + BETA*(a[2]*b[3] + a[3]*b[2]); + asm("{ .reg.b32 %lo, %hi, %m; .reg.pred %p;\n\t" + "mul.lo.u32 %lo, %4, %7; mul.hi.u32 %hi, %4, %7;\n\t" + "mad.lo.cc.u32 %lo, %3, %8, %lo; madc.hi.u32 %hi, %3, %8, %hi;\n\t" + + "mul.lo.u32 %m, %lo, %10;\n\t" + "mad.lo.cc.u32 %lo, %m, %9, %lo; madc.hi.u32 %hi, %m, %9, %hi;\n\t" + //"setp.ge.u32 %p, %hi, %9;\n\t" + //"@%p sub.u32 %hi, %hi, %9;\n\t" + + "mul.lo.u32 %lo, %hi, %11; mul.hi.u32 %hi, %hi, %11;\n\t" + "mad.lo.cc.u32 %lo, %2, %5, %lo; madc.hi.u32 %hi, %2, %5, %hi;\n\t" + "mad.lo.cc.u32 %lo, %1, %6, %lo; madc.hi.u32 %hi, %1, %6, %hi;\n\t" + "setp.ge.u32 %p, %hi, %9;\n\t" + "@%p sub.u32 %hi, %hi, %9;\n\t" + + "mul.lo.u32 %m, %lo, %10;\n\t" + 
"mad.lo.cc.u32 %lo, %m, %9, %lo; madc.hi.u32 %0, %m, %9, %hi;\n\t" + "setp.ge.u32 %p, %0, %9;\n\t" + "@%p sub.u32 %0, %0, %9;\n\t" + "}" : "=r"(ret.u[1]) + : "r"(a.u[0]), "r"(a.u[1]), "r"(a.u[2]), "r"(a.u[3]), + "r"(b.u[0]), "r"(b.u[1]), "r"(b.u[2]), "r"(b.u[3]), + "r"(MOD), "r"(M), "r"(BETA)); + + // ret[2] = a[0]*b[2] + a[1]*b[1] + a[2]*b[0] + BETA*(a[3]*b[3]); + asm("{ .reg.b32 %lo, %hi, %m; .reg.pred %p;\n\t" + "mul.lo.u32 %lo, %4, %8; mul.hi.u32 %hi, %4, %8;\n\t" + + "mul.lo.u32 %m, %lo, %10;\n\t" + "mad.lo.cc.u32 %lo, %m, %9, %lo; madc.hi.u32 %hi, %m, %9, %hi;\n\t" + //"setp.ge.u32 %p, %hi, %9;\n\t" + //"@%p sub.u32 %hi, %hi, %9;\n\t" + + "mul.lo.u32 %lo, %hi, %11; mul.hi.u32 %hi, %hi, %11;\n\t" + "mad.lo.cc.u32 %lo, %3, %5, %lo; madc.hi.u32 %hi, %3, %5, %hi;\n\t" + "mad.lo.cc.u32 %lo, %2, %6, %lo; madc.hi.u32 %hi, %2, %6, %hi;\n\t" + "mad.lo.cc.u32 %lo, %1, %7, %lo; madc.hi.u32 %hi, %1, %7, %hi;\n\t" + "setp.ge.u32 %p, %hi, %9;\n\t" + "@%p sub.u32 %hi, %hi, %9;\n\t" + + "mul.lo.u32 %m, %lo, %10;\n\t" + "mad.lo.cc.u32 %lo, %m, %9, %lo; madc.hi.u32 %0, %m, %9, %hi;\n\t" + "setp.ge.u32 %p, %0, %9;\n\t" + "@%p sub.u32 %0, %0, %9;\n\t" + "}" : "=r"(ret.u[2]) + : "r"(a.u[0]), "r"(a.u[1]), "r"(a.u[2]), "r"(a.u[3]), + "r"(b.u[0]), "r"(b.u[1]), "r"(b.u[2]), "r"(b.u[3]), + "r"(MOD), "r"(M), "r"(BETA)); + + // ret[3] = a[0]*b[3] + a[1]*b[2] + a[2]*b[1] + a[3]*b[0]; + asm("{ .reg.b32 %lo, %hi, %m; .reg.pred %p;\n\t" + "mul.lo.u32 %lo, %4, %5; mul.hi.u32 %hi, %4, %5;\n\t" + "mad.lo.cc.u32 %lo, %3, %6, %lo; madc.hi.u32 %hi, %3, %6, %hi;\n\t" + "mad.lo.cc.u32 %lo, %2, %7, %lo; madc.hi.u32 %hi, %2, %7, %hi;\n\t" + "mad.lo.cc.u32 %lo, %1, %8, %lo; madc.hi.u32 %hi, %1, %8, %hi;\n\t" + "setp.ge.u32 %p, %hi, %9;\n\t" + "@%p sub.u32 %hi, %hi, %9;\n\t" + + "mul.lo.u32 %m, %lo, %10;\n\t" + "mad.lo.cc.u32 %lo, %m, %9, %lo; madc.hi.u32 %0, %m, %9, %hi;\n\t" + "setp.ge.u32 %p, %0, %9;\n\t" + "@%p sub.u32 %0, %0, %9;\n\t" + "}" : "=r"(ret.u[3]) + : "r"(a.u[0]), "r"(a.u[1]), "r"(a.u[2]), "r"(a.u[3]), + "r"(b.u[0]), "r"(b.u[1]), "r"(b.u[2]), "r"(b.u[3]), + "r"(MOD), "r"(M), "r"(BETA)); +# undef asm +# else + union { uint64_t ul; uint32_t u[2]; }; + + // ret[0] = a[0]*b[0] + BETA*(a[1]*b[3] + a[2]*b[2] + a[3]*b[1]); + ul = a.u[1] * (uint64_t)b.u[3]; + ul += a.u[2] * (uint64_t)b.u[2]; + ul += a.u[3] * (uint64_t)b.u[1]; if (u[1] >= MOD) u[1] -= MOD; + ul += (u[0] * M) * (uint64_t)MOD; // if (u[1] >= MOD) u[1] -= MOD; + ul = u[1] * (uint64_t)BETA; + ul += a.u[0] * (uint64_t)b.u[0]; + ul += (u[0] * M) * (uint64_t)MOD; + ret.u[0] = u[1] >= MOD ? u[1] - MOD : u[1]; + + // ret[1] = a[0]*b[1] + a[1]*b[0] + BETA*(a[2]*b[3] + a[3]*b[2]); + ul = a.u[2] * (uint64_t)b.u[3]; + ul += a.u[3] * (uint64_t)b.u[2]; + ul += (u[0] * M) * (uint64_t)MOD; // if (u[1] >= MOD) u[1] -= MOD; + ul = u[1] * (uint64_t)BETA; + ul += a.u[0] * (uint64_t)b.u[1]; + ul += a.u[1] * (uint64_t)b.u[0]; if (u[1] >= MOD) u[1] -= MOD; + ul += (u[0] * M) * (uint64_t)MOD; + ret.u[1] = u[1] >= MOD ? u[1] - MOD : u[1]; + + // ret[2] = a[0]*b[2] + a[1]*b[1] + a[2]*b[0] + BETA*(a[3]*b[3]); + ul = a.u[3] * (uint64_t)b.u[3]; + ul += (u[0] * M) * (uint64_t)MOD; // if (u[1] >= MOD) u[1] -= MOD; + ul = u[1] * (uint64_t)BETA; + ul += a.u[0] * (uint64_t)b.u[2]; + ul += a.u[1] * (uint64_t)b.u[1]; + ul += a.u[2] * (uint64_t)b.u[0]; if (u[1] >= MOD) u[1] -= MOD; + ul += (u[0] * M) * (uint64_t)MOD; + ret.u[2] = u[1] >= MOD ? 
u[1] - MOD : u[1]; + + // ret[3] = a[0]*b[3] + a[1]*b[2] + a[2]*b[1] + a[3]*b[0]; + ul = a.u[0] * (uint64_t)b.u[3]; + ul += a.u[1] * (uint64_t)b.u[2]; + ul += a.u[2] * (uint64_t)b.u[1]; + ul += a.u[3] * (uint64_t)b.u[0]; if (u[1] >= MOD) u[1] -= MOD; + ul += (u[0] * M) * (uint64_t)MOD; + ret.u[3] = u[1] >= MOD ? u[1] - MOD : u[1]; # endif + + return ret; + } + inline bb31_4_t& operator*=(const bb31_4_t& b) + { return *this = *this * b; } + + friend __device__ __noinline__ bb31_4_t operator*(bb31_4_t a, bb31_t b) + { + bb31_4_t ret; + + for (size_t i = 0; i < 4; i++) + ret[i] = a[i] * b; + + return ret; + } + friend inline bb31_4_t operator*(bb31_t b, const bb31_4_t& a) + { return a * b; } + inline bb31_4_t& operator*=(bb31_t b) + { return *this = *this * b; } + + friend inline bb31_4_t operator+(const bb31_4_t& a, const bb31_4_t& b) + { + bb31_4_t ret; + + for (size_t i = 0; i < 4; i++) + ret[i] = a[i] + b[i]; + + return ret; + } + inline bb31_4_t& operator+=(const bb31_4_t& b) + { return *this = *this + b; } + + friend inline bb31_4_t operator+(const bb31_4_t& a, bb31_t b) + { + bb31_4_t ret; + + ret[0] = a[0] + b; + ret[1] = a[1]; + ret[2] = a[2]; + ret[3] = a[3]; + + return ret; + } + friend inline bb31_4_t operator+(bb31_t b, const bb31_4_t& a) + { return a + b; } + inline bb31_4_t& operator+=(bb31_t b) + { c[0] += b; return *this; } + + friend inline bb31_4_t operator-(const bb31_4_t& a, const bb31_4_t& b) + { + bb31_4_t ret; + + for (size_t i = 0; i < 4; i++) + ret[i] = a[i] - b[i]; + + return ret; + } + inline bb31_4_t& operator-=(const bb31_4_t& b) + { return *this = *this - b; } + + friend inline bb31_4_t operator-(const bb31_4_t& a, bb31_t b) + { + bb31_4_t ret; + + ret[0] = a[0] - b; + ret[1] = a[1]; + ret[2] = a[2]; + ret[3] = a[3]; + + return ret; + } + friend inline bb31_4_t operator-(bb31_t b, const bb31_4_t& a) + { + bb31_4_t ret; + + ret[0] = b - a[0]; + ret[1] = -a[1]; + ret[2] = -a[2]; + ret[3] = -a[3]; + + return ret; + } + inline bb31_4_t& operator-=(bb31_t b) + { c[0] -= b; return *this; } + + inline bb31_4_t reciprocal() const + { + const bb31_t beta{BETA}; + + // don't bother with breaking this down, 1/x dominates. 
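+ // (reduction to the base field: with q~(x) = q(-x), the product q*q~
+ // mod x^4-BETA keeps only even powers, b0 + b2*x^2, and multiplying
+ // that by its conjugate b0 - b2*x^2 gives b0^2 - BETA*b2^2 in the
+ // base field, so a single base-field reciprocal suffices)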
+ bb31_t b0 = c[0]*c[0] - beta*(c[1]*bb31_t{u[3]<<1} - c[2]*c[2]); + bb31_t b2 = c[0]*bb31_t{u[2]<<1} - c[1]*c[1] - beta*(c[3]*c[3]); + + bb31_t inv = 1/(b0*b0 - beta*b2*b2); + + b0 *= inv; + b2 *= inv; + + bb31_4_t ret; + bb31_t beta_b2 = beta*b2; + ret[0] = c[0]*b0 - c[2]*beta_b2; + ret[1] = c[3]*beta_b2 - c[1]*b0; + ret[2] = c[2]*b0 - c[0]*b2; + ret[3] = c[1]*b2 - c[3]*b0; + + return ret; + } + friend inline bb31_4_t operator/(int one, const bb31_4_t& a) + { assert(one == 1); return a.reciprocal(); } + friend inline bb31_4_t operator/(const bb31_4_t& a, const bb31_4_t& b) + { return a * b.reciprocal(); } + friend inline bb31_4_t operator/(bb31_t a, const bb31_4_t& b) + { return b.reciprocal() * a; } + friend inline bb31_4_t operator/(const bb31_4_t& a, bb31_t b) + { return a * b.reciprocal(); } + inline bb31_4_t& operator/=(const bb31_4_t& a) + { return *this *= a.reciprocal(); } + inline bb31_4_t& operator/=(bb31_t a) + { return *this *= a.reciprocal(); } + + inline bool is_one() const + { return c[0].is_one() & u[1]==0 & u[2]==0 & u[3]==0; } + inline bool is_zero() const + { return u[0]==0 & u[1]==0 & u[2]==0 & u[3]==0; } +# undef inline + +public: + friend inline bool operator==(const bb31_4_t& a, const bb31_4_t& b) + { return a.u[0]==b.u[0] & a.u[1]==b.u[1] & a.u[2]==b.u[2] & a.u[3]==b.u[3]; } + friend inline bool operator!=(const bb31_4_t& a, const bb31_4_t& b) + { return a.u[0]!=b.u[0] | a.u[1]!=b.u[1] | a.u[2]!=b.u[2] | a.u[3]!=b.u[3]; } + +# if defined(_GLIBCXX_IOSTREAM) || defined(_IOSTREAM_) // non-standard + friend std::ostream& operator<<(std::ostream& os, const bb31_4_t& a) + { + os << "[" << a.c[0] << ", " << a.c[1] << ", " a.c[2] << ", " << a.c[3] << "]"; + return os; + } +# endif +}; + typedef bb31_t fr_t; +typedef bb31_4_t fr4_t; + +#endif #endif diff --git a/ff/bb31_t.cuh b/ff/bb31_t.cuh index 9eea5d9..6658049 100644 --- a/ff/bb31_t.cuh +++ b/ff/bb31_t.cuh @@ -211,6 +211,12 @@ public: if (p < 2) asm("trap;"); + if (p == 7) { + bb31_t temp = sqr_n_mul(*this, 1, *this); + *this = sqr_n_mul(temp, 1, *this); + return *this; + } + bb31_t sqr = *this; if ((p&1) == 0) { do { diff --git a/ff/gl64_t.cuh b/ff/gl64_t.cuh index eb7164d..03ac4de 100644 --- a/ff/gl64_t.cuh +++ b/ff/gl64_t.cuh @@ -2,16 +2,10 @@ // Licensed under the Apache License, Version 2.0, see LICENSE for details. 
diff --git a/ff/gl64_t.cuh b/ff/gl64_t.cuh
index eb7164d..03ac4de 100644
--- a/ff/gl64_t.cuh
+++ b/ff/gl64_t.cuh
@@ -2,16 +2,10 @@
 // Licensed under the Apache License, Version 2.0, see LICENSE for details.
 // SPDX-License-Identifier: Apache-2.0
 
-#ifndef __SPPARK_FF_GL64_T_CUH__
+#if defined(__CUDACC__) && !defined(__SPPARK_FF_GL64_T_CUH__)
 #define __SPPARK_FF_GL64_T_CUH__
 
 # include <cstdint>
-
-namespace gl64_device {
- static __device__ __constant__ /*const*/ uint32_t W = 0xffffffffU;
-}
-
-#ifdef __CUDA_ARCH__
 # define inline __device__ __forceinline__
 # ifdef __GNUC__
 # define asm __asm__ __volatile__
@@ -19,6 +13,10 @@ namespace gl64_device {
 # define asm asm volatile
 # endif
 
+namespace gl64_device {
+ static __device__ __constant__ /*const*/ uint32_t W = 0xffffffffU;
+}
+
 #ifdef GL64_PARTIALLY_REDUCED
 //
 // This variant operates with partially reduced values, ones less than
@@ -51,8 +49,12 @@ public:
 inline size_t len() const { return 1; }
 
 inline gl64_t() {}
- inline gl64_t(const uint64_t a) { val = a; to(); }
- inline gl64_t(const uint64_t *p) { val = *p; to(); }
+#ifdef __CUDA_ARCH__
+ inline gl64_t(uint64_t a) : val(a) { to(); }
+#else
+ __host__ constexpr gl64_t(uint64_t a) : val(a) {}
+#endif
+ inline gl64_t(const uint64_t *p) : val(*p) { to(); }
 
 inline operator uint64_t() const
 { auto ret = *this; ret.from(); return ret.val; }
@@ -594,9 +596,23 @@ public:
 inline void shfl_bfly(uint32_t laneMask)
 { val = __shfl_xor_sync(0xFFFFFFFF, val, laneMask); }
-};
 # undef inline
 # undef asm
-#endif
+
+public:
+ friend inline bool operator==(gl64_t a, gl64_t b)
+ { return a.val == b.val; }
+ friend inline bool operator!=(gl64_t a, gl64_t b)
+ { return a.val != b.val; }
+# if defined(_GLIBCXX_IOSTREAM) || defined(_IOSTREAM_) // non-standard
+ friend std::ostream& operator<<(std::ostream& os, const gl64_t& obj)
+ {
+ auto f = os.flags();
+ os << "0x" << std::hex << obj.val;
+ os.flags(f);
+ return os;
+ }
+# endif
+};
 #endif
diff --git a/ff/goldilocks.hpp b/ff/goldilocks.hpp
index deae186..a909a2d 100644
--- a/ff/goldilocks.hpp
+++ b/ff/goldilocks.hpp
@@ -2,36 +2,7 @@
 // Licensed under the Apache License, Version 2.0, see LICENSE for details.
 // SPDX-License-Identifier: Apache-2.0
 
-#ifdef __NVCC__
-# include "gl64_t.cuh" // device-side field types
-# ifndef __CUDA_ARCH__ // host-side stand-in to make CUDA code compile,
-# include <cstdint> // not to produce correct result...
-
-# if defined(__GNUC__) || defined(__clang__)
-# pragma GCC diagnostic push
-# pragma GCC diagnostic ignored "-Wunused-parameter"
-# endif
-class gl64_t {
- uint64_t val;
-public:
- using mem_t = gl64_t;
- static const uint32_t degree = 1;
- static const uint64_t MOD = 0xffffffff00000001U;
-
- inline gl64_t() {}
- inline gl64_t(uint64_t a) : val(a) {}
- inline operator uint64_t() const { return val; }
- static inline const gl64_t one() { return 1; }
- inline gl64_t& operator+=(gl64_t b) { return *this; }
- inline gl64_t& operator-=(gl64_t b) { return *this; }
- inline gl64_t& operator*=(gl64_t b) { return *this; }
- inline gl64_t& operator^=(int p) { return *this; }
- inline gl64_t& sqr() { return *this; }
- inline void zero() { val = 0; }
-};
-# if defined(__GNUC__) || defined(__clang__)
-# pragma GCC diagnostic pop
-# endif
-# endif
+#ifdef __CUDACC__
+# include "gl64_t.cuh" // CUDA device-side field types
 typedef gl64_t fr_t;
 #endif
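Next, the new ff/mersenne31.hpp. Everything in it rests on the identity 2^31 == 1 (mod 2^31 - 1): a 62-bit product folds back into the field with a single shift-and-add, which is exactly what mrs31_t::mul() below does. A stand-alone sketch of that reduction (mul_mod is an illustrative name, not from the file):

    #include <cassert>
    #include <cstdint>

    static const uint32_t MOD = 0x7fffffff;    // 2^31 - 1

    // t = hi*2^31 + lo reduces to lo + hi, since 2^31 == 1 (mod MOD)
    static uint32_t mul_mod(uint32_t a, uint32_t b)
    {
        uint64_t t = (uint64_t)a * b;
        uint32_t r = ((uint32_t)t & MOD) + (uint32_t)(t >> 31);
        return r >= MOD ? r - MOD : r;
    }

    int main()
    {
        assert(mul_mod(3, 5) == 15);
        assert(mul_mod(MOD - 1, MOD - 1) == 1);   // (-1)*(-1) == 1
        assert(mul_mod(1u << 30, 4) == 2);        // 2^32 == 2 (mod MOD)
        return 0;
    }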
diff --git a/ff/mersenne31.hpp b/ff/mersenne31.hpp
new file mode 100644
index 0000000..84097ea
--- /dev/null
+++ b/ff/mersenne31.hpp
@@ -0,0 +1,419 @@
+// Copyright Supranational LLC
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
+#ifndef __SPPARK_FF_MERSENNE31_HPP__
+#define __SPPARK_FF_MERSENNE31_HPP__
+
+#ifdef __CUDACC__ // CUDA device-side field types
+# include "mont32_t.cuh"
+# define inline __device__ __forceinline__
+
+using mrs31_base = mont32_t<31, 0x7fffffff, 0x80000001, 4, 2>;
+
+struct mrs31_t : public mrs31_base {
+ // mem_t bridges the host-side non-Montgomery representation
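+ // (for p = 2^31-1 the Montgomery radix 2^32 is congruent to 2 mod p,
+ // so the conversion is a one-bit shift: to() is <<= 1, from() is >>= 1)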
+ class mem_t { friend mrs31_t;
+ uint32_t val;
+
+ public:
+ inline operator mrs31_t() const
+ { return mrs31_t{val} << 1; }
+ inline mem_t& operator=(const mrs31_t& a)
+ { val = *(a >> 1); return *this; }
+ };
+
+ inline mrs31_t() {}
+ inline mrs31_t(const mrs31_base& a) : mrs31_base(a) {}
+ inline mrs31_t(const uint32_t *p) : mrs31_base(p) {}
+ inline mrs31_t(const mem_t* p) { *this = *p; }
+ // this is used in constant declaration, e.g. as mrs31_t{11}
+ __host__ __device__ constexpr mrs31_t(int a) : mrs31_base(a) {}
+ __host__ __device__ constexpr mrs31_t(uint32_t a) : mrs31_base(a) {}
+
+ inline operator uint32_t() const { return *(*this >> 1); }
+ inline void to() { *this <<= 1; }
+ inline void from() { *this >>= 1; }
+ inline void store(mem_t* p) const { *p = *this; }
+
+ inline mrs31_t reciprocal() const
+ {
+ mrs31_t x05, x0f, x7d, xff, ret = *this;
+
+ x05 = sqr_n_mul(ret, 2, ret); // 0b101
+ x0f = sqr_n_mul(x05, 1, x05); // 0b1111
+ x7d = sqr_n_mul(x0f, 3, x05); // 0b1111101
+ xff = sqr_n_mul(x7d, 1, x05); // 0b11111111
+ ret = sqr_n_mul(xff, 8, xff); // 0b1111111111111111
+ ret = sqr_n_mul(ret, 8, xff); // 0b111111111111111111111111
+ ret = sqr_n_mul(ret, 7, x7d); // 0b1111111111111111111111111111101
+
+ return ret;
+ }
+ friend inline mrs31_t operator/(int one, mrs31_t a)
+ { if (one != 1) asm("trap;"); return a.reciprocal(); }
+ friend inline mrs31_t operator/(mrs31_t a, mrs31_t b)
+ { return a * b.reciprocal(); }
+ inline mrs31_t& operator/=(const mrs31_t a)
+ { *this *= a.reciprocal(); return *this; }
+
+ inline mrs31_t sqrt() const
+ { return sqr_n(*this, 29); }
+ friend inline mrs31_t sqrt(mrs31_t a)
+ { return a.sqrt(); }
+};
+# undef inline
+
+typedef mrs31_t fr_t;
+
+#else
+# include <cassert>
+# include <cstdint>
+# if defined(__CUDACC__) || defined(__HIPCC__)
+# define inline __host__ __device__ __forceinline__
+# endif
+
+class mrs31_t {
+private:
+ uint32_t val;
+
+ const static uint32_t MOD = 0x7fffffff;
+
+public:
+ // mem_t is a pass-through to mirror the corresponding CUDA bridge
+ class mem_t { friend mrs31_t;
+ uint32_t val;
+
+ public:
+ inline operator mrs31_t() const
+ { return mrs31_t{val}; }
+ inline mem_t& operator=(const mrs31_t& a)
+ { val = a; return *this; }
+ };
+
+ static const uint32_t degree = 1;
+ static constexpr size_t bit_length() { return 31; }
+
+ inline uint32_t& operator[](size_t i) { return val; (void)i; }
+ inline const uint32_t& operator[](size_t i) const { return val; (void)i; }
+ inline uint32_t& operator*() { return val; }
+ inline uint32_t operator*() const { return val; }
+ inline size_t len() const { return 1; }
+
+ inline mrs31_t() {}
+ inline constexpr mrs31_t(int a) : val(a) {}
+ inline constexpr mrs31_t(uint32_t a) : val(a) {}
+ inline mrs31_t(const uint32_t* p) : val(*p) {}
+ inline mrs31_t(const mem_t* p) { *this = *p; }
+
+ inline operator uint32_t() const { return val; }
+ inline void store(uint32_t* p) const { *p = val; }
+ inline void store(mem_t* p) const { *p = *this; }
+ inline mrs31_t& operator=(uint32_t b) { val = b; return *this; }
+
+ inline mrs31_t& operator+=(const mrs31_t b)
+ {
+ val += b.val;
+ if (val >= MOD)
+ val -= MOD;
+
+ return *this;
+ }
+ friend inline mrs31_t operator+(mrs31_t a, const mrs31_t b)
+ { return a += b; }
+
+ inline mrs31_t& operator<<=(uint32_t l)
+ {
+ l %= 31;
+
+ if (l > 2) {
+ uint64_t tmp = (uint64_t)val << l;
+ val = ((uint32_t)tmp & MOD) + (uint32_t)(tmp >> 31);
+ if (val >= MOD)
+ val -= MOD;
+ } else {
+ while (l--) {
+ val <<= 1;
+ if (val >= MOD)
+ val -= MOD;
+ }
+ }
+
+ return *this;
+ }
+ friend inline mrs31_t operator<<(mrs31_t a, uint32_t l)
+ { return a <<= l; }
+
+ inline mrs31_t& operator>>=(uint32_t r)
+ {
+ r %= 31;
+
+ if (r > 2) {
+ uint32_t red = val & ((1<<r) - 1);
+ val = (red << (31-r)) + (val >> r);
+ } else {
+ while (r--) {
+ val += val&1 ? MOD : 0;
+ val >>= 1;
+ }
+ }
+
+ return *this;
+ }
+ friend inline mrs31_t operator>>(mrs31_t a, uint32_t r)
+ { return a >>= r; }
+
+ inline mrs31_t& operator-=(const mrs31_t b)
+ {
+ bool borrow = val < b.val;
+
+ val -= b.val;
+ if (borrow)
+ val += MOD;
+
+ return *this;
+ }
+ friend inline mrs31_t operator-(mrs31_t a, const mrs31_t b)
+ { return a -= b; }
+
+ inline mrs31_t cneg(bool flag)
+ {
+ if (flag && val != 0)
+ val = MOD - val;
+
+ return *this;
+ }
+ static inline mrs31_t cneg(mrs31_t a, bool flag)
+ { return a.cneg(flag); }
+ inline mrs31_t operator-() const
+ { return cneg(*this, true); }
+
+ static inline const mrs31_t one() { return mrs31_t{1}; }
+ inline bool is_one() const { return val == 1; }
+ inline bool is_zero() const { return val == 0; }
+ inline void zero() { val = 0; }
+
+ friend inline mrs31_t czero(const mrs31_t a, int set_z)
+ { return set_z ? mrs31_t{0} : a; }
+
+ static inline mrs31_t csel(const mrs31_t a, const mrs31_t b, int sel_a)
+ { return sel_a ? a : b; }
+
+private:
+ inline mrs31_t& mul(const mrs31_t b)
+ {
+ uint64_t tmp = val * (uint64_t)b.val;
+
+ val = ((uint32_t)tmp & MOD) + (uint32_t)(tmp >> 31);
+ if (val >= MOD)
+ val -= MOD;
+
+ return *this;
+ }
+
+public:
+ friend inline mrs31_t operator*(mrs31_t a, const mrs31_t b)
+ { return a.mul(b); }
+ inline mrs31_t& operator*=(const mrs31_t a)
+ { return mul(a); }
+
+ // raise to a variable power, variable in respect to threadIdx,
+ // but mind the ^ operator's precedence!
+ inline mrs31_t& operator^=(uint32_t p)
+ {
+ mrs31_t sqr = *this;
+ *this = csel(val, 1, p&1);
+
+ #pragma unroll 1
+ while (p >>= 1) {
+ sqr.mul(sqr);
+ if (p&1)
+ mul(sqr);
+ }
+
+ return *this;
+ }
+ friend inline mrs31_t operator^(mrs31_t a, uint32_t p)
+ { return a ^= p; }
+ inline mrs31_t operator()(uint32_t p)
+ { return *this^p; }
+
+ // raise to a constant power, e.g. x^7, to be unrolled at compile time
+ inline mrs31_t& operator^=(int p)
+ {
+ assert(p >= 2);
+
+ mrs31_t sqr = *this;
+ if ((p&1) == 0) {
+ do {
+ sqr.mul(sqr);
+ p >>= 1;
+ } while ((p&1) == 0);
+ *this = sqr;
+ }
+ for (p >>= 1; p; p >>= 1) {
+ sqr.mul(sqr);
+ if (p&1)
+ mul(sqr);
+ }
+
+ return *this;
+ }
+ friend inline mrs31_t operator^(mrs31_t a, int p)
+ { return a ^= p; }
+ inline mrs31_t operator()(int p)
+ { return *this^p; }
+ friend inline mrs31_t sqr(mrs31_t a)
+ { return a.sqr(); }
+ inline mrs31_t& sqr()
+ { return mul(*this); }
+
+ template<size_t T>
+ static inline mrs31_t dot_product(const mrs31_t a[T], const mrs31_t b[T])
+ {
+ union { uint64_t acc; uint32_t u[2]; };
+ size_t i = 1;
+
+ acc = *a[0] * (uint64_t)*b[0];
+
+ if ((T&1) == 0) {
+ acc += *a[i] * (uint64_t)*b[i];
+ i++;
+ }
+ for (; i < T; i += 2) {
+ acc += *a[i] * (uint64_t)*b[i];
+ acc += *a[i+1] * (uint64_t)*b[i+1];
+ if (u[1] >= MOD)
+ u[1] -= MOD;
+ }
+
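+ // 2^32 == 2 (mod 2^31-1), so the accumulator u[1]*2^32 + u[0] folds
+ // to u[0] + 2*u[1]; a carry out of the 32-bit add is itself worth
+ // 2^32, i.e. another +2 (the strided variant below folds identically)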
+ uint32_t ret = u[0] + (u[1] << 1);
+
+ if (ret < u[0])
+ ret += 2;
+ if (ret >= MOD)
+ ret -= MOD;
+
+ return mrs31_t{ret};
+ }
+
+ template<size_t T>
+ static inline mrs31_t dot_product(mrs31_t a0, mrs31_t b0,
+ const mrs31_t a[T-1], const mrs31_t* b,
+ size_t stride_b = 1)
+ {
+ union { uint64_t acc; uint32_t u[2]; };
+ size_t i = 0;
+
+ acc = *a0 * (uint64_t)*b0;
+
+ if ((T&1) == 0) {
+ acc += *a[i] * (uint64_t)*b[0];
+ i++, b += stride_b;
+ }
+ for (; i < T-1; i += 2) {
+ acc += *a[i] * (uint64_t)*b[0];
+ b += stride_b;
+ acc += *a[i+1] * (uint64_t)*b[0];
+ b += stride_b;
+ if (u[1] >= MOD)
+ u[1] -= MOD;
+ }
+
+ uint32_t ret = u[0] + (u[1] << 1);
+
+ if (ret < u[0])
+ ret += 2;
+ if (ret >= MOD)
+ ret -= MOD;
+
+ return mrs31_t{ret};
+ }
+
+private:
+ static inline mrs31_t sqr_n(mrs31_t s, uint32_t n)
+ {
+ #pragma unroll 4
+ while (n--) {
+ uint64_t tmp = s.val * (uint64_t)s.val;
+
+ s.val = ((uint32_t)tmp & MOD) + (uint32_t)(tmp >> 31);
+
+ if (s.val >= MOD)
+ s.val -= MOD;
+ }
+
+ return s;
+ }
+
+ static inline mrs31_t sqr_n_mul(mrs31_t s, uint32_t n, mrs31_t m)
+ {
+ s = sqr_n(s, n);
+ s.mul(m);
+
+ return s;
+ }
+
+public:
+ inline mrs31_t sqrt() const
+ { return sqr_n(*this, 29); }
+ friend inline mrs31_t sqrt(mrs31_t a)
+ { return a.sqrt(); }
+
+ inline mrs31_t reciprocal() const
+ {
+ mrs31_t x05, x0f, x7d, xff, ret = *this;
+
+ x05 = sqr_n_mul(ret, 2, ret); // 0b101
+ x0f = sqr_n_mul(x05, 1, x05); // 0b1111
+ x7d = sqr_n_mul(x0f, 3, x05); // 0b1111101
+ xff = sqr_n_mul(x7d, 1, x05); // 0b11111111
+ ret = sqr_n_mul(xff, 8, xff); // 0b1111111111111111
+ ret = sqr_n_mul(ret, 8, xff); // 0b111111111111111111111111
+ ret = sqr_n_mul(ret, 7, x7d); // 0b1111111111111111111111111111101
+
+ return ret;
+ }
+ friend inline mrs31_t operator/(int one, mrs31_t a)
+ { assert(one == 1); return a.reciprocal(); }
+ friend inline mrs31_t operator/(mrs31_t a, mrs31_t b)
+ { return a * b.reciprocal(); }
+ inline mrs31_t& operator/=(const mrs31_t a)
+ { *this *= a.reciprocal(); return *this; }
+
+# if defined(__CUDACC__)
+# undef inline
+ __device__ __forceinline__ void shfl_bfly(uint32_t laneMask)
+ { val = __shfl_xor_sync(0xFFFFFFFF, val, laneMask); }
+# elif defined(__HIPCC__)
+# undef inline
+ __device__ __forceinline__ void shfl_bfly(uint32_t laneMask)
+ {
+ uint32_t idx = (threadIdx.x ^ laneMask) << 2;
+
+ val = __builtin_amdgcn_ds_bpermute(idx, val);
+ }
+# endif
+
+public:
+ friend inline bool operator==(mrs31_t a, mrs31_t b)
+ { return a.val == b.val; }
+ friend inline bool operator!=(mrs31_t a, mrs31_t b)
+ { return a.val != b.val; }
+
+# if defined(_GLIBCXX_IOSTREAM) || defined(_IOSTREAM_) // non-standard
+ friend std::ostream& operator<<(std::ostream& os, const mrs31_t& obj)
+ {
+ auto f = os.flags();
+ os << "0x" << std::hex << obj.val;
+ os.flags(f);
+ return os;
+ }
+# endif
+};
+
+typedef mrs31_t fr_t;
+
+#endif
+#endif /* __SPPARK_FF_MERSENNE31_HPP__ */
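The mont32_t template that both bb31_t and mrs31_t derive from follows. Judging by the two instantiations above, its parameters are <N, MOD, M0, RR, ONE>, with M0 = -MOD^-1 mod 2^32, RR = 2^64 mod MOD, and ONE = 2^32 mod MOD (1 in Montgomery form). A host-side sketch of the REDC step that mul() and mul_by_1() implement in PTX, using the Baby Bear constants (mont_mul is an illustrative name):

    #include <cassert>
    #include <cstdint>

    static const uint32_t MOD = 0x78000001;   // Baby Bear
    static const uint32_t M0  = 0x77ffffff;   // -MOD^-1 mod 2^32
    static const uint32_t ONE = 0x0ffffffe;   // 2^32 mod MOD

    // Montgomery multiplication: a*b*2^-32 mod MOD for Montgomery-form inputs
    static uint32_t mont_mul(uint32_t a, uint32_t b)
    {
        uint64_t t = (uint64_t)a * b;
        uint32_t red = (uint32_t)t * M0;      // makes t + red*MOD divisible by 2^32
        uint32_t hi = (uint32_t)((t + (uint64_t)red * MOD) >> 32);
        return hi >= MOD ? hi - MOD : hi;
    }

    int main()
    {
        assert(MOD * M0 == 0xffffffffu);      // M0 sanity: MOD*M0 == -1 mod 2^32
        assert(mont_mul(ONE, ONE) == ONE);    // 1*1 == 1 in Montgomery form
        return 0;
    }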
diff --git a/ff/mont32_t.cuh b/ff/mont32_t.cuh
new file mode 100644
index 0000000..7050b82
--- /dev/null
+++ b/ff/mont32_t.cuh
@@ -0,0 +1,441 @@
+// Copyright Supranational LLC
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
+#if defined(__CUDACC__) && !defined(__SPPARK_FF_MONT32_T_CUH__)
+#define __SPPARK_FF_MONT32_T_CUH__
+
+# include <cstdint>
+# define inline __device__ __forceinline__
+# ifdef __GNUC__
+# define asm __asm__ __volatile__
+# else
+# define asm asm volatile
+# endif
+
+template<const size_t N, const uint32_t MOD, const uint32_t M0,
+         const uint32_t RR, const uint32_t ONE>
+class mont32_t {
+private:
+ uint32_t val;
+
+public:
+ using mem_t = mont32_t;
+ static const uint32_t degree = 1;
+ static constexpr size_t __device__ bit_length() { return N; }
+
+ inline uint32_t& operator[](size_t i) { return val; }
+ inline uint32_t& operator*() { return val; }
+ inline const uint32_t& operator[](size_t i) const { return val; }
+ inline uint32_t operator*() const { return val; }
+ inline size_t len() const { return 1; }
+
+ inline mont32_t() {}
+ inline mont32_t(const uint32_t *p) { val = *p; }
+ // this is used in constant declaration, e.g. as mont32_t{11}
+ __host__ __device__ constexpr mont32_t(int a) : val(((uint64_t)a << 32) % MOD) {}
+ __host__ __device__ constexpr mont32_t(uint32_t a) : val(a) {}
+
+ inline operator uint32_t() const { return mul_by_1(); }
+ inline void store(uint32_t *p) const { *p = mul_by_1(); }
+ inline mont32_t& operator=(uint32_t b) { val = b; to(); return *this; }
+
+ inline mont32_t& operator+=(const mont32_t b)
+ {
+ val += b.val;
+ if (N == 32) {
+ if (val < b.val || val >= MOD) val -= MOD;
+ } else {
+ if (val >= MOD) val -= MOD;
+ }
+
+ return *this;
+ }
+ friend inline mont32_t operator+(mont32_t a, const mont32_t b)
+ { return a += b; }
+
+ inline mont32_t& operator<<=(uint32_t l)
+ {
+ if (N == 32) {
+ while (l--) {
+ bool carry = val >> 31;
+ val <<= 1;
+ if (carry || val >= MOD) val -= MOD;
+ }
+ } else {
+ while (l--) {
+ val <<= 1;
+ if (val >= MOD) val -= MOD;
+ }
+ }
+
+ return *this;
+ }
+ friend inline mont32_t operator<<(mont32_t a, uint32_t l)
+ { return a <<= l; }
+
+ inline mont32_t& operator>>=(uint32_t r)
+ {
+ while (r >= 32) {
+ val = mul_by_1();
+ r -= 32;
+ }
+
+ if (r > 2) {
+ uint32_t lo, hi, red = (val * M0) & ((1<<r) - 1);
+
+ asm("mad.lo.cc.u32 %0, %2, %3, %4; madc.hi.u32 %1, %2, %3, 0;"
+ : "=r"(lo), "=r"(hi)
+ : "r"(red), "r"(MOD), "r"(val));
+ val = (hi << (32-r)) + (lo >> r);
+ } else {
+ while (r--) {
+ val += (val&1) ? MOD : 0;
+ val >>= 1;
+ }
+ }
+
+ return *this;
+ }
+ friend inline mont32_t operator>>(mont32_t a, uint32_t r)
+ { return a >>= r; }
+
+ inline mont32_t& operator-=(const mont32_t b)
+ {
+ asm("{");
+ asm(".reg.pred %brw;");
+ asm("setp.lt.u32 %brw, %0, %1;" :: "r"(val), "r"(b.val));
+ asm("sub.u32 %0, %0, %1;" : "+r"(val) : "r"(b.val));
+ asm("@%brw add.u32 %0, %0, %1;" : "+r"(val) : "r"(MOD));
+ asm("}");
+
+ return *this;
+ }
+ friend inline mont32_t operator-(mont32_t a, const mont32_t b)
+ { return a -= b; }
+
+ inline mont32_t cneg(bool flag)
+ {
+ asm("{");
+ asm(".reg.pred %flag;");
+ asm("setp.ne.u32 %flag, %0, 0;" :: "r"(val));
+ asm("@%flag setp.ne.u32 %flag, %0, 0;" :: "r"((int)flag));
+ asm("@%flag sub.u32 %0, %1, %0;" : "+r"(val) : "r"(MOD));
+ asm("}");
+
+ return *this;
+ }
+ static inline mont32_t cneg(mont32_t a, bool flag)
+ { return a.cneg(flag); }
+ inline mont32_t operator-() const
+ { return cneg(*this, true); }
+
+ static inline const mont32_t 
one() { return mont32_t{ONE}; } + inline bool is_one() const { return val == ONE; } + inline bool is_zero() const { return val == 0; } + inline void zero() { val = 0; } + + friend inline mont32_t czero(const mont32_t a, int set_z) + { + mont32_t ret; + + asm("{"); + asm(".reg.pred %set_z;"); + asm("setp.ne.s32 %set_z, %0, 0;" : : "r"(set_z)); + asm("selp.u32 %0, 0, %1, %set_z;" : "=r"(ret.val) : "r"(a.val)); + asm("}"); + + return ret; + } + + static inline mont32_t csel(const mont32_t a, const mont32_t b, int sel_a) + { + mont32_t ret; + + asm("{"); + asm(".reg.pred %sel_a;"); + asm("setp.ne.s32 %sel_a, %0, 0;" :: "r"(sel_a)); + asm("selp.u32 %0, %1, %2, %sel_a;" : "=r"(ret.val) : "r"(a.val), "r"(b.val)); + asm("}"); + + return ret; + } + +private: + static inline uint32_t final_sub(uint32_t val) + { + asm("{"); + asm(".reg.pred %p;"); + if (N == 32) { + uint32_t carry; + + asm("addc.u32 %0, 0, 0;" : "=r"(carry)); + asm("setp.lt.u32 %p, %0, %1;" :: "r"(val), "r"(MOD)); + asm("@%p setp.eq.u32 %p, %0, 0;" :: "r"(carry)); + asm("@!%p sub.u32 %0, %0, %1;" : "+r"(val) : "r"(MOD)); + } else { + asm("setp.ge.u32 %p, %0, %1;" :: "r"(val), "r"(MOD)); + asm("@%p sub.u32 %0, %0, %1;" : "+r"(val) : "r"(MOD)); + } + asm("}"); + + return val; + } + + inline mont32_t& mul(const mont32_t b) + { + uint32_t tmp[2], red; + + asm("mul.lo.u32 %0, %2, %3; mul.hi.u32 %1, %2, %3;" + : "=r"(tmp[0]), "=r"(tmp[1]) + : "r"(val), "r"(b.val)); + asm("mul.lo.u32 %0, %1, %2;" : "=r"(red) : "r"(tmp[0]), "r"(M0)); + asm("mad.lo.cc.u32 %0, %2, %3, %0; madc.hi.cc.u32 %1, %2, %3, %1;" + : "+r"(tmp[0]), "+r"(tmp[1]) + : "r"(red), "r"(MOD)); + + val = final_sub(tmp[1]); + + return *this; + } + + inline uint32_t mul_by_1() const + { + uint32_t tmp[2], red; + + asm("mul.lo.u32 %0, %1, %2;" : "=r"(red) : "r"(val), "r"(M0)); + asm("mad.lo.cc.u32 %0, %2, %3, %4; madc.hi.u32 %1, %2, %3, 0;" + : "=r"(tmp[0]), "=r"(tmp[1]) + : "r"(red), "r"(MOD), "r"(val)); + + return tmp[1]; + } + +public: + friend inline mont32_t operator*(mont32_t a, const mont32_t b) + { return a.mul(b); } + inline mont32_t& operator*=(const mont32_t a) + { return mul(a); } + + // raise to a variable power, variable in respect to threadIdx, + // but mind the ^ operator's precedence! + inline mont32_t& operator^=(uint32_t p) + { + mont32_t sqr = *this; + *this = csel(val, ONE, p&1); + + #pragma unroll 1 + while (p >>= 1) { + sqr.mul(sqr); + if (p&1) + mul(sqr); + } + + return *this; + } + friend inline mont32_t operator^(mont32_t a, uint32_t p) + { return a ^= p; } + inline mont32_t operator()(uint32_t p) + { return *this^p; } + + // raise to a constant power, e.g. 
x^7, to be unrolled at compile time
+ inline mont32_t& operator^=(int p)
+ {
+ if (p < 2)
+ asm("trap;");
+
+ if (p == 7) {
+ mont32_t temp = sqr_n_mul(*this, 1, *this);
+ *this = sqr_n_mul(temp, 1, *this);
+ return *this;
+ }
+
+ mont32_t sqr = *this;
+ if ((p&1) == 0) {
+ do {
+ sqr.mul(sqr);
+ p >>= 1;
+ } while ((p&1) == 0);
+ *this = sqr;
+ }
+ for (p >>= 1; p; p >>= 1) {
+ sqr.mul(sqr);
+ if (p&1)
+ mul(sqr);
+ }
+
+ return *this;
+ }
+ friend inline mont32_t operator^(mont32_t a, int p)
+ { return a ^= p; }
+ inline mont32_t operator()(int p)
+ { return *this^p; }
+ friend inline mont32_t sqr(mont32_t a)
+ { return a.sqr(); }
+ inline mont32_t& sqr()
+ { return mul(*this); }
+
+ inline void to() { mul(RR); }
+ inline void from() { val = mul_by_1(); }
+
+ template<size_t T>
+ static inline mont32_t dot_product(const mont32_t a[T], const mont32_t b[T])
+ {
+ uint32_t acc[2];
+
+ asm("mul.lo.u32 %0, %2, %3; mul.hi.u32 %1, %2, %3;"
+ : "=r"(acc[0]), "=r"(acc[1]) : "r"(*a[0]), "r"(*b[0]));
+
+ if (N == 32) {
+ for (size_t i = 1; i < T; i++) {
+ asm("mad.lo.cc.u32 %0, %2, %3, %0; madc.hi.cc.u32 %1, %2, %3, %1;"
+ : "+r"(acc[0]), "+r"(acc[1]) : "r"(*a[i]), "r"(*b[i]));
+ acc[1] = final_sub(acc[1]);
+ }
+ } else {
+ size_t i = 1;
+
+ if ((T&1) == 0) {
+ asm("mad.lo.cc.u32 %0, %2, %3, %0; madc.hi.u32 %1, %2, %3, %1;"
+ : "+r"(acc[0]), "+r"(acc[1]) : "r"(*a[i]), "r"(*b[i]));
+ i++;
+ }
+ for (; i < T; i += 2) {
+ asm("mad.lo.cc.u32 %0, %2, %3, %0; madc.hi.u32 %1, %2, %3, %1;"
+ : "+r"(acc[0]), "+r"(acc[1]) : "r"(*a[i]), "r"(*b[i]));
+ asm("mad.lo.cc.u32 %0, %2, %3, %0; madc.hi.u32 %1, %2, %3, %1;"
+ : "+r"(acc[0]), "+r"(acc[1]) : "r"(*a[i+1]), "r"(*b[i+1]));
+ acc[1] = final_sub(acc[1]);
+ }
+ }
+
+ uint32_t red;
+ asm("mul.lo.u32 %0, %1, %2;" : "=r"(red) : "r"(acc[0]), "r"(M0));
+ asm("mad.lo.cc.u32 %0, %2, %3, %0; madc.hi.cc.u32 %1, %2, %3, %1;"
+ : "+r"(acc[0]), "+r"(acc[1]) : "r"(red), "r"(MOD));
+
+ return final_sub(acc[1]);
+ }
+
+ template<size_t T>
+ static inline mont32_t dot_product(mont32_t a0, mont32_t b0,
+ const mont32_t a[T-1], const mont32_t *b,
+ size_t stride_b = 1)
+ {
+ uint32_t acc[2];
+
+ asm("mul.lo.u32 %0, %2, %3; mul.hi.u32 %1, %2, %3;"
+ : "=r"(acc[0]), "=r"(acc[1]) : "r"(*a0), "r"(*b0));
+
+ if (N == 32) {
+ for (size_t i = 0; i < T-1; i++, b += stride_b) {
+ asm("mad.lo.cc.u32 %0, %2, %3, %0; madc.hi.cc.u32 %1, %2, %3, %1;"
+ : "+r"(acc[0]), "+r"(acc[1]) : "r"(*a[i]), "r"(*b[0]));
+ acc[1] = final_sub(acc[1]);
+ }
+ } else {
+ size_t i = 0;
+
+ if ((T&1) == 0) {
+ asm("mad.lo.cc.u32 %0, %2, %3, %0; madc.hi.u32 %1, %2, %3, %1;"
+ : "+r"(acc[0]), "+r"(acc[1]) : "r"(*a[i]), "r"(*b[0]));
+ i++, b += stride_b;
+ }
+ for (; i < T-1; i += 2) {
+ asm("mad.lo.cc.u32 %0, %2, %3, %0; madc.hi.u32 %1, %2, %3, %1;"
+ : "+r"(acc[0]), "+r"(acc[1]) : "r"(*a[i]), "r"(*b[0]));
+ b += stride_b;
+ asm("mad.lo.cc.u32 %0, %2, %3, %0; madc.hi.u32 %1, %2, %3, %1;"
+ : "+r"(acc[0]), "+r"(acc[1]) : "r"(*a[i+1]), "r"(*b[0]));
+ b += stride_b;
+ acc[1] = final_sub(acc[1]);
+ }
+ }
+
+ uint32_t red;
+ asm("mul.lo.u32 %0, %1, %2;" : "=r"(red) : "r"(acc[0]), "r"(M0));
+ asm("mad.lo.cc.u32 %0, %2, %3, %0; madc.hi.cc.u32 %1, %2, %3, %1;"
+ : "+r"(acc[0]), "+r"(acc[1]) : "r"(red), "r"(MOD));
+
+ return final_sub(acc[1]);
+ }
+
+ inline mont32_t reciprocal() const
+ { return *this ^ (MOD-2); }
+ friend inline mont32_t operator/(int one, mont32_t a)
+ { if (one != 1) asm("trap;"); return a.reciprocal(); }
+ friend inline mont32_t operator/(mont32_t a, mont32_t b)
+ { return a * b.reciprocal(); }
+ inline mont32_t& 
operator/=(const mont32_t a) + { return *this *= a.reciprocal(); } + + inline void shfl_bfly(uint32_t laneMask) + { val = __shfl_xor_sync(0xFFFFFFFF, val, laneMask); } + +protected: + static inline mont32_t sqr_n(mont32_t s, uint32_t n) + { + if (N == 32 || M0 > MOD) { + #pragma unroll 4 + while (n--) + s.sqr(); + } else { // +20% [for bb31_t::reciprocal()] + #pragma unroll 4 + while (n--) { + uint32_t tmp[2], red; + + asm("mul.lo.u32 %0, %2, %2; mul.hi.u32 %1, %2, %2;" + : "=r"(tmp[0]), "=r"(tmp[1]) + : "r"(s.val)); + asm("mul.lo.u32 %0, %1, %2;" : "=r"(red) : "r"(tmp[0]), "r"(M0)); + asm("mad.lo.cc.u32 %0, %2, %3, %0; madc.hi.u32 %1, %2, %3, %4;" + : "+r"(tmp[0]), "=r"(s.val) + : "r"(red), "r"(MOD), "r"(tmp[1])); + + if (n&1) + s.val = final_sub(s.val); + } + } + + return s; + } + + static inline mont32_t sqr_n_mul(mont32_t s, uint32_t n, mont32_t m) + { + s = sqr_n(s, n); + s.mul(m); + + return s; + } + +# undef inline +# undef asm + +public: + friend inline bool operator==(mont32_t a, mont32_t b) + { return a.val == b.val; } + friend inline bool operator!=(mont32_t a, mont32_t b) + { return a.val != b.val; } + +# if defined(_GLIBCXX_IOSTREAM) || defined(_IOSTREAM_) // non-standard + friend std::ostream& operator<<(std::ostream& os, const mont32_t& obj) + { + auto f = os.flags(); + uint32_t red = obj.val * M0; + uint64_t v = obj.val + red * (uint64_t)MOD; + os << "0x" << std::hex << (uint32_t)(v >> 32); + os.flags(f); + return os; + } +# endif +}; + +#endif /* __SPPARK_FF_MONT32_T_CUH__ */ diff --git a/ntt/kernels.cu b/ntt/kernels.cu index da44c26..7d358eb 100644 --- a/ntt/kernels.cu +++ b/ntt/kernels.cu @@ -196,7 +196,6 @@ void LDE_spread_distribute_powers(fr_t* out, fr_t* in, } index_t idx0 = blockDim.x * blockIdx.x; - uint32_t thread_pos = threadIdx.x & (blowup - 1); #if 0 index_t iters = domain_size / stride; @@ -225,15 +224,18 @@ void LDE_spread_distribute_powers(fr_t* out, fr_t* in, else __syncthreads(); - r.zero(); - - for (uint32_t i = 0; i < blowup; i++) { - uint32_t offset = i * blockDim.x + threadIdx.x; - - if (thread_pos == 0) + for (uint32_t offset = threadIdx.x, i = 0; i < blowup; i += 2) { + r.zero(); + if ((offset & (blowup-1)) == 0) r = exchange[offset >> lg_blowup]; + out[(idx0 << lg_blowup) + offset] = r; + offset += blockDim.x; + r.zero(); + if ((offset & (blowup-1)) == 0) + r = exchange[offset >> lg_blowup]; out[(idx0 << lg_blowup) + offset] = r; + offset += blockDim.x; } idx0 += stride; diff --git a/ntt/kernels/ct_mixed_radix_narrow.cu b/ntt/kernels/ct_mixed_radix_narrow.cu index 801d046..5d393ac 100644 --- a/ntt/kernels/ct_mixed_radix_narrow.cu +++ b/ntt/kernels/ct_mixed_radix_narrow.cu @@ -101,7 +101,7 @@ void _CT_NTT(const unsigned int radix, const unsigned int lg_domain_size, t.shfl_bfly(laneMask); - r[0][z] = fr_t::csel(t, r[0][z], !pos); + r[0][z] = fr_t::csel(r[0][z], t, pos); r[1][z] = fr_t::csel(t, r[1][z], pos); t = root * r[1][z]; @@ -133,7 +133,7 @@ void _CT_NTT(const unsigned int radix, const unsigned int lg_domain_size, for (int z = 0; z < z_count; z++) { fr_t t = xchg[threadIdx.x ^ laneMask][z]; - r[0][z] = fr_t::csel(t, r[0][z], !pos); + r[0][z] = fr_t::csel(r[0][z], t, pos); r[1][z] = fr_t::csel(t, r[1][z], pos); t = root * r[1][z]; diff --git a/ntt/kernels/ct_mixed_radix_wide.cu b/ntt/kernels/ct_mixed_radix_wide.cu index 6bee49e..ce6d1ca 100644 --- a/ntt/kernels/ct_mixed_radix_wide.cu +++ b/ntt/kernels/ct_mixed_radix_wide.cu @@ -80,7 +80,7 @@ void _CT_NTT(const unsigned int radix, const unsigned int lg_domain_size, fr_t x = fr_t::csel(r1, r0, 
pos); x.shfl_bfly(laneMask); - r0 = fr_t::csel(x, r0, !pos); + r0 = fr_t::csel(r0, x, pos); r1 = fr_t::csel(x, r1, pos); fr_t t = d_radix6_twiddles[rank << (6 - (s + 1))]; @@ -106,7 +106,7 @@ void _CT_NTT(const unsigned int radix, const unsigned int lg_domain_size, shared_exchange[threadIdx.x] = x; __syncthreads(); x = shared_exchange[threadIdx.x ^ laneMask]; - r0 = fr_t::csel(x, r0, !pos); + r0 = fr_t::csel(r0, x, pos); r1 = fr_t::csel(x, r1, pos); t *= r1; diff --git a/ntt/kernels/gs_mixed_radix_narrow.cu b/ntt/kernels/gs_mixed_radix_narrow.cu index 798f911..0b50305 100644 --- a/ntt/kernels/gs_mixed_radix_narrow.cu +++ b/ntt/kernels/gs_mixed_radix_narrow.cu @@ -80,7 +80,7 @@ void _GS_NTT(const unsigned int radix, const unsigned int lg_domain_size, #pragma unroll for (int z = 0; z < z_count; z++) { fr_t t = xchg[threadIdx.x ^ laneMask][z]; - r[0][z] = fr_t::csel(t, r[0][z], !pos); + r[0][z] = fr_t::csel(r[0][z], t, pos); r[1][z] = fr_t::csel(t, r[1][z], pos); } } @@ -104,7 +104,7 @@ void _GS_NTT(const unsigned int radix, const unsigned int lg_domain_size, t.shfl_bfly(laneMask); - r[0][z] = fr_t::csel(t, r[0][z], !pos); + r[0][z] = fr_t::csel(r[0][z], t, pos); r[1][z] = fr_t::csel(t, r[1][z], pos); } } diff --git a/ntt/kernels/gs_mixed_radix_wide.cu b/ntt/kernels/gs_mixed_radix_wide.cu index 34b2d3a..2a2725d 100644 --- a/ntt/kernels/gs_mixed_radix_wide.cu +++ b/ntt/kernels/gs_mixed_radix_wide.cu @@ -51,7 +51,7 @@ void _GS_NTT(const unsigned int radix, const unsigned int lg_domain_size, shared_exchange[threadIdx.x] = t; __syncthreads(); t = shared_exchange[threadIdx.x ^ laneMask]; - r0 = fr_t::csel(t, r0, !pos); + r0 = fr_t::csel(r0, t, pos); r1 = fr_t::csel(t, r1, pos); } @@ -70,7 +70,7 @@ void _GS_NTT(const unsigned int radix, const unsigned int lg_domain_size, t = fr_t::csel(r1, r0, pos); t.shfl_bfly(laneMask); - r0 = fr_t::csel(t, r0, !pos); + r0 = fr_t::csel(r0, t, pos); r1 = fr_t::csel(t, r1, pos); } diff --git a/ntt/parameters/baby_bear.h b/ntt/parameters/baby_bear.h index c5378e6..56cdda7 100644 --- a/ntt/parameters/baby_bear.h +++ b/ntt/parameters/baby_bear.h @@ -2,12 +2,83 @@ // Licensed under the Apache License, Version 2.0, see LICENSE for details. 
// SPDX-License-Identifier: Apache-2.0
 
+const int S = 27;
+
+#ifdef BABY_BEAR_CANONICAL
+
+const fr_t group_gen = fr_t(0x1f); // primitive_root(0x78000001)
+const fr_t group_gen_inverse = fr_t(0x03def7be);
+
 // Values in Montgomery form
-const fr_t group_gen = fr_t(0x2ffffffau);
-const fr_t group_gen_inverse = fr_t(0x2d555555u);
+const fr_t forward_roots_of_unity[S + 1] = {
+ fr_t(0x0ffffffeu),
+ fr_t(0x68000003u),
+ fr_t(0x1c38d511u),
+ fr_t(0x3d85298fu),
+ fr_t(0x5f06e481u),
+ fr_t(0x3f5c39ecu),
+ fr_t(0x5516a97au),
+ fr_t(0x3d6be592u),
+ fr_t(0x5bb04149u),
+ fr_t(0x4907f9abu),
+ fr_t(0x548b8e90u),
+ fr_t(0x1d8ca617u),
+ fr_t(0x2ce7f0e6u),
+ fr_t(0x621b371fu),
+ fr_t(0x6d4d2d78u),
+ fr_t(0x18716fcdu),
+ fr_t(0x3b30a682u),
+ fr_t(0x1c6f4728u),
+ fr_t(0x59b01f7cu),
+ fr_t(0x1a7f97acu),
+ fr_t(0x0732561cu),
+ fr_t(0x2b5a1cd4u),
+ fr_t(0x6f7d26f9u),
+ fr_t(0x16e2f919u),
+ fr_t(0x285ab85bu),
+ fr_t(0x0dd5a9ecu),
+ fr_t(0x43f13568u),
+ fr_t(0x57fab6eeu)
+};
 
-const int S = 27;
+const fr_t inverse_roots_of_unity[S + 1] = {
+ fr_t(0x0ffffffeu),
+ fr_t(0x68000003u),
+ fr_t(0x5bc72af0u),
+ fr_t(0x02ec07f3u),
+ fr_t(0x67e027cau),
+ fr_t(0x5e1a0700u),
+ fr_t(0x4bcc008cu),
+ fr_t(0x0bed94d1u),
+ fr_t(0x330b2e00u),
+ fr_t(0x6b469805u),
+ fr_t(0x0d83fad2u),
+ fr_t(0x26e64394u),
+ fr_t(0x0855523bu),
+ fr_t(0x5c9f0045u),
+ fr_t(0x5a7ba8c3u),
+ fr_t(0x3c8b04e2u),
+ fr_t(0x0c0f2066u),
+ fr_t(0x1b51d34cu),
+ fr_t(0x59f9bc12u),
+ fr_t(0x3511f012u),
+ fr_t(0x061ec85fu),
+ fr_t(0x5fd09c6bu),
+ fr_t(0x26bdc06cu),
+ fr_t(0x1272832eu),
+ fr_t(0x052ce2e8u),
+ fr_t(0x02ff110du),
+ fr_t(0x216ce204u),
+ fr_t(0x5e12c8e9u)
+};
+
+#else
+
+const fr_t group_gen = fr_t(3);
+const fr_t group_gen_inverse = fr_t(0x50000001);
+
+// Values in Montgomery form
 
 const fr_t forward_roots_of_unity[S + 1] = {
 fr_t(0x0ffffffeu),
@@ -70,6 +141,7 @@ const fr_t inverse_roots_of_unity[S + 1] = {
 fr_t(0x167ca34bu),
 fr_t(0x50b3630au)
 };
+#endif
 
 const fr_t domain_size_inverse[S + 1] = {
 fr_t(0x0ffffffeu),
diff --git a/util/gpu_t.cuh b/util/gpu_t.cuh
index 65ef882..ce01e20 100644
--- a/util/gpu_t.cuh
+++ b/util/gpu_t.cuh
@@ -77,6 +77,10 @@ public:
 inline void Dfree(void* d_ptr) const
 { CUDA_OK(cudaFreeAsync(d_ptr, stream)); }
 
+ template<typename T>
+ inline void bzero(T* dst, size_t nelems) const
+ { CUDA_OK(cudaMemsetAsync(dst, 0, nelems * sizeof(T), stream)); }
+
 template<typename T>
 inline void HtoD(T* dst, const void* src, size_t nelems,
 size_t sz = sizeof(T)) const
@@ -212,6 +216,10 @@ public:
 zero.sync();
 }
 
+ template<typename T>
+ inline void bzero(T* dst, size_t nelems) const
+ { zero.bzero(dst, nelems); }
+
 template<typename T>
 inline void HtoD(T* dst, const void* src, size_t nelems,
 size_t sz = sizeof(T)) const
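The bzero() additions give stream_t and gpu_t an asynchronous, element-typed clear on their own stream. A hypothetical call site (gpu, d_scratch, and clear_scratch are illustrative names, not from this diff):

    // zero out a device-side scratch buffer of field elements before reuse;
    // the memset is queued on the object's stream, so it orders naturally
    // with subsequent kernel launches on that stream
    void clear_scratch(const gpu_t& gpu, fr_t* d_scratch, size_t nelems)
    {
        gpu.bzero(d_scratch, nelems);
    }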