libcruft-util/rand/distribution/normal.hpp

108 lines
2.7 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*
* Copyright 2020-2021, Danny Robson <danny@nerdcruft.net>
*/
#pragma once
#include "uniform.hpp"

#include <cmath>
#include <optional>
#include <tuple>
#include <type_traits>
namespace cruft::rand::distribution {
    /// A normal (Gaussian) distribution over floating point values.
    ///
    /// Values are generated in pairs using the polar form of the Box-Muller
    /// transform (Marsaglia's polar method). The second value of each pair is
    /// cached in standardised form (mean 0, stddev 1), and is scaled by the
    /// parameters supplied to the *following* call when it is returned.
    ///
    /// The interface loosely follows std::normal_distribution.
    template <typename ResultT>
    requires (std::is_floating_point_v<ResultT>)
    class normal {
    public:
        using result_type = ResultT;

        /// Parameters of the distribution.
        struct param_type {
            result_type mean;
            result_type stddev;
        };

        /// Constructs a standard normal distribution (mean 0, stddev 1).
        normal ():
            normal (0)
        { ; }

        /// Constructs a distribution with the given mean and (default 1) stddev.
        explicit normal (result_type mean, result_type stddev = 1)
            : m_param {
                .mean = mean,
                .stddev = stddev
            }
        { ; }

        /// Constructs a distribution from a parameter pack.
        explicit normal (param_type const &_param)
            : m_param (_param)
        { ; }

        /// Discards any cached second value of a Box-Muller pair, so the
        /// next call draws fresh uniform values from the generator.
        void reset (void)
        {
            m_prev.reset ();
        }

        /// Draws a value using the stored parameters.
        template <typename GeneratorT>
        result_type
        operator() (GeneratorT &&g)
        {
            return (*this)(g, m_param);
        }

        /// Draws a value using the supplied parameters.
        ///
        /// On alternating calls this either consumes the cached standardised
        /// value, or generates a fresh Box-Muller pair (returning one value
        /// and caching the other).
        template <typename GeneratorT>
        result_type
        operator() (GeneratorT &&g, param_type const &params)
        {
            if (m_prev) {
                auto const res = *m_prev * params.stddev + params.mean;
                m_prev.reset ();
                return res;
            }

            auto const [u, v, s] = find_uvs (g);

            // The radial scale factor is shared by both outputs of the
            // transform, so evaluate the sqrt/log only once.
            result_type const factor = std::sqrt (-2 * std::log (s) / s);

            // Cache the second value unscaled so it remains valid even if a
            // different param_type is supplied on the next call.
            m_prev = v * factor;
            return u * factor * params.stddev + params.mean;
        }

        result_type mean (void) const { return m_param.mean; }
        result_type stddev (void) const { return m_param.stddev; }

        /// Returns the stored parameters.
        param_type param (void) const { return m_param; }

        /// Replaces the stored parameters. The cached value (if any) is kept,
        /// because it is stored in standardised form.
        void param (param_type const &_param) { m_param = _param; }

    private:
        /// Rejection-samples a point (u, v) with components drawn from
        /// `uniform_real_distribution(-1, 1)`, until it lies strictly inside
        /// the unit circle and is not the origin.
        ///
        /// Returns (u, v, s) where s = u*u + v*v, with 0 < s < 1.
        template <typename GeneratorT>
        std::tuple<result_type, result_type, result_type>
        find_uvs (GeneratorT &&g)
        {
            // Loop-invariant: construct once rather than per rejection.
            uniform_real_distribution<result_type> unit (-1, 1);

            while (true) {
                result_type const u = unit (g);
                result_type const v = unit (g);
                result_type const s = u * u + v * v;

                // s == 0 would make log(s)/s undefined; s >= 1 lies outside
                // the unit circle and must be rejected for the polar method.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wfloat-equal"
                if (s != 0 && s < 1)
                    return { u, v, s };
#pragma GCC diagnostic pop
            }
        }

        param_type m_param;

        /// Standardised second value of the most recent Box-Muller pair.
        std::optional<result_type> m_prev;
    };
}