-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor all SIMD to one file, simd_support.rs
This should make it a bit easier to port to other SIMD instruction sets when the SIMD instructions are not littered randomly around the tensor.rs file.
- Loading branch information
Showing
3 changed files
with
340 additions
and
316 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
// This file contains platform-specific SIMD so that rest of rllama does not need to care which | ||
// platform it is on. | ||
|
||
use core::arch::x86_64::*; | ||
use half::f16; | ||
|
||
pub type I32x8 = __m256i; | ||
pub type F32x8 = __m256; | ||
pub type I16x8 = __m128i; | ||
|
||
/* ------------------ */ | ||
/* Loading and storing things */ | ||
/* ------------------ */ | ||
|
||
#[inline] | ||
pub fn load_i16x8(ptr: *const I16x8) -> I16x8 { | ||
unsafe { _mm_loadu_si128(ptr) } | ||
} | ||
|
||
#[inline] | ||
pub fn store_i16x8(ptr: *mut I16x8, a: I16x8) { | ||
unsafe { _mm_storeu_si128(ptr, a) } | ||
} | ||
|
||
#[inline] | ||
pub fn load_f32x8(ptr: *const F32x8) -> F32x8 { | ||
unsafe { _mm256_loadu_ps(ptr as *const f32) } | ||
} | ||
|
||
#[inline] | ||
pub fn store_f32x8(ptr: *mut F32x8, a: F32x8) { | ||
unsafe { _mm256_storeu_ps(ptr as *mut f32, a) } | ||
} | ||
|
||
#[inline] | ||
pub fn gather_f32x8(ptr: *const f32, indices: I32x8) -> F32x8 { | ||
unsafe { _mm256_i32gather_ps(ptr, indices, 1) } | ||
} | ||
|
||
/* ------------------ */ | ||
/* Conversions */ | ||
/* ------------------ */ | ||
|
||
#[inline] | ||
pub fn i16x8_as_f16_to_f32x8(a: I16x8) -> F32x8 { | ||
unsafe { _mm256_cvtph_ps(a) } | ||
} | ||
|
||
#[inline] | ||
pub fn f32x8_to_i16x8_as_f16(a: F32x8) -> I16x8 { | ||
unsafe { _mm256_cvtps_ph(a, 0) } | ||
} | ||
|
||
/* | ||
* Constants, creating from constants | ||
*/ | ||
|
||
pub fn f32x8_zero() -> F32x8 { | ||
unsafe { _mm256_setzero_ps() } | ||
} | ||
|
||
pub fn i16x8_zero() -> I16x8 { | ||
unsafe { _mm_setzero_si128() } | ||
} | ||
|
||
pub fn f32x8_singleton(value: f32) -> F32x8 { | ||
unsafe { _mm256_set1_ps(value) } | ||
} | ||
|
||
pub fn i32x8_from_values( | ||
val0: i32, | ||
val1: i32, | ||
val2: i32, | ||
val3: i32, | ||
val4: i32, | ||
val5: i32, | ||
val6: i32, | ||
val7: i32, | ||
) -> I32x8 { | ||
unsafe { _mm256_set_epi32(val0, val1, val2, val3, val4, val5, val6, val7) } | ||
} | ||
|
||
/* | ||
* Operations | ||
*/ | ||
|
||
// FMA | ||
|
||
// a * b + c | ||
pub fn fma_f32x8(a: F32x8, b: F32x8, c: F32x8) -> F32x8 { | ||
unsafe { _mm256_fmadd_ps(a, b, c) } | ||
} | ||
|
||
// Horizontal sums | ||
|
||
#[inline] | ||
pub fn horizontal_sum_f32x8(mut ymm: __m256) -> f32 { | ||
unsafe { | ||
let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1); | ||
ymm = _mm256_add_ps(ymm, ymm2); | ||
ymm = _mm256_hadd_ps(ymm, ymm); | ||
ymm = _mm256_hadd_ps(ymm, ymm); | ||
_mm256_cvtss_f32(ymm) | ||
} | ||
} | ||
|
||
#[inline] | ||
pub fn horizontal_sum_and_f32_to_f16(mut ymm: __m256) -> f16 { | ||
unsafe { | ||
let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1); | ||
ymm = _mm256_add_ps(ymm, ymm2); | ||
ymm = _mm256_hadd_ps(ymm, ymm); | ||
ymm = _mm256_hadd_ps(ymm, ymm); | ||
f16::from_f32(_mm256_cvtss_f32(ymm)) | ||
} | ||
} |
Oops, something went wrong.