Skip to content

Commit

Permalink
Reduce Mod n p <= q to p <= q + 1 and 1 <= p
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanG077 committed Apr 19, 2024
1 parent 473f7d2 commit 085a8dd
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# Unreleased
* Fix faulty lookup for `Mod` and `Div` in GHC >= 9.2
* Reduce `Mod n p <= q` to `p <= q + 1` and `1 <= p`

# 0.4.7
* Fix Plugin silently fails when normalizing <= in GHC 9.4+ [#50](https://github.com/clash-lang/ghc-typelits-extra/issues/50)
Expand Down
12 changes: 9 additions & 3 deletions src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ import GHC.TcPluginM.Extra
import GHC.Builtin.Names (eqPrimTyConKey, hasKey, getUnique)
import GHC.Builtin.Types (promotedTrueDataCon, promotedFalseDataCon)
import GHC.Builtin.Types (boolTy, naturalTy, cTupleDataCon, cTupleTyCon)
import GHC.Builtin.Types.Literals (typeNatDivTyCon, typeNatModTyCon, typeNatCmpTyCon)
import GHC.Builtin.Types.Literals (typeNatAddTyCon, typeNatDivTyCon, typeNatModTyCon, typeNatCmpTyCon)
import GHC.Core.Coercion (mkUnivCo)
import GHC.Core.DataCon (dataConWrapId)
import GHC.Core.Predicate (EqRel (NomEq), Pred (EqPred, IrredPred), classifyPredType)
import GHC.Core.Reduction (Reduction(..))
import GHC.Core.TyCon (TyCon)
import GHC.Core.TyCo.Rep (Type (..), TyLit (..), UnivCoProvenance (PluginProv))
import GHC.Core.Type (Kind, mkTyConApp, splitTyConApp_maybe, typeKind)
import GHC.Core.Type (Kind, mkTyConApp, mkNumLitTy, splitTyConApp_maybe, typeKind)
#if MIN_VERSION_ghc(9,6,0)
import GHC.Core.TyCo.Compare (eqType)
#else
Expand Down Expand Up @@ -181,7 +181,13 @@ simplifyExtra defs eqs = tcPluginTrace "simplifyExtra" (ppr eqs) >> simples [] [
| otherwise -> return (Impossible eq)
(p, Max x y)
| b && (p == x || p == y) -> simples (((,) <$> evMagic ct <*> pure ct):evs) news eqs'

-- transform: Mod n p <= q
-- to: p <= q + 1, 1 <= p
(Mod _ p, q) | isWantedCt ct -> do
let succQ = toCType $ TyConApp typeNatAddTyCon [reifyEOP defs q, mkNumLitTy 1]
modCt <- createWantedFromNormalised defs (NatInequality ct p succQ b norm)
gteOneCt <- createWantedFromNormalised defs (NatInequality ct (I 1) p b norm)
simples (((,) <$> evMagic ct <*> pure ct):evs) (modCt:gteOneCt:news) eqs'
-- transform: q ~ Max x y => (p <=? q ~ True)
-- to: (p <=? Max x y) ~ True
-- and try to solve that along with the rest of the eqs'
Expand Down
16 changes: 11 additions & 5 deletions src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ import GHC.Builtin.Types (boolTy, naturalTy)
#else
import GHC.Builtin.Types (typeNatKind)
#endif
import GHC.Builtin.Types.Literals (typeNatDivTyCon, typeNatModTyCon)
import GHC.Builtin.Types.Literals (typeNatAddTyCon, typeNatDivTyCon, typeNatModTyCon)
#if MIN_VERSION_ghc(9,2,0)
import GHC.Builtin.Types.Literals (typeNatCmpTyCon)
#else
import GHC.Builtin.Types.Literals (typeNatLeqTyCon)
#endif
import GHC.Core.Predicate (EqRel (NomEq), Pred (EqPred), classifyPredType)
import GHC.Core.TyCo.Rep (Type (..))
import GHC.Core.Type (Kind, eqType, mkTyConApp, splitTyConApp_maybe, typeKind)
import GHC.Core.Type (Kind, eqType, mkNumLitTy, mkTyConApp, splitTyConApp_maybe, typeKind)
import GHC.Data.FastString (fsLit)
import GHC.Driver.Plugins (Plugin (..), defaultPlugin, purePlugin)
import GHC.Tc.Plugin (TcPluginM, tcLookupTyCon, tcPluginTrace)
Expand All @@ -77,10 +77,10 @@ import PrelNames (eqPrimTyConKey, hasKey)
import TcEvidence (EvTerm)
import TcPluginM (TcPluginM, tcLookupTyCon, tcPluginTrace)
import TcRnTypes (TcPlugin(..), TcPluginResult (..))
import Type (Kind, eqType, mkTyConApp, splitTyConApp_maybe)
import Type (Kind, eqType, mkNumLitTy, mkTyConApp, splitTyConApp_maybe)
import TyCoRep (Type (..))
import TysWiredIn (typeNatKind, promotedTrueDataCon, promotedFalseDataCon)
import TcTypeNats (typeNatLeqTyCon)
import TcTypeNats (typeNatAddTyCon, typeNatLeqTyCon)
#if MIN_VERSION_ghc(8,4,0)
import TcTypeNats (typeNatDivTyCon, typeNatModTyCon)
#else
Expand Down Expand Up @@ -209,7 +209,13 @@ simplifyExtra defs eqs = tcPluginTrace "simplifyExtra" (ppr eqs) >> simples [] [
| otherwise -> return (Impossible eq)
(p, Max x y)
| b && (p == x || p == y) -> simples (((,) <$> evMagic ct <*> pure ct):evs) news eqs'

-- transform: Mod n p <= q
-- to: p <= q + 1, 1 <= p
(Mod _ p, q) | isWantedCt ct -> do
let succQ = toCType $ TyConApp typeNatAddTyCon [reifyEOP defs q, mkNumLitTy 1]
modCt <- createWantedFromNormalised defs (NatInequality ct p succQ b norm)
gteOneCt <- createWantedFromNormalised defs (NatInequality ct (I 1) p b norm)
simples (((,) <$> evMagic ct <*> pure ct):evs) (modCt:gteOneCt:news) eqs'
-- transform: q ~ Max x y => (p <=? q ~ True)
-- to: (p <=? Max x y) ~ True
-- and try to solve that along with the rest of the eqs'
Expand Down
7 changes: 5 additions & 2 deletions src/GHC/TypeLits/Extra/Solver/Unify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ module GHC.TypeLits.Extra.Solver.Unify
( ExtraDefs (..)
, UnifyResult (..)
, NormaliseResult
, toCType
, normaliseNat
, unifyExtra
)
Expand Down Expand Up @@ -60,6 +61,8 @@ mergeNormResWith f x y = do
(res, n3) <- f x' y'
pure (res, n1 `mergeNormalised` n2 `mergeNormalised` n3)

toCType :: Type -> ExtraOp
toCType ty = C $ CType ty

normaliseNat :: ExtraDefs -> Type -> MaybeT TcPluginM NormaliseResult
normaliseNat defs ty | Just ty1 <- coreView ty = normaliseNat defs ty1
Expand Down Expand Up @@ -105,9 +108,9 @@ normaliseNat defs (TyConApp tc tys) = do
normResults <- lift (sequence (runMaybeT . normaliseNat defs <$> tys))
let anyNormalised = foldr mergeNormalised Untouched (snd <$> catMaybes normResults)
let tys' = mergeExtraOp (zip normResults tys)
pure (C (CType (TyConApp tc tys')), anyNormalised)
pure (toCType $ TyConApp tc tys', anyNormalised)

normaliseNat _ t = return (C (CType t), Untouched)
normaliseNat _ t = return (toCType t, Untouched)

-- | Result of comparing two 'SOP' terms, returning a potential substitution
-- list under which the two terms are equal.
Expand Down
9 changes: 9 additions & 0 deletions tests-ghc-9.4/ErrorTests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ testFail26 = testFail26' (Proxy @4) (Proxy @6) (Proxy @6)
testFail27 :: Proxy n -> Proxy (n + 2 <=? Max (n + 1) 1) -> Proxy True
testFail27 _ = id

testFail28 :: Proxy n -> Proxy (Mod n p <=? p) -> Proxy True
testFail28 _ = id

testFail1Errors =
["Expected: Proxy (GCD 6 8) -> Proxy 4"
," Actual: Proxy 4 -> Proxy 4"
Expand Down Expand Up @@ -231,3 +234,9 @@ testFail26Errors =
,"from the context: (x <=? n) ~ 'True"
]
#endif

testFail28Errors =
["Couldn't match type ‘Data.Type.Ord.OrdCond"
,"(CmpNat 1 p) True True False"
, "with ‘True’"
]
13 changes: 13 additions & 0 deletions tests-pre-ghc-9.4/ErrorTests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ testFail26 = testFail26' (Proxy @4) (Proxy @6) (Proxy @6)
testFail27 :: Proxy n -> Proxy (n + 2 <=? Max (n + 1) 1) -> Proxy True
testFail27 _ = id

testFail28 :: Proxy n -> Proxy (Mod n p <=? p) -> Proxy True
testFail28 _ = id

#if __GLASGOW_HASKELL__ >= 900
testFail1Errors =
["Expected: Proxy (GCD 6 8) -> Proxy 4"
Expand Down Expand Up @@ -345,3 +348,13 @@ testFail26Errors =
["Could not deduce: Max x y ~ n"
,"from the context: (x <=? n) ~ 'True"
]

testFail28Errors =
#if __GLASGOW_HASKELL__ >= 902
["Couldn't match type ‘Data.Type.Ord.OrdCond"
,"(CmpNat 1 p) True True False"
, "with ‘True’"
]
#else
["Couldn't match type ‘1 <=? p’ with ‘'True’"]
#endif
21 changes: 21 additions & 0 deletions tests/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,20 @@ test58b
-> Proxy (Max (n+2) 1)
test58b = test58a

test59
:: Proxy n
-> Proxy p
-> Proxy (Mod n (p + 1) <=? p)
-> Proxy True
test59 _ _ x = x

test60
:: Proxy n
-> Proxy p
-> Proxy (Mod n (3 * p + 5) <=? (4 + p * 3))
-> Proxy True
test60 _ _ x = x

main :: IO ()
main = defaultMain tests

Expand Down Expand Up @@ -411,6 +425,12 @@ tests = testGroup "ghc-typelits-natnormalise"
, testCase "forall n p . n + 1 <= Max (n + p + 1) p" $
show (test57 Proxy Proxy Proxy) @?=
"Proxy"
, testCase "forall n p . Mod n (p + 1) <= p" $
show (test59 Proxy Proxy Proxy) @?=
"Proxy"
, testCase "forall n p . Mod n (3 * p + 5) <= (4 + p * 3)" $
show (test60 Proxy Proxy Proxy) @?=
"Proxy"
]
, testGroup "errors"
[ testCase "GCD 6 8 /~ 4" $ testFail1 `throws` testFail1Errors
Expand Down Expand Up @@ -440,6 +460,7 @@ tests = testGroup "ghc-typelits-natnormalise"
, testCase "(x+1 <=? Max x y) /~ True" $ testFail25 `throws` testFail25Errors
, testCase "(x <= n) /=> (Max x y) ~ n" $ testFail26 `throws` testFail26Errors
, testCase "n + 2 <=? Max (n + 1) 1 /~ True" $ testFail27 `throws` testFail27Errors
, testCase "Mod n p <=? p" $ testFail28 `throws` testFail28Errors
]
]

Expand Down

0 comments on commit 085a8dd

Please sign in to comment.