/*
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 *
    * Copyright 2018 Danny Robson <danny@nerdcruft.net>
 */

#ifndef CRUFT_UTIL_MATHS_FAST_HPP
#define CRUFT_UTIL_MATHS_FAST_HPP

#include "../maths.hpp"
#include "../debug.hpp"

#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

namespace cruft::maths::fast {
    // evaluates the cubic polynomial a + b*t + c*t^2 + d*t^3 using Estrin's
    // scheme: the two linear halves (a + b*t) and (c + d*t) do not depend on
    // each other, so they can be evaluated in parallel before being combined
    // with t^2.
    //
    // defined inline because this is a header; a plain definition would
    // violate the one-definition rule once included from two translation
    // units.
    inline float
    estrin (float t, float a, float b, float c, float d) noexcept
    {
        const auto t2 = t * t;
        return std::fma (b, t, a) + std::fma (d, t, c) * t2;
    }


    // approximates sin(x) for x in (-2pi; 2pi).
    //
    // the argument is first folded into [0; pi/2] using the identities
    //   sin(-x)     = -sin(x)   (sin is an odd function)
    //   sin(x + pi) = -sin(x)
    //   sin(pi - x) =  sin(x)
    // then:
    //   * [0; pi/4]     is evaluated with a minimax polynomial for sin whose
    //                   relative error should be on the order of 1e-6;
    //   * (pi/4; pi/2]  is evaluated as cos(pi/2 - x) using the Taylor series
    //                   of cos truncated after the t^6 term. for t < 1 the
    //                   series alternates with decreasing terms, so the
    //                   truncation error is at most (pi/4)^8 / 8! ~= 3.6e-6.
    // this keeps each polynomial inside the interval it was designed for.
    //
    // results near multiples of pi are limited by how closely pi<float>
    // approximates pi, so judge them by absolute rather than relative error.
    //
    // sollya (sin polynomial):
    //   display=hexadecimal;
    //   canonical=on!;
    //   P = fpminimax(sin(x), [|3,5,7|], [|single...|],[0x1p-126;pi/4], , relative, x);
    //   P;
    //   dirtyinfnorm(sin(x)-P, [0;pi/4]);
    //   plot(P(x)-sin(x),[0x1p-126;pi/4]);
    //
    // using minimax(sin(x)-x, [|3,5,7,9|], ...) has higher accuracy, but
    // probably not enough to offset the extra multiplications.
    inline float
    sin (float x) noexcept
    {
        CHECK_LT (x,  2 * pi<float>);
        CHECK_GT (x, -2 * pi<float>);

        // sin is odd: work with |x| and restore the sign at the end.
        float sign = 1;
        if (x < 0) {
            x = -x;
            sign = -1;
        }

        // remove whole periods. inside the checked range this is a no-op; it
        // only matters for inputs beyond it.
        int periods = static_cast<int> (x / 2 / pi<float>);
        x -= periods * 2 * pi<float>;

        // sin(x + pi) = -sin(x): fold [pi; 2pi) onto [0; pi)
        if (x >= pi<float>) {
            x -= pi<float>;
            sign = -sign;
        }

        // sin(pi - x) = sin(x): fold (pi/2; pi) onto (0; pi/2)
        if (x > pi<float> / 2)
            x = pi<float> - x;

        // x now lies within [0; pi/2]
        if (x <= pi<float> / 4) {
            // 1,3,5,7:      1, -0x1.555546p-3f,  0x1.11077ap-7f, -0x1.9954e8p-13f;
            // 1,3,5,7,9:    1, -0x1.555556p-3f,  0x1.11115cp-7f, -0x1.a0403ap-13f,  0x1.75dc26p-19f;
            //
            // sin(x) =    x + Ax3 + Bx5 + Cx7
            // sin(x) = x (1 + Ax2 + Bx4 + Cx6)
            // sin(x) = x (1 + Ay1 + By2 + Cy3), y = x^2
            return sign * x * estrin (x * x, 1, -0x1.555546p-3f, 0x1.11077ap-7f, -0x1.9954e8p-13f);
        }

        // sin(x) = cos(pi/2 - x), with t = pi/2 - x in [0; pi/4)
        // cos(t) ~= 1 - t^2/2! + t^4/4! - t^6/6!
        const float t = pi<float> / 2 - x;
        return sign * estrin (t * t, 1, -0.5f, 1.f / 24, -1.f / 720);
    }


    // calculates an approximation of e^x
    //
    // the argument is split as x = n*ln2 + r, where n = floor(x / ln2) is an
    // integer and r lies in [0; ln2), and we use the identity
    //   e^x = e^(n*ln2 + r) = 2^n * e^r
    //
    // 2^n is exact: n + 127 is written directly into the exponent field of an
    // IEEE-754 single. ln2 is split into a short head and a tail correction
    // (Cody-Waite; constants as used by Cephes' expf) so that computing r
    // loses little precision.
    //
    // e^r is approximated by a polynomial fitted across [0;1], which contains
    // the reduced range [0; ln2).
    //
    // we force the first two terms as 1 and x because they are fantastically
    // efficient to load in this context.
    //
    // using an order 5 polynomial is overkill, but we use something with 4
    // coefficients because it should be pretty efficient to evaluate using
    // SIMD.
    //
    // special cases:
    //   * NaN is returned unchanged;
    //   * x > 88.72 returns +inf (2^n would leave the normal exponent range);
    //   * x < -87.33 returns 0 (results that would be subnormal are flushed).
    //
    // sollya:
    //   display=hexadecimal;
    //   canonical=on!;
    //   P = fpminimax(exp(x), [|2,3,4,5|], [|single...|], [0x1p-126;1], 1+x);
    //   P;
    //   dirtyinfnorm(P(x)-exp(x), [0;1]);
    //   plot(P(x)-exp(x),[0; 1]);
    inline float
    exp (float x) noexcept
    {
        static_assert (std::numeric_limits<float>::is_iec559,
                       "exp builds its result from IEEE-754 single precision bit patterns");

        // NaN would otherwise reach an undefined float -> int conversion.
        if (std::isnan (x))
            return x;

        // keep n within the normal exponent range [-126; 127]
        if (x > 88.72f)
            return std::numeric_limits<float>::infinity ();
        if (x < -87.33f)
            return 0;

        constexpr float log2e  =  1.44269504088896341f;
        constexpr float ln2_hi =  0.693359375f;
        constexpr float ln2_lo = -2.12194440e-4f;

        // floor (not truncation) keeps r non-negative for negative x, so the
        // polynomial is never evaluated outside its fitted interval.
        const int n = static_cast<int> (std::floor (x * log2e));
        float r = std::fma (static_cast<float> (n), -ln2_hi, x);
        r = std::fma (static_cast<float> (n), -ln2_lo, r);

        // 2^n, constructed in the exponent field. memcpy avoids the undefined
        // behaviour of reading an inactive union member.
        const std::uint32_t bits = static_cast<std::uint32_t> (n + 127) << 23;
        float scale {};
        std::memcpy (&scale, &bits, sizeof (scale));

        // e^r ~= 1 + r + r^2 * P(r)
        const float poly = 1 + r + r * r * estrin (r,
            0.4997530281543731689453125f,
            0.16853784024715423583984375f,
            3.6950431764125823974609375e-2f,
            1.303750835359096527099609375e-2f
        );

        return poly * scale;
    }
}

#endif