Skip to content

Commit

Permalink
Add Tensor type
Browse files Browse the repository at this point in the history
Multi-dimensional arrays
  • Loading branch information
seanmcl committed Jan 4, 2025
1 parent 1a5aa13 commit 5b914b5
Show file tree
Hide file tree
Showing 4 changed files with 475 additions and 11 deletions.
1 change: 1 addition & 0 deletions TensorLib/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ import TensorLib.Dtype
import TensorLib.Npy
import TensorLib.Broadcast
import TensorLib.Slice
import TensorLib.Tensor
16 changes: 11 additions & 5 deletions TensorLib/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ def natDivCeil (num denom : Nat) : Nat := (num + denom - 1) / denom

def natProd (shape : List Nat) : Nat := shape.foldl (fun x y => x * y) 1

-- We generally have large tensors, so don't show them by default
instance ByteArrayRepr : Repr ByteArray where
reprPrec x _ :=
if x.size < 100 then x.toList.repr 100 else
s!"ByteArray of size {x.size}"

/-!
NumPy arrays can be stored in big-endian or little-endian order on disk, regardless
of the architecture of the machine saving the array. Since we read these arrays
Expand All @@ -39,15 +45,11 @@ inductive ByteOrder where
| bigEndian
deriving BEq, Repr, Inhabited

namespace ByteOrder

@[simp]
def isMultiByte (x : ByteOrder) : Bool := match x with
def ByteOrder.isMultiByte (x : ByteOrder) : Bool := match x with
| .oneByte => false
| .littleEndian | .bigEndian => true

end ByteOrder

/-!
The strides are how many bytes you need to skip to get to the next element in that
"row". For example, in an array of 8-byte data with shape 2, 3, the strides are (24, 8).
Expand Down Expand Up @@ -142,13 +144,17 @@ deriving BEq, Repr, Inhabited

namespace Shape

def empty : Shape := Shape.mk []

--! The number of elements in a tensor. All that's needed is the shape for this calculation.
-- TODO: Put this in the struct?
def count (shape : Shape) : Nat := natProd shape.val

--! Number of dimensions
def ndim (shape : Shape) : Nat := shape.val.length

def map (shape : Shape) (f : List Nat -> List Nat) : Shape := Shape.mk (f shape.val)

/-!
Strides can be computed from the shape by figuring out how many elements you
need to jump over to get to the next spot and mulitplying by the bytes in each
Expand Down
6 changes: 0 additions & 6 deletions TensorLib/Npy.lean
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,6 @@ def dataSize (header : Header): Nat := header.descr.itemsize * header.shape.coun

end Header

-- We generally have large tensors, so don't show them by default
local instance ByteArrayRepr : Repr ByteArray where
reprPrec x _ :=
let s := toString x.size
s!"ByteArray of size {s}"

structure Ndarray where
header : Header
data : ByteArray
Expand Down
Loading

0 comments on commit 5b914b5

Please sign in to comment.