Skip to content

Commit

Permalink
Merge pull request #2018 from vneiger/nmod_vec_dot_product_small_modulus
Browse files Browse the repository at this point in the history
nmod_vec_dot_product enhancements (including avx2 for small moduli)
  • Loading branch information
fredrik-johansson authored Jul 3, 2024
2 parents b579cdd + 6ca1357 commit 4496945
Show file tree
Hide file tree
Showing 36 changed files with 1,655 additions and 468 deletions.
126 changes: 92 additions & 34 deletions doc/source/nmod_vec.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Random functions

.. function:: void _nmod_vec_randtest(nn_ptr vec, flint_rand_t state, slong len, nmod_t mod)

Sets ``vec`` to a random vector of the given length with entries
Sets ``vec`` to a random vector of the given length with entries
reduced modulo ``mod.n``.


Expand All @@ -46,7 +46,7 @@ Basic manipulation and comparison

.. function:: void _nmod_vec_reduce(nn_ptr res, nn_srcptr vec, slong len, nmod_t mod)

Reduces the entries of ``(vec, len)`` modulo ``mod.n`` and set
Reduces the entries of ``(vec, len)`` modulo ``mod.n`` and set
``res`` to the result.

.. function:: flint_bitcnt_t _nmod_vec_max_bits(nn_srcptr vec, slong len)
Expand All @@ -55,8 +55,8 @@ Basic manipulation and comparison

.. function:: int _nmod_vec_equal(nn_srcptr vec, nn_srcptr vec2, slong len)

Returns~`1` if ``(vec, len)`` is equal to ``(vec2, len)``,
otherwise returns~`0`.
Returns `1` if ``(vec, len)`` is equal to ``(vec2, len)``,
otherwise returns `0`.


Printing
Expand Down Expand Up @@ -92,12 +92,12 @@ Arithmetic operations

.. function:: void _nmod_vec_add(nn_ptr res, nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod)

Sets ``(res, len)`` to the sum of ``(vec1, len)``
Sets ``(res, len)`` to the sum of ``(vec1, len)``
and ``(vec2, len)``.

.. function:: void _nmod_vec_sub(nn_ptr res, nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod)

Sets ``(res, len)`` to the difference of ``(vec1, len)``
Sets ``(res, len)`` to the difference of ``(vec1, len)``
and ``(vec2, len)``.

.. function:: void _nmod_vec_neg(nn_ptr res, nn_srcptr vec, slong len, nmod_t mod)
Expand All @@ -107,62 +107,120 @@ Arithmetic operations
.. function:: void _nmod_vec_scalar_mul_nmod(nn_ptr res, nn_srcptr vec, slong len, ulong c, nmod_t mod)

Sets ``(res, len)`` to ``(vec, len)`` multiplied by `c`. The element
`c` and all elements of `vec` are assumed to be less than `mod.n`.
`c` and all elements of ``vec`` are assumed to be less than ``mod.n``.

.. function:: void _nmod_vec_scalar_mul_nmod_shoup(nn_ptr res, nn_srcptr vec, slong len, ulong c, nmod_t mod)

Sets ``(res, len)`` to ``(vec, len)`` multiplied by `c` using
:func:`n_mulmod_shoup`. `mod.n` should be less than `2^{\mathtt{FLINT\_BITS} - 1}`. `c`
and all elements of `vec` should be less than `mod.n`.
:func:`n_mulmod_shoup`. `mod.n` should be less than `2^{\mathtt{FLINT\_BITS} - 1}`. `c`
and all elements of ``vec`` should be less than ``mod.n``.

.. function:: void _nmod_vec_scalar_addmul_nmod(nn_ptr res, nn_srcptr vec, slong len, ulong c, nmod_t mod)

Adds ``(vec, len)`` times `c` to the vector ``(res, len)``. The element
`c` and all elements of `vec` are assumed to be less than `mod.n`.
`c` and all elements of ``vec`` are assumed to be less than ``mod.n``.


Dot products
--------------------------------------------------------------------------------

Dot products functions and macros rely on several implementations, depending on
the length of this dot product and on the underlying modulus. What
implementations will be called is determined via ``_nmod_vec_dot_params``,
which returns a ``dot_params_t`` element which can then be used as input to the
dot product routines.

.. function:: int _nmod_vec_dot_bound_limbs(slong len, nmod_t mod)
The efficiency of the different approaches range roughly as follows, from
faster to slower, on 64 bit machines. In all cases, modular reduction is only
performed at the very end of the computation.

Returns the number of limbs (0, 1, 2 or 3) needed to represent the
unreduced dot product of two vectors of length ``len`` having entries
modulo ``mod.n``, assuming that ``len`` is nonnegative and that
``mod.n`` is nonzero. The computed bound is tight. In other words,
this function returns the precise limb size of ``len`` times
``(mod.n - 1) ^ 2``.
- moduli up to `1515531528` (about `2^{30.5}`): implemented via single limb
integer multiplication, using explicit vectorization if supported (current
support is for AVX2);

- moduli that are a power of `2` up to `2^{32}`: same efficiency as the above
case;

- moduli that are a power of `2` between `2^{33}` and `2^{63}`: efficiency
between that of the above case and that of the below one (depending on the
machine and on automatic vectorization);

- other moduli up to `2^{32}`: implemented via single limb integer
multiplication combined with accumulation in two limbs;

- moduli more than `2^{32}`, unreduced dot product fits in two limbs:
implemented via two limbs integer multiplication, with a final modular
reduction;

- unreduced dot product fits in three limbs, moduli up to about `2^{62.5}`:
implemented via two limbs integer multiplication, with intermediate
accumulation of sub-products in two limbs, and overall accumulation in three
limbs;

- unreduced dot product fits in three limbs, other moduli: implemented via two
limbs integer multiplication, with accumulation in three limbs.


.. type:: dot_params_t

.. function:: dot_params_t _nmod_vec_dot_params(slong len, nmod_t mod)

Returns a ``dot_params_t`` element. This element can be used as input for
the dot product macros and functions that require it, for any dot product
of vector with entries reduced modulo ``mod.n`` and whose length is less
than or equal to ``len``.

Internals, subject to change: its field ``method`` indicates the method that
will be used to compute a dot product of this length ``len`` when working
with the given ``mod``. Its field ``pow2_precomp`` is set to ``2**DOT_SPLIT_BITS
% mod.n`` if ``method == _DOT2_SPLIT``, and to `0` otherwise.

.. macro:: NMOD_VEC_DOT(res, i, len, expr1, expr2, mod, nlimbs)
.. function:: ulong _nmod_vec_dot(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, dot_params_t params)

Returns the dot product of (``vec1``, ``len``) and (``vec2``, ``len``). The
input ``params`` has type ``dot_params_t`` and must have been computed via
``_nmod_vec_dot_params`` with the specified ``mod`` and with a length
greater than or equal to ``len``.

.. function:: ulong _nmod_vec_dot_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, dot_params_t params)

The same as ``_nmod_vec_dot``, but reverses ``vec2``.

.. function:: ulong _nmod_vec_dot_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod, dot_params_t params)

Returns the dot product of (``vec1``, ``len``) and the values at
``vec2[i][offset]``. The input ``params`` has type ``dot_params_t`` and
must have been computed via ``_nmod_vec_dot_params`` with the specified
``mod`` and with a length greater than or equal to ``len``.

.. macro:: NMOD_VEC_DOT(res, i, len, expr1, expr2, mod, params)

Effectively performs the computation::

res = 0;
for (i = 0; i < len; i++)
res += (expr1) * (expr2);

but with the arithmetic performed modulo ``mod``. The ``nlimbs`` parameter
should be 0, 1, 2 or 3, specifying the number of limbs needed to represent
the unreduced result.
but with the arithmetic performed modulo ``mod``. The input ``params`` has
type ``dot_params_t`` and must have been computed via
``_nmod_vec_dot_params`` with the specified ``mod`` and with a length
greater than or equal to ``len``.

``nmod.h`` has to be included in order for this macro to work (order of
inclusions does not matter).

.. function:: ulong _nmod_vec_dot(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, int nlimbs)

Returns the dot product of (``vec1``, ``len``) and
(``vec2``, ``len``). The ``nlimbs`` parameter should be
0, 1, 2 or 3, specifying the number of limbs needed to represent the
unreduced result.
.. function:: int _nmod_vec_dot_bound_limbs(slong len, nmod_t mod)

.. function:: ulong _nmod_vec_dot_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, int nlimbs)
Returns the number of limbs (0, 1, 2 or 3) needed to represent the
unreduced dot product of two vectors of length ``len`` having entries
modulo ``mod.n``, assuming that ``len`` is nonnegative and that
``mod.n`` is nonzero. The computed bound is tight. In other words,
this function returns the precise limb size of ``len`` times
``(mod.n - 1)**2``.

The same as ``_nmod_vec_dot``, but reverses ``vec2``.
.. function:: int _nmod_vec_dot_bound_limbs_from_params(slong len, nmod_t mod, dot_params_t params)

.. function:: ulong _nmod_vec_dot_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod, int nlimbs)
Same specification as ``_nmod_vec_dot_bound_limbs``, but uses the additional
input ``params`` to reduce the amount of computations; for correctness
``params`` must have been computed for the specified ``len`` and ``mod``.

Returns the dot product of (``vec1``, ``len``) and the values at
``vec2[i][offset]``. The ``nlimbs`` parameter should be
0, 1, 2 or 3, specifying the number of limbs needed to represent the
unreduced result.
10 changes: 5 additions & 5 deletions src/arith/stirling2.c
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ stirling_2_nmod(const unsigned int * divtab, ulong n, ulong k, nmod_t mod)
nn_ptr t, u;
slong i, bin_len, pow_len;
ulong s1, s2, bden, bd;
int bound_limbs;
dot_params_t params;
TMP_INIT;
TMP_START;

Expand Down Expand Up @@ -575,13 +575,13 @@ stirling_2_nmod(const unsigned int * divtab, ulong n, ulong k, nmod_t mod)
for (i = 1; i < bin_len; i += 2)
t[i] = nmod_neg(t[i], mod);

bound_limbs = _nmod_vec_dot_bound_limbs(bin_len, mod);
s1 = _nmod_vec_dot(t, u, bin_len, mod, bound_limbs);
params = _nmod_vec_dot_params(bin_len, mod);
s1 = _nmod_vec_dot(t, u, bin_len, mod, params);

if (pow_len > bin_len)
{
bound_limbs = _nmod_vec_dot_bound_limbs(pow_len - bin_len, mod);
s2 = _nmod_vec_dot_rev(u + bin_len, t + k - pow_len + 1, pow_len - bin_len, mod, bound_limbs);
params = _nmod_vec_dot_params(pow_len - bin_len, mod);
s2 = _nmod_vec_dot_rev(u + bin_len, t + k - pow_len + 1, pow_len - bin_len, mod, params);
if (k % 2)
s1 = nmod_sub(s1, s2, mod);
else
Expand Down
16 changes: 8 additions & 8 deletions src/fq_nmod_mpoly/fq_nmod_embed.c
Original file line number Diff line number Diff line change
Expand Up @@ -352,12 +352,12 @@ void bad_n_fq_embed_lg_to_sm(
slong smd = fq_nmod_ctx_degree(emb->smctx);
slong lgd = fq_nmod_ctx_degree(emb->lgctx);
slong i;
int nlimbs = _nmod_vec_dot_bound_limbs(lgd, emb->lgctx->mod);
const dot_params_t params = _nmod_vec_dot_params(lgd, emb->lgctx->mod);

n_poly_fit_length(out, lgd);
for (i = 0; i < lgd; i++)
out->coeffs[i] = _nmod_vec_dot(emb->lg_to_sm_mat->rows[i], in, lgd,
emb->lgctx->mod, nlimbs);
emb->lgctx->mod, params);
FLINT_ASSERT(lgd/smd == emb->h->length - 1);
out->length = emb->h->length - 1;
_n_fq_poly_normalise(out, smd);
Expand Down Expand Up @@ -438,7 +438,7 @@ void bad_n_fq_embed_sm_to_lg(
slong smd = fq_nmod_ctx_degree(emb->smctx);
slong lgd = fq_nmod_ctx_degree(emb->lgctx);
slong i;
int nlimbs = _nmod_vec_dot_bound_limbs(lgd, emb->lgctx->mod);
const dot_params_t params = _nmod_vec_dot_params(lgd, emb->lgctx->mod);
n_poly_stack_t St; /* TODO: pass the stack in */
n_fq_poly_struct * q, * in_red;

Expand All @@ -454,7 +454,7 @@ void bad_n_fq_embed_sm_to_lg(

for (i = 0; i < lgd; i++)
out[i] = _nmod_vec_dot(emb->sm_to_lg_mat->rows[i], in_red->coeffs,
smd*in_red->length, emb->lgctx->mod, nlimbs);
smd*in_red->length, emb->lgctx->mod, params);

n_poly_stack_give_back(St, 2);

Expand Down Expand Up @@ -544,11 +544,11 @@ void bad_n_fq_embed_sm_elem_to_lg(
slong smd = fq_nmod_ctx_degree(emb->smctx);
slong lgd = fq_nmod_ctx_degree(emb->lgctx);
slong i;
int nlimbs = _nmod_vec_dot_bound_limbs(smd, emb->lgctx->mod);
const dot_params_t params = _nmod_vec_dot_params(smd, emb->lgctx->mod);

for (i = 0; i < lgd; i++)
out[i] = _nmod_vec_dot(emb->sm_to_lg_mat->rows[i], in, smd,
emb->lgctx->mod, nlimbs);
emb->lgctx->mod, params);
}

void bad_fq_nmod_embed_sm_elem_to_lg(
Expand All @@ -559,7 +559,7 @@ void bad_fq_nmod_embed_sm_elem_to_lg(
slong smd = fq_nmod_ctx_degree(emb->smctx);
slong lgd = fq_nmod_ctx_degree(emb->lgctx);
slong i;
int nlimbs = _nmod_vec_dot_bound_limbs(smd, emb->lgctx->mod);
const dot_params_t params = _nmod_vec_dot_params(smd, emb->lgctx->mod);

FLINT_ASSERT(in->length <= smd);

Expand All @@ -568,7 +568,7 @@ void bad_fq_nmod_embed_sm_elem_to_lg(
for (i = 0; i < lgd; i++)
{
out->coeffs[i] = _nmod_vec_dot(emb->sm_to_lg_mat->rows[i],
in->coeffs, in->length, emb->lgctx->mod, nlimbs);
in->coeffs, in->length, emb->lgctx->mod, params);
}

out->length = lgd;
Expand Down
5 changes: 2 additions & 3 deletions src/fq_nmod_mpoly_factor/n_bpoly_fq_factor_lgprime.c
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,10 @@ static void _lattice(
n_bpoly_t Q, R, dg;
n_bpoly_struct * ld;
nmod_mat_t M, T1, T2;
int nlimbs;
ulong * trow;
slong lift_order = lift_alpha_pow->length - 1;

nlimbs = _nmod_vec_dot_bound_limbs(r, ctx->mod);
const dot_params_t params = _nmod_vec_dot_params(r, ctx->mod);
trow = (ulong *) flint_malloc(r*sizeof(ulong));
n_bpoly_init(Q);
n_bpoly_init(R);
Expand Down Expand Up @@ -243,7 +242,7 @@ static void _lattice(

for (i = 0; i < d; i++)
nmod_mat_entry(M, (j - starts[k])*deg + l, i) =
_nmod_vec_dot(trow, N->rows[i], r, ctx->mod, nlimbs);
_nmod_vec_dot(trow, N->rows[i], r, ctx->mod, params);
}

nmod_mat_init_nullspace_tr(T1, M);
Expand Down
27 changes: 13 additions & 14 deletions src/fq_nmod_mpoly_factor/n_bpoly_fq_factor_smprime.c
Original file line number Diff line number Diff line change
Expand Up @@ -957,10 +957,9 @@ static void _lattice(
n_fq_bpoly_t Q, R, dg;
n_fq_bpoly_struct * ld;
nmod_mat_t M, T1, T2;
int nlimbs;
ulong * trow;

nlimbs = _nmod_vec_dot_bound_limbs(r, ctx->mod);
const dot_params_t params = _nmod_vec_dot_params(r, ctx->mod);
trow = (ulong *) flint_malloc(r*sizeof(ulong));
n_fq_bpoly_init(Q);
n_fq_bpoly_init(R);
Expand All @@ -985,20 +984,20 @@ static void _lattice(
nmod_mat_init(M, d*(lift_order - CLD[k]), nrows, ctx->modulus->mod.n);

for (j = CLD[k]; j < lift_order; j++)
for (l = 0; l < d; l++)
{
for (i = 0; i < r; i++)
for (l = 0; l < d; l++)
{
if (k >= ld[i].length || j >= ld[i].coeffs[k].length)
trow[i] = 0;
else
trow[i] = ld[i].coeffs[k].coeffs[d*j + l];
}
for (i = 0; i < r; i++)
{
if (k >= ld[i].length || j >= ld[i].coeffs[k].length)
trow[i] = 0;
else
trow[i] = ld[i].coeffs[k].coeffs[d*j + l];
}

for (i = 0; i < nrows; i++)
nmod_mat_entry(M, (j - CLD[k])*d + l, i) =
_nmod_vec_dot(trow, N->rows[i], r, ctx->mod, nlimbs);
}
for (i = 0; i < nrows; i++)
nmod_mat_entry(M, (j - CLD[k])*d + l, i) =
_nmod_vec_dot(trow, N->rows[i], r, ctx->mod, params);
}

nmod_mat_init_nullspace_tr(T1, M);

Expand Down
5 changes: 2 additions & 3 deletions src/fq_zech_mpoly_factor/bpoly_factor_smprime.c
Original file line number Diff line number Diff line change
Expand Up @@ -496,10 +496,9 @@ static void _lattice(
fq_zech_bpoly_t Q, R, dg;
fq_zech_bpoly_struct * ld;
nmod_mat_t M, T1, T2;
int nlimbs;
ulong * trow;

nlimbs = _nmod_vec_dot_bound_limbs(r, fq_zech_ctx_mod(ctx));
const dot_params_t params = _nmod_vec_dot_params(r, fq_zech_ctx_mod(ctx));
trow = (ulong *) flint_malloc(r*sizeof(ulong));
fq_zech_bpoly_init(Q, ctx);
fq_zech_bpoly_init(R, ctx);
Expand Down Expand Up @@ -549,7 +548,7 @@ static void _lattice(

for (i = 0; i < d; i++)
nmod_mat_entry(M, (j - starts[k])*deg + l, i) =
_nmod_vec_dot(trow, N->rows[i], r, fq_zech_ctx_mod(ctx), nlimbs);
_nmod_vec_dot(trow, N->rows[i], r, fq_zech_ctx_mod(ctx), params);
}

nmod_mat_init_nullspace_tr(T1, M);
Expand Down
Loading

0 comments on commit 4496945

Please sign in to comment.