Skip to content

Commit

Permalink
Advanced indexing
Browse files Browse the repository at this point in the history
We implement a common version of advanced indexing, where
all arguments are (broadcastable to) int arrays.
  • Loading branch information
seanmcl committed Jan 8, 2025
1 parent 2f3e42c commit 575bf43
Show file tree
Hide file tree
Showing 5 changed files with 265 additions and 38 deletions.
24 changes: 22 additions & 2 deletions TensorLib/Broadcast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
-/

import Aesop
import Batteries.Data.List
import TensorLib.Common

/-!
Expand Down Expand Up @@ -39,6 +38,8 @@ Rule 2
A: (3, 2, 7)
B: (3, 2, 7)
Theorem to prove: If we can broadcast s1 to s2, then given an array with shape s1, then s1.reshape s2 succeeds
-/

namespace TensorLib
Expand Down Expand Up @@ -81,7 +82,7 @@ private def matchPairs (b : Broadcast) : Option Shape :=
else if x == 1 then some y
else if y == 1 then some x
else none
let dims := (b.left.val.zip b.right.val).traverse f
let dims := (b.left.val.zip b.right.val).mapM f
dims.map Shape.mk

--! Returns the shape resulting from broadcast the arguments
Expand All @@ -103,4 +104,23 @@ def canBroadcast (b : Broadcast) : Bool := (broadcast b).isSome
broadcast b2 == broadcast b1 &&
broadcast b2 == .some (Shape.mk [1, 2, 3])

def broadcastList (shapes : List Shape) : Option Shape := Id.run do
match shapes with
| [] => none
| shape :: shapes =>
let mut shape := shape
for s in shapes do
let b := Broadcast.mk shape s
match b.broadcast with
| .none => return .none
| .some s =>
shape := s
return shape

#guard
let x1 := Shape.mk [1, 2, 3]
let x2 := Shape.mk [2, 3]
let x3 := Shape.mk []
broadcastList [x1, x2, x3] == .some x1

end Broadcast
67 changes: 60 additions & 7 deletions TensorLib/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
-/

import Std.Tactic.BVDecide

namespace TensorLib

--! The error monad for TensorLib
Expand All @@ -27,6 +26,7 @@ def natDivCeil (num denom : Nat) : Nat := (num + denom - 1) / denom

def natProd (shape : List Nat) : Nat := shape.foldl (fun x y => x * y) 1


-- We generally have large tensors, so don't show them by default
instance ByteArrayRepr : Repr ByteArray where
reprPrec x _ :=
Expand All @@ -45,11 +45,41 @@ inductive ByteOrder where
| bigEndian
deriving BEq, Repr, Inhabited

namespace ByteOrder

@[simp]
def ByteOrder.isMultiByte (x : ByteOrder) : Bool := match x with
def isMultiByte (x : ByteOrder) : Bool := match x with
| .oneByte => false
| .littleEndian | .bigEndian => true

def bytesToInt (order : ByteOrder) (bytes : ByteArray) : Int := Id.run do
let mut n : Nat := 0
let nbytes := bytes.size
let signByte := match order with
| .littleEndian => bytes.get! (nbytes - 1)
| .bigEndian | oneByte => bytes.get! 0
let negative := 128 <= signByte
for i in [0:nbytes] do
let v : UInt8 := bytes.get! i
let v := if negative then UInt8.complement v else v
let p := match order with
| .oneByte => 0 -- nbytes = 1
| .littleEndian => i
| .bigEndian => nbytes - 1 - i
n := n + Pow.pow 2 (8 * p) * v.toNat
return if 128 <= signByte then -(n + 1) else n

#guard bytesToInt .littleEndian (ByteArray.mk #[1, 1]) == 257
#guard bytesToInt .bigEndian (ByteArray.mk #[1, 1]) == 257
#guard bytesToInt .littleEndian (ByteArray.mk #[0, 1]) == 256
#guard bytesToInt .bigEndian (ByteArray.mk #[0, 1]) == 1
#guard bytesToInt .littleEndian (ByteArray.mk #[0xFF, 0xFF]) == -1
#guard bytesToInt .bigEndian (ByteArray.mk #[0xFF, 0xFF]) == -1
#guard bytesToInt .bigEndian (ByteArray.mk #[0x80, 0]) == -32768
#guard bytesToInt .littleEndian (ByteArray.mk #[0x80, 0]) == 0x80

end ByteOrder

/-!
The strides are how many bytes you need to skip to get to the next element in that
"row". For example, in an array of 8-byte data with shape 2, 3, the strides are (24, 8).
Expand Down Expand Up @@ -144,6 +174,9 @@ deriving BEq, Repr, Inhabited

namespace Shape

instance : ToString Shape where
toString := reprStr

def empty : Shape := Shape.mk []

--! The number of elements in a tensor. All that's needed is the shape for this calculation.
Expand All @@ -155,6 +188,10 @@ def ndim (shape : Shape) : Nat := shape.val.length

def map (shape : Shape) (f : List Nat -> List Nat) : Shape := Shape.mk (f shape.val)

def dimIndexInRange (shape : Shape) (dimIndex : DimIndex) :=
shape.ndim == dimIndex.length &&
(shape.val.zip dimIndex).all fun (n, i) => i < n

/-!
Strides can be computed from the shape by figuring out how many elements you
need to jump over to get to the next spot and mulitplying by the bytes in each
Expand Down Expand Up @@ -205,6 +242,7 @@ def positionToDimIndex (strides : Strides) (n : Position) : DimIndex :=
let (_, idx) := strides.foldl foldFn (n, [])
idx.reverse

-- TODO: Return `Err Offset` for when the strides and index have different lengths?
def dimIndexToOffset (strides : Strides) (index : DimIndex) : Offset := dot strides (index.map Int.ofNat)

#guard positionToDimIndex [3, 1] 4 == [1, 1]
Expand All @@ -223,6 +261,21 @@ def allDimIndices (shape : Shape) : List DimIndex := Id.run do
#guard allDimIndices (Shape.mk [5]) == [[0], [1], [2], [3], [4]]
#guard allDimIndices (Shape.mk [3, 2]) == [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]]

-- NumPy supports negative indices, which simply wrap around. E.g. `x[.., -1, ..] = x[.., n-1, ..]` where `n` is the
-- dimension in question. It only supports `-n` to `n`.
def intIndexToDimIndex (shape : Shape) (index : List Int) : Err DimIndex := do
if shape.ndim != index.length then .error "intsToDimIndex length mismatch" else
let conv (dim : Nat) (ind : Int) : Err Nat :=
if 0 <= ind then
if ind < dim then .ok ind.toNat
else .error "index out of bounds"
else if ind < -dim then .error "index out of bounds"
else .ok (dim + ind).toNat
(shape.val.zip index).mapM (fun (dim, ind) => conv dim ind)

#guard intIndexToDimIndex (Shape.mk [1, 2, 3]) [0, -1, -1] == (.ok [0, 1, 2])
#guard intIndexToDimIndex (Shape.mk [1, 2, 3]) [0, 1, -2] == (.ok [0, 1, 1])

end Shape

/-
Expand Down Expand Up @@ -278,7 +331,7 @@ Note: I tried writing this as a `do/for` loop and in this case the recursive
one seems nicer. We are walking over two lists simultaneously, which is easy
here but with a for loop is either quadratic or awkward.
-/
def next (iter : DimsIter) : List Nat × DimsIter :=
def next (iter : DimsIter) : DimIndex × DimsIter :=
-- Invariant: `acc` is a list of 0s, so doesn't need to be reversed
let rec loop (acc ms ns : List Nat) : List Nat :=
match ms, ns with
Expand All @@ -293,7 +346,7 @@ def next (iter : DimsIter) : List Nat × DimsIter :=
let curr' := loop [] iter.dims iter.curr
(iter.curr.reverse, { iter with curr := curr' })

instance [Monad m] : ForIn m DimsIter (List Nat) where
instance [Monad m] : ForIn m DimsIter DimIndex where
forIn {α} [Monad m] (iter : DimsIter) (x : α) (f : List Nat -> α -> m (ForInStep α)) : m α := do
let mut iter := iter
let mut res := x
Expand All @@ -305,7 +358,7 @@ instance [Monad m] : ForIn m DimsIter (List Nat) where
| .done k => return k
return res

private def toList (iter : DimsIter) : List (List Nat) := Id.run do
private def toList (iter : DimsIter) : List DimIndex := Id.run do
let mut res := []
for xs in iter do
res := xs :: res
Expand All @@ -321,7 +374,7 @@ private def toList (iter : DimsIter) : List (List Nat) := Id.run do
#guard (DimsIter.make $ Shape.mk [1, 1, 2]).toList == [[0, 0, 0], [0, 0, 1]]
#guard (DimsIter.make $ Shape.mk [3, 2]).toList == [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]]

private def testBreak (iter : DimsIter) : List (List Nat) := Id.run do
private def testBreak (iter : DimsIter) : List DimIndex := Id.run do
let mut res := []
for xs in iter do
res := xs :: res
Expand All @@ -330,7 +383,7 @@ private def testBreak (iter : DimsIter) : List (List Nat) := Id.run do

#guard (DimsIter.make $ Shape.mk [3, 2]).testBreak == [[0, 0]]

private def testReturn (iter : DimsIter) : List (List Nat) := Id.run do
private def testReturn (iter : DimsIter) : List DimIndex := Id.run do
let mut res := []
let mut i := 0
for xs in iter do
Expand Down
14 changes: 14 additions & 0 deletions TensorLib/Dtype.lean
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ def isMultiByte (x : Name) : Bool := match x with
| bool | int8 | uint8 => false
| _ => true

def isInt (x : Name) : Bool := match x with
| int8 | int16 | int32 | int64 => true
| _ => false

def isUint (x : Name) : Bool := match x with
| uint8 | uint16 | uint32 | uint64 => true
| _ => false

def isIntLike (x : Name) : Bool := x.isInt || x.isUint

--! Number of bytes used by each element of the given dtype
def itemsize (x : Name) : Nat := match x with
| float64 | int64 | uint64 => 8
Expand All @@ -57,6 +67,10 @@ def itemsize (t : Dtype) := t.name.itemsize

def sizedStrides (dtype : Dtype) (s : Shape) : Strides := List.map (fun x => x * dtype.itemsize) s.unitStrides

def isInt (dtype : Dtype) : Bool := dtype.name.isInt
def isUint (dtype : Dtype) : Bool := dtype.name.isUint
def isIntLike (dtype : Dtype) : Bool := dtype.isInt || dtype.isUint

def int8 : Dtype := Dtype.mk Dtype.Name.int8 ByteOrder.littleEndian
def uint8 : Dtype := Dtype.mk Dtype.Name.uint8 ByteOrder.littleEndian
def uint64 : Dtype := Dtype.mk Dtype.Name.uint64 ByteOrder.littleEndian
Expand Down
Loading

0 comments on commit 575bf43

Please sign in to comment.