/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Copyright 2018 Danny Robson
 */

#ifndef CRUFT_UTIL_COORD_SIMD_SSE_HPP
#define CRUFT_UTIL_COORD_SIMD_SSE_HPP

// SSE2 is the baseline for everything below. SSE3 (horizontal add),
// SSE4.1 (native floor/ceil/blend/dot-product) and FMA are used when the
// compiler advertises them; otherwise SSE2 fallbacks are compiled instead.
#ifndef __SSE2__
#error "SSE2 is required"
#endif

#include <immintrin.h>
#include <ostream>


namespace util::coord {
    /// Required alignment, in bytes, of a simd value.
    inline constexpr int alignment = 16;


    ///////////////////////////////////////////////////////////////////////////
    /// Four single-precision lanes held in one SSE register.
    ///
    /// Comparison operators produce lane masks (all bits set for true, all
    /// bits clear for false) meant to be consumed by the bitwise operators,
    /// `select`, `all`, and `any`.
    ///
    /// Every function in this header is `inline` so the header may be
    /// included from several translation units without duplicate symbols.
    struct alignas (alignment) simd {
        ///////////////////////////////////////////////////////////////////////
        /// Builds from four lane values; `a` lands in lane 0.
        simd (float a, float b, float c, float d):
            data (_mm_setr_ps (a, b, c, d))
        { ; }

        //---------------------------------------------------------------------
        /// Broadcasts `v` into every lane.
        simd (float v):
            data (_mm_set_ps1 (v))
        { ; }

        //---------------------------------------------------------------------
        /// Wraps a raw register. Deliberately implicit so intrinsic results
        /// convert straight into simd.
        simd (__m128 _data):
            data (_data)
        { ; }

        //---------------------------------------------------------------------
        explicit operator __m128& () { return data; }
        explicit operator const __m128& () const { return data; }

        /// True iff every lane has its sign bit set, i.e. every lane of a
        /// comparison mask is true.
        explicit operator bool () const;

        /// Reads lane `idx`. Relies on the GCC/Clang vector-subscript
        /// extension on __m128.
        float operator[] (int idx) const { return data[idx]; }

        ///////////////////////////////////////////////////////////////////////
        __m128 data;
    };


    ///////////////////////////////////////////////////////////////////////////
    // Lane-wise arithmetic.
    inline simd operator+ (simd a, simd b) { return _mm_add_ps (a.data, b.data); }
    inline simd operator- (simd a, simd b) { return _mm_sub_ps (a.data, b.data); }
    inline simd operator/ (simd a, simd b) { return _mm_div_ps (a.data, b.data); }
    inline simd operator* (simd a, simd b) { return _mm_mul_ps (a.data, b.data); }


    //-------------------------------------------------------------------------
    /// Computes a*b + c. It is fused (a single rounding) only when FMA is
    /// available. It always returns simd, whichever path is compiled.
    inline simd fma (simd a, simd b, simd c)
    {
    #if defined(__FMA__)
        return _mm_fmadd_ps (a.data, b.data, c.data);
    #else
        return a * b + c;
    #endif
    }


    ///////////////////////////////////////////////////////////////////////////
    // Lane-wise comparisons; each returns a lane mask.
    inline simd operator<  (simd a, simd b) { return _mm_cmplt_ps (a.data, b.data); }
    inline simd operator<= (simd a, simd b) { return _mm_cmple_ps (a.data, b.data); }
    inline simd operator>  (simd a, simd b) { return _mm_cmpgt_ps (a.data, b.data); }
    inline simd operator>= (simd a, simd b) { return _mm_cmpge_ps (a.data, b.data); }
    inline simd operator== (simd a, simd b) { return _mm_cmpeq_ps (a.data, b.data); }


    //-------------------------------------------------------------------------
    // Bitwise combinations. The `||`/`&&` overloads are bitwise too and, being
    // overloaded operators, they do not short-circuit.
    inline simd operator|  (simd a, simd b) { return _mm_or_ps  (a.data, b.data); }
    inline simd operator|| (simd a, simd b) { return _mm_or_ps  (a.data, b.data); }
    inline simd operator&  (simd a, simd b) { return _mm_and_ps (a.data, b.data); }
    inline simd operator&& (simd a, simd b) { return _mm_and_ps (a.data, b.data); }


    ///////////////////////////////////////////////////////////////////////////
    /// Per lane: yields `a` where `mask` is set, `b` otherwise.
    ///
    /// `mask` is expected to be a comparison result (each lane all-ones or
    /// all-zero). The SSE4.1 path inspects only the sign bit of each lane.
    inline simd select (simd mask, simd a, simd b)
    {
    #if defined(__SSE4_1__)
        // blendv takes its *second* operand where the mask is set.
        return _mm_blendv_ps (b.data, a.data, mask.data);
    #else
        return _mm_or_ps (
            _mm_and_ps    (mask.data, a.data),
            _mm_andnot_ps (mask.data, b.data)
        );
    #endif
    }


    //-------------------------------------------------------------------------
    /// True iff the sign bit of every lane is set.
    inline bool all (simd val) { return _mm_movemask_ps (val.data) == 0b1111; }


    //-------------------------------------------------------------------------
    /// True iff the sign bit of at least one lane is set.
    inline bool any (simd val) { return _mm_movemask_ps (val.data) != 0; }


    ///////////////////////////////////////////////////////////////////////////
    /// Lane-wise absolute value, computed by clearing each sign bit.
    /// Returns the raw register (as the original `auto` deduced).
    inline __m128 abs (simd a)
    {
        // all-ones shifted right by one gives 0x7fffffff in every lane
        auto const bffff = _mm_set1_epi32 (-1);
        auto const b7fff = _mm_srli_epi32 (bffff, 1);
        auto const mask  = _mm_castsi128_ps (b7fff);

        return _mm_and_ps (mask, a.data);
    }


    ///////////////////////////////////////////////////////////////////////////
    namespace detail {
        /// Mask of lanes that must pass through rounding untouched.
        ///
        /// Every float with |x| >= 2^23 is already an integer, and beyond
        /// 2^31 it cannot round-trip through int32. Infinities and NaNs also
        /// pass through. cmpnlt (unlike cmpge) is true for NaN lanes.
        inline simd rounding_passthrough (simd val)
        {
            return _mm_cmpnlt_ps (abs (val), _mm_set1_ps (8388608.f));
        }

        /// Gives `rounded` the sign bit of `val`, so that e.g.
        /// floor(-0.0) == -0.0 and ceil(-0.5) == -0.0, matching SSE4.1.
        inline simd with_sign_of (simd rounded, simd val)
        {
            return rounded | (val & simd (-0.f));
        }
    }


    ///////////////////////////////////////////////////////////////////////////
    /// Lane-wise round toward negative infinity.
    inline simd floor (simd val)
    {
    #if defined(__SSE4_1__)
        return _mm_floor_ps (val.data);
    #else
        // cvttps always truncates toward zero, independent of MXCSR.
        const simd truncated = _mm_cvtepi32_ps (_mm_cvttps_epi32 (val.data));

        // Truncation rounds negative non-integers up; step those down by one.
        const simd floored = truncated - ((truncated > val) & simd (1));

        return select (
            detail::rounding_passthrough (val),
            val,
            detail::with_sign_of (floored, val)
        );
    #endif
    }


    //---------------------------------------------------------------------------
    /// Lane-wise round toward positive infinity.
    inline simd ceil (simd val)
    {
    #if defined(__SSE4_1__)
        return _mm_ceil_ps (val.data);
    #else
        // cvttps always truncates toward zero, independent of MXCSR.
        const simd truncated = _mm_cvtepi32_ps (_mm_cvttps_epi32 (val.data));

        // Truncation rounds positive non-integers down; step those up by one.
        const simd ceiled = truncated + ((truncated < val) & simd (1));

        return select (
            detail::rounding_passthrough (val),
            val,
            detail::with_sign_of (ceiled, val)
        );
    #endif
    }


    ///////////////////////////////////////////////////////////////////////////
    inline simd min (simd a, simd b) { return _mm_min_ps (a.data, b.data); }
    inline simd max (simd a, simd b) { return _mm_max_ps (a.data, b.data); }

    /// Lane-wise clamp of `val` into [lo, hi].
    inline simd clamp (simd val, simd lo, simd hi)
    {
        return min (max (val, lo), hi);
    }


    ///////////////////////////////////////////////////////////////////////////
    /// Horizontal sum (a0+a1)+(a2+a3), broadcast into every lane.
    inline simd sum (simd a)
    {
    #if defined(__SSE3__)
        auto part = _mm_hadd_ps (a.data, a.data);
        return _mm_hadd_ps (part, part);
    #else
        // [a1,a0,a3,a2] + a  ->  [a0+a1, a1+a0, a2+a3, a3+a2]
        const __m128 swapped = _mm_shuffle_ps (a.data, a.data, _MM_SHUFFLE (2, 3, 0, 1));
        const __m128 pairs   = _mm_add_ps (a.data, swapped);
        // rotate by two lanes and add: every lane becomes (a0+a1)+(a2+a3)
        const __m128 rotated = _mm_shuffle_ps (pairs, pairs, _MM_SHUFFLE (1, 0, 3, 2));
        return _mm_add_ps (pairs, rotated);
    #endif
    }


    ///////////////////////////////////////////////////////////////////////////
    /// Dot product of all four lanes, broadcast into every lane.
    inline simd dot (simd a, simd b)
    {
    #if defined(__SSE4_1__)
        return _mm_dp_ps (a.data, b.data, 0xff);
    #else
        return sum (a * b);
    #endif
    }


    ///////////////////////////////////////////////////////////////////////////
    inline simd sqrt (simd a) { return _mm_sqrt_ps (a.data); }

    /// Approximate reciprocal square root (about 12 bits of precision).
    inline simd rsqrt (simd a) { return _mm_rsqrt_ps (a.data); }


    ///////////////////////////////////////////////////////////////////////////
    /// Squared Euclidean length, broadcast into every lane.
    inline simd norm2 (simd a) { return dot (a, a); }


    //-------------------------------------------------------------------------
    /// Euclidean length, broadcast into every lane.
    inline simd norm (simd a) { return sqrt (norm2 (a)); }


    //-------------------------------------------------------------------------
    /// Unit-length copy of `a`. Uses the approximate rsqrt, so the result is
    /// only accurate to roughly 12 bits. A zero vector yields NaN lanes.
    inline simd normalised (simd a)
    {
        // 1/sqrt(|a|^2) == 1/|a|; the previous rsqrt(norm(a)) gave 1/sqrt(|a|)
        return a * rsqrt (norm2 (a));
    }


    ///////////////////////////////////////////////////////////////////////////
    /// Euclidean length via sum of squares, broadcast into every lane.
    inline simd hypot (simd a)
    {
        return sqrt (sum (a * a));
    }


    ///////////////////////////////////////////////////////////////////////////
    inline simd::operator bool () const
    {
        return all (data);
    }


    //-------------------------------------------------------------------------
    /// Writes the four lanes as "[ a, b, c, d ]".
    inline std::ostream&
    operator<< (std::ostream &os, simd val)
    {
        return os << "[ "
                  << val[0] << ", "
                  << val[1] << ", "
                  << val[2] << ", "
                  << val[3] << " ]";
    }
}

#endif