Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley committed Dec 5, 2023
1 parent 78b55d2 commit 35419fb
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 13 deletions.
19 changes: 6 additions & 13 deletions src/Tensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -628,13 +628,6 @@ fill xs = broadcast {shapesOK=scalarToAnyOk shape} !(tensor (Scalar xs))

----------------------------- generic operations ----------------------------

{-
arg : Primitive dtype => {shape : _} -> Graph (Tensor shape dtype, Nat, ShapeAndType)
arg = do
i <- new
pure (MkTensor i (singleton i (Arg i)), (i, MkShapeAndType shape dtype))
-}

||| Lift a unary function on scalars to an element-wise function on `Tensor`s of arbitrary shape.
||| For example,
||| ```idris
Expand All @@ -649,12 +642,12 @@ map :
(Tensor [] a -> Graph $ Tensor [] b) ->
Tensor shape a ->
Graph $ Tensor shape b
{-
map f $ MkTensor {shape = _} i env = do
(arg, param) <- arg
MkTensor l subEnv <- f arg
env `addNode` Map (MkFn [param] l subEnv) [i] (range $ length shape)
-}
map f $ MkTensor {shape = _} x = do
MkEnvN max env <- get
(subEnv, MkTensor l) <- runState !get (f $ MkTensor $ S max)
let shapeDtype = MkShapeAndType shape dtype
fn = MkFn [(S max, shapeDtype)] l subEnv
addNode $ Map fn [x] (range $ length shape)

||| Lift a binary function on scalars to an element-wise function on `Tensor`s of arbitrary shape.
||| For example,
Expand Down
2 changes: 2 additions & 0 deletions test/Unit/TestTensor/HigherOrder.idr
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ condResultWithReusedArgs = fixedProperty $ do
export partial
all : List (PropertyName, Property)
all = [
{-
("map", mapResult)
, ("map with non-trivial function", mapNonTrivial)
, ("map2", map2Result)
Expand All @@ -217,4 +218,5 @@ all = [
, ("sort with repeated elements", sortWithRepeatedElements)
, ("cond for trivial usage", condResultTrivialUsage)
, ("cond for re-used arguments", condResultWithReusedArgs)
-}
]
2 changes: 2 additions & 0 deletions test/Unit/TestTensor/Slice.idr
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ sliceForVariableIndex = property $ do
export partial
all : List (PropertyName, Property)
all = [
{-
("MultiSlice.slice", MultiSlice.slice)
, ("slice for static index", sliceStaticIndex)
, ("slice for static slice", sliceStaticSlice)
Expand All @@ -213,4 +214,5 @@ all = [
, ("slice for static index and slice", sliceStaticMixed)
, ("slice for mixed static and dynamic index and slice", sliceMixed)
, ("slice for variable index", sliceForVariableIndex)
-}
]

0 comments on commit 35419fb

Please sign in to comment.