diff --git a/src/Data/Array/Accelerate/Data/Complex.hs b/src/Data/Array/Accelerate/Data/Complex.hs index 58c188d87..6872c1db7 100644 --- a/src/Data/Array/Accelerate/Data/Complex.hs +++ b/src/Data/Array/Accelerate/Data/Complex.hs @@ -1,19 +1,20 @@ -{-# LANGUAGE CPP #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RebindableSyntax #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RebindableSyntax #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Data.Complex @@ -30,7 +31,7 @@ module Data.Array.Accelerate.Data.Complex ( -- * Rectangular from - Complex(..), pattern (::+), + Complex, pattern (:+), real, imag, @@ -66,7 +67,7 @@ import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Type import qualified Data.Primitive.Vec as Prim -import Data.Complex ( Complex(..) ) +import Data.Complex ( Complex ) import Data.Primitive.Types import Prelude ( ($) ) import qualified Data.Complex as C @@ -76,19 +77,32 @@ import Data.Type.Equality #endif -infix 6 ::+ -pattern (::+) :: Elt a => Exp a -> Exp a -> Exp (Complex a) -pattern r ::+ i <- (deconstructComplex -> (r, i)) - where (::+) = constructComplex -{-# COMPLETE (::+) #-} +infix 6 :+ +pattern (:+) :: IsComplex a b => a -> a -> b +pattern r :+ i <- (matchComplex -> (r,i)) + where (:+) = buildComplex +{-# COMPLETE (:+) :: Complex #-} +{-# COMPLETE (:+) :: Exp #-} + +class IsComplex a b | b -> a where + matchComplex :: b -> (a, a) + buildComplex :: a -> a -> b + +instance IsComplex a (Complex a) where + buildComplex = (C.:+) + matchComplex (r C.:+ i) = (r, i) -- Use an array-of-structs representation for complex numbers if possible. --- This matches the standard C-style layout, but we can use this representation only at --- specific types (not for any type 'a') as we can only have vectors of primitive type. --- For other types, we use a structure-of-arrays representation. This is handled by the --- ComplexR. We use the GADT ComplexR and function complexR to reconstruct --- information on how the elements are represented. +-- +-- This matches the standard C-style layout, but we can use this representation +-- only at specific types (not for any type 'a') as we can only have vectors of +-- primitive type. For other types, we use a structure-of-arrays representation. +-- This is handled by the ComplexR. We use the GADT ComplexR and function +-- complexR to reconstruct information on how the elements are represented. +-- +-- TODO: This is no longer true, we could SIMD-ify more types here. +-- - TLM 2023-09-28 -- instance Elt a => Elt (Complex a) where type EltR (Complex a) = ComplexR (EltR a) @@ -180,108 +194,108 @@ complexR = tuple TypeFloat64 -> ComplexVec numType TypeFloat128 -> ComplexVec numType -constructComplex :: forall a. Elt a => Exp a -> Exp a -> Exp (Complex a) -constructComplex r@(Exp r') i@(Exp i') = - case complexR (eltR @a) of - ComplexTup -> Pattern (r,i) - ComplexVec t -> Exp $ num t r' i' - where - num :: NumType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) - num (IntegralNumType t) = integral t - num (FloatingNumType t) = floating t - - integral :: IntegralType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) - integral (SingleIntegralType t) = case t of - integral (VectorIntegralType n t) = - let v = NumScalarType (IntegralNumType (VectorIntegralType n t)) - in case t of - TypeInt8 -> pack v - TypeInt16 -> pack v - TypeInt32 -> pack v - TypeInt64 -> pack v - TypeInt128 -> pack v - TypeWord8 -> pack v - TypeWord16 -> pack v - TypeWord32 -> pack v - TypeWord64 -> pack v - TypeWord128 -> pack v - - floating :: FloatingType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) - floating (SingleFloatingType t) = case t of - floating (VectorFloatingType n t) = - let v = NumScalarType (FloatingNumType (VectorFloatingType n t)) - in case t of - TypeFloat16 -> pack v - TypeFloat32 -> pack v - TypeFloat64 -> pack v - TypeFloat128 -> pack v - - pack :: ScalarType (Prim.Vec 2 t) -> SmartExp t -> SmartExp t -> SmartExp (Prim.Vec 2 t) - pack v x y - = SmartExp (Insert v TypeWord8 - (SmartExp (Insert v TypeWord8 (SmartExp (Undef v)) (SmartExp (Const scalarType 0)) x)) - (SmartExp (Const scalarType 1)) y) - -deconstructComplex :: forall a. Elt a => Exp (Complex a) -> (Exp a, Exp a) -deconstructComplex (Exp c) = - case complexR (eltR @a) of - ComplexTup -> - let i = SmartExp (Prj PairIdxRight c) - r = SmartExp (Prj PairIdxRight (SmartExp (Prj PairIdxLeft c))) - in (Exp r, Exp i) - ComplexVec t -> - let (r, i) = num t c - in (Exp r, Exp i) - where - num :: NumType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) - num (IntegralNumType t) = integral t - num (FloatingNumType t) = floating t - - integral :: IntegralType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) - integral (SingleIntegralType t) = case t of - integral (VectorIntegralType n t) = - let v = NumScalarType (IntegralNumType (VectorIntegralType n t)) - in case t of - TypeInt8 -> unpack v - TypeInt16 -> unpack v - TypeInt32 -> unpack v - TypeInt64 -> unpack v - TypeInt128 -> unpack v - TypeWord8 -> unpack v - TypeWord16 -> unpack v - TypeWord32 -> unpack v - TypeWord64 -> unpack v - TypeWord128 -> unpack v - - floating :: FloatingType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) - floating (SingleFloatingType t) = case t of - floating (VectorFloatingType n t) = - let v = NumScalarType (FloatingNumType (VectorFloatingType n t)) - in case t of - TypeFloat16 -> unpack v - TypeFloat32 -> unpack v - TypeFloat64 -> unpack v - TypeFloat128 -> unpack v - - unpack :: ScalarType (Prim.Vec 2 t) -> SmartExp (Prim.Vec 2 t) -> (SmartExp t, SmartExp t) - unpack v x = - let r = SmartExp (Extract v TypeWord8 x (SmartExp (Const scalarType 0))) - i = SmartExp (Extract v TypeWord8 x (SmartExp (Const scalarType 1))) - in - (r, i) + +instance Elt a => IsComplex (Exp a) (Exp (Complex a)) where + matchComplex (Exp c) = + case complexR (eltR @a) of + ComplexTup -> + let i = SmartExp (Prj PairIdxRight c) + r = SmartExp (Prj PairIdxRight (SmartExp (Prj PairIdxLeft c))) + in (Exp r, Exp i) + ComplexVec t -> + let (r, i) = num t c + in (Exp r, Exp i) + where + num :: NumType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) + integral (SingleIntegralType t) = case t of + integral (VectorIntegralType n t) = + let v = NumScalarType (IntegralNumType (VectorIntegralType n t)) + in case t of + TypeInt8 -> unpack v + TypeInt16 -> unpack v + TypeInt32 -> unpack v + TypeInt64 -> unpack v + TypeInt128 -> unpack v + TypeWord8 -> unpack v + TypeWord16 -> unpack v + TypeWord32 -> unpack v + TypeWord64 -> unpack v + TypeWord128 -> unpack v + + floating :: FloatingType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) + floating (SingleFloatingType t) = case t of + floating (VectorFloatingType n t) = + let v = NumScalarType (FloatingNumType (VectorFloatingType n t)) + in case t of + TypeFloat16 -> unpack v + TypeFloat32 -> unpack v + TypeFloat64 -> unpack v + TypeFloat128 -> unpack v + + unpack :: ScalarType (Prim.Vec 2 t) -> SmartExp (Prim.Vec 2 t) -> (SmartExp t, SmartExp t) + unpack v x = + let r = SmartExp (Extract v TypeWord8 x (SmartExp (Const scalarType 0))) + i = SmartExp (Extract v TypeWord8 x (SmartExp (Const scalarType 1))) + in + (r, i) + + buildComplex r@(Exp r') i@(Exp i') = + case complexR (eltR @a) of + ComplexTup -> Pattern (r,i) + ComplexVec t -> Exp $ num t r' i' + where + num :: NumType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) + integral (SingleIntegralType t) = case t of + integral (VectorIntegralType n t) = + let v = NumScalarType (IntegralNumType (VectorIntegralType n t)) + in case t of + TypeInt8 -> pack v + TypeInt16 -> pack v + TypeInt32 -> pack v + TypeInt64 -> pack v + TypeInt128 -> pack v + TypeWord8 -> pack v + TypeWord16 -> pack v + TypeWord32 -> pack v + TypeWord64 -> pack v + TypeWord128 -> pack v + + floating :: FloatingType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) + floating (SingleFloatingType t) = case t of + floating (VectorFloatingType n t) = + let v = NumScalarType (FloatingNumType (VectorFloatingType n t)) + in case t of + TypeFloat16 -> pack v + TypeFloat32 -> pack v + TypeFloat64 -> pack v + TypeFloat128 -> pack v + + pack :: ScalarType (Prim.Vec 2 t) -> SmartExp t -> SmartExp t -> SmartExp (Prim.Vec 2 t) + pack v x y + = SmartExp (Insert v TypeWord8 + (SmartExp (Insert v TypeWord8 (SmartExp (Undef v)) (SmartExp (Const scalarType 0)) x)) + (SmartExp (Const scalarType 1)) y) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Complex a) where type Plain (Complex a) = Complex (Plain a) - lift (r :+ i) = lift r ::+ lift i + lift (r :+ i) = lift r :+ lift i instance Elt a => Unlift Exp (Complex (Exp a)) where - unlift (r ::+ i) = r :+ i + unlift (r :+ i) = r :+ i instance Eq a => Eq (Complex a) where - r1 ::+ c1 == r2 ::+ c2 = r1 == r2 && c1 == c2 - r1 ::+ c1 /= r2 ::+ c2 = r1 /= r2 || c1 /= c2 + r1 :+ c1 == r2 :+ c2 = r1 == r2 && c1 == c2 + r1 :+ c1 /= r2 :+ c2 = r1 /= r2 || c1 /= c2 instance (RealFloat a, Exponent a ~ Int) => P.Num (Exp (Complex a)) where (+) = case complexR (eltR @a) of @@ -294,20 +308,20 @@ instance (RealFloat a, Exponent a ~ Int) => P.Num (Exp (Complex a)) where negate = case complexR (eltR @a) of ComplexTup -> lift1 (negate :: Complex (Exp a) -> Complex (Exp a)) ComplexVec t -> mkPrimUnary $ PrimNeg t - signum z@(x ::+ y) = + signum z@(x :+ y) = if z == 0 then z else let r = magnitude z - in x/r ::+ y/r - abs z = magnitude z ::+ 0 - fromInteger n = fromInteger n ::+ 0 + in x/r :+ y/r + abs z = magnitude z :+ 0 + fromInteger n = fromInteger n :+ 0 instance (RealFloat a, Exponent a ~ Int) => P.Fractional (Exp (Complex a)) where - fromRational x = fromRational x ::+ 0 - z / z' = (x*x''+y*y'') / d ::+ (y*x''-x*y'') / d + fromRational x = fromRational x :+ 0 + z / z' = (x*x''+y*y'') / d :+ (y*x''-x*y'') / d where - x :+ y = unlift z - x' :+ y' = unlift z' + x :+ y = z + x' :+ y' = z' -- x'' = scaleFloat k x' y'' = scaleFloat k y' @@ -315,14 +329,14 @@ instance (RealFloat a, Exponent a ~ Int) => P.Fractional (Exp (Complex a)) where d = x'*x'' + y'*y'' instance (RealFloat a, Exponent a ~ Int, BitOrMask (EltR a) ~ Bit) => P.Floating (Exp (Complex a)) where - pi = pi ::+ 0 - exp (x ::+ y) = let expx = exp x - in expx * cos y ::+ expx * sin y - log z = log (magnitude z) ::+ phase z - sqrt z@(x ::+ y) = + pi = pi :+ 0 + exp (x :+ y) = let expx = exp x + in expx * cos y :+ expx * sin y + log z = log (magnitude z) :+ phase z + sqrt z@(x :+ y) = if z == 0 then 0 - else u ::+ (y < 0 ? (-v, v)) + else u :+ (y < 0 ? (-v, v)) where T2 u v = x < 0 ? (T2 v' u', T2 u' v') v' = abs y / (u'*2) @@ -331,50 +345,50 @@ instance (RealFloat a, Exponent a ~ Int, BitOrMask (EltR a) ~ Bit) => P.Floating x ** y = if y == 0 then 1 else if x == 0 then if exp_r > 0 then 0 else - if exp_r < 0 then inf ::+ 0 - else nan ::+ nan + if exp_r < 0 then inf :+ 0 + else nan :+ nan else if isInfinite r || isInfinite i - then if exp_r > 0 then inf ::+ 0 else + then if exp_r > 0 then inf :+ 0 else if exp_r < 0 then 0 - else nan ::+ nan + else nan :+ nan else exp (log x * y) where - r ::+ i = x - exp_r ::+ _ = y + r :+ i = x + exp_r :+ _ = y -- inf = 1 / 0 nan = 0 / 0 - sin (x ::+ y) = sin x * cosh y ::+ cos x * sinh y - cos (x ::+ y) = cos x * cosh y ::+ (- sin x * sinh y) - tan (x ::+ y) = (sinx*coshy ::+ cosx*sinhy) / (cosx*coshy ::+ (-sinx*sinhy)) + sin (x :+ y) = sin x * cosh y :+ cos x * sinh y + cos (x :+ y) = cos x * cosh y :+ (- sin x * sinh y) + tan (x :+ y) = (sinx*coshy :+ cosx*sinhy) / (cosx*coshy :+ (-sinx*sinhy)) where sinx = sin x cosx = cos x sinhy = sinh y coshy = cosh y - sinh (x ::+ y) = cos y * sinh x ::+ sin y * cosh x - cosh (x ::+ y) = cos y * cosh x ::+ sin y * sinh x - tanh (x ::+ y) = (cosy*sinhx ::+ siny*coshx) / (cosy*coshx ::+ siny*sinhx) + sinh (x :+ y) = cos y * sinh x :+ sin y * cosh x + cosh (x :+ y) = cos y * cosh x :+ sin y * sinh x + tanh (x :+ y) = (cosy*sinhx :+ siny*coshx) / (cosy*coshx :+ siny*sinhx) where siny = sin y cosy = cos y sinhx = sinh x coshx = cosh x - asin z@(x ::+ y) = y' ::+ (-x') + asin z@(x :+ y) = y' :+ (-x') where - x' ::+ y' = log (((-y) ::+ x) + sqrt (1 - z*z)) + x' :+ y' = log (((-y) :+ x) + sqrt (1 - z*z)) - acos z = y'' ::+ (-x'') + acos z = y'' :+ (-x'') where - x'' ::+ y'' = log (z + ((-y') ::+ x')) - x' ::+ y' = sqrt (1 - z*z) + x'' :+ y'' = log (z + ((-y') :+ x')) + x' :+ y' = sqrt (1 - z*z) - atan z@(x ::+ y) = y' ::+ (-x') + atan z@(x :+ y) = y' :+ (-x') where - x' ::+ y' = log (((1-y) ::+ x) / sqrt (1+z*z)) + x' :+ y' = log (((1-y) :+ x) / sqrt (1+z*z)) asinh z = log (z + sqrt (1+z*z)) acosh z = log (z + (z+1) * sqrt ((z-1)/(z+1))) @@ -382,18 +396,18 @@ instance (RealFloat a, Exponent a ~ Int, BitOrMask (EltR a) ~ Bit) => P.Floating instance (FromIntegral a b, Num b, Elt (Complex b)) => FromIntegral a (Complex b) where - fromIntegral x = fromIntegral x ::+ 0 + fromIntegral x = fromIntegral x :+ 0 -- | @since 1.2.0.0 -- instance Functor Complex where - fmap f (r ::+ i) = f r ::+ f i + fmap f (r :+ i) = f r :+ f i -- | The non-negative magnitude of a complex number -- magnitude :: (RealFloat a, Exponent a ~ Int) => Exp (Complex a) -> Exp a -magnitude (r ::+ i) = scaleFloat k (sqrt (sqr (scaleFloat mk r) + sqr (scaleFloat mk i))) +magnitude (r :+ i) = scaleFloat k (sqrt (sqr (scaleFloat mk r) + sqr (scaleFloat mk i))) where k = max (exponent r) (exponent i) mk = -k @@ -405,13 +419,13 @@ magnitude (r ::+ i) = scaleFloat k (sqrt (sqr (scaleFloat mk r) + sqr (scaleFloa -- @since 1.3.0.0 -- magnitude' :: RealFloat a => Exp (Complex a) -> Exp a -magnitude' (r ::+ i) = sqrt (r*r + i*i) +magnitude' (r :+ i) = sqrt (r*r + i*i) -- | The phase of a complex number, in the range @(-'pi', 'pi']@. If the -- magnitude is zero, then so is the phase. -- phase :: RealFloat a => Exp (Complex a) -> Exp a -phase (r ::+ i) = +phase (r :+ i) = if r == 0 && i == 0 then 0 else atan2 i r @@ -437,17 +451,17 @@ cis = lift1 (C.cis :: Exp a -> Complex (Exp a)) -- | Return the real part of a complex number -- real :: Elt a => Exp (Complex a) -> Exp a -real (r ::+ _) = r +real (r :+ _) = r -- | Return the imaginary part of a complex number -- imag :: Elt a => Exp (Complex a) -> Exp a -imag (_ ::+ i) = i +imag (_ :+ i) = i -- | Return the complex conjugate of a complex number, defined as -- -- > conjugate(Z) = X - iY -- conjugate :: Num a => Exp (Complex a) -> Exp (Complex a) -conjugate z = real z ::+ (- imag z) +conjugate z = real z :+ (- imag z)