Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Field creation automated through macros #551

Merged
merged 9 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 51 additions & 102 deletions icicle/include/fields/field.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public:

static constexpr HOST_DEVICE_INLINE Field from(uint32_t value)
{
storage<TLC> scalar;
storage<TLC> scalar{};
scalar.limbs[0] = value;
for (int i = 1; i < TLC; i++) {
scalar.limbs[i] = 0;
Expand All @@ -58,8 +58,10 @@ public:

if (logn > CONFIG::omegas_count) { THROW_ICICLE_ERR(IcicleError_t::InvalidArgument, "Field: Invalid omega index"); }

storage_array<CONFIG::omegas_count, TLC> const omega = CONFIG::omega;
return Field{omega.storages[logn - 1]};
Field omega = Field{CONFIG::rou};
for (int i = 0; i < CONFIG::omegas_count - logn; i++)
omega = sqr(omega);
return omega;
}

static HOST_INLINE Field omega_inv(uint32_t logn)
Expand All @@ -70,8 +72,10 @@ public:
THROW_ICICLE_ERR(IcicleError_t::InvalidArgument, "Field: Invalid omega_inv index");
}

storage_array<CONFIG::omegas_count, TLC> const omega_inv = CONFIG::omega_inv;
return Field{omega_inv.storages[logn - 1]};
Field omega = inverse(Field{CONFIG::rou});
for (int i = 0; i < CONFIG::omegas_count - logn; i++)
omega = sqr(omega);
return omega;
}

static HOST_DEVICE_INLINE Field inv_log_size(uint32_t logn)
Expand Down Expand Up @@ -182,32 +186,32 @@ public:
if (REDUCTION_SIZE == 0) return xs;
const ff_wide_storage modulus = get_modulus_squared<REDUCTION_SIZE>();
Wide rs = {};
return sub_limbs<true>(xs.limbs_storage, modulus, rs.limbs_storage) ? xs : rs;
return sub_limbs<2 * TLC, true>(xs.limbs_storage, modulus, rs.limbs_storage) ? xs : rs;
}

template <unsigned MODULUS_MULTIPLE = 1>
static constexpr HOST_DEVICE_INLINE Wide neg(const Wide& xs)
{
const ff_wide_storage modulus = get_modulus_squared<MODULUS_MULTIPLE>();
Wide rs = {};
sub_limbs<false>(modulus, xs.limbs_storage, rs.limbs_storage);
sub_limbs<2 * TLC, false>(modulus, xs.limbs_storage, rs.limbs_storage);
return rs;
}

friend HOST_DEVICE_INLINE Wide operator+(Wide xs, const Wide& ys)
{
Wide rs = {};
add_limbs<false>(xs.limbs_storage, ys.limbs_storage, rs.limbs_storage);
add_limbs<2 * TLC, false>(xs.limbs_storage, ys.limbs_storage, rs.limbs_storage);
return sub_modulus_squared<1>(rs);
}

friend HOST_DEVICE_INLINE Wide operator-(Wide xs, const Wide& ys)
{
Wide rs = {};
uint32_t carry = sub_limbs<true>(xs.limbs_storage, ys.limbs_storage, rs.limbs_storage);
uint32_t carry = sub_limbs<2 * TLC, true>(xs.limbs_storage, ys.limbs_storage, rs.limbs_storage);
if (carry == 0) return rs;
const ff_wide_storage modulus = get_modulus_squared<1>();
add_limbs<false>(rs.limbs_storage, modulus, rs.limbs_storage);
add_limbs<2 * TLC, false>(rs.limbs_storage, modulus, rs.limbs_storage);
return rs;
}
};
Expand All @@ -228,12 +232,6 @@ public:
}
}

template <unsigned MULTIPLIER = 1>
static constexpr HOST_DEVICE_INLINE ff_wide_storage modulus_wide()
{
return CONFIG::modulus_wide;
}

// return m
static constexpr HOST_DEVICE_INLINE ff_storage get_m() { return CONFIG::m; }

Expand All @@ -253,12 +251,11 @@ public:
}
}

template <bool SUBTRACT, bool CARRY_OUT>
static constexpr DEVICE_INLINE uint32_t
add_sub_u32_device(const uint32_t* x, const uint32_t* y, uint32_t* r, size_t n = (TLC >> 1))
template <unsigned NLIMBS, bool SUBTRACT, bool CARRY_OUT>
static constexpr DEVICE_INLINE uint32_t add_sub_u32_device(const uint32_t* x, const uint32_t* y, uint32_t* r)
{
r[0] = SUBTRACT ? ptx::sub_cc(x[0], y[0]) : ptx::add_cc(x[0], y[0]);
for (unsigned i = 1; i < n; i++)
for (unsigned i = 1; i < NLIMBS; i++)
r[i] = SUBTRACT ? ptx::subc_cc(x[i], y[i]) : ptx::addc_cc(x[i], y[i]);
if (!CARRY_OUT) {
ptx::addc(0, 0);
Expand All @@ -267,71 +264,35 @@ public:
return SUBTRACT ? ptx::subc(0, 0) : ptx::addc(0, 0);
}

// add or subtract limbs
template <bool SUBTRACT, bool CARRY_OUT>
static constexpr DEVICE_INLINE uint32_t
add_sub_limbs_device(const ff_storage& xs, const ff_storage& ys, ff_storage& rs)
{
const uint32_t* x = xs.limbs;
const uint32_t* y = ys.limbs;
uint32_t* r = rs.limbs;
return add_sub_u32_device<SUBTRACT, CARRY_OUT>(x, y, r, TLC);
}

template <bool SUBTRACT, bool CARRY_OUT>
template <unsigned NLIMBS, bool SUBTRACT, bool CARRY_OUT>
static constexpr DEVICE_INLINE uint32_t
add_sub_limbs_device(const ff_wide_storage& xs, const ff_wide_storage& ys, ff_wide_storage& rs)
{
const uint32_t* x = xs.limbs;
const uint32_t* y = ys.limbs;
uint32_t* r = rs.limbs;
return add_sub_u32_device<SUBTRACT, CARRY_OUT>(x, y, r, 2 * TLC);
}

template <bool SUBTRACT, bool CARRY_OUT>
static constexpr HOST_INLINE uint32_t add_sub_limbs_host(const ff_storage& xs, const ff_storage& ys, ff_storage& rs)
{
const uint32_t* x = xs.limbs;
const uint32_t* y = ys.limbs;
uint32_t* r = rs.limbs;
uint32_t carry = 0;
host_math::carry_chain<TLC, false, CARRY_OUT> chain;
for (unsigned i = 0; i < TLC; i++)
r[i] = SUBTRACT ? chain.sub(x[i], y[i], carry) : chain.add(x[i], y[i], carry);
return CARRY_OUT ? carry : 0;
}

template <bool SUBTRACT, bool CARRY_OUT>
static constexpr HOST_INLINE uint32_t
add_sub_limbs_host(const ff_wide_storage& xs, const ff_wide_storage& ys, ff_wide_storage& rs)
add_sub_limbs_device(const storage<NLIMBS>& xs, const storage<NLIMBS>& ys, storage<NLIMBS>& rs)
{
const uint32_t* x = xs.limbs;
const uint32_t* y = ys.limbs;
uint32_t* r = rs.limbs;
uint32_t carry = 0;
host_math::carry_chain<2 * TLC, false, CARRY_OUT> chain;
for (unsigned i = 0; i < 2 * TLC; i++)
r[i] = SUBTRACT ? chain.sub(x[i], y[i], carry) : chain.add(x[i], y[i], carry);
return CARRY_OUT ? carry : 0;
return add_sub_u32_device<NLIMBS, SUBTRACT, CARRY_OUT>(x, y, r);
}

template <bool CARRY_OUT, typename T>
static constexpr HOST_DEVICE_INLINE uint32_t add_limbs(const T& xs, const T& ys, T& rs)
template <unsigned NLIMBS, bool CARRY_OUT>
static constexpr HOST_DEVICE_INLINE uint32_t
add_limbs(const storage<NLIMBS>& xs, const storage<NLIMBS>& ys, storage<NLIMBS>& rs)
{
#ifdef __CUDA_ARCH__
return add_sub_limbs_device<false, CARRY_OUT>(xs, ys, rs);
return add_sub_limbs_device<NLIMBS, false, CARRY_OUT>(xs, ys, rs);
#else
return add_sub_limbs_host<false, CARRY_OUT>(xs, ys, rs);
return host_math::template add_sub_limbs<NLIMBS, false, CARRY_OUT>(xs, ys, rs);
#endif
}

template <bool CARRY_OUT, typename T>
static constexpr HOST_DEVICE_INLINE uint32_t sub_limbs(const T& xs, const T& ys, T& rs)
template <unsigned NLIMBS, bool CARRY_OUT>
static constexpr HOST_DEVICE_INLINE uint32_t
sub_limbs(const storage<NLIMBS>& xs, const storage<NLIMBS>& ys, storage<NLIMBS>& rs)
{
#ifdef __CUDA_ARCH__
return add_sub_limbs_device<true, CARRY_OUT>(xs, ys, rs);
return add_sub_limbs_device<NLIMBS, true, CARRY_OUT>(xs, ys, rs);
#else
return add_sub_limbs_host<true, CARRY_OUT>(xs, ys, rs);
return host_math::template add_sub_limbs<NLIMBS, true, CARRY_OUT>(xs, ys, rs);
#endif
}

Expand Down Expand Up @@ -531,7 +492,7 @@ public:
// are necessarily NTT-friendly, `b[0]` often turns out to be \f$ 2^{32} - 1 \f$. This actually leads to
// less efficient SASS generated by nvcc, so this case needed separate handling.
if (b[0] == UINT32_MAX) {
add_sub_u32_device<true, false>(c, a, even, TLC);
add_sub_u32_device<TLC, true, false>(c, a, even);
for (i = 0; i < TLC - 1; i++)
odd[i] = a[i];
} else {
Expand Down Expand Up @@ -639,17 +600,18 @@ public:
__align__(16) uint32_t diffs[TLC];
// Differences of halves \f$ a_{hi} - a_{lo}; b_{lo} - b_{hi} \$f are written into `diffs`, signs written to
// `carry1` and `carry2`.
uint32_t carry1 = add_sub_u32_device<true, true>(&a[TLC >> 1], a, diffs);
uint32_t carry2 = add_sub_u32_device<true, true>(b, &b[TLC >> 1], &diffs[TLC >> 1]);
uint32_t carry1 = add_sub_u32_device<(TLC >> 1), true, true>(&a[TLC >> 1], a, diffs);
uint32_t carry2 = add_sub_u32_device<(TLC >> 1), true, true>(b, &b[TLC >> 1], &diffs[TLC >> 1]);
// Compute the "middle part" of Karatsuba: \f$ a_{lo} \cdot b_{hi} + b_{lo} \cdot a_{hi} \f$.
// This is where the assumption about unset high bit of `a` and `b` is relevant.
multiply_and_add_short_raw_device(diffs, &diffs[TLC >> 1], middle_part, r, &r[TLC]);
// Corrections that need to be performed when differences are negative.
// Again, carry doesn't need to be propagated due to unset high bits of `a` and `b`.
if (carry1) add_sub_u32_device<true, false>(&middle_part[TLC >> 1], &diffs[TLC >> 1], &middle_part[TLC >> 1]);
if (carry2) add_sub_u32_device<true, false>(&middle_part[TLC >> 1], diffs, &middle_part[TLC >> 1]);
if (carry1)
add_sub_u32_device<(TLC >> 1), true, false>(&middle_part[TLC >> 1], &diffs[TLC >> 1], &middle_part[TLC >> 1]);
if (carry2) add_sub_u32_device<(TLC >> 1), true, false>(&middle_part[TLC >> 1], diffs, &middle_part[TLC >> 1]);
// Now that middle part is fully correct, it can be added to the result.
add_sub_u32_device<false, true>(&r[TLC >> 1], middle_part, &r[TLC >> 1], TLC);
add_sub_u32_device<TLC, false, true>(&r[TLC >> 1], middle_part, &r[TLC >> 1]);

// Carry from adding middle part has to be propagated to the highest limb.
for (size_t i = TLC + (TLC >> 1); i < 2 * TLC; i++)
Expand All @@ -673,25 +635,12 @@ public:
}
}

static HOST_INLINE void multiply_raw_host(const ff_storage& as, const ff_storage& bs, ff_wide_storage& rs)
{
const uint32_t* a = as.limbs;
const uint32_t* b = bs.limbs;
uint32_t* r = rs.limbs;
for (unsigned i = 0; i < TLC; i++) {
uint32_t carry = 0;
for (unsigned j = 0; j < TLC; j++)
r[j + i] = host_math::madc_cc(a[j], b[i], r[j + i], carry);
r[TLC + i] = carry;
}
}

static HOST_DEVICE_INLINE void multiply_raw(const ff_storage& as, const ff_storage& bs, ff_wide_storage& rs)
{
#ifdef __CUDA_ARCH__
return multiply_raw_device(as, bs, rs);
#else
return multiply_raw_host(as, bs, rs);
return host_math::template multiply_raw<TLC>(as, bs, rs);
#endif
}

Expand All @@ -702,9 +651,9 @@ public:
return multiply_and_add_lsb_neg_modulus_raw_device(as, cs, rs);
#else
Wide r_wide = {};
multiply_raw_host(as, get_neg_modulus(), r_wide.limbs_storage);
host_math::template multiply_raw<TLC>(as, get_neg_modulus(), r_wide.limbs_storage);
Field r = Wide::get_lower(r_wide);
add_limbs<false>(cs, r.limbs_storage, rs);
add_limbs<TLC, false>(cs, r.limbs_storage, rs);
#endif
}

Expand All @@ -713,7 +662,7 @@ public:
#ifdef __CUDA_ARCH__
return multiply_msb_raw_device(as, bs, rs);
#else
return multiply_raw_host(as, bs, rs);
return host_math::template multiply_raw<TLC>(as, bs, rs);
#endif
}

Expand Down Expand Up @@ -759,7 +708,7 @@ public:
if (REDUCTION_SIZE == 0) return xs;
const ff_storage modulus = get_modulus<REDUCTION_SIZE>();
Field rs = {};
return sub_limbs<true>(xs.limbs_storage, modulus, rs.limbs_storage) ? xs : rs;
return sub_limbs<TLC, true>(xs.limbs_storage, modulus, rs.limbs_storage) ? xs : rs;
}

friend std::ostream& operator<<(std::ostream& os, const Field& xs)
Expand All @@ -778,17 +727,17 @@ public:
friend HOST_DEVICE_INLINE Field operator+(Field xs, const Field& ys)
{
Field rs = {};
add_limbs<false>(xs.limbs_storage, ys.limbs_storage, rs.limbs_storage);
add_limbs<TLC, false>(xs.limbs_storage, ys.limbs_storage, rs.limbs_storage);
return sub_modulus<1>(rs);
}

friend HOST_DEVICE_INLINE Field operator-(Field xs, const Field& ys)
{
Field rs = {};
uint32_t carry = sub_limbs<true>(xs.limbs_storage, ys.limbs_storage, rs.limbs_storage);
uint32_t carry = sub_limbs<TLC, true>(xs.limbs_storage, ys.limbs_storage, rs.limbs_storage);
if (carry == 0) return rs;
const ff_storage modulus = get_modulus<1>();
add_limbs<false>(rs.limbs_storage, modulus, rs.limbs_storage);
add_limbs<TLC, false>(rs.limbs_storage, modulus, rs.limbs_storage);
return rs;
}

Expand Down Expand Up @@ -838,10 +787,10 @@ public:
uint32_t carry;
// As mentioned, either 2 or 1 reduction can be performed depending on the field in question.
if (num_of_reductions() == 2) {
carry = sub_limbs<true>(r.limbs_storage, get_modulus<2>(), r_reduced);
carry = sub_limbs<TLC, true>(r.limbs_storage, get_modulus<2>(), r_reduced);
if (carry == 0) r = Field{r_reduced};
}
carry = sub_limbs<true>(r.limbs_storage, get_modulus<1>(), r_reduced);
carry = sub_limbs<TLC, true>(r.limbs_storage, get_modulus<1>(), r_reduced);
if (carry == 0) r = Field{r_reduced};

return r;
Expand Down Expand Up @@ -933,7 +882,7 @@ public:
{
const ff_storage modulus = get_modulus<MODULUS_MULTIPLE>();
Field rs = {};
sub_limbs<false>(modulus, xs.limbs_storage, rs.limbs_storage);
sub_limbs<TLC, false>(modulus, xs.limbs_storage, rs.limbs_storage);
return rs;
}

Expand Down Expand Up @@ -963,7 +912,7 @@ public:
static constexpr HOST_DEVICE_INLINE bool lt(const Field& xs, const Field& ys)
{
ff_storage dummy = {};
uint32_t carry = sub_limbs<true>(xs.limbs_storage, ys.limbs_storage, dummy);
uint32_t carry = sub_limbs<TLC, true>(xs.limbs_storage, ys.limbs_storage, dummy);
return carry;
}

Expand All @@ -983,12 +932,12 @@ public:
while (!(u == one) && !(v == one)) {
while (is_even(u)) {
u = div2(u);
if (is_odd(b)) add_limbs<false>(b.limbs_storage, modulus, b.limbs_storage);
if (is_odd(b)) add_limbs<TLC, false>(b.limbs_storage, modulus, b.limbs_storage);
b = div2(b);
}
while (is_even(v)) {
v = div2(v);
if (is_odd(c)) add_limbs<false>(c.limbs_storage, modulus, c.limbs_storage);
if (is_odd(c)) add_limbs<TLC, false>(c.limbs_storage, modulus, c.limbs_storage);
c = div2(c);
}
if (lt(v, u)) {
Expand Down
Loading
Loading