Skip to content

Commit

Permalink
Add Tensor type
Browse files Browse the repository at this point in the history
Multi-dimensional arrays
  • Loading branch information
seanmcl committed Jan 4, 2025
1 parent 1a5aa13 commit 5a52db1
Show file tree
Hide file tree
Showing 3 changed files with 496 additions and 0 deletions.
1 change: 1 addition & 0 deletions TensorLib/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ import TensorLib.Dtype
import TensorLib.Npy
import TensorLib.Broadcast
import TensorLib.Slice
import TensorLib.Tensor
4 changes: 4 additions & 0 deletions TensorLib/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,17 @@ deriving BEq, Repr, Inhabited

namespace Shape

def empty : Shape := Shape.mk []

--! The number of elements in a tensor. All that's needed is the shape for this calculation.
-- TODO: Put this in the struct?
def count (shape : Shape) : Nat := natProd shape.val

--! Number of dimensions
def ndim (shape : Shape) : Nat := shape.val.length

def map (shape : Shape) (f : List Nat -> List Nat) : Shape := Shape.mk (f shape.val)

/-!
Strides can be computed from the shape by figuring out how many elements you
need to jump over to get to the next spot and mulitplying by the bytes in each
Expand Down
Loading

0 comments on commit 5a52db1

Please sign in to comment.