FastRng/FastRng/ShapeFitter.cs
2023-07-06 10:08:46 +02:00

78 lines
2.7 KiB
C#

using System;
using System.Threading;
using System.Threading.Tasks;
using FastRng.Distributions;
namespace FastRng
{
/// <summary>
/// ShapeFitter is a rejection sampler, cf. https://en.wikipedia.org/wiki/Rejection_sampling
/// </summary>
/// <remarks>
/// The shape function is evaluated once, in the constructor, at equally spaced positions and the
/// results are kept in a lookup table. Each call to NextNumber draws a candidate from the random
/// number generator, looks up the table entry for the candidate's bucket, and accepts the candidate
/// when a second value, requested from the uniform distribution for the range 0 to the largest
/// sampled value, does not exceed that entry.
/// </remarks>
public sealed class ShapeFitter
{
    // Sampled values of the shape function; one entry per bucket.
    private readonly float[] probabilities;

    // Source of the candidate values.
    private readonly IRandom rng;

    // Largest sampled shape value (never below 0, the starting value of the search).
    private readonly float max;

    // Number of buckets, stored as float because it is only used in floating-point arithmetic.
    private readonly float sampleSize;

    // Distribution used to request the acceptance value for each candidate.
    private readonly IDistribution uniform;

    /// <summary>
    /// Creates a shape fitter instance.
    /// </summary>
    /// <param name="shapeFunction">The function which describes the desired shape.</param>
    /// <param name="rng">The random number generator instance to use.</param>
    /// <param name="sampleSize">The number of sampling steps to sample the given function. Must be at least 1.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="shapeFunction"/> or <paramref name="rng"/> is null.
    /// </exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="sampleSize"/> is 0.
    /// </exception>
    public ShapeFitter(Func<float, float> shapeFunction, IRandom rng, ushort sampleSize = 50)
    {
        if (shapeFunction is null)
        {
            throw new ArgumentNullException(nameof(shapeFunction));
        }

        if (rng is null)
        {
            throw new ArgumentNullException(nameof(rng));
        }

        // An empty table cannot be indexed: NextNumber would clamp the bucket to -1 and fail later.
        if (sampleSize == 0)
        {
            throw new ArgumentOutOfRangeException(nameof(sampleSize), sampleSize, "At least one sampling step is required.");
        }

        this.rng = rng;
        this.uniform = new Uniform(rng);
        this.sampleSize = sampleSize;
        this.probabilities = new float[sampleSize];

        var maxValue = 0.0f;
        for (var n = 0; n < sampleSize; n++)
        {
            // Bucket n is represented by the shape's value at (n + 1) / sampleSize.
            // The position is computed directly instead of by repeated addition, so that
            // rounding errors do not accumulate over many sampling steps.
            var position = (n + 1) / (float)sampleSize;
            var value = shapeFunction(position);

            this.probabilities[n] = value;
            if (value > maxValue)
            {
                maxValue = value;
            }
        }

        this.max = maxValue;
    }

    /// <summary>
    /// Returns a random number regarding the given shape.
    /// </summary>
    /// <param name="token">An optional cancellation token.</param>
    /// <returns>
    /// The next value regarding the given shape. Returns <see cref="float.NaN"/> when cancellation
    /// is requested before a candidate was accepted, or when one of the underlying draws yields NaN.
    /// </returns>
    public async ValueTask<float> NextNumber(CancellationToken token = default)
    {
        while (!token.IsCancellationRequested)
        {
            // Candidate value; it is returned unchanged when accepted.
            var x = await this.rng.GetUniform(token);
            if (float.IsNaN(x))
            {
                return x;
            }

            // Map the candidate to its bucket; clamp so that the top of the range
            // falls into the last bucket instead of one past the end.
            var nextBucket = (int)MathF.Floor(x * this.sampleSize);
            if (nextBucket >= this.probabilities.Length)
            {
                nextBucket = this.probabilities.Length - 1;
            }

            var threshold = this.probabilities[nextBucket];

            // Acceptance value, requested for the range 0 to the largest sampled shape value.
            var y = await this.uniform.NextNumber(0.0f, this.max, token);
            if (float.IsNaN(y))
            {
                return y;
            }

            // Accept the candidate unless the acceptance value lies above the shape.
            if (y <= threshold)
            {
                return x;
            }
        }

        return float.NaN;
    }
}
}