251 lines
7.9 KiB
C++
Raw Normal View History

/*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*
* Copyright 2020, Danny Robson <danny@nerdcruft.net>
*/
#pragma once
#include "../../debug/assert.hpp"
#include "../../cast.hpp"
#include <algorithm>
#include <type_traits>
#include <limits>
#include <cmath>
///////////////////////////////////////////////////////////////////////////////
namespace cruft::rand::distribution {
template <typename ResultT>
struct uniform_int_distribution {
static_assert (std::is_integral_v<ResultT>);
using result_type = ResultT;
struct param_type {
result_type a;
result_type b;
};
uniform_int_distribution (
result_type _a,
result_type _b = std::numeric_limits<ResultT>::max ()
)
: uniform_int_distribution (param_type { .a = _a, .b = _b })
{ ; }
uniform_int_distribution ()
: uniform_int_distribution (0)
{ ; }
uniform_int_distribution (param_type const &_param)
: m_param (_param)
{ ; }
void reset (void);
// As a difference to libstd++ and libcxx we accept universal
// references as arguments; purely because it fits some of our
// existing code better, not because it's an efficient or desirable
// pattern.
template <typename GeneratorT>
result_type operator() (GeneratorT &&gen)
{
return this->template operator() (std::forward<GeneratorT> (gen), m_param);
}
// As a difference to libstd++ and libcxx we accept universal
// references as arguments; purely because it fits some of our
// existing code better, not because it's an efficient or desirable
// pattern.
template <typename GeneratorT>
result_type
operator() (GeneratorT &&gen, param_type const &p)
{
// We use the same approach as libstdc++.
//
// Specialising for downscaling gen, upscaling gen, and identify
// gen transforms.
using num_t = std::common_type_t<
result_type,
typename std::remove_cvref_t<GeneratorT>::result_type
>;
num_t const gen_range = gen.max () - gen.min ();
num_t const our_range = p.b - p.a;
if (gen_range > our_range) {
// The output of gen can be seen as: (range * scale) + bias.
//
// We calculate the scale to fit gen's range...
num_t const range = our_range + 1;
num_t const scale = gen_range / range;
num_t const limit = range * scale;
// ...and use rejection sampling if the value falls outside
// the target range...
num_t res;
do {
res = gen () - gen.min ();
} while (res >= limit);
// ...and rescale back to the target range, and add the
// target offset.
return cruft::cast::lossless<result_type> (res / scale + p.a);
} else if (gen_range < our_range) {
// The output range can be modelled as (range * scale) + bias
num_t top, low, res;
num_t const range = gen_range + 1;
num_t const scale = our_range / range;
CHECK_LE (scale, num_t (std::numeric_limits<result_type>::max ()));
param_type p_ {
.a = 0,
.b = cruft::cast::lossless<result_type> (scale),
};
do {
top = range * this->operator() (gen, p_);
low = gen () - gen.min ();
res = top + low;
} while (res > our_range || res < top);
return cruft::cast::lossless<result_type> (res + p.a);
} else {
return cruft::cast::lossless<result_type> (gen () - gen.min () + p.a);
}
}
result_type a (void) const { return m_param.a; }
result_type b (void) const { return m_param.b; }
param_type param (void) const { return m_param; }
void param (param_type const &_param) { m_param = _param; }
result_type min (void) const { return m_param.a; }
result_type max (void) const { return m_param.b; }
private:
param_type m_param;
};
/// Equality / inequality comparisons between two uniform_int_distribution
/// objects of the same result type.
///
/// NOTE(review): these are only declared; no definition is visible in this
/// header, so any use will fail to link unless one is provided elsewhere.
template <typename ResultT>
bool
operator== (
    uniform_int_distribution<ResultT> const &lhs,
    uniform_int_distribution<ResultT> const &rhs
);

template <typename ResultT>
bool
operator!= (
    uniform_int_distribution<ResultT> const &lhs,
    uniform_int_distribution<ResultT> const &rhs
);
template <typename RealT, std::size_t BitsV, typename GeneratorT>
RealT
generate_canonical (GeneratorT &gen)
{
static_assert (std::is_floating_point_v<RealT>);
static constexpr std::size_t b = std::min (
BitsV,
std::size_t (std::numeric_limits<RealT>::digits)
);
// We use a RealT here so that we can avoid overflow when the
// generator output spans the range of size_t (given the +1).
RealT R = RealT (gen.max () - gen.min ()) + 1;
// Ideally we'd compute this without floating point overhead by using
// integral log2 and summation identities.
std::size_t const log2R = std::size_t (std::log (R) / std::log (2));
std::size_t const k = std::max<std::size_t> (1, (b + log2R - 1) / log2R);
RealT base = 1;
RealT accum = 0;
for (std::size_t i = 0; i < k; ++i) {
accum += RealT (gen () - gen.min ()) * base;
base *= R;
}
return accum / base;
}
/// A replacement for std::uniform_real_distribution.
///
/// Values are produced by scaling the output of generate_canonical
/// (nominally within [0, 1)) onto the interval starting at a and spanning
/// (b - a).
template <typename ResultT>
struct uniform_real_distribution {
    static_assert (std::is_floating_point_v<ResultT>);

    using result_type = ResultT;

    /// The lower (a) and upper (b) bounds of the output interval.
    struct param_type {
        result_type a;
        result_type b;
    };

    /// Construct a distribution with lower bound _a and upper bound _b,
    /// where _b defaults to 1.
    uniform_real_distribution (
        result_type _a,
        result_type _b = 1
    )
        : uniform_real_distribution (param_type { .a = _a, .b = _b })
    { ; }

    /// Construct a distribution with bounds a = 0 and b = 1.
    uniform_real_distribution ()
        : uniform_real_distribution (0)
    { ; }

    uniform_real_distribution (param_type const &_param)
        : m_param (_param)
    { ; }

    // NOTE(review): declared but not defined in this header. The only
    // state held by this object is m_param.
    void reset (void);

    /// Draw a value using the stored parameters.
    ///
    /// As a difference to libstd++ and libcxx we accept universal
    /// references as arguments; purely because it fits some of our
    /// existing code better, not because it's an efficient or desirable
    /// pattern.
    template <typename GeneratorT>
    result_type operator() (GeneratorT &&gen)
    {
        // Call the two-argument overload directly. Spelling this as
        // `this->template operator() (...)` without a template argument
        // list is ill-formed and is rejected by recent clang releases.
        return (*this) (std::forward<GeneratorT> (gen), m_param);
    }

    /// Draw a value using the supplied generator and bounds.
    ///
    /// As a difference to libstd++ and libcxx we accept universal
    /// references as arguments; purely because it fits some of our
    /// existing code better, not because it's an efficient or desirable
    /// pattern.
    template <typename GeneratorT>
    result_type
    operator() (GeneratorT &&gen, param_type const &p)
    {
        // Request as many random bits as result_type's mantissa can hold,
        // then map the canonical value onto the requested interval.
        return ::cruft::rand::distribution::generate_canonical<
            result_type,
            std::numeric_limits<result_type>::digits
        > (gen) * (p.b - p.a) + p.a;
    }

    result_type a (void) const { return m_param.a; }
    result_type b (void) const { return m_param.b; }

    param_type param (void) const { return m_param; }
    void param (param_type const &_param) { m_param = _param; }

    result_type min (void) const { return m_param.a; }
    result_type max (void) const { return m_param.b; }

private:
    param_type m_param;
};
}