Skip to content

Commit

Permalink
Refactor all SIMD to one file, simd_support.rs
Browse files Browse the repository at this point in the history
This should make it a bit easier to port to other SIMD instruction sets
when the SIMD instructions are not littered randomly around the
tensor.rs file.
  • Loading branch information
Noeda committed Mar 18, 2023
1 parent 25e3e12 commit 9c86c17
Show file tree
Hide file tree
Showing 3 changed files with 340 additions and 316 deletions.
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
pub mod embedding;
pub mod protomodels;
pub mod rllama_main;
pub mod simd_support;
pub mod tensor;
#[cfg(feature = "opencl")]
pub mod tensor_opencl_support;
Expand Down
116 changes: 116 additions & 0 deletions src/simd_support.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// This file contains platform-specific SIMD so that rest of rllama does not need to care which
// platform it is on.

use core::arch::x86_64::*;
use half::f16;

pub type I32x8 = __m256i;
pub type F32x8 = __m256;
pub type I16x8 = __m128i;

/* ------------------ */
/* Loading and storing things */
/* ------------------ */

#[inline]
pub fn load_i16x8(ptr: *const I16x8) -> I16x8 {
unsafe { _mm_loadu_si128(ptr) }
}

#[inline]
pub fn store_i16x8(ptr: *mut I16x8, a: I16x8) {
unsafe { _mm_storeu_si128(ptr, a) }
}

#[inline]
pub fn load_f32x8(ptr: *const F32x8) -> F32x8 {
unsafe { _mm256_loadu_ps(ptr as *const f32) }
}

#[inline]
pub fn store_f32x8(ptr: *mut F32x8, a: F32x8) {
unsafe { _mm256_storeu_ps(ptr as *mut f32, a) }
}

#[inline]
pub fn gather_f32x8(ptr: *const f32, indices: I32x8) -> F32x8 {
unsafe { _mm256_i32gather_ps(ptr, indices, 1) }
}

/* ------------------ */
/* Conversions */
/* ------------------ */

#[inline]
pub fn i16x8_as_f16_to_f32x8(a: I16x8) -> F32x8 {
unsafe { _mm256_cvtph_ps(a) }
}

#[inline]
pub fn f32x8_to_i16x8_as_f16(a: F32x8) -> I16x8 {
unsafe { _mm256_cvtps_ph(a, 0) }
}

/*
* Constants, creating from constants
*/

pub fn f32x8_zero() -> F32x8 {
unsafe { _mm256_setzero_ps() }
}

pub fn i16x8_zero() -> I16x8 {
unsafe { _mm_setzero_si128() }
}

pub fn f32x8_singleton(value: f32) -> F32x8 {
unsafe { _mm256_set1_ps(value) }
}

pub fn i32x8_from_values(
val0: i32,
val1: i32,
val2: i32,
val3: i32,
val4: i32,
val5: i32,
val6: i32,
val7: i32,
) -> I32x8 {
unsafe { _mm256_set_epi32(val0, val1, val2, val3, val4, val5, val6, val7) }
}

/*
* Operations
*/

// FMA

// a * b + c
pub fn fma_f32x8(a: F32x8, b: F32x8, c: F32x8) -> F32x8 {
unsafe { _mm256_fmadd_ps(a, b, c) }
}

// Horizontal sums

#[inline]
pub fn horizontal_sum_f32x8(mut ymm: __m256) -> f32 {
unsafe {
let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1);
ymm = _mm256_add_ps(ymm, ymm2);
ymm = _mm256_hadd_ps(ymm, ymm);
ymm = _mm256_hadd_ps(ymm, ymm);
_mm256_cvtss_f32(ymm)
}
}

#[inline]
pub fn horizontal_sum_and_f32_to_f16(mut ymm: __m256) -> f16 {
unsafe {
let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1);
ymm = _mm256_add_ps(ymm, ymm2);
ymm = _mm256_hadd_ps(ymm, ymm);
ymm = _mm256_hadd_ps(ymm, ymm);
f16::from_f32(_mm256_cvtss_f32(ymm))
}
}
Loading

0 comments on commit 9c86c17

Please sign in to comment.