Skip to content

Commit

Permalink
fix undef size computation
Browse files Browse the repository at this point in the history
  • Loading branch information
tmcdonell committed Sep 28, 2023
1 parent 03a0b60 commit 934a12f
Showing 1 changed file with 10 additions and 50 deletions.
60 changes: 10 additions & 50 deletions src/Data/Array/Accelerate/Representation/Elt.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ undefElt = tuple

vector :: KnownNat n => Proxy# n -> SingleIntegralType t -> Vec n t
vector n t = runST $ do
let bytes = bytesElt (TupRsingle (NumScalarType (IntegralNumType (VectorIntegralType n t))))
let bits = case t of
TypeInt w -> w
TypeWord w -> w
bytes = max 1 (quot (fromInteger (natVal' n) * bits) 8)
mba <- newAlignedPinnedByteArray bytes 16
ByteArray ba# <- unsafeFreezeByteArray mba
return $! Vec ba#
Expand All @@ -92,60 +95,17 @@ undefElt = tuple

vector :: KnownNat n => Proxy# n -> SingleFloatingType t -> Vec n t
vector n t = runST $ do
let bytes = bytesElt (TupRsingle (NumScalarType (FloatingNumType (VectorFloatingType n t))))
let bits = case t of
TypeFloat16 -> 16
TypeFloat32 -> 32
TypeFloat64 -> 64
TypeFloat128 -> 128
bytes = max 1 (quot (fromInteger (natVal' n) * bits) 8)
mba <- newAlignedPinnedByteArray bytes 16
ByteArray ba# <- unsafeFreezeByteArray mba
return $! Vec ba#


bytesElt :: TypeR e -> Int
bytesElt = tuple
where
tuple :: TypeR t -> Int
tuple TupRunit = 0
tuple (TupRpair ta tb) = tuple ta + tuple tb
tuple (TupRsingle t) = scalar t

scalar :: ScalarType t -> Int
scalar (NumScalarType t) = num t
scalar (BitScalarType t) = bit t

bit :: BitType t -> Int
bit TypeBit = 1 -- stored as Word8
bit (TypeMask n) = quot (fromInteger (natVal' n)+7) 8

num :: NumType t -> Int
num (IntegralNumType t) = integral t
num (FloatingNumType t) = floating t

integral :: IntegralType t -> Int
integral = \case
SingleIntegralType t -> single t
VectorIntegralType n t -> fromInteger (natVal' n) * single t
where
single :: SingleIntegralType t -> Int
single TypeInt8 = 1
single TypeInt16 = 2
single TypeInt32 = 4
single TypeInt64 = 8
single TypeInt128 = 16
single TypeWord8 = 1
single TypeWord16 = 2
single TypeWord32 = 4
single TypeWord64 = 8
single TypeWord128 = 16

floating :: FloatingType t -> Int
floating = \case
SingleFloatingType t -> single t
VectorFloatingType n t -> fromInteger (natVal' n) * single t
where
single :: SingleFloatingType t -> Int
single TypeFloat16 = 2
single TypeFloat32 = 4
single TypeFloat64 = 8
single TypeFloat128 = 16

showElt :: TypeR e -> e -> String
showElt t v = showsElt t v ""

Expand Down

0 comments on commit 934a12f

Please sign in to comment.