Skip to content

Commit

Permalink
feat: tracing Random.jl functionality correctly (#363)
Browse files Browse the repository at this point in the history
* refactor: move stdlib overloads to a different directory

* fix: Ops.rng_bit_generator

* feat: initial prototype for random number generation

* feat: add support for scalar sampling

* feat: efficient sampling for non-native RNGs

* fix: handling floating point sampling

* feat: use the override macro

* fix: use `@noinline`

* feat: support randexp

* feat: override seeding inside interpreter

* refactor: move things into a module

* refactor: rework how the overlays are implemented

* docs: add internal api to the docs

* test: include floating point tests

* test: setup testing

* feat: overlay all generators

* test: ensure distributions are correct

* test: overlay generation

* fix: test whether we can call into the non-overlayed version

* fix: try marking TracedRandom in whitelist

* fix: workaround the AbsInt issues for now

* fix: throw errors for now instead of crashing
  • Loading branch information
avik-pal authored Dec 18, 2024
1 parent 0713d99 commit 94e9576
Show file tree
Hide file tree
Showing 17 changed files with 690 additions and 16 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ jobs:
version: '1.10'
assertions: true
test_group: neural_networks
- os: ubuntu-20.04
arch: x64
libReactant: packaged
version: '1.10'
assertions: true
test_group: integration
- os: ubuntu-20.04
arch: x86
libReactant: packaged
Expand Down
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
Expand All @@ -23,17 +24,19 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"

[sources.ReactantCore]
path = "lib/ReactantCore"
[sources]
ReactantCore = {path = "lib/ReactantCore"}

[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantCUDAExt = "CUDA"
ReactantNNlibExt = "NNlib"
ReactantRandom123Ext = "Random123"
ReactantStatisticsExt = "Statistics"
ReactantYaoBlocksExt = "YaoBlocks"

Expand All @@ -51,6 +54,8 @@ LinearAlgebra = "1.10"
NNlib = "0.9.26"
OrderedCollections = "1"
Preferences = "1.4"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.3"
Reactant_jll = "0.0.26"
Scratch = "1.2"
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pages = [
],
"MLIR API" => "api/mlirc.md",
"XLA" => "api/xla.md",
"Internal API" => "api/internal.md",
],
]

Expand Down
4 changes: 3 additions & 1 deletion docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ export default defineConfig({
{ text: "MLIR API", link: "/api/mlirc" },
{ text: "XLA", link: "/api/xla" },
],
}
},
{ text: "Internal API", link: "/api/internal" },
],
},
{
Expand Down Expand Up @@ -132,6 +133,7 @@ export default defineConfig({
{ text: "XLA", link: "/api/xla" },
],
},
{ text: "Internal API", link: "/api/internal" },
],
},
},
Expand Down
12 changes: 12 additions & 0 deletions docs/src/api/internal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
```@meta
CollapsedDocStrings = true
```

# Internal API

These functions are not part of the public API and are subject to change at any time.

```@docs
Reactant.REDUB_ARGUMENTS_NAME
Reactant.within_reactant_interpreter
```
11 changes: 11 additions & 0 deletions ext/ReactantRandom123Ext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module ReactantRandom123Ext

using Random123: Threefry4x, Threefry2x, Philox4x, Philox2x
using Reactant: TracedRandom

TracedRandom.rng_algorithm(::Threefry4x) = "THREE_FRY"
TracedRandom.rng_algorithm(::Threefry2x) = "THREE_FRY"
TracedRandom.rng_algorithm(::Philox4x) = "PHILOX"
TracedRandom.rng_algorithm(::Philox2x) = "PHILOX"

end
141 changes: 136 additions & 5 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1016,19 +1016,150 @@ end
end

# random ops
"""
rng_bit_generator(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Generate a random array of type `T` with the given shape and seed from a uniform random
distribution between 0 and 1. Returns a NamedTuple with the following fields:
- `output_state`: The state of the random number generator after the operation.
- `output`: The generated array.
# Arguments
- `T`: The type of the generated array.
- `seed`: The seed for the random number generator.
- `shape`: The shape of the generated array.
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
"""
@noinline function rng_bit_generator(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
)
output = MLIR.IR.TensorType(TracedRArray{UInt64,1}, shape)
) where {T<:Integer}
@assert algorithm in ("DEFAULT", "PHILOX", "THREE_FRY")
if algorithm == "PHILOX"
@assert length(seed) (2, 3)
elseif algorithm == "THREE_FRY"
@assert length(seed) == 2
end

output = MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
output_state = MLIR.IR.TensorType(size(seed), MLIR.IR.Type(UInt64))
rng_algorithm = MLIR.API.stablehloRngAlgorithmAttrGet(MLIR.IR.context(), algorithm)
op = stablehlo.rng_bit_generator(seed.mlir_data; output, rng_algorithm, location)
op = stablehlo.rng_bit_generator(
seed.mlir_data; output, output_state, rng_algorithm, location
)
return (;
output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), MLIR.IR.size(seed)),
output=TracedRArray{T,length(shape)}((), MLIR.IR.result(op, 2), shape),
output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), size(seed)),
output=TracedRArray{T,length(shape)}((), MLIR.IR.result(op, 2), Tuple(shape)),
)
end

@noinline function rng_bit_generator(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
) where {T<:AbstractFloat}
nbits = sizeof(T) * 8
uT = nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64)
(; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location)
output = divide(
convert(TracedRArray{T,ndims(output)}, output),
constant(fill(T(typemax(uT)), Tuple(shape)); location),
)
return (; output_state, output)
end

"""
randn(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Generate a random array of type `T` with the given shape and seed from a standard normal
distribution of mean 0 and standard deviation 1. Returns a NamedTuple with the following
fields:
- `output_state`: The state of the random number generator after the operation.
- `output`: The generated array.
# Arguments
- `T`: The type of the generated array.
- `seed`: The seed for the random number generator.
- `shape`: The shape of the generated array.
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
"""
@noinline function randn(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
) where {T}
res = rng_bit_generator(T, seed, shape; algorithm, location)
rand_uniform = res.output
seed = res.output_state
scaled_uniform = subtract(
multiply(rand_uniform, constant(fill(T(2), size(rand_uniform)))),
constant(fill(T(1), size(rand_uniform))),
)
probit = erf_inv(scaled_uniform)
rand_normal = multiply(probit, constant(fill(Base.sqrt(T(2)), size(rand_uniform))))
return (; output_state=seed, output=rand_normal)
end

"""
randexp(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Generate a random array of type `T` with the given shape and seed from an exponential
distribution with rate 1. Returns a NamedTuple with the following fields:
- `output_state`: The state of the random number generator after the operation.
- `output`: The generated array.
# Arguments
- `T`: The type of the generated array.
- `seed`: The seed for the random number generator.
- `shape`: The shape of the generated array.
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
"""
@noinline function randexp(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
) where {T}
res = rng_bit_generator(T, seed, shape; algorithm, location)
rand_uniform = res.output
seed = res.output_state
rand_exp = negate(log_plus_one(negate(rand_uniform)))
return (; output_state=seed, output=rand_exp)
end

# functional ops
Expand Down
95 changes: 94 additions & 1 deletion src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,23 @@
# correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved
# we should move all the reactant_overrides to relevant files.

# Helper Function to determine if we are inside the ReactantInterpreter
"""
within_reactant_interpreter()
Returns `true` if we are currently inside the ReactantInterpreter.
"""
@noinline within_reactant_interpreter() = false
@reactant_overlay @noinline within_reactant_interpreter() = true

# Compiling within a compile should return simply the original function
@reactant_overlay function Compiler.compile(
f, args; client=nothing, optimize=true, sync=false
)
return f
end

# Enzyme overrides
# Enzyme.jl overlays
@reactant_overlay @noinline function Enzyme.autodiff_deferred(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
Expand All @@ -22,3 +31,87 @@ end
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end

# Random.jl overlays
@reactant_overlay @noinline function Random.default_rng()
return call_with_reactant(TracedRandom.default_rng)
end

## Only problematic edge case here is the direct `<randfun!>(rng, A::AbstractArray)` call
## We can't directly overlay that call without breaking the semantics of inplace update
for randfun in (:rand, :randn, :randexp)
randfun! = Symbol(randfun, :!)
overload_randfun = Symbol(:overload_, randfun)
overload_randfun! = Symbol(:overload_, randfun!)

@eval begin
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}, dims::Dims
) where {T}
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T, dims)
end
return error(
"Reactant doesn't support sampling of $(T) with the current interpreter."
)
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T, dims)
end

@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, dim1::Integer, dims::Integer...
)
return TracedRandom.$(overload_randfun)(rng, dim1, dims...)
end

@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer...
) where {T}
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...)
end
return error(
"Reactant doesn't support sampling of $(T) with the current interpreter."
)
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T, dim1, dims...)
end

# scalars
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}=Float64
) where {T}
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T)
end
return error(
"Reactant doesn't support sampling of $(T) with the current interpreter."
)
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T)
end

# inplace
@reactant_overlay @noinline function Random.$(randfun!)(
rng::AbstractRNG, A::AnyTracedRArray
)
return TracedRandom.$(overload_randfun!)(rng, A)
end

# XXX: Uncomment once AbsInt issues with recursive calls are resolved
# @reactant_overlay @noinline function Random.$(randfun!)(
# rng::AbstractRNG, A::AbstractArray
# )
# @warn "Directly writing to an array using Random.jl functions inside \
# ReactantInterpreter will generate a constant array in the IR. Use with \
# caution." maxlog = 1
# return Random.$(randfun!)(rng, A)
# end
end
end
11 changes: 10 additions & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ module Reactant
using ReactantCore: ReactantCore, @trace, MissingTracedValue

using LinearAlgebra: LinearAlgebra
using Random: Random, AbstractRNG

using Adapt: Adapt, WrappedArray
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`

Expand Down Expand Up @@ -122,7 +124,14 @@ include("TracedRArray.jl")

include("ConcreteRArray.jl")

include("linear_algebra.jl")
mutable struct TracedRNG <: Random.AbstractRNG
seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}}
const algorithm::String
end

# StdLib Overloads
include("stdlibs/LinearAlgebra.jl")
include("stdlibs/Random.jl")

const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}

Expand Down
File renamed without changes.
Loading

0 comments on commit 94e9576

Please sign in to comment.