diff --git a/FastRng/Double/ShapeFitter.cs b/FastRng/Double/ShapeFitter.cs
index a0c9c43..6fe2ff8 100644
--- a/FastRng/Double/ShapeFitter.cs
+++ b/FastRng/Double/ShapeFitter.cs
@@ -2,51 +2,55 @@ using System;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
+using FastRng.Double.Distributions;
namespace FastRng.Double
{
+ ///
+ /// ShapeFitter is a rejection sampler, cf. https://en.wikipedia.org/wiki/Rejection_sampling
+ ///
public sealed class ShapeFitter
{
private readonly double[] probabilities;
- private readonly double[] samples;
private readonly IRandom rng;
- private readonly ushort sampleSize;
- private readonly double threshold;
-
- public ShapeFitter(Func shapeFunction, IRandom rng, ushort sampleSize = 50, double threshold = 0.99)
+ private readonly double max;
+ private readonly double sampleSize;
+ private readonly IDistribution uniform = new Uniform();
+
+ public ShapeFitter(Func shapeFunction, IRandom rng, ushort sampleSize = 50)
{
this.rng = rng;
- this.threshold = threshold;
this.sampleSize = sampleSize;
- this.samples = new double[sampleSize];
this.probabilities = new double[sampleSize];
-
- var sampleStepSize = 1.0 / sampleSize;
+
+ var sampleStepSize = 1.0d / sampleSize;
var nextStep = 0.0 + sampleStepSize;
+ var maxValue = 0.0d;
for (var n = 0; n < sampleSize; n++)
{
this.probabilities[n] = shapeFunction(nextStep);
+ if (this.probabilities[n] > maxValue)
+ maxValue = this.probabilities[n];
+
nextStep += sampleStepSize;
}
+
+ this.max = maxValue;
}
public async ValueTask NextNumber(CancellationToken token = default)
{
while (!token.IsCancellationRequested)
{
- var nextNumber = await this.rng.GetUniform(token);
- var nextBucket = (int)Math.Floor(nextNumber * this.sampleSize);
- // var firstInBucket = this.samples[nextBucket] == 0;
- this.samples[nextBucket] += this.probabilities[nextBucket];
-
- // if (firstInBucket) // TODO: Could be an option (optional)
- // return nextNumber;
+ var x = await this.rng.GetUniform(token);
+ var nextBucket = (int)Math.Floor(x * this.sampleSize);
+ var threshold = this.probabilities[nextBucket];
+ var y = await this.rng.NextNumber(0.0d, this.max, this.uniform, token);
- if (this.samples[nextBucket] >= this.threshold)
- {
- this.samples[nextBucket] -= this.threshold;
- return nextNumber;
- }
+ if(y > threshold)
+ continue;
+
+ return x;
}
return double.NaN;