/*
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 *
 * Copyright 2018 Danny Robson <danny@nerdcruft.net>
 */


#ifndef CRUFT_UTIL_COORD_SIMD_SSE_HPP
#define CRUFT_UTIL_COORD_SIMD_SSE_HPP

#ifndef __SSE3__
#error "SSE3 is required"
#endif

#include <xmmintrin.h>
#include <pmmintrin.h>
#include <immintrin.h>

#include <array>
#include <cstddef>
#include <cstdint>
#include <iosfwd>

namespace cruft::coord {
    ///////////////////////////////////////////////////////////////////////////
    /// Alignment, in bytes, of the 128-bit SSE register storage.
    ///
    /// `inline` (C++17) gives this header-defined constant a single definition
    /// shared by every translation unit. A plain namespace-scope `constexpr`
    /// has internal linkage, so ODR-uses from inline code in different TUs
    /// would refer to different objects.
    inline constexpr int alignment = 16;


    /// Maps a lane count and element type to the SSE register type able to
    /// hold them. Pairs without a specialisation below resolve to an empty
    /// struct that has no nested `type`.
    template <size_t CountV, typename ValueT>
    struct native_type { };

    // single precision: up to four lanes in one __m128
    template <> struct native_type<1, float>    { using type = __m128;  };
    template <> struct native_type<2, float>    { using type = __m128;  };
    template <> struct native_type<3, float>    { using type = __m128;  };
    template <> struct native_type<4, float>    { using type = __m128;  };

    // 32-bit unsigned integers: up to four lanes in one __m128i
    template <> struct native_type<1, uint32_t> { using type = __m128i; };
    template <> struct native_type<2, uint32_t> { using type = __m128i; };
    template <> struct native_type<3, uint32_t> { using type = __m128i; };
    template <> struct native_type<4, uint32_t> { using type = __m128i; };

    // double precision: up to two lanes in one __m128d
    template <> struct native_type<1, double>   { using type = __m128d; };
    template <> struct native_type<2, double>   { using type = __m128d; };


    /// A wrapper around a single 128-bit SSE register.
    ///
    /// All constructors and accessors go through the single-precision (`_ps`)
    /// intrinsics, so lanes are stored as float whatever `ValueT` is.
    /// NOTE(review): non-float ValueT is presumably unsupported; confirm.
    /// The 16-byte alignment matches `coord::alignment`.
    template <size_t CountV, typename ValueT>
    struct alignas (16) simd {
        ///////////////////////////////////////////////////////////////////////
        /// Builds from four lane values in memory order (`a` is lane 0).
        simd (ValueT a, ValueT b, ValueT c, ValueT d):
            data (_mm_setr_ps (a, b, c, d))
        { ; }


        //---------------------------------------------------------------------
        /// Broadcasts `v` into every lane.
        simd (ValueT v):
            data (_mm_set_ps1 (v))
        { ; }


        //---------------------------------------------------------------------
        /// Wraps an existing register. This is deliberately implicit so the
        /// free functions can `return` raw intrinsic results directly.
        simd (__m128 _data):
            data (_data)
        { ; }


        //---------------------------------------------------------------------
        explicit operator       __m128& ()       { return data; }
        explicit operator const __m128& () const { return data; }

        /// Defined out of line. True when the sign bits of all four lanes are
        /// set, i.e. every lane of a comparison mask is true.
        explicit operator bool () const;

        /// Reads lane `idx`. Relies on the GCC/Clang vector subscript extension.
        ValueT operator[] (int idx) const { return data[idx]; }


        ///////////////////////////////////////////////////////////////////////
        /// Named read access to one lane (x, y, z, w) through the union below.
        template <size_t IndexV>
        struct accessor {
            static_assert (IndexV < 4, "SSE registers only have four lanes");

            operator ValueT () const noexcept
            {
                // Broadcast the wanted lane into lane 0, then read lane 0 as a
                // float. _mm_extract_epi32 / _mm_extract_ps are deliberately
                // avoided. Both return the lane's raw bit pattern as an int,
                // which would then be numerically converted to ValueT instead
                // of reinterpreted. _mm_extract_epi32 also needs an integer
                // vector, not a __m128.
                return _mm_cvtss_f32 (
                    _mm_shuffle_ps (
                        data,
                        data,
                        _MM_SHUFFLE (IndexV, IndexV, IndexV, IndexV)
                    )
                );
            }

            // NOTE(review): declared but not defined in this header.
            accessor& operator= (ValueT);

            __m128 data;
        };


        // Every member begins with an __m128, so x/y/z/w view the same
        // storage as `data`. This relies on the compiler tolerating reads
        // through an inactive union member.
        union {
            __m128 data;
            accessor<0> x;
            accessor<1> y;
            accessor<2> z;
            accessor<3> w;
        };
    };


    ///////////////////////////////////////////////////////////////////////////
    // Lane-wise arithmetic. Each operator is one packed single-precision SSE
    // instruction applied to all four lanes.
    template <size_t S,typename T>
    simd<S,T>
    operator+ (simd<S,T> a, simd<S,T> b)
    {
        const __m128 result = _mm_add_ps (a.data, b.data);
        return simd<S,T> (result);
    }


    //-------------------------------------------------------------------------
    template <size_t S,typename T>
    simd<S,T>
    operator- (simd<S,T> a, simd<S,T> b)
    {
        const __m128 result = _mm_sub_ps (a.data, b.data);
        return simd<S,T> (result);
    }


    //-------------------------------------------------------------------------
    template <size_t S,typename T>
    simd<S,T>
    operator* (simd<S,T> a, simd<S,T> b)
    {
        const __m128 result = _mm_mul_ps (a.data, b.data);
        return simd<S,T> (result);
    }


    //-------------------------------------------------------------------------
    template <size_t S,typename T>
    simd<S,T>
    operator/ (simd<S,T> a, simd<S,T> b)
    {
        const __m128 result = _mm_div_ps (a.data, b.data);
        return simd<S,T> (result);
    }


    ///////////////////////////////////////////////////////////////////////////
    // computes a*b + c
    template <size_t S, typename T>
    auto
    fma (simd<S,T> a, simd<S,T> b, simd<S,T> c)
    {
#if defined(__FMA__)
        return _mm_fmadd_ps (a.data, b.data, c.data);
#else
        return a * b + c;
#endif
    }


    ///////////////////////////////////////////////////////////////////////////
    // Lane-wise comparisons. Each result lane is all-ones where the relation
    // holds and all-zeros where it does not, including when either input lane
    // is NaN. The masks are meant for select/all/any and the bitwise operators.
    template <size_t S, typename T>
    simd<S,T>
    operator< (simd<S,T> a, simd<S,T> b)
    {
        return simd<S,T> (_mm_cmplt_ps (a.data, b.data));
    }


    template <size_t S, typename T>
    simd<S,T>
    operator<= (simd<S,T> a, simd<S,T> b)
    {
        return simd<S,T> (_mm_cmple_ps (a.data, b.data));
    }


    // a > b is evaluated as b < a
    template <size_t S, typename T>
    simd<S,T>
    operator>  (simd<S,T> a, simd<S,T> b)
    {
        return simd<S,T> (_mm_cmplt_ps (b.data, a.data));
    }


    // a >= b is evaluated as b <= a
    template <size_t S, typename T>
    simd<S,T>
    operator>= (simd<S,T> a, simd<S,T> b)
    {
        return simd<S,T> (_mm_cmple_ps (b.data, a.data));
    }


    template <size_t S, typename T>
    simd<S,T>
    operator== (simd<S,T> a, simd<S,T> b)
    {
        return simd<S,T> (_mm_cmpeq_ps (a.data, b.data));
    }


    ///////////////////////////////////////////////////////////////////////////
    // Bitwise combination of lanes, typically of comparison masks. The
    // logical spellings forward to the bitwise ones. Like any overloaded
    // && / ||, both operands are always evaluated (no short-circuiting).
    template <size_t S, typename T>
    simd<S,T>
    operator|  (simd<S,T> a, simd<S,T> b)
    {
        return simd<S,T> (_mm_or_ps (a.data, b.data));
    }


    template <size_t S, typename T>
    simd<S,T>
    operator|| (simd<S,T> a, simd<S,T> b)
    {
        return a | b;
    }


    template <size_t S, typename T>
    simd<S,T>
    operator& (simd<S,T> a, simd<S,T> b)
    {
        return simd<S,T> (_mm_and_ps (a.data, b.data));
    }


    template <size_t S, typename T>
    simd<S,T>
    operator&& (simd<S,T> a, simd<S,T> b)
    {
        return a & b;
    }


    ///////////////////////////////////////////////////////////////////////////
    /// Lane-wise floor: the largest integral value not greater than each lane.
    ///
    /// The pre-SSE4.1 fallback truncates through an int32 conversion. Lanes
    /// with magnitude >= 2^23 are already integral but may not fit in an
    /// int32, so they are returned unchanged. NaN lanes are also returned
    /// unchanged.
    template <size_t S, typename T>
    simd<S,T>
    floor (simd<S,T> val)
    {
#if defined(__SSE4_1__)
        return _mm_floor_ps (val.data);
#else
        // Convert to int and back using the *truncating* conversion (cvtt),
        // so the result does not depend on the current MXCSR rounding mode.
        const simd<S,T> truncated = _mm_cvtepi32_ps (_mm_cvttps_epi32 (val.data));

        // Truncation moves negative non-integers towards zero (upwards), so
        // step any lane that ended up above the input down by one.
        const simd<S,T> floored = truncated - ((truncated > val) & simd<S,T> (1));

        // Keep the input where |val| >= 2^23 (already integral) or val is
        // NaN. A "not less than" compare is true for NaN, so one mask covers
        // both cases.
        const __m128 magnitude = _mm_andnot_ps (_mm_set1_ps (-0.f), val.data);
        const __m128 keep      = _mm_cmpnlt_ps (magnitude, _mm_set1_ps (8388608.f));

        return _mm_or_ps (
            _mm_and_ps    (keep, val.data),
            _mm_andnot_ps (keep, floored.data)
        );
#endif
    }


    //---------------------------------------------------------------------------
    /// Lane-wise ceiling: the smallest integral value not less than each lane.
    ///
    /// The pre-SSE4.1 fallback truncates through an int32 conversion. Lanes
    /// with magnitude >= 2^23 are already integral but may not fit in an
    /// int32, so they are returned unchanged. NaN lanes are also returned
    /// unchanged.
    template <size_t S, typename T>
    simd<S,T>
    ceil (simd<S,T> val)
    {
#if defined(__SSE4_1__)
        return _mm_ceil_ps (val.data);
#else
        // Convert to int and back using the *truncating* conversion (cvtt),
        // so the result does not depend on the current MXCSR rounding mode.
        const simd<S,T> truncated = _mm_cvtepi32_ps (_mm_cvttps_epi32 (val.data));

        // Truncation moves positive non-integers towards zero (downwards), so
        // step any lane that ended up below the input up by one.
        const simd<S,T> ceiled = truncated + ((truncated < val) & simd<S,T> (1));

        // Keep the input where |val| >= 2^23 (already integral) or val is
        // NaN. A "not less than" compare is true for NaN, so one mask covers
        // both cases.
        const __m128 magnitude = _mm_andnot_ps (_mm_set1_ps (-0.f), val.data);
        const __m128 keep      = _mm_cmpnlt_ps (magnitude, _mm_set1_ps (8388608.f));

        return _mm_or_ps (
            _mm_and_ps    (keep, val.data),
            _mm_andnot_ps (keep, ceiled.data)
        );
#endif
    }

    ///////////////////////////////////////////////////////////////////////////
    /// Lane-wise choice: takes `a` where `mask` is set and `b` elsewhere.
    ///
    /// `mask` is expected to be a comparison result, with each lane all-ones
    /// or all-zeros. The SSE4.1 path inspects only each lane's sign bit, while
    /// the fallback uses every bit, so other masks may differ between builds.
    template <size_t S, typename T>
    simd<S,T>
    select (simd<S,T> mask, simd<S,T> a, simd<S,T> b)
    {
#if defined(__SSE4_1__)
        // blendv takes its *second* operand where the mask sign bit is set,
        // so `a` goes second to agree with the fallback below. The raw
        // registers are passed because simd only converts to __m128
        // explicitly.
        return _mm_blendv_ps (b.data, a.data, mask.data);
#else
        return _mm_or_ps (
            _mm_and_ps    (mask.data, a.data),
            _mm_andnot_ps (mask.data, b.data)
        );
#endif
    }


    //-------------------------------------------------------------------------
    /// True when the sign bit of every lane is set, i.e. each lane of a
    /// comparison mask is true. All four lanes are tested whatever S is.
    template <size_t S, typename T>
    bool
    all (simd<S,T> val)
    {
        const int sign_bits = _mm_movemask_ps (val.data);
        return sign_bits == 0xF;
    }


    //-------------------------------------------------------------------------
    /// True when the sign bit of at least one lane is set, i.e. some lane of a
    /// comparison mask is true. All four lanes are tested whatever S is.
    template <size_t S, typename T>
    bool
    any (simd<S,T> val)
    {
        const int sign_bits = _mm_movemask_ps (val.data);
        return sign_bits != 0;
    }


    ///////////////////////////////////////////////////////////////////////////
    // Lane-wise minimum / maximum. As with the underlying minps/maxps, when
    // either lane is NaN the lane from the second argument is returned.
    template <size_t S, typename T>
    simd<S,T>
    min (simd<S,T> a, simd<S,T> b)
    {
        return simd<S,T> (_mm_min_ps (a.data, b.data));
    }


    template <size_t S, typename T>
    simd<S,T>
    max (simd<S,T> a, simd<S,T> b)
    {
        return simd<S,T> (_mm_max_ps (a.data, b.data));
    }


    /// Limits each lane of `val` to [lo, hi]. The max against `lo` is applied
    /// first and the min against `hi` second, so `hi` wins when lo > hi.
    template <size_t S, typename T>
    simd<S,T>
    clamp (simd<S,T> val, simd<S,T> lo, simd<S,T> hi)
    {
        const __m128 raised = _mm_max_ps (val.data, lo.data);
        return simd<S,T> (_mm_min_ps (raised, hi.data));
    }


    ///////////////////////////////////////////////////////////////////////////
    /// Horizontal add. Every lane of the result holds the sum of all four
    /// input lanes; for S < 4 the unused lanes are included in the sum.
    template <size_t S, typename T>
    simd<S,T>
    sum (simd<S,T> a)
    {
        // first pass: (a0+a1, a2+a3, a0+a1, a2+a3)
        const __m128 pairs = _mm_hadd_ps (a.data, a.data);
        // second pass: the total, broadcast to every lane
        return _mm_hadd_ps (pairs, pairs);
    }


    ///////////////////////////////////////////////////////////////////////////
    /// Dot product over all four lanes, broadcast into every lane of the
    /// result.
    ///
    /// The SSE4.1 and fallback paths may add the products in different orders,
    /// so results can differ in the last bit between builds.
    template <size_t S, typename T>
    simd<S,T>
    dot (simd<S,T> a, simd<S,T> b)
    {
#if defined(__SSE4_1__)
        // 0xff: the high nibble multiplies all four lanes and the low nibble
        // writes the sum to all four lanes, matching the broadcast of `sum`.
        return _mm_dp_ps (a.data, b.data, 0xff);
#else
        return sum (a * b);
#endif
    }

    ///////////////////////////////////////////////////////////////////////////
    /// Lane-wise square root.
    template <size_t S, typename T>
    simd<S,T>
    sqrt (simd<S,T> a)
    {
        return simd<S,T> (_mm_sqrt_ps (a.data));
    }


    /// Lane-wise *approximate* reciprocal square root (_mm_rsqrt_ps, which
    /// Intel documents as having relative error below 1.5 * 2^-12).
    template <size_t S, typename T>
    simd<S,T>
    rsqrt (simd<S,T> a)
    {
        return simd<S,T> (_mm_rsqrt_ps (a.data));
    }


    ///////////////////////////////////////////////////////////////////////////
    /// Squared Euclidean length: the dot product of `a` with itself.
    template <size_t S, typename T>
    auto
    norm2 (simd<S,T> a)
    {
        const auto squared_length = dot (a, a);
        return squared_length;
    }


    //-------------------------------------------------------------------------
    /// Euclidean length: the square root of the squared length.
    template <size_t S, typename T>
    auto
    norm (simd<S,T> a)
    {
        const auto squared_length = norm2 (a);
        return sqrt (squared_length);
    }


    //-------------------------------------------------------------------------
    /// Scales `a` to (approximately) unit length.
    ///
    /// The argument to rsqrt must be the *squared* length: rsqrt (|a|^2) is
    /// 1/|a|, whereas rsqrt (norm (a)) would be |a|^-1/2. rsqrt is the
    /// approximate SSE instruction, so the result carries roughly 12 bits of
    /// precision. A zero vector produces NaN lanes (0 * inf).
    template <size_t S, typename T>
    auto
    normalised (simd<S,T> a)
    {
        return a * rsqrt (norm2 (a));
    }


    ///////////////////////////////////////////////////////////////////////////
    /// Lane-wise absolute value, computed by clearing each lane's sign bit.
    template <size_t S, typename T>
    simd<S,T>
    abs (simd<S,T> a)
    {
        // -0.f has only the sign bit set; andnot keeps every other bit of `a`.
        const __m128 sign_bit = _mm_set1_ps (-0.f);
        return simd<S,T> (_mm_andnot_ps (sign_bit, a.data));
    }



    ///////////////////////////////////////////////////////////////////////////
    /// Square root of the sum of the squared lanes. Unlike std::hypot there is
    /// no rescaling, so very large lanes can overflow to infinity.
    template <size_t S, typename T>
    auto
    hypot (simd<S,T> a)
    {
        const simd<S,T> squares = a * a;
        return sqrt (sum (squares));
    }


    ///////////////////////////////////////////////////////////////////////////
    // A simd tests true only when `all` reports every lane's sign bit set.
    template <size_t S, typename T>
    simd<S,T>::operator bool () const
    {
        const bool every_lane = all (*this);
        return every_lane;
    }


    ///////////////////////////////////////////////////////////////////////////
    /// Stream output for simd values. Only declared in this header.
    /// NOTE(review): the definition presumably lives in a source file.
    template <size_t S, typename T>
    std::ostream&
    operator<< (std::ostream&, simd<S,T>);
}

#endif