From 934a12fe022b42d3441dec549c41513c7b8b08d2 Mon Sep 17 00:00:00 2001 From: "Trevor L. McDonell" Date: Thu, 28 Sep 2023 15:13:07 +0200 Subject: [PATCH] fix undef size computation --- .../Array/Accelerate/Representation/Elt.hs | 60 ++++--------------- 1 file changed, 10 insertions(+), 50 deletions(-) diff --git a/src/Data/Array/Accelerate/Representation/Elt.hs b/src/Data/Array/Accelerate/Representation/Elt.hs index b888074fa..af2ab2d64 100644 --- a/src/Data/Array/Accelerate/Representation/Elt.hs +++ b/src/Data/Array/Accelerate/Representation/Elt.hs @@ -74,7 +74,10 @@ undefElt = tuple vector :: KnownNat n => Proxy# n -> SingleIntegralType t -> Vec n t vector n t = runST $ do - let bytes = bytesElt (TupRsingle (NumScalarType (IntegralNumType (VectorIntegralType n t)))) + let bits = case t of + TypeInt w -> w + TypeWord w -> w + bytes = max 1 (quot (fromInteger (natVal' n) * bits) 8) mba <- newAlignedPinnedByteArray bytes 16 ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# @@ -92,60 +95,17 @@ undefElt = tuple vector :: KnownNat n => Proxy# n -> SingleFloatingType t -> Vec n t vector n t = runST $ do - let bytes = bytesElt (TupRsingle (NumScalarType (FloatingNumType (VectorFloatingType n t)))) + let bits = case t of + TypeFloat16 -> 16 + TypeFloat32 -> 32 + TypeFloat64 -> 64 + TypeFloat128 -> 128 + bytes = max 1 (quot (fromInteger (natVal' n) * bits) 8) mba <- newAlignedPinnedByteArray bytes 16 ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# -bytesElt :: TypeR e -> Int -bytesElt = tuple - where - tuple :: TypeR t -> Int - tuple TupRunit = 0 - tuple (TupRpair ta tb) = tuple ta + tuple tb - tuple (TupRsingle t) = scalar t - - scalar :: ScalarType t -> Int - scalar (NumScalarType t) = num t - scalar (BitScalarType t) = bit t - - bit :: BitType t -> Int - bit TypeBit = 1 -- stored as Word8 - bit (TypeMask n) = quot (fromInteger (natVal' n)+7) 8 - - num :: NumType t -> Int - num (IntegralNumType t) = integral t - num (FloatingNumType t) = floating t - - integral :: IntegralType t -> Int - integral = \case - SingleIntegralType t -> single t - VectorIntegralType n t -> fromInteger (natVal' n) * single t - where - single :: SingleIntegralType t -> Int - single TypeInt8 = 1 - single TypeInt16 = 2 - single TypeInt32 = 4 - single TypeInt64 = 8 - single TypeInt128 = 16 - single TypeWord8 = 1 - single TypeWord16 = 2 - single TypeWord32 = 4 - single TypeWord64 = 8 - single TypeWord128 = 16 - - floating :: FloatingType t -> Int - floating = \case - SingleFloatingType t -> single t - VectorFloatingType n t -> fromInteger (natVal' n) * single t - where - single :: SingleFloatingType t -> Int - single TypeFloat16 = 2 - single TypeFloat32 = 4 - single TypeFloat64 = 8 - single TypeFloat128 = 16 - showElt :: TypeR e -> e -> String showElt t v = showsElt t v ""