diff --git a/FastRng/Double/ShapeFitter.cs b/FastRng/Double/ShapeFitter.cs index a0c9c43..6fe2ff8 100644 --- a/FastRng/Double/ShapeFitter.cs +++ b/FastRng/Double/ShapeFitter.cs @@ -2,51 +2,55 @@ using System; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; +using FastRng.Double.Distributions; namespace FastRng.Double { + /// + /// ShapeFitter is a rejection sampler, cf. https://en.wikipedia.org/wiki/Rejection_sampling + /// public sealed class ShapeFitter { private readonly double[] probabilities; - private readonly double[] samples; private readonly IRandom rng; - private readonly ushort sampleSize; - private readonly double threshold; - - public ShapeFitter(Func shapeFunction, IRandom rng, ushort sampleSize = 50, double threshold = 0.99) + private readonly double max; + private readonly double sampleSize; + private readonly IDistribution uniform = new Uniform(); + + public ShapeFitter(Func shapeFunction, IRandom rng, ushort sampleSize = 50) { this.rng = rng; - this.threshold = threshold; this.sampleSize = sampleSize; - this.samples = new double[sampleSize]; this.probabilities = new double[sampleSize]; - - var sampleStepSize = 1.0 / sampleSize; + + var sampleStepSize = 1.0d / sampleSize; var nextStep = 0.0 + sampleStepSize; + var maxValue = 0.0d; for (var n = 0; n < sampleSize; n++) { this.probabilities[n] = shapeFunction(nextStep); + if (this.probabilities[n] > maxValue) + maxValue = this.probabilities[n]; + nextStep += sampleStepSize; } + + this.max = maxValue; } public async ValueTask NextNumber(CancellationToken token = default) { while (!token.IsCancellationRequested) { - var nextNumber = await this.rng.GetUniform(token); - var nextBucket = (int)Math.Floor(nextNumber * this.sampleSize); - // var firstInBucket = this.samples[nextBucket] == 0; - this.samples[nextBucket] += this.probabilities[nextBucket]; - - // if (firstInBucket) // TODO: Could be an option (optional) - // return nextNumber; + var x = await this.rng.GetUniform(token); + var nextBucket = (int)Math.Floor(x * this.sampleSize); + var threshold = this.probabilities[nextBucket]; + var y = await this.rng.NextNumber(0.0d, this.max, this.uniform, token); - if (this.samples[nextBucket] >= this.threshold) - { - this.samples[nextBucket] -= this.threshold; - return nextNumber; - } + if(y > threshold) + continue; + + return x; } return double.NaN;