From 93185775e69bdd43175714b5193b19a8ab795cd2 Mon Sep 17 00:00:00 2001 From: Danny Robson Date: Fri, 23 Mar 2018 17:52:08 +1100 Subject: [PATCH] coord/simd: template on arity and type --- coord/simd.cpp | 4 +- coord/simd_sse.hpp | 246 ++++++++++++++++++++++++++++++++++---------- test/coord/simd.cpp | 29 +++--- 3 files changed, 210 insertions(+), 69 deletions(-) diff --git a/coord/simd.cpp b/coord/simd.cpp index cab5b249..91fdfadb 100644 --- a/coord/simd.cpp +++ b/coord/simd.cpp @@ -18,9 +18,11 @@ #include + /////////////////////////////////////////////////////////////////////////////// +template std::ostream& -util::coord::operator<< (std::ostream &os, simd val) +util::coord::operator<< (std::ostream &os, simd val) { return os << "[ " << val[0] << ", " diff --git a/coord/simd_sse.hpp b/coord/simd_sse.hpp index 2346ff5c..485b17bf 100644 --- a/coord/simd_sse.hpp +++ b/coord/simd_sse.hpp @@ -32,21 +32,36 @@ namespace util::coord { /////////////////////////////////////////////////////////////////////////// - namespace detail { - - } - constexpr int alignment = 16; + + template + struct native_type { }; + + template <> struct native_type<1,float> { using type = __m128; }; + template <> struct native_type<2,float> { using type = __m128; }; + template <> struct native_type<3,float> { using type = __m128; }; + template <> struct native_type<4,float> { using type = __m128; }; + + template <> struct native_type<1,double> { using type = __m128d; }; + template <> struct native_type<2,double> { using type = __m128d; }; + + template <> struct native_type<1,uint32_t> { using type = __m128i; }; + template <> struct native_type<2,uint32_t> { using type = __m128i; }; + template <> struct native_type<3,uint32_t> { using type = __m128i; }; + template <> struct native_type<4,uint32_t> { using type = __m128i; }; + + + template struct alignas (16) simd { /////////////////////////////////////////////////////////////////////// - simd (float a, float b, float c, float d): + simd (ValueT a, ValueT b, ValueT c, ValueT d): data (_mm_setr_ps (a, b, c, d)) { ; } //--------------------------------------------------------------------- - simd (float v): + simd (ValueT v): data (_mm_set_ps1 (v)) { ; } @@ -63,25 +78,77 @@ namespace util::coord { explicit operator bool () const; - float operator[] (int idx) const { return data[idx]; } + ValueT operator[] (int idx) const { return data[idx]; } /////////////////////////////////////////////////////////////////////// - __m128 data; + + template + struct accessor { + operator ValueT () const noexcept + { +#ifdef __SSE4_1__ + return _mm_extrat_epi32 (data, IndexV); +#else + return _mm_cvtss_f32 ( + _mm_shuffle_ps ( + data, + data, + _MM_SHUFFLE (IndexV, IndexV, IndexV, IndexV) + ) + ); +#endif + } + + accessor& operator= (ValueT); + + __m128 data; + }; + + + union { + __m128 data; + accessor<0> x; + accessor<1> y; + accessor<2> z; + accessor<3> w; + }; }; /////////////////////////////////////////////////////////////////////////// - simd operator+ (simd a, simd b) { return _mm_add_ps (a.data, b.data); } - simd operator- (simd a, simd b) { return _mm_sub_ps (a.data, b.data); } - simd operator/ (simd a, simd b) { return _mm_div_ps (a.data, b.data); } - simd operator* (simd a, simd b) { return _mm_mul_ps (a.data, b.data); } + template + simd + operator+ (simd a, simd b) + { return _mm_add_ps (a.data, b.data); } //------------------------------------------------------------------------- + template + simd + operator- (simd a, simd b) + { return _mm_sub_ps (a.data, b.data); } + + + //------------------------------------------------------------------------- + template + simd + operator/ (simd a, simd b) + { return _mm_div_ps (a.data, b.data); } + + + //------------------------------------------------------------------------- + template + simd + operator* (simd a, simd b) + { return _mm_mul_ps (a.data, b.data); } + + + /////////////////////////////////////////////////////////////////////////// // computes a*b + c + template auto - fma (simd a, simd b, simd c) + fma (simd a, simd b, simd c) { #if defined(__FMA__) return _mm_fmadd_ps (a.data, b.data, c.data); @@ -92,22 +159,65 @@ namespace util::coord { /////////////////////////////////////////////////////////////////////////// - simd operator< (simd a, simd b) { return _mm_cmplt_ps (a.data, b.data); } - simd operator<= (simd a, simd b) { return _mm_cmple_ps (a.data, b.data); } - simd operator> (simd a, simd b) { return _mm_cmpgt_ps (a.data, b.data); } - simd operator>= (simd a, simd b) { return _mm_cmpge_ps (a.data, b.data); } - simd operator== (simd a, simd b) { return _mm_cmpeq_ps (a.data, b.data); } + template + simd + operator< (simd a, simd b) + { return _mm_cmplt_ps (a.data, b.data); } - //------------------------------------------------------------------------- - simd operator| (simd a, simd b) { return _mm_or_ps (a.data, b.data); } - simd operator|| (simd a, simd b) { return _mm_or_ps (a.data, b.data); } - simd operator& (simd a, simd b) { return _mm_and_ps (a.data, b.data); } - simd operator&& (simd a, simd b) { return _mm_and_ps (a.data, b.data); } + template + simd + operator<= (simd a, simd b) + { return _mm_cmple_ps (a.data, b.data); } + + + template + simd + operator> (simd a, simd b) + { return _mm_cmpgt_ps (a.data, b.data); } + + + template + simd + operator>= (simd a, simd b) + { return _mm_cmpge_ps (a.data, b.data); } + + + template + simd + operator== (simd a, simd b) + { return _mm_cmpeq_ps (a.data, b.data); } /////////////////////////////////////////////////////////////////////////// - simd floor (simd val) + template + simd + operator| (simd a, simd b) + { return _mm_or_ps (a.data, b.data); } + + + template + simd + operator|| (simd a, simd b) + { return _mm_or_ps (a.data, b.data); } + + + template + simd + operator& (simd a, simd b) + { return _mm_and_ps (a.data, b.data); } + + + template + simd + operator&& (simd a, simd b) + { return _mm_and_ps (a.data, b.data); } + + + /////////////////////////////////////////////////////////////////////////// + template + simd + floor (simd val) { #if defined(__SSE4_1__) return mm_floor_ps (val.data); @@ -115,17 +225,19 @@ namespace util::coord { // NOTE: assumes the rounding mode is 'nearest' // cast to int and back to truncate - const auto truncated = _mm_cvtepi32_ps (_mm_cvtps_epi32 (val.data)); + const simd truncated = _mm_cvtepi32_ps (_mm_cvtps_epi32 (val.data)); // if the truncated value is greater than the original value we got // rounded up so we need to decrement to get the true value. - return truncated - ((truncated > val) & simd (1)); + return truncated - ((truncated > val) & simd (1)); #endif } //--------------------------------------------------------------------------- - simd ceil (simd val) + template + simd + ceil (simd val) { #if defined(__SSE4_1__) return _mm_ceil_ps (val.data); @@ -133,17 +245,18 @@ namespace util::coord { // NOTE: assumes the rounding mode is 'nearest' // truncate by casting to int and back - const auto truncated = _mm_cvtepi32_ps (_mm_cvtps_epi32 (val.data)); + const simd truncated = _mm_cvtepi32_ps (_mm_cvtps_epi32 (val.data)); // if the truncated value is below the original value it got rounded // down and needs to be incremented to get the true value. - return truncated + ((truncated < val) & simd (1)); + return truncated + ((truncated < val) & simd (1)); #endif } /////////////////////////////////////////////////////////////////////////// - simd - select (simd mask, simd a, simd b) + template + simd + select (simd mask, simd a, simd b) { #if defined(__SSE4_1__) return _mm_blendv_ps (a, b, mask); @@ -157,35 +270,48 @@ namespace util::coord { //------------------------------------------------------------------------- + template bool - all (simd val) + all (simd val) { return _mm_movemask_ps (val.data) == 0b1111; } //------------------------------------------------------------------------- + template bool - any (simd val) + any (simd val) { return _mm_movemask_ps (val.data); } /////////////////////////////////////////////////////////////////////////// - simd min (simd a, simd b) { return _mm_min_ps (a.data, b.data); } - simd max (simd a, simd b) { return _mm_max_ps (a.data, b.data); } + template + simd + min (simd a, simd b) + { return _mm_min_ps (a.data, b.data); } - simd - clamp (simd val, simd lo, simd hi) + + template + simd + max (simd a, simd b) + { return _mm_max_ps (a.data, b.data); } + + + template + simd + clamp (simd val, simd lo, simd hi) { return min (max (val, lo), hi); } /////////////////////////////////////////////////////////////////////////// - simd - sum (simd a) + template + simd + sum (simd a) { auto part = _mm_hadd_ps (a.data, a.data); return _mm_hadd_ps (part, part); @@ -200,45 +326,50 @@ namespace util::coord { return _mm_dp_ps (a, b, 0xff); } #else - simd - dot (simd a, simd b) + template + simd + dot (simd a, simd b) { return sum (a * b); } #endif /////////////////////////////////////////////////////////////////////////// - simd sqrt (simd a) { return _mm_sqrt_ps (a.data); } - simd rsqrt (simd a) { return _mm_rsqrt_ps (a.data); } + template simd sqrt (simd a) { return _mm_sqrt_ps (a.data); } + template simd rsqrt (simd a) { return _mm_rsqrt_ps (a.data); } /////////////////////////////////////////////////////////////////////////// - simd - norm2 (simd a) + template + auto + norm2 (simd a) { return dot (a, a); } //------------------------------------------------------------------------- - simd - norm (simd a) + template + auto + norm (simd a) { return sqrt (norm2 (a)); } //------------------------------------------------------------------------- - simd - normalised (simd a) + template + auto + normalised (simd a) { return a * rsqrt (norm (a)); } /////////////////////////////////////////////////////////////////////////// - auto - abs (simd a) + template + simd + abs (simd a) { auto bffff = _mm_set1_epi32 (-1); auto b7fff = _mm_srli_epi32 (bffff, 1); @@ -250,22 +381,25 @@ namespace util::coord { /////////////////////////////////////////////////////////////////////////// - simd - hypot (simd a) + template + auto + hypot (simd a) { return sqrt (sum (a * a)); } /////////////////////////////////////////////////////////////////////////// - simd::operator bool() const + template + simd::operator bool() const { - return all (data); + return all (*this); } /////////////////////////////////////////////////////////////////////////// - std::ostream& operator<< (std::ostream &os, simd val); + template + std::ostream& operator<< (std::ostream &os, simd val); } #endif diff --git a/test/coord/simd.cpp b/test/coord/simd.cpp index e076afb3..56acdddc 100644 --- a/test/coord/simd.cpp +++ b/test/coord/simd.cpp @@ -5,41 +5,46 @@ int main () { - using util::coord::simd; + using simd_t = util::coord::simd<4,float>; util::TAP::logger tap; { - const simd a (1,2,3,4); - const simd b (4,1,3,2); + const simd_t a (1,2,3,4); + const simd_t b (4,1,3,2); const float res = dot (a, b)[0]; tap.expect_eq (res, 4+2+9+8, "trivial dot product"); } { - const simd a (1, 2, 3, 4); - const simd b (0, 3, 3, 9); + const simd_t a (1, 2, 3, 4); + const simd_t b (0, 3, 3, 9); const auto lo = min (a, b); const auto hi = max (a, b); - tap.expect_eq (lo, simd {0,2,3,4}, "vector minimum"); - tap.expect_eq (hi, simd {1,3,3,9}, "vector maximum"); + tap.expect_eq (lo, simd_t {0,2,3,4}, "vector minimum"); + tap.expect_eq (hi, simd_t {1,3,3,9}, "vector maximum"); } { - const simd val { -INFINITY, INFINITY, 0, -9 }; - tap.expect_eq (abs (val), simd {INFINITY,INFINITY,0,9}, "absolute value"); + const simd_t val { -INFINITY, INFINITY, 0, -9 }; + tap.expect_eq (abs (val), simd_t {INFINITY,INFINITY,0,9}, "absolute value"); } { - const simd test { -1.25f, 1.25f, 0.f, 1.f }; + const simd_t test { -1.25f, 1.25f, 0.f, 1.f }; const auto lo = floor (test); const auto hi = ceil (test); - tap.expect_eq (lo, simd { -2, 1, 0, 1 }, "floor"); - tap.expect_eq (hi, simd { -1, 2, 0, 1 }, "ceil"); + tap.expect_eq (lo, simd_t { -2, 1, 0, 1 }, "floor"); + tap.expect_eq (hi, simd_t { -1, 2, 0, 1 }, "ceil"); }; + tap.expect_eq (simd_t {1,2,3,4}.x, 1.f, "named accessor, x"); + tap.expect_eq (simd_t {1,2,3,4}.y, 2.f, "named accessor, y"); + tap.expect_eq (simd_t {1,2,3,4}.z, 3.f, "named accessor, z"); + tap.expect_eq (simd_t {1,2,3,4}.w, 4.f, "named accessor, w"); + return tap.status (); } \ No newline at end of file