From 341907ac79f34a2ef7bd8375af136cabe8145c1a Mon Sep 17 00:00:00 2001 From: Danny Robson Date: Tue, 20 Mar 2018 13:30:05 +1100 Subject: [PATCH] coord/simd: add more sse operations --- coord/simd_neon.hpp | 96 ++++++++++++++++++++++++++ coord/simd_sse.hpp | 162 ++++++++++++++++++++++++++------------------ test/coord/simd.cpp | 40 ++++++++--- 3 files changed, 224 insertions(+), 74 deletions(-) create mode 100644 coord/simd_neon.hpp diff --git a/coord/simd_neon.hpp b/coord/simd_neon.hpp new file mode 100644 index 00000000..ab0eb059 --- /dev/null +++ b/coord/simd_neon.hpp @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Copyright 2018 Danny Robson + */ + +#ifndef __ARM_NEON__ +#error +#endif + +#include "arm_neon.h" + +namespace util::coord { + struct alignas (16) simd { + using value_type = float32x4_t; + + /////////////////////////////////////////////////////////////////////// + simd (float a, float b, float c, float d): + data (_mm_setr_ps (a, b, c, d)) + { ; } + + + //--------------------------------------------------------------------- + simd (float v): + data (_mm_set_ps1 (v)) + { ; } + + + //--------------------------------------------------------------------- + simd (value_type _data): + data (_data) + { ; } + + + //--------------------------------------------------------------------- + operator value_type& () { return data; } + operator const value_type& () const { return data; } + + explicit operator bool () const; + + float operator[] (int idx) const { return data[idx]; } + + + /////////////////////////////////////////////////////////////////////// + value_type data; + }; + + + /////////////////////////////////////////////////////////////////////////// + simd operator* (simd a, simd b) { return vmulq_f32 (a, b); }; + simd operator/ (simd a, simd b) { return vdivq_f32 (a, b); }; + simd operator+ (simd a, simd b) { return vaddq_f32 (a, b); }; + simd operator- (simd a, simd b) { return vsubq_f32 (a, b); }; + + simd operator< (simd a, simd b); + simd operator<= (simd a, simd b); + simd operator> (simd a, simd b); + simd operator>= (simd a, simd b); + simd operator== (simd a, simd b); + + simd select (simd mask, simd a, simd b); + + + auto sum (simd val) + { + // reverse and add to self giving: 0123 + 3210 + auto revq = vrev64q_f32 (val); + auto pair = vaddq_f32 (val, revq); + + // reverse the upper and lower pairs given (2301 + 1023) + auto shuf = vcombine_f32 ( + vget_high_f32 (pair), + vget_low_f32 (pair) + ); + + // add both partial sums: (2301 + 1032) + (0123 + 3210) + return vaddq_f32 (shuf, pair); + } + + simd + dot (simd a, simd b) + { + return sum (a * b); + } +} +#endif diff --git a/coord/simd_sse.hpp b/coord/simd_sse.hpp index a544ce60..cb18bc96 100644 --- a/coord/simd_sse.hpp +++ b/coord/simd_sse.hpp @@ -18,8 +18,8 @@ #ifndef CRUFT_UTIL_COORD_SIMD_SSE_HPP #define CRUFT_UTIL_COORD_SIMD_SSE_HPP -#ifndef __SSE2__ -#error "SSE2 is required" +#ifndef __SSE3__ +#error "SSE3 is required" #endif #include @@ -58,8 +58,8 @@ namespace util::coord { //--------------------------------------------------------------------- - operator __m128& () { return data; } - operator const __m128& () const { return data; } + explicit operator __m128& () { return data; } + explicit operator const __m128& () const { return data; } explicit operator bool () const; @@ -72,25 +72,75 @@ namespace util::coord { /////////////////////////////////////////////////////////////////////////// - simd operator+ (simd a, simd b) { return _mm_add_ps (a, b); } - simd operator- (simd a, simd b) { return _mm_sub_ps (a, b); } - simd operator/ (simd a, simd b) { return _mm_div_ps (a, b); } - simd operator* (simd a, simd b) { return _mm_mul_ps (a, b); } + simd operator+ (simd a, simd b) { return _mm_add_ps (a.data, b.data); } + simd operator- (simd a, simd b) { return _mm_sub_ps (a.data, b.data); } + simd operator/ (simd a, simd b) { return _mm_div_ps (a.data, b.data); } + simd operator* (simd a, simd b) { return _mm_mul_ps (a.data, b.data); } //------------------------------------------------------------------------- - simd operator< (simd a, simd b) { return _mm_cmplt_ps (a, b); } - simd operator<= (simd a, simd b) { return _mm_cmple_ps (a, b); } - simd operator> (simd a, simd b) { return _mm_cmpgt_ps (a, b); } - simd operator>= (simd a, simd b) { return _mm_cmpge_ps (a, b); } - simd operator== (simd a, simd b) { return _mm_cmpeq_ps (a, b); } + // computes a*b + c + auto + fma (simd a, simd b, simd c) + { +#if defined(__FMA__) + return _mm_fmadd_ps (a.data, b.data, c.data); +#else + return a * b + c; +#endif + } + + + /////////////////////////////////////////////////////////////////////////// + simd operator< (simd a, simd b) { return _mm_cmplt_ps (a.data, b.data); } + simd operator<= (simd a, simd b) { return _mm_cmple_ps (a.data, b.data); } + simd operator> (simd a, simd b) { return _mm_cmpgt_ps (a.data, b.data); } + simd operator>= (simd a, simd b) { return _mm_cmpge_ps (a.data, b.data); } + simd operator== (simd a, simd b) { return _mm_cmpeq_ps (a.data, b.data); } //------------------------------------------------------------------------- - simd operator| (simd a, simd b) { return _mm_or_ps (a, b); } - simd operator& (simd a, simd b) { return _mm_and_ps (a, b); } + simd operator| (simd a, simd b) { return _mm_or_ps (a.data, b.data); } + simd operator|| (simd a, simd b) { return _mm_or_ps (a.data, b.data); } + simd operator& (simd a, simd b) { return _mm_and_ps (a.data, b.data); } + simd operator&& (simd a, simd b) { return _mm_and_ps (a.data, b.data); } + /////////////////////////////////////////////////////////////////////////// + simd floor (simd val) + { +#if defined(__SSE4_1__) + return mm_floor_ps (val.data); +#else + // NOTE: assumes the rounding mode is 'nearest' + + // cast to int and back to truncate + const auto truncated = _mm_cvtepi32_ps (_mm_cvtps_epi32 (val.data)); + + // if the truncated value is greater than the original value we got + // rounded up so we need to decrement to get the true value. + return truncated - ((truncated > val) & simd (1)); +#endif + } + + + //--------------------------------------------------------------------------- + simd ceil (simd val) + { +#if defined(__SSE4_1__) + return _mm_ceil_ps (val.data); +#else + // NOTE: assumes the rounding mode is 'nearest' + + // truncate by casting to int and back + const auto truncated = _mm_cvtepi32_ps (_mm_cvtps_epi32 (val.data)); + + // if the truncated value is below the original value it got rounded + // down and needs to be incremented to get the true value. + return truncated + ((truncated < val) & simd (1)); +#endif + } + /////////////////////////////////////////////////////////////////////////// simd select (simd mask, simd a, simd b) @@ -99,8 +149,8 @@ namespace util::coord { return _mm_blendv_ps (a, b, mask); #else return _mm_or_ps ( - _mm_and_ps (mask, a), - _mm_andnot_ps (mask, b) + _mm_and_ps (mask.data, a.data), + _mm_andnot_ps (mask.data, b.data) ); #endif } @@ -110,58 +160,36 @@ namespace util::coord { bool all (simd val) { - return _mm_movemask_ps (val) == 0b1111; + return _mm_movemask_ps (val.data) == 0b1111; } //------------------------------------------------------------------------- - auto - clamp (simd val, simd lo, simd hi) + bool + any (simd val) { - auto lo_mask = val > lo; - auto hi_mask = val < hi; - - auto res = (lo_mask & val) + return _mm_movemask_ps (val.data); } /////////////////////////////////////////////////////////////////////////// - // use the same comparator in both because we're likely to use min - // and max near each other and the mask might be sharable this way. - simd min (simd a, simd b) { return select (a < b, a, b); } - simd max (simd a, simd b) { return select (a < b, b, a); } + simd min (simd a, simd b) { return _mm_min_ps (a.data, b.data); } + simd max (simd a, simd b) { return _mm_max_ps (a.data, b.data); } + + simd + clamp (simd val, simd lo, simd hi) + { + return min (max (val, lo), hi); + } /////////////////////////////////////////////////////////////////////////// -#if defined (__SSE3__) simd sum (simd a) { - auto part = _mm_hadd_ps (a, a); + auto part = _mm_hadd_ps (a.data, a.data); return _mm_hadd_ps (part, part); } -#else - auto - sum (simd vals) - { - // swap pairs of components - // vals: 3 2 1 0 - // shuf: 2 3 0 1 - auto shuf = _mm_shuffle_ps (vals, vals, _MM_SHUFFLE(2, 3, 0, 1)); - - // combine the pairs - auto sums = _mm_add_ps (vals, shuf); - - // copy the lower components of sums up, then add with the original sums - // sums: 2+3 2+3 1+0 1+0 - // shuf: xxx xxx 2+3 2+3 - shuf = _mm_movehl_ps (shuf, sums); - sums = _mm_add_ss (sums, shuf); - - // sums: xxx xxx 0123 1234 - return _mm_cvtss_f32 (sums); - } -#endif /////////////////////////////////////////////////////////////////////////// @@ -171,25 +199,17 @@ namespace util::coord { { return _mm_dp_ps (a, b, 0xff); } -#elif defined(__SSE3__) +#else simd dot (simd a, simd b) { - return sum (a * b) + return sum (a * b); } -#else - auto - dot (simd a, simd b) - { - auto mul = a * b; - return sum (mul); - } #endif - /////////////////////////////////////////////////////////////////////////// - simd sqrt (simd a) { return _mm_sqrt_ps (a); } - simd rsqrt (simd a) { return _mm_rsqrt_ps (a); } + simd sqrt (simd a) { return _mm_sqrt_ps (a.data); } + simd rsqrt (simd a) { return _mm_rsqrt_ps (a.data); } /////////////////////////////////////////////////////////////////////////// @@ -224,7 +244,7 @@ namespace util::coord { auto b7fff = _mm_srli_epi32 (bffff, 1); auto mask = _mm_castsi128_ps (b7fff); - return _mm_and_ps (mask, a); + return _mm_and_ps (mask, a.data); } @@ -242,6 +262,18 @@ namespace util::coord { { return all (data); } + + + std::ostream& + operator<< (std::ostream &os, simd val) + { + return os << "[ " + << val[0] << ", " + << val[1] << ", " + << val[2] << ", " + << val[3] + << " ]"; + } } #endif diff --git a/test/coord/simd.cpp b/test/coord/simd.cpp index 99a3b120..24d71e76 100644 --- a/test/coord/simd.cpp +++ b/test/coord/simd.cpp @@ -5,30 +5,52 @@ int main () { + using util::coord::simd; + util::TAP::logger tap; + std::clog << "rounding mode is: " << [] () { + switch (_MM_GET_ROUNDING_MODE ()) { + case _MM_ROUND_NEAREST: return "nearest"; + case _MM_ROUND_DOWN: return "down"; + case _MM_ROUND_UP: return "up"; + case _MM_ROUND_TOWARD_ZERO: return "toward_zero"; + } + + return "unknown"; + } () << '\n'; + { - const util::coord::simd a (1,2,3,4); - const util::coord::simd b (4,1,3,2); - const float res = dot (a, b); + const simd a (1,2,3,4); + const simd b (4,1,3,2); + const float res = dot (a, b)[0]; tap.expect_eq (res, 4+2+9+8, "trivial dot product"); } { - const util::coord::simd a (1, 2, 3, 4); - const util::coord::simd b (0, 3, 3, 9); + const simd a (1, 2, 3, 4); + const simd b (0, 3, 3, 9); const auto lo = min (a, b); const auto hi = max (a, b); - tap.expect_eq (lo, util::coord::simd {0,2,3,4}, "vector minimum"); - tap.expect_eq (hi, util::coord::simd {1,3,3,9}, "vector maximum"); + tap.expect_eq (lo, simd {0,2,3,4}, "vector minimum"); + tap.expect_eq (hi, simd {1,3,3,9}, "vector maximum"); } { - const util::coord::simd val { -INFINITY, INFINITY, 0, -9 }; - tap.expect_eq (abs (val), util::coord::simd {INFINITY,INFINITY,0,9}, "absolute value"); + const simd val { -INFINITY, INFINITY, 0, -9 }; + tap.expect_eq (abs (val), simd {INFINITY,INFINITY,0,9}, "absolute value"); } + { + const simd test { -1.25f, 1.25f, 0.f, 1.f }; + const auto lo = floor (test); + const auto hi = ceil (test); + + tap.expect_eq (lo, simd { -2, 1, 0, 1 }, "floor"); + tap.expect_eq (hi, simd { -1, 2, 0, 1 }, "ceil"); + }; + return tap.status (); } \ No newline at end of file