/*
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 *
 * Copyright 2020-2021, Danny Robson <danny@nerdcruft.net>
 */

#pragma once

#include "uniform.hpp"

#include <cmath>
#include <optional>
#include <tuple>
#include <type_traits>


namespace cruft::rand::distribution {
    /// A normal (Gaussian) distribution over floating point values.
    ///
    /// Variates are generated in pairs using the Marsaglia polar method (a
    /// rejection-sampling variant of the Box–Muller transform). The second
    /// variate of each pair is cached in unscaled (standard normal) form and
    /// returned, scaled by the parameters supplied at that time, on the next
    /// invocation.
    ///
    /// The interface loosely follows the standard RandomNumberDistribution
    /// requirements (result_type, param_type, reset, param, operator()).
    template <typename ResultT>
    requires (std::is_floating_point_v<ResultT>)
    class normal {
    public:
        using result_type = ResultT;

        /// The distribution parameters.
        struct param_type {
            result_type mean;
            result_type stddev;
        };

        /// Constructs a distribution with mean 0 and stddev 1.
        normal ():
            normal (0)
        { ; }

        explicit normal (result_type mean, result_type stddev = 1)
            : m_param {
                .mean   = mean,
                .stddev = stddev
            }
        { ; }


        explicit normal (param_type const &_param)
            : m_param (_param)
        { ; }


        /// Discards any cached variate so the next invocation draws a fresh
        /// pair from the generator.
        void reset (void) noexcept
        {
            m_prev.reset ();
        }


        /// Draws a variate using the stored parameters.
        template <typename GeneratorT>
        result_type
        operator() (GeneratorT &&g)
        {
            return (*this)(g, m_param);
        }


        /// Draws a variate using the supplied parameters.
        ///
        /// Consumes uniform values from `g` only when no cached variate is
        /// available; otherwise returns the cached variate scaled by `params`.
        template <typename GeneratorT>
        result_type
        operator() (GeneratorT &&g, param_type const &params)
        {
            if (m_prev) {
                auto const res = m_prev.value () * params.stddev + params.mean;
                m_prev.reset ();
                return res;
            }

            auto const [u, v, s] = find_uvs (g);

            // Marsaglia polar method: sqrt(-2 ln(s) / s) maps the accepted
            // unit-disk point (u, v) onto two independent standard normal
            // variates. The factor is shared, so compute it only once.
            result_type const scale = std::sqrt (-2 * std::log (s) / s);
            result_type const z0 = u * scale;
            result_type const z1 = v * scale;

            // Cache the unscaled partner; it is scaled on retrieval so that a
            // change of parameters between calls is respected.
            m_prev = z1;

            return z0 * params.stddev + params.mean;
        }


        result_type mean   (void) const { return m_param.mean; }
        result_type stddev (void) const { return m_param.stddev; }

        /// Returns the stored parameters.
        param_type param (void) const { return m_param; }

        /// Replaces the stored parameters. Any cached variate is kept because
        /// it is stored unscaled and is scaled by the parameters at use time.
        void param (param_type const &_param) { m_param = _param; }


    private:
        /// Rejection-samples a point (u, v) uniformly from the open unit disk
        /// (excluding the origin), returning u, v, and s = u² + v².
        template <typename GeneratorT>
        std::tuple<result_type, result_type, result_type>
        find_uvs (GeneratorT &&g)
        {
            uniform_real_distribution<result_type> unit (-1, 1);

            while (true) {
                result_type const u = unit (g);
                result_type const v = unit (g);
                result_type const s = u * u + v * v;

                // s is a sum of squares and therefore never negative, so
                // `s > 0` is equivalent to `s != 0` (NaN is rejected either
                // way) while avoiding a floating point equality comparison.
                // s == 0 must be rejected because log(0) diverges.
                if (s > 0 && s < 1)
                    return { u, v, s };
            }
        }

        param_type m_param;

        /// The unscaled second variate of the most recently generated pair.
        std::optional<result_type> m_prev;
    };
}