Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley committed May 27, 2023
1 parent 09f5d88 commit 8e0bbe8
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/Optimize.idr
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,23 @@ Optimizer domain = (domain -> Ref $ Tensor [] F64) -> Ref domain
||| @upper The upper (exclusive) bound of the grid.
export
gridSearch : {d : _} ->
(density : Vect d Nat) ->
(density : Literal [d] Nat) ->
(lower : Tensor [d] F64) ->
(upper : Tensor [d] F64) ->
Optimizer $ Tensor [d] F64
gridSearch {d=Z} _ _ _ _ = fromLiteral []
gridSearch {d=S k} density lower upper f =
let densityAll : Nat
densityAll = product density
let gridSize : Nat
gridSize = product density

prodDims : Tensor [S k] U64 := fromLiteral $ cast $ scanr (*) 1 (tail density)
idxs = fromLiteral {shape=[densityAll]} $ cast $ Vect.range densityAll
idxs = fromLiteral {shape=[gridSize]} $ cast $ Vect.range gridSize
densityTensor = broadcast $ fromLiteral {shape=[S k]} {dtype=U64} (cast density)
grid = broadcast {to=[densityAll, S k]} (expand 1 idxs)
grid = broadcast {to=[gridSize, S k]} (expand 1 idxs)
`div` broadcast {from=[S k]} (cast prodDims) `rem` densityTensor
gridRelative = cast grid / cast densityTensor
points = with Tensor.(+)
broadcast lower + broadcast {to=[densityAll, _]} (upper - lower) * gridRelative
broadcast lower + broadcast {to=[gridSize, _]} (upper - lower) * gridRelative
in slice [at (argmin 0 (vmap f points))] points

||| The limited-memory BFGS (L-BFGS) optimization tactic, see
Expand Down

0 comments on commit 8e0bbe8

Please sign in to comment.