coord/simd: add more sse operations
This commit is contained in:
parent
7708b12c37
commit
341907ac79
96
coord/simd_neon.hpp
Normal file
96
coord/simd_neon.hpp
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
/*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
* Copyright 2018 Danny Robson <danny@nerdcruft.net>
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef __ARM_NEON__
|
||||||
|
#error
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "arm_neon.h"
|
||||||
|
|
||||||
|
namespace util::coord {
|
||||||
|
struct alignas (16) simd {
|
||||||
|
using value_type = float32x4_t;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////
|
||||||
|
simd (float a, float b, float c, float d):
|
||||||
|
data (_mm_setr_ps (a, b, c, d))
|
||||||
|
{ ; }
|
||||||
|
|
||||||
|
|
||||||
|
//---------------------------------------------------------------------
|
||||||
|
simd (float v):
|
||||||
|
data (_mm_set_ps1 (v))
|
||||||
|
{ ; }
|
||||||
|
|
||||||
|
|
||||||
|
//---------------------------------------------------------------------
|
||||||
|
simd (value_type _data):
|
||||||
|
data (_data)
|
||||||
|
{ ; }
|
||||||
|
|
||||||
|
|
||||||
|
//---------------------------------------------------------------------
|
||||||
|
operator value_type& () { return data; }
|
||||||
|
operator const value_type& () const { return data; }
|
||||||
|
|
||||||
|
explicit operator bool () const;
|
||||||
|
|
||||||
|
float operator[] (int idx) const { return data[idx]; }
|
||||||
|
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////
|
||||||
|
value_type data;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////
|
||||||
|
simd operator* (simd a, simd b) { return vmulq_f32 (a, b); };
|
||||||
|
simd operator/ (simd a, simd b) { return vdivq_f32 (a, b); };
|
||||||
|
simd operator+ (simd a, simd b) { return vaddq_f32 (a, b); };
|
||||||
|
simd operator- (simd a, simd b) { return vsubq_f32 (a, b); };
|
||||||
|
|
||||||
|
simd operator< (simd a, simd b);
|
||||||
|
simd operator<= (simd a, simd b);
|
||||||
|
simd operator> (simd a, simd b);
|
||||||
|
simd operator>= (simd a, simd b);
|
||||||
|
simd operator== (simd a, simd b);
|
||||||
|
|
||||||
|
simd select (simd mask, simd a, simd b);
|
||||||
|
|
||||||
|
|
||||||
|
auto sum (simd val)
|
||||||
|
{
|
||||||
|
// reverse and add to self giving: 0123 + 3210
|
||||||
|
auto revq = vrev64q_f32 (val);
|
||||||
|
auto pair = vaddq_f32 (val, revq);
|
||||||
|
|
||||||
|
// reverse the upper and lower pairs given (2301 + 1023)
|
||||||
|
auto shuf = vcombine_f32 (
|
||||||
|
vget_high_f32 (pair),
|
||||||
|
vget_low_f32 (pair)
|
||||||
|
);
|
||||||
|
|
||||||
|
// add both partial sums: (2301 + 1032) + (0123 + 3210)
|
||||||
|
return vaddq_f32 (shuf, pair);
|
||||||
|
}
|
||||||
|
|
||||||
|
simd
|
||||||
|
dot (simd a, simd b)
|
||||||
|
{
|
||||||
|
return sum (a * b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
@ -18,8 +18,8 @@
|
|||||||
#ifndef CRUFT_UTIL_COORD_SIMD_SSE_HPP
|
#ifndef CRUFT_UTIL_COORD_SIMD_SSE_HPP
|
||||||
#define CRUFT_UTIL_COORD_SIMD_SSE_HPP
|
#define CRUFT_UTIL_COORD_SIMD_SSE_HPP
|
||||||
|
|
||||||
#ifndef __SSE2__
|
#ifndef __SSE3__
|
||||||
#error "SSE2 is required"
|
#error "SSE3 is required"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include <xmmintrin.h>
|
#include <xmmintrin.h>
|
||||||
@ -58,8 +58,8 @@ namespace util::coord {
|
|||||||
|
|
||||||
|
|
||||||
//---------------------------------------------------------------------
|
//---------------------------------------------------------------------
|
||||||
operator __m128& () { return data; }
|
explicit operator __m128& () { return data; }
|
||||||
operator const __m128& () const { return data; }
|
explicit operator const __m128& () const { return data; }
|
||||||
|
|
||||||
explicit operator bool () const;
|
explicit operator bool () const;
|
||||||
|
|
||||||
@ -72,25 +72,75 @@ namespace util::coord {
|
|||||||
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////
|
||||||
simd operator+ (simd a, simd b) { return _mm_add_ps (a, b); }
|
simd operator+ (simd a, simd b) { return _mm_add_ps (a.data, b.data); }
|
||||||
simd operator- (simd a, simd b) { return _mm_sub_ps (a, b); }
|
simd operator- (simd a, simd b) { return _mm_sub_ps (a.data, b.data); }
|
||||||
simd operator/ (simd a, simd b) { return _mm_div_ps (a, b); }
|
simd operator/ (simd a, simd b) { return _mm_div_ps (a.data, b.data); }
|
||||||
simd operator* (simd a, simd b) { return _mm_mul_ps (a, b); }
|
simd operator* (simd a, simd b) { return _mm_mul_ps (a.data, b.data); }
|
||||||
|
|
||||||
|
|
||||||
//-------------------------------------------------------------------------
|
//-------------------------------------------------------------------------
|
||||||
simd operator< (simd a, simd b) { return _mm_cmplt_ps (a, b); }
|
// computes a*b + c
|
||||||
simd operator<= (simd a, simd b) { return _mm_cmple_ps (a, b); }
|
auto
|
||||||
simd operator> (simd a, simd b) { return _mm_cmpgt_ps (a, b); }
|
fma (simd a, simd b, simd c)
|
||||||
simd operator>= (simd a, simd b) { return _mm_cmpge_ps (a, b); }
|
{
|
||||||
simd operator== (simd a, simd b) { return _mm_cmpeq_ps (a, b); }
|
#if defined(__FMA__)
|
||||||
|
return _mm_fmadd_ps (a.data, b.data, c.data);
|
||||||
|
#else
|
||||||
|
return a * b + c;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////
|
||||||
|
simd operator< (simd a, simd b) { return _mm_cmplt_ps (a.data, b.data); }
|
||||||
|
simd operator<= (simd a, simd b) { return _mm_cmple_ps (a.data, b.data); }
|
||||||
|
simd operator> (simd a, simd b) { return _mm_cmpgt_ps (a.data, b.data); }
|
||||||
|
simd operator>= (simd a, simd b) { return _mm_cmpge_ps (a.data, b.data); }
|
||||||
|
simd operator== (simd a, simd b) { return _mm_cmpeq_ps (a.data, b.data); }
|
||||||
|
|
||||||
|
|
||||||
//-------------------------------------------------------------------------
|
//-------------------------------------------------------------------------
|
||||||
simd operator| (simd a, simd b) { return _mm_or_ps (a, b); }
|
simd operator| (simd a, simd b) { return _mm_or_ps (a.data, b.data); }
|
||||||
simd operator& (simd a, simd b) { return _mm_and_ps (a, b); }
|
simd operator|| (simd a, simd b) { return _mm_or_ps (a.data, b.data); }
|
||||||
|
simd operator& (simd a, simd b) { return _mm_and_ps (a.data, b.data); }
|
||||||
|
simd operator&& (simd a, simd b) { return _mm_and_ps (a.data, b.data); }
|
||||||
|
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////
|
||||||
|
simd floor (simd val)
|
||||||
|
{
|
||||||
|
#if defined(__SSE4_1__)
|
||||||
|
return mm_floor_ps (val.data);
|
||||||
|
#else
|
||||||
|
// NOTE: assumes the rounding mode is 'nearest'
|
||||||
|
|
||||||
|
// cast to int and back to truncate
|
||||||
|
const auto truncated = _mm_cvtepi32_ps (_mm_cvtps_epi32 (val.data));
|
||||||
|
|
||||||
|
// if the truncated value is greater than the original value we got
|
||||||
|
// rounded up so we need to decrement to get the true value.
|
||||||
|
return truncated - ((truncated > val) & simd (1));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//---------------------------------------------------------------------------
|
||||||
|
simd ceil (simd val)
|
||||||
|
{
|
||||||
|
#if defined(__SSE4_1__)
|
||||||
|
return _mm_ceil_ps (val.data);
|
||||||
|
#else
|
||||||
|
// NOTE: assumes the rounding mode is 'nearest'
|
||||||
|
|
||||||
|
// truncate by casting to int and back
|
||||||
|
const auto truncated = _mm_cvtepi32_ps (_mm_cvtps_epi32 (val.data));
|
||||||
|
|
||||||
|
// if the truncated value is below the original value it got rounded
|
||||||
|
// down and needs to be incremented to get the true value.
|
||||||
|
return truncated + ((truncated < val) & simd (1));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////
|
||||||
simd
|
simd
|
||||||
select (simd mask, simd a, simd b)
|
select (simd mask, simd a, simd b)
|
||||||
@ -99,8 +149,8 @@ namespace util::coord {
|
|||||||
return _mm_blendv_ps (a, b, mask);
|
return _mm_blendv_ps (a, b, mask);
|
||||||
#else
|
#else
|
||||||
return _mm_or_ps (
|
return _mm_or_ps (
|
||||||
_mm_and_ps (mask, a),
|
_mm_and_ps (mask.data, a.data),
|
||||||
_mm_andnot_ps (mask, b)
|
_mm_andnot_ps (mask.data, b.data)
|
||||||
);
|
);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
@ -110,58 +160,36 @@ namespace util::coord {
|
|||||||
bool
|
bool
|
||||||
all (simd val)
|
all (simd val)
|
||||||
{
|
{
|
||||||
return _mm_movemask_ps (val) == 0b1111;
|
return _mm_movemask_ps (val.data) == 0b1111;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//-------------------------------------------------------------------------
|
//-------------------------------------------------------------------------
|
||||||
auto
|
bool
|
||||||
clamp (simd val, simd lo, simd hi)
|
any (simd val)
|
||||||
{
|
{
|
||||||
auto lo_mask = val > lo;
|
return _mm_movemask_ps (val.data);
|
||||||
auto hi_mask = val < hi;
|
|
||||||
|
|
||||||
auto res = (lo_mask & val)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////
|
||||||
// use the same comparator in both because we're likely to use min
|
simd min (simd a, simd b) { return _mm_min_ps (a.data, b.data); }
|
||||||
// and max near each other and the mask might be sharable this way.
|
simd max (simd a, simd b) { return _mm_max_ps (a.data, b.data); }
|
||||||
simd min (simd a, simd b) { return select (a < b, a, b); }
|
|
||||||
simd max (simd a, simd b) { return select (a < b, b, a); }
|
simd
|
||||||
|
clamp (simd val, simd lo, simd hi)
|
||||||
|
{
|
||||||
|
return min (max (val, lo), hi);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////
|
||||||
#if defined (__SSE3__)
|
|
||||||
simd
|
simd
|
||||||
sum (simd a)
|
sum (simd a)
|
||||||
{
|
{
|
||||||
auto part = _mm_hadd_ps (a, a);
|
auto part = _mm_hadd_ps (a.data, a.data);
|
||||||
return _mm_hadd_ps (part, part);
|
return _mm_hadd_ps (part, part);
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
auto
|
|
||||||
sum (simd vals)
|
|
||||||
{
|
|
||||||
// swap pairs of components
|
|
||||||
// vals: 3 2 1 0
|
|
||||||
// shuf: 2 3 0 1
|
|
||||||
auto shuf = _mm_shuffle_ps (vals, vals, _MM_SHUFFLE(2, 3, 0, 1));
|
|
||||||
|
|
||||||
// combine the pairs
|
|
||||||
auto sums = _mm_add_ps (vals, shuf);
|
|
||||||
|
|
||||||
// copy the lower components of sums up, then add with the original sums
|
|
||||||
// sums: 2+3 2+3 1+0 1+0
|
|
||||||
// shuf: xxx xxx 2+3 2+3
|
|
||||||
shuf = _mm_movehl_ps (shuf, sums);
|
|
||||||
sums = _mm_add_ss (sums, shuf);
|
|
||||||
|
|
||||||
// sums: xxx xxx 0123 1234
|
|
||||||
return _mm_cvtss_f32 (sums);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////
|
||||||
@ -171,25 +199,17 @@ namespace util::coord {
|
|||||||
{
|
{
|
||||||
return _mm_dp_ps (a, b, 0xff);
|
return _mm_dp_ps (a, b, 0xff);
|
||||||
}
|
}
|
||||||
#elif defined(__SSE3__)
|
#else
|
||||||
simd
|
simd
|
||||||
dot (simd a, simd b)
|
dot (simd a, simd b)
|
||||||
{
|
{
|
||||||
return sum (a * b)
|
return sum (a * b);
|
||||||
}
|
|
||||||
#else
|
|
||||||
auto
|
|
||||||
dot (simd a, simd b)
|
|
||||||
{
|
|
||||||
auto mul = a * b;
|
|
||||||
return sum (mul);
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////
|
||||||
simd sqrt (simd a) { return _mm_sqrt_ps (a); }
|
simd sqrt (simd a) { return _mm_sqrt_ps (a.data); }
|
||||||
simd rsqrt (simd a) { return _mm_rsqrt_ps (a); }
|
simd rsqrt (simd a) { return _mm_rsqrt_ps (a.data); }
|
||||||
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////
|
||||||
@ -224,7 +244,7 @@ namespace util::coord {
|
|||||||
auto b7fff = _mm_srli_epi32 (bffff, 1);
|
auto b7fff = _mm_srli_epi32 (bffff, 1);
|
||||||
auto mask = _mm_castsi128_ps (b7fff);
|
auto mask = _mm_castsi128_ps (b7fff);
|
||||||
|
|
||||||
return _mm_and_ps (mask, a);
|
return _mm_and_ps (mask, a.data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -242,6 +262,18 @@ namespace util::coord {
|
|||||||
{
|
{
|
||||||
return all (data);
|
return all (data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
std::ostream&
|
||||||
|
operator<< (std::ostream &os, simd val)
|
||||||
|
{
|
||||||
|
return os << "[ "
|
||||||
|
<< val[0] << ", "
|
||||||
|
<< val[1] << ", "
|
||||||
|
<< val[2] << ", "
|
||||||
|
<< val[3]
|
||||||
|
<< " ]";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@ -5,30 +5,52 @@
|
|||||||
int
|
int
|
||||||
main ()
|
main ()
|
||||||
{
|
{
|
||||||
|
using util::coord::simd;
|
||||||
|
|
||||||
util::TAP::logger tap;
|
util::TAP::logger tap;
|
||||||
|
|
||||||
|
std::clog << "rounding mode is: " << [] () {
|
||||||
|
switch (_MM_GET_ROUNDING_MODE ()) {
|
||||||
|
case _MM_ROUND_NEAREST: return "nearest";
|
||||||
|
case _MM_ROUND_DOWN: return "down";
|
||||||
|
case _MM_ROUND_UP: return "up";
|
||||||
|
case _MM_ROUND_TOWARD_ZERO: return "toward_zero";
|
||||||
|
}
|
||||||
|
|
||||||
|
return "unknown";
|
||||||
|
} () << '\n';
|
||||||
|
|
||||||
{
|
{
|
||||||
const util::coord::simd a (1,2,3,4);
|
const simd a (1,2,3,4);
|
||||||
const util::coord::simd b (4,1,3,2);
|
const simd b (4,1,3,2);
|
||||||
const float res = dot (a, b);
|
const float res = dot (a, b)[0];
|
||||||
tap.expect_eq (res, 4+2+9+8, "trivial dot product");
|
tap.expect_eq (res, 4+2+9+8, "trivial dot product");
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
const util::coord::simd a (1, 2, 3, 4);
|
const simd a (1, 2, 3, 4);
|
||||||
const util::coord::simd b (0, 3, 3, 9);
|
const simd b (0, 3, 3, 9);
|
||||||
|
|
||||||
const auto lo = min (a, b);
|
const auto lo = min (a, b);
|
||||||
const auto hi = max (a, b);
|
const auto hi = max (a, b);
|
||||||
|
|
||||||
tap.expect_eq (lo, util::coord::simd {0,2,3,4}, "vector minimum");
|
tap.expect_eq (lo, simd {0,2,3,4}, "vector minimum");
|
||||||
tap.expect_eq (hi, util::coord::simd {1,3,3,9}, "vector maximum");
|
tap.expect_eq (hi, simd {1,3,3,9}, "vector maximum");
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
const util::coord::simd val { -INFINITY, INFINITY, 0, -9 };
|
const simd val { -INFINITY, INFINITY, 0, -9 };
|
||||||
tap.expect_eq (abs (val), util::coord::simd {INFINITY,INFINITY,0,9}, "absolute value");
|
tap.expect_eq (abs (val), simd {INFINITY,INFINITY,0,9}, "absolute value");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const simd test { -1.25f, 1.25f, 0.f, 1.f };
|
||||||
|
const auto lo = floor (test);
|
||||||
|
const auto hi = ceil (test);
|
||||||
|
|
||||||
|
tap.expect_eq (lo, simd { -2, 1, 0, 1 }, "floor");
|
||||||
|
tap.expect_eq (hi, simd { -1, 2, 0, 1 }, "ceil");
|
||||||
|
};
|
||||||
|
|
||||||
return tap.status ();
|
return tap.status ();
|
||||||
}
|
}
|
Loading…
x
Reference in New Issue
Block a user