Skip to content

Commit

Permalink
LPGEMM <u|s>8s8s16ou8 fixes for incorrect zero point addition.
Browse files Browse the repository at this point in the history
-The zero point data type is different based on the downscale data
type. For int8_t downscale type, zero point type is int8_t whereas for
uint8_t downscale type, it is uint8_t. During downscale post-op, the
micro-kernels upscales the zero point from its data type (int8_t or
uint8_t) to that of the accumulation data type and then performs the
zero point addition. The accumulated output is then stored as downscaled
type in a later storage phase. For the <u|s>8s8s16 micro-kernels, the
upscaling to int16_t (accumulation type) is always performed assuming
the zero point is int8_t using the _mm256_cvtepi8_epi16 instruction.
However this will result in incorrect upscaled zero point values if the
downscale type is uint8_t and the associated zero point type is also
uint8_t. This issue is corrected by switching between the correct
upscale instruction based on the zero point type.

AMD-Internal: [SWLCSG-2500]
Change-Id: I92eed4aed686c447d29312836b9e551d6dd4b076
  • Loading branch information
MithunMohanKadavil committed Nov 2, 2023
1 parent b3391ef commit d184467
Show file tree
Hide file tree
Showing 9 changed files with 399 additions and 79 deletions.
21 changes: 19 additions & 2 deletions kernels/zen/lpgemm/s8s8s16/lpgemm_s8_6x32rowmajor_amd256.c
Original file line number Diff line number Diff line change
Expand Up @@ -776,10 +776,19 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32)
post_ops_attr.post_op_c_j + (1 * 8));

// Load zero points (2 byte values).
__m128i zero_point_0 =
__m128i _zero_point_0 =
_mm_loadu_si128(
( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) ) );
__m256i zero_point_0 = _mm256_setzero_si256();
if ( post_ops_attr.c_stor_type == S8 )
{
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
}

// Scale first 16 columns of the 6 rows.
CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0)
Expand All @@ -798,10 +807,18 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32)
(float *)post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + (3 * 8));

zero_point_0 =
_zero_point_0 =
_mm_loadu_si128(
( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 16 ) ) );
if ( post_ops_attr.c_stor_type == S8 )
{
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
}

// Scale next 16 columns of the 6 rows.
CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2, zero_point_0)
Expand Down
69 changes: 60 additions & 9 deletions kernels/zen/lpgemm/s8s8s16/lpgemm_s8_m_fringe_amd256.c
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,8 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32)
__m256i temp_32[2];
__m256 temp_float[2];
__m256 scale_1, scale_2;
__m128i zero_point_0;
__m128i _zero_point_0;
__m256i zero_point_0 = _mm256_setzero_si256();
__m256 res_1, res_2;

/* Load the scale vector values into the register*/
Expand All @@ -535,10 +536,18 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32)
post_ops_attr.post_op_c_j + (1 * 8));

// Load zero points (2 byte values).
zero_point_0 =
_zero_point_0 =
_mm_loadu_si128(
( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) ) );
if ( post_ops_attr.c_stor_type == S8 )
{
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
}

// Scale first 16 columns of the 4 rows.
CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0)
Expand All @@ -555,10 +564,18 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32)
(float *)post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + (3 * 8));

zero_point_0 =
_zero_point_0 =
_mm_loadu_si128(
( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 16 ) ) );
if ( post_ops_attr.c_stor_type == S8 )
{
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
}

// Scale next 16 columns of the 4 rows.
CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2, zero_point_0)
Expand Down Expand Up @@ -930,7 +947,8 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32)
__m256i temp_32[2];
__m256 temp_float[2];
__m256 scale_1, scale_2;
__m128i zero_point_0;
__m128i _zero_point_0;
__m256i zero_point_0 = _mm256_setzero_si256();
__m256 res_1, res_2;

/* Load the scale vector values into the register*/
Expand All @@ -944,10 +962,18 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32)
post_ops_attr.post_op_c_j + (1 * 8));

// Load zero points (2 byte values).
zero_point_0 =
_zero_point_0 =
_mm_loadu_si128(
( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) ) );
if ( post_ops_attr.c_stor_type == S8 )
{
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
}

// Scale first 16 columns of the 4 rows.
CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0)
Expand All @@ -962,10 +988,18 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32)
(float *)post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + (3 * 8));

zero_point_0 =
_zero_point_0 =
_mm_loadu_si128(
( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 16 ) ) );
if ( post_ops_attr.c_stor_type == S8 )
{
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
}

// Scale next 16 columns of the 4 rows.
CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2, zero_point_0)
Expand Down Expand Up @@ -1229,7 +1263,8 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32)
__m256i temp_32[2];
__m256 temp_float[2];
__m256 scale_1, scale_2;
__m128i zero_point_0;
__m128i _zero_point_0;
__m256i zero_point_0 = _mm256_setzero_si256();
__m256 res_1, res_2;

/* Load the scale vector values into the register*/
Expand All @@ -1243,10 +1278,18 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32)
post_ops_attr.post_op_c_j + (1 * 8));

// Load zero points (2 byte values).
zero_point_0 =
_zero_point_0 =
_mm_loadu_si128(
( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) ) );
if ( post_ops_attr.c_stor_type == S8 )
{
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
}

// Scale first 16 columns of the 4 rows.
CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0)
Expand All @@ -1260,10 +1303,18 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32)
(float *)post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + (3 * 8));

zero_point_0 =
_zero_point_0 =
_mm_loadu_si128(
( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 16 ) ) );
if ( post_ops_attr.c_stor_type == S8 )
{
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
}

// Scale next 16 columns of the 4 rows.
CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2, zero_point_0)
Expand Down
111 changes: 90 additions & 21 deletions kernels/zen/lpgemm/s8s8s16/lpgemm_s8_mn_fringe_amd256.c
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,8 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16)
__m256i temp_32[2];
__m256 temp_float[2];
__m256 scale_1, scale_2;
__m128i zero_point_0;
__m128i _zero_point_0;
__m256i zero_point_0 = _mm256_setzero_si256();
__m256 res_1, res_2;

/* Load the scale vector values into the register*/
Expand All @@ -398,10 +399,18 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16)
post_ops_attr.post_op_c_j + (1 * 8));

// Load zero points (2 byte values).
zero_point_0 =
_zero_point_0 =
_mm_loadu_si128(
( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) ) );
if ( post_ops_attr.c_stor_type == S8 )
{
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
}

// Scale first 16 columns of the 4 rows.
CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0)
Expand Down Expand Up @@ -816,7 +825,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16)
__m256i temp_32[2];
__m256 temp_float[2];
__m256 scale_1, scale_2;
__m128i zero_point_0;
__m128i _zero_point_0;
__m256i zero_point_0 = _mm256_setzero_si256();
__m256 res_1, res_2;

float float_buf[16];
Expand All @@ -828,11 +838,24 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16)
scale_1 = _mm256_loadu_ps(float_buf + (0 * 8));
scale_2 = _mm256_loadu_ps(float_buf + (1 * 8));

int8_t zero_point_buf[16];
if ( post_ops_attr.c_stor_type == S8 )
{
int8_t zero_point_buf[16];

memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) );
_zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf );
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
uint8_t zero_point_buf[16];

memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) );
zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf );
memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) );
_zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf );
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
}

// Scale first 16 columns of the 6 rows.
CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0)
Expand Down Expand Up @@ -1135,7 +1158,8 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x16)
__m256i temp_32[2];
__m256 temp_float[2];
__m256 scale_1, scale_2;
__m128i zero_point_0;
__m128i _zero_point_0;
__m256i zero_point_0 = _mm256_setzero_si256();
__m256 res_1, res_2;

/* Load the scale vector values into the register*/
Expand All @@ -1149,10 +1173,18 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x16)
post_ops_attr.post_op_c_j + (1 * 8));

// Load zero points (2 byte values).
zero_point_0 =
_zero_point_0 =
_mm_loadu_si128(
( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) ) );
if ( post_ops_attr.c_stor_type == S8 )
{
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
}

// Scale first 16 columns of the 2 rows.
CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0)
Expand Down Expand Up @@ -1439,7 +1471,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2xlt16)
__m256i temp_32[2];
__m256 temp_float[2];
__m256 scale_1, scale_2;
__m128i zero_point_0;
__m128i _zero_point_0;
__m256i zero_point_0 = _mm256_setzero_si256();
__m256 res_1, res_2;

float float_buf[16];
Expand All @@ -1451,11 +1484,24 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2xlt16)
scale_1 = _mm256_loadu_ps(float_buf + (0 * 8));
scale_2 = _mm256_loadu_ps(float_buf + (1 * 8));

int8_t zero_point_buf[16];
if ( post_ops_attr.c_stor_type == S8 )
{
int8_t zero_point_buf[16];

memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) );
zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf );
memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) );
_zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf );
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
uint8_t zero_point_buf[16];

memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) );
_zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf );
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
}

// Scale first 16 columns of the 6 rows.
CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0)
Expand Down Expand Up @@ -1684,7 +1730,8 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x16)
__m256i temp_32[2];
__m256 temp_float[2];
__m256 scale_1, scale_2;
__m128i zero_point_0;
__m128i _zero_point_0;
__m256i zero_point_0 = _mm256_setzero_si256();
__m256 res_1, res_2;

/* Load the scale vector values into the register*/
Expand All @@ -1698,10 +1745,18 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x16)
post_ops_attr.post_op_c_j + (1 * 8));

// Load zero points (2 byte values).
zero_point_0 =
_zero_point_0 =
_mm_loadu_si128(
( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) ) );
if ( post_ops_attr.c_stor_type == S8 )
{
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
}

// Scale first 16 columns of the 2 rows.
CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0)
Expand Down Expand Up @@ -1927,7 +1982,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1xlt16)
__m256i temp_32[2];
__m256 temp_float[2];
__m256 scale_1, scale_2;
__m128i zero_point_0;
__m128i _zero_point_0;
__m256i zero_point_0 = _mm256_setzero_si256();
__m256 res_1, res_2;

float float_buf[16];
Expand All @@ -1939,11 +1995,24 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1xlt16)
scale_1 = _mm256_loadu_ps(float_buf + (0 * 8));
scale_2 = _mm256_loadu_ps(float_buf + (1 * 8));

int8_t zero_point_buf[16];
if ( post_ops_attr.c_stor_type == S8 )
{
int8_t zero_point_buf[16];

memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) );
_zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf );
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
uint8_t zero_point_buf[16];

memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) );
zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf );
memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) );
_zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf );
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
}

// Scale first 16 columns of the 2 rows.
CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0)
Expand Down
Loading

0 comments on commit d184467

Please sign in to comment.