r/dist: add the normal distribution
This commit is contained in:
parent
05880da691
commit
a2fa34c619
@ -487,6 +487,8 @@ list (
|
|||||||
"${CMAKE_CURRENT_BINARY_DIR}/prefix/${PREFIX}/preprocessor.hpp"
|
"${CMAKE_CURRENT_BINARY_DIR}/prefix/${PREFIX}/preprocessor.hpp"
|
||||||
quaternion.cpp
|
quaternion.cpp
|
||||||
quaternion.hpp
|
quaternion.hpp
|
||||||
|
rand/distribution/normal.cpp
|
||||||
|
rand/distribution/normal.hpp
|
||||||
rand/distribution/uniform.cpp
|
rand/distribution/uniform.cpp
|
||||||
rand/distribution/uniform.hpp
|
rand/distribution/uniform.hpp
|
||||||
rand/generic.hpp
|
rand/generic.hpp
|
||||||
@ -732,6 +734,7 @@ if (TESTS)
|
|||||||
preprocessor
|
preprocessor
|
||||||
quaternion
|
quaternion
|
||||||
rand/buckets
|
rand/buckets
|
||||||
|
rand/generator/normal
|
||||||
random
|
random
|
||||||
range
|
range
|
||||||
rational
|
rational
|
||||||
|
9
rand/distribution/normal.cpp
Normal file
9
rand/distribution/normal.cpp
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
/*
|
||||||
|
* This Source Code Form is subject to the terms of the Mozilla Public
|
||||||
|
* License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
*
|
||||||
|
* Copyright 2020, Danny Robson <danny@nerdcruft.net>
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "normal.hpp"
|
103
rand/distribution/normal.hpp
Normal file
103
rand/distribution/normal.hpp
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
/*
|
||||||
|
* This Source Code Form is subject to the terms of the Mozilla Public
|
||||||
|
* License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
*
|
||||||
|
* Copyright 2020, Danny Robson <danny@nerdcruft.net>
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "uniform.hpp"
|
||||||
|
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
|
||||||
|
namespace cruft::rand::distribution {
|
||||||
|
template <typename ResultT>
|
||||||
|
requires (std::is_floating_point_v<ResultT>)
|
||||||
|
class normal {
|
||||||
|
public:
|
||||||
|
using result_type = ResultT;
|
||||||
|
struct param_type {
|
||||||
|
result_type mean;
|
||||||
|
result_type stddev;
|
||||||
|
};
|
||||||
|
|
||||||
|
normal ():
|
||||||
|
normal (0)
|
||||||
|
{ ; }
|
||||||
|
|
||||||
|
explicit normal (result_type mean, result_type stddev = 1)
|
||||||
|
: m_param {
|
||||||
|
.mean = mean,
|
||||||
|
.stddev = stddev
|
||||||
|
}
|
||||||
|
{ ; }
|
||||||
|
|
||||||
|
|
||||||
|
explicit normal (param_type const &_param)
|
||||||
|
: m_param (_param)
|
||||||
|
{ ; }
|
||||||
|
|
||||||
|
|
||||||
|
void reset (void)
|
||||||
|
{
|
||||||
|
m_live = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename GeneratorT>
|
||||||
|
result_type
|
||||||
|
operator() (GeneratorT &&g)
|
||||||
|
{
|
||||||
|
return (*this)(g, m_param);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// We use the Box–Muller transform to convert pairs of uniform reals
|
||||||
|
// to normally distributed reals.
|
||||||
|
template <typename GeneratorT>
|
||||||
|
result_type
|
||||||
|
operator() (GeneratorT &&g, param_type const ¶ms)
|
||||||
|
{
|
||||||
|
if (m_live) {
|
||||||
|
m_live = false;
|
||||||
|
return m_prev * params.stddev + params.mean;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto [u, v, s] = find_uvs (g);
|
||||||
|
result_type z0 = u * std::sqrt (-2 * std::log (s) / s);
|
||||||
|
result_type z1 = v * std::sqrt (-2 * std::log (s) / s);
|
||||||
|
|
||||||
|
m_prev = z1;
|
||||||
|
m_live = true;
|
||||||
|
|
||||||
|
return z0 * params.stddev + params.mean;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <typename GeneratorT>
|
||||||
|
std::tuple<result_type, result_type, result_type>
|
||||||
|
find_uvs (GeneratorT &&g)
|
||||||
|
{
|
||||||
|
while (1) {
|
||||||
|
uniform_real_distribution<result_type> unit (-1, 1);
|
||||||
|
result_type u = unit (g);
|
||||||
|
result_type v = unit (g);
|
||||||
|
result_type s = u * u + v * v;
|
||||||
|
|
||||||
|
#pragma GCC diagnostic push
|
||||||
|
#pragma GCC diagnostic ignored "-Wfloat-equal"
|
||||||
|
if (s != 0 && s < 1)
|
||||||
|
return { u, v, s };
|
||||||
|
#pragma GCC diagnostic pop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
param_type m_param;
|
||||||
|
bool m_live = false;
|
||||||
|
result_type m_prev;
|
||||||
|
};
|
||||||
|
}
|
69
test/rand/generator/normal.cpp
Normal file
69
test/rand/generator/normal.cpp
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
#include "maths.hpp"
|
||||||
|
#include "rand/distribution/normal.hpp"
|
||||||
|
#include "tap.hpp"
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <random>
|
||||||
|
|
||||||
|
|
||||||
|
/// Probability density function for a normal distribution with specified
|
||||||
|
/// mean and stddev at point `x`.
|
||||||
|
static
|
||||||
|
float pdf (float x, float mean, float stddev)
|
||||||
|
{
|
||||||
|
float const power = cruft::pow2 ((x - mean) / stddev) / -2;
|
||||||
|
float const scale = 1.f / (stddev * std::sqrt (2.f * cruft::pi<float>));
|
||||||
|
return scale * std::exp (power);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Calculate the maximum difference between a histogram and a PDF for a
|
||||||
|
/// normal distribution with a number of buckets.
|
||||||
|
static
|
||||||
|
float max_histogram_error (int buckets)
|
||||||
|
{
|
||||||
|
// These constants weren't rigorously selected. Eyeballing the generated
|
||||||
|
// values suggested they had some level of precision and didn't explode
|
||||||
|
// the test's runtime.
|
||||||
|
int const BUCKETS = buckets;
|
||||||
|
int const ITERATIONS = BUCKETS * 10'000;
|
||||||
|
float const MEAN = BUCKETS / 2.f;
|
||||||
|
float const STDDEV = BUCKETS * .15f;
|
||||||
|
|
||||||
|
// Use _our_ normal distribution, not the stdlib one.
|
||||||
|
cruft::rand::distribution::normal<float> g (MEAN, STDDEV);
|
||||||
|
// We use a stdlib generator with reasonable quality just so we're not
|
||||||
|
// testing both our generator and our distributions simultaneously.
|
||||||
|
std::mt19937_64 rand;
|
||||||
|
|
||||||
|
std::vector<int> counts (BUCKETS, 0);
|
||||||
|
for (int i = 0; i < ITERATIONS; ++i) {
|
||||||
|
auto const val = g (rand);
|
||||||
|
if (val >= BUCKETS || val < 0)
|
||||||
|
continue;
|
||||||
|
counts[int (val)]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
float max_diff = 0.f;
|
||||||
|
for (int i = 0; i < BUCKETS; ++i) {
|
||||||
|
float expected = pdf (i, MEAN, STDDEV);
|
||||||
|
float actual = counts[i] / float (ITERATIONS);
|
||||||
|
float diff = std::abs (expected - actual);
|
||||||
|
|
||||||
|
max_diff = std::max (max_diff, diff);
|
||||||
|
}
|
||||||
|
|
||||||
|
return max_diff;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
int main (void)
|
||||||
|
{
|
||||||
|
cruft::TAP::logger tap;
|
||||||
|
tap.expect_lt (
|
||||||
|
max_histogram_error (500),
|
||||||
|
1.e-4f,
|
||||||
|
"normal distribution histogram maximum relative error"
|
||||||
|
);
|
||||||
|
return tap.status ();
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user