Skip to content

Commit

Permalink
tests : sync test-grad0 from llama.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Jun 19, 2023
1 parent 169ff7b commit ea67c09
Showing 1 changed file with 60 additions and 2 deletions.
62 changes: 60 additions & 2 deletions tests/test-grad0.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <stdlib.h>
#include <assert.h>

#define MAX_NARGS 2
#define MAX_NARGS 3

#undef MIN
#undef MAX
Expand Down Expand Up @@ -350,7 +350,7 @@ int main(int argc, const char ** argv) {

// original loop: 1000
int niter = 4;
const char *env = getenv("GGML_NITER");
const char *env = getenv("GGML_NLOOP");
if (env != NULL) {
niter = atoi(env);
}
Expand Down Expand Up @@ -1090,6 +1090,25 @@ int main(int argc, const char ** argv) {
}
}

// cross_entropy_loss
// Gradient check for ggml_cross_entropy_loss(logits, targets).
// Only x[0] (the logits) is registered as a differentiable parameter,
// hence nargs = 1; x[1] plays the role of fixed target values.
{
const int nargs = 1;

// Randomize all 4 dims up front; only the first `ndims` entries are
// consumed per iteration of the loop below.
int64_t ne2[4];
get_random_dims(ne2, 4);

for (int ndims = 1; ndims <= 3; ++ndims) {
// x[0]: logits in [-1, 1]; x[1]: targets in [0, 1]
// NOTE(review): targets are random in [0,1] and presumably not
// normalized to sum to 1 — confirm ggml_cross_entropy_loss does
// not require a proper probability distribution.
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
x[1] = get_random_tensor(ctx0, ndims, ne2, 0.0f, 1.0f);
ggml_set_param(ctx0, x[0]);

// Reduce to a scalar so a single backward pass covers all elements.
struct ggml_tensor * f = ggml_sum(ctx0, ggml_cross_entropy_loss(ctx0, x[0], x[1]));

// Loose tolerances (1e-1 abs / 1e-2 rel — presumably; confirm against
// check_gradient's signature) and INFINITY for the last threshold,
// because the finite-difference reference is unstable for this op.
check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-1f, 1e-2f, INFINITY);
// finite differences regularly fails!
}
}

// rope
{
const int nargs = 1;
Expand Down Expand Up @@ -1124,6 +1143,45 @@ int main(int argc, const char ** argv) {
}
}

// flash_attn
// Gradient check for ggml_flash_attn(q, k, v, masked) over all three
// inputs (nargs = 3), for both masked and unmasked attention and for
// tensor ranks 2..4.
{
const int nargs = 3;

int64_t ne2[4];

// Derive attention dimensions from one random draw:
//   D = head dimension, N = query length,
//   M = key/value length (>= N so masking has effect), B = batch.
get_random_dims(ne2, 4);
int64_t D = ne2[0];
int64_t N = ne2[1];
int64_t M = ne2[2] + N;
int64_t B = ne2[3];

for (int masked = 0; masked <= 1; ++masked) {
for (int ndims = 2; ndims <= 4; ++ndims) {
// Q is (D, N, ...), K is (D, M, ...); V is laid out (M, D, ...) —
// presumably pre-transposed as ggml_flash_attn expects; confirm
// against the ggml_flash_attn documentation.
// NOTE(review): the trailing dim uses `ne[3]` (a variable declared
// earlier in main, outside this hunk), not the local `ne2[3]` —
// verify this is intentional and not a typo for `ne2[3]`.
int64_t neq[4] = { D, N, B, ne[3] };
int64_t nek[4] = { D, M, B, ne[3] };
int64_t nev[4] = { M, D, B, ne[3] };
// Collapse unused trailing dims to 1 for lower-rank cases.
if (ndims == 2) {
neq[2] = 1; neq[3] = 1;
nek[2] = 1; nek[3] = 1;
nev[2] = 1; nev[3] = 1;
} else if (ndims == 3) {
neq[3] = 1;
nek[3] = 1;
nev[3] = 1;
}
// Small value range keeps softmax inputs well-conditioned for the
// finite-difference comparison.
x[0] = get_random_tensor(ctx0, ndims, neq, -0.1250f, 0.1250f);
x[1] = get_random_tensor(ctx0, ndims, nek, -0.1250f, 0.1250f);
x[2] = get_random_tensor(ctx0, ndims, nev, -0.1250f, 0.1250f);
ggml_set_param(ctx0, x[0]);
ggml_set_param(ctx0, x[1]);
ggml_set_param(ctx0, x[2]);

// NOTE(review): `masked == 0` passes true on the masked==0 iteration,
// i.e. the flag is inverted relative to the loop variable's name —
// confirm which iteration is meant to exercise causal masking.
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));

// Tight abs tolerance, relative check disabled (INFINITY), generous
// max-gradient bound of 3.5 — presumably tuned empirically; confirm
// against check_gradient's parameter meanings.
check_gradient("flash_attn", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
}
}
}
ggml_free(ctx0);
}

Expand Down

0 comments on commit ea67c09

Please sign in to comment.