diff --git a/src/Optical/Emitters/Origins.jl b/src/Optical/Emitters/Origins.jl index b5ff29dd8..01ae45065 100644 --- a/src/Optical/Emitters/Origins.jl +++ b/src/Optical/Emitters/Origins.jl @@ -3,7 +3,7 @@ # See LICENSE in the project root for full license information. module Origins -export Point, RectUniform, RectGrid, Hexapolar +export Point, RectUniform, RectGrid, Hexapolar, RectJitterGrid using ....OpticSim using ...Emitters @@ -115,6 +115,51 @@ function Emitters.generate(o::RectGrid{T}, n::Int64) where {T<:Real} return zeros(Vec3{T}) + ((o.width / 2) * u * unitX3(T)) + ((o.height/2) * v * unitY3(T)) end +""" + RectJitterGrid{T} <: AbstractOriginDistribution{T} + +Encapsulates a rectangle sampled in a grid fashion with jitter. + +```julia +RectGrid(width::T, height::T, ures::Int64, vres::Int64, samplesPerRegion::Int64) where {T<:Real} +``` +""" +struct RectJitterGrid{T} <: AbstractOriginDistribution{T} + width::T + height::T + uResolution::Int64 + vResolution::Int64 + samplesPerRegion::Int64 + ustep::T + vstep::T + rng::Random.AbstractRNG + + function RectJitterGrid(width::T, height::T, ures::Int64, vres::Int64, samplesPerRegion::Int64; rng=Random.GLOBAL_RNG) where {T<:Real} + return new{T}(width, height, ures, vres, samplesPerRegion, width / ures, height / vres, rng) + end +end + +Base.length(o::RectJitterGrid) = o.uResolution * o.vResolution * o.samplesPerRegion +Emitters.visual_size(o::RectJitterGrid) = max(o.width, o.height) + +# generate origin on the grid +function Emitters.generate(o::RectJitterGrid{T}, n::Int64) where {T<:Real} + n = mod(n, length(o)) + + uu = rand(o.rng, Distributions.Uniform(zero(T), o.ustep)) + vv = rand(o.rng, Distributions.Uniform(zero(T), o.vstep)) + + nn = Int64(floor(n / o.samplesPerRegion)) + v = (o.vResolution == 1) ? zero(T) : Int64(floor(nn / o.uResolution)) + u = (o.uResolution == 1) ? zero(T) : Int64(floor(mod(nn, o.uResolution))) + + v = v * o.vstep + vv - (o.height / 2.0) + u = u * o.ustep + uu - (o.width / 2.0) + + return zeros(Vec3{T}) + u * unitX3(T) + v * unitY3(T) +end + + """ Hexapolar{T} <: AbstractOriginDistribution{T} diff --git a/test/testsets/Emitters.jl b/test/testsets/Emitters.jl index 6fd5a43da..c90b481c4 100644 --- a/test/testsets/Emitters.jl +++ b/test/testsets/Emitters.jl @@ -161,6 +161,27 @@ using StaticArrays ] end + @testset "RectJitterGrid" begin + @test Origins.RectJitterGrid(1., 2., 3, 4, 1).width === 1.0 + @test Origins.RectJitterGrid(1., 2., 3, 4, 1).height === 2.0 + @test Origins.RectJitterGrid(1., 2., 2, 2, 3).ustep === 0.5 + @test Origins.RectJitterGrid(1., 2., 2, 2, 3).vstep === 1.0 + + @test Base.length(Origins.RectJitterGrid(1., 2., 5, 6, 7)) === 210 + @test Emitters.visual_size(Origins.RectJitterGrid(1., 2., 5, 6, 7)) === 2.0 + + @test collect(Origins.RectJitterGrid(1., 2., 2, 2, 2, rng=Random.MersenneTwister(0))) == [ + [-0.08817624601129381, -0.08964346207356355, 0.0], + [-0.4177171009331574, -0.8226711535337354, 0.0], + [0.1394400546656005, -0.7965234419580773, 0.0], + [0.021150832966014832, -0.9317307444943552, 0.0], + [-0.3190858046118913, 0.9732164043865108, 0.0], + [-0.2070942241283379, 0.5392892841426182, 0.0], + [0.13001792513452393, 0.910046541351011, 0.0], + [0.08351809722107484, 0.6554484126999125, 0.0], + ] + end + @testset "Hexapolar" begin @test Origins.Hexapolar(1, 0, 0).nrings === 1 @test_throws MethodError Origins.Hexapolar(1., 0, 0)