Skip to content

Commit

Permalink
Merge pull request #129 from ekmett/fizruk/th-ghc8
Browse files Browse the repository at this point in the history
Add GadtC and RecGadtC support for makeFree
  • Loading branch information
RyanGlScott committed Feb 8, 2016
2 parents baeedfe + 7a67863 commit ad1080d
Showing 1 changed file with 119 additions and 36 deletions.
155 changes: 119 additions & 36 deletions src/Control/Monad/Free/TH.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ module Control.Monad.Free.TH
import Control.Arrow
import Control.Monad
import Data.Char (toLower)
import Data.List ((\\), nub)
import Language.Haskell.TH
import Language.Haskell.TH.Syntax

#if !(MIN_VERSION_base(4,8,0))
import Control.Applicative
Expand Down Expand Up @@ -77,7 +79,7 @@ findValueOrFail s = lookupValueName s >>= maybe (fail $ s ++ "is not in scope")
mkOpName :: String -> Q String
mkOpName (':':name) = return name
mkOpName ( c :name) = return $ toLower c : name
mkOpName _ = fail "null constructor name"
mkOpName _ = fail "impossible happened: empty (null) constructor name"

-- | Check if parameter is used in type.
usesTV :: Name -> Type -> Bool
Expand All @@ -88,8 +90,8 @@ usesTV n (ForallT bs _ t) = usesTV n t && n `notElem` map tyVarBndrName bs
usesTV _ _ = False

-- | Analyze constructor argument.
mkArg :: Name -> Type -> Q Arg
mkArg n t
mkArg :: Type -> Type -> Q Arg
mkArg (VarT n) t
| usesTV n t =
case t of
-- if parameter is used as is, the return type should be ()
Expand All @@ -100,18 +102,37 @@ mkArg n t
-- expression is an N-tuple secion (,...,).
AppT (AppT ArrowT _) _ -> do
(ts, name) <- arrowsToTuple t
when (name /= n) $ fail "return type is not the parameter"
when (any (usesTV n) ts) $ fail $ unlines
[ "type variable " ++ pprint n ++ " is forbidden"
, "in a type like (a1 -> ... -> aN -> " ++ pprint n ++ ")"
, "in a constructor's argument type: " ++ pprint t ]
when (name /= n) $ fail $ unlines
[ "expected final return type `" ++ pprint n ++ "'"
, "but got `" ++ pprint name ++ "'"
, "in a constructor's argument type: `" ++ pprint t ++ "'" ]
let tup = foldl AppT (TupleT $ length ts) ts
xs <- mapM (const $ newName "x") ts
return $ Captured tup (LamE (map VarP xs) (TupE (map VarE xs)))
_ -> fail "don't know how to make Arg"
_ -> fail $ unlines
[ "expected a type variable `" ++ pprint n ++ "'"
, "or a type like (a1 -> ... -> aN -> " ++ pprint n ++ ")"
, "but got `" ++ pprint t ++ "'"
, "in a constructor's argument" ]
| otherwise = return $ Param t
where
arrowsToTuple (AppT (AppT ArrowT t1) (VarT name)) = return ([t1], name)
arrowsToTuple (AppT (AppT ArrowT t1) t2) = do
(ts, name) <- arrowsToTuple t2
return (t1:ts, name)
arrowsToTuple _ = fail "return type is not a variable"
arrowsToTuple (VarT name) = return ([], name)
arrowsToTuple rt = fail $ unlines
[ "expected final return type `" ++ pprint n ++ "'"
, "but got `" ++ pprint rt ++ "'"
, "in a constructor's argument type: `" ++ pprint t ++ "'" ]

mkArg n _ = fail $ unlines
[ "expected a type variable"
, "but got `" ++ pprint n ++ "'"
, "as the last parameter of the type constructor" ]

-- | Apply transformation to the return value independently of how many
-- parameters does @e@ have.
Expand Down Expand Up @@ -144,9 +165,32 @@ unifyCaptured :: Name -> [(Type, Exp)] -> Q (Type, [Exp])
unifyCaptured a [] = return (VarT a, [])
unifyCaptured _ [(t, e)] = return (t, [e])
unifyCaptured _ [x, y] = unifyT x y
unifyCaptured _ _ = fail "can't unify more than 2 arguments that use type parameter"
unifyCaptured _ xs = fail $ unlines
[ "can't unify more than 2 return types"
, "that use type parameter"
, "when unifying return types: "
, unlines (map (pprint . fst) xs) ]

extractVars :: Type -> [Name]
extractVars (ForallT bs _ t) = extractVars t \\ map bndrName bs
where
bndrName (PlainTV n) = n
bndrName (KindedTV n _) = n
extractVars (VarT n) = [n]
extractVars (AppT x y) = extractVars x ++ extractVars y
#if MIN_VERSION_template_haskell(2,8,0)
extractVars (SigT x k) = extractVars x ++ extractVars k
#else
extractVars (SigT x k) = extractVars x
#endif
#if MIN_VERSION_template_haskell(2,11,0)
extractVars (InfixT x _ y) = extractVars x ++ extractVars y
extractVars (UInfixT x _ y) = extractVars x ++ extractVars y
extractVars (ParensT x) = extractVars x
#endif
extractVars _ = []

liftCon' :: Bool -> [TyVarBndr] -> Cxt -> Type -> Name -> [Name] -> Name -> [Type] -> Q [Dec]
liftCon' :: Bool -> [TyVarBndr] -> Cxt -> Type -> Type -> [Type] -> Name -> [Type] -> Q [Dec]
liftCon' typeSig tvbs cx f n ns cn ts = do
-- prepare some names
opName <- mkName <$> mkOpName (nameBase cn)
Expand All @@ -168,9 +212,10 @@ liftCon' typeSig tvbs cx f n ns cn ts = do
let pat = map VarP xs -- this is LHS
exprs = zipExprs (map VarE xs) es args -- this is what ctor would be applied to
fval = foldl AppE (ConE cn) exprs -- this is RHS without liftF
q = tvbs ++ map PlainTV (qa ++ m : ns)
ns' = nub (concatMap extractVars ns)
q = filter nonNext tvbs ++ map PlainTV (qa ++ m : ns')
qa = case retType of VarT b | a == b -> [a]; _ -> []
f' = foldl AppT f (map VarT ns)
f' = foldl AppT f ns
return $ concat
[ if typeSig
#if MIN_VERSION_template_haskell(2,10,0)
Expand All @@ -180,16 +225,60 @@ liftCon' typeSig tvbs cx f n ns cn ts = do
#endif
else []
, [ FunD opName [ Clause pat (NormalB $ AppE (VarE liftF) fval) [] ] ] ]
where
nonNext (PlainTV pn) = VarT pn /= n
nonNext (KindedTV kn _) = VarT kn /= n

-- | Provide free monadic actions for a single value constructor.
liftCon :: Bool -> [TyVarBndr] -> Cxt -> Type -> Name -> [Name] -> Con -> Q [Dec]
liftCon typeSig ts cx f n ns con =
case con of
NormalC cName fields -> liftCon' typeSig ts cx f n ns cName $ map snd fields
RecC cName fields -> liftCon' typeSig ts cx f n ns cName $ map (\(_, _, ty) -> ty) fields
InfixC (_,t1) cName (_,t2) -> liftCon' typeSig ts cx f n ns cName [t1, t2]
ForallC ts' cx' con' -> liftCon typeSig (ts ++ ts') (cx ++ cx') f n ns con'
_ -> fail "Unsupported constructor type"
liftCon :: Bool -> [TyVarBndr] -> Cxt -> Type -> Type -> [Type] -> Maybe [Name] -> Con -> Q [Dec]
liftCon typeSig ts cx f n ns onlyCons con
| not (any (`melem` onlyCons) (constructorNames con)) = return []
| otherwise = case con of
NormalC cName fields -> liftCon' typeSig ts cx f n ns cName $ map snd fields
RecC cName fields -> liftCon' typeSig ts cx f n ns cName $ map (\(_, _, ty) -> ty) fields
InfixC (_,t1) cName (_,t2) -> liftCon' typeSig ts cx f n ns cName [t1, t2]
ForallC ts' cx' con' -> liftCon typeSig (ts ++ ts') (cx ++ cx') f n ns onlyCons con'
#if MIN_VERSION_template_haskell(2,11,0)
GadtC cNames fields resType -> do
decs <- forM (filter (`melem` onlyCons) cNames) $ \cName ->
liftGadtC cName fields resType typeSig ts cx f
return (concat decs)
RecGadtC cNames fields resType -> do
let fields' = map (\(_, x, y) -> (x, y)) fields
decs <- forM (filter (`melem` onlyCons) cNames) $ \cName ->
liftGadtC cName fields' resType typeSig ts cx f
return (concat decs)
#endif
_ -> fail $ "Unsupported constructor type: `" ++ pprint con ++ "'"

#if MIN_VERSION_template_haskell(2,11,0)
splitAppT :: Type -> [Type]
splitAppT (AppT x y) = splitAppT x ++ [y]
splitAppT t = [t]

liftGadtC :: Name -> [BangType] -> Type -> Bool -> [TyVarBndr] -> Cxt -> Type -> Q [Dec]
liftGadtC cName fields resType typeSig ts cx f =
liftCon typeSig ts cx f nextTy (init tys) Nothing (NormalC cName fields)
where
(_f : tys) = splitAppT resType
nextTy = last tys
#endif

melem :: Eq a => a -> Maybe [a] -> Bool
melem _ Nothing = True
melem x (Just xs) = x `elem` xs

-- | Get construstor name(s).
constructorNames :: Con -> [Name]
constructorNames (NormalC name _) = [name]
constructorNames (RecC name _) = [name]
constructorNames (InfixC _ name _) = [name]
constructorNames (ForallC _ _ c) = constructorNames c
#if MIN_VERSION_template_haskell(2,11,0)
constructorNames (GadtC names _ _) = names
constructorNames (RecGadtC names _ _) = names
#endif
constructorNames con' = fail $ "Unsupported constructor type: `" ++ pprint con' ++ "'"

-- | Provide free monadic actions for a type declaration.
liftDec :: Bool -- ^ Include type signature?
Expand All @@ -201,24 +290,16 @@ liftDec typeSig onlyCons (DataD _ tyName tyVarBndrs _ cons _)
#else
liftDec typeSig onlyCons (DataD _ tyName tyVarBndrs cons _)
#endif
| null tyVarBndrs = fail $ "Type " ++ show tyName ++ " needs at least one free variable"
| otherwise = concat <$> mapM (liftCon typeSig [] [] con nextTyName (init tyNames)) cons'
| null tyVarBndrs = fail $ "Type constructor " ++ pprint tyName ++ " needs at least one type parameter"
| otherwise = concat <$> mapM (liftCon typeSig [] [] con nextTy (init tys) onlyCons) cons
where
cons' = case onlyCons of
Nothing -> cons
Just ns -> filter (\c -> constructorName c `elem` ns) cons
tyNames = map tyVarBndrName tyVarBndrs
nextTyName = last tyNames
tys = map (VarT . tyVarBndrName) tyVarBndrs
nextTy = last tys
con = ConT tyName
liftDec _ _ dec = fail $ "liftDec: Don't know how to lift " ++ show dec

-- | Get construstor name.
constructorName :: Con -> Name
constructorName (NormalC name _) = name
constructorName (RecC name _) = name
constructorName (InfixC _ name _) = name
constructorName (ForallC _ _ c) = constructorName c
constructorName _ = error "Unsupported constructor type"
liftDec _ _ dec = fail $ unlines
[ "failed to derive makeFree operations:"
, "expected a data type constructor"
, "but got " ++ pprint dec ]

-- | Generate monadic actions for a data type.
genFree :: Bool -- ^ Include type signature?
Expand All @@ -243,7 +324,9 @@ genFreeCon typeSig cname = do
_
#endif
-> genFree typeSig (Just [cname]) tname
_ -> fail "makeFreeCon expects a data constructor"
_ -> fail $ unlines
[ "expected a data constructor"
, "but got " ++ pprint info ]

-- | @$('makeFree' ''T)@ provides free monadic actions for the
-- constructors of the given data type @T@.
Expand Down

0 comments on commit ad1080d

Please sign in to comment.