diff --git a/containers-tests/benchmarks/Map.hs b/containers-tests/benchmarks/Map.hs index b53a4914d..0e324e556 100644 --- a/containers-tests/benchmarks/Map.hs +++ b/containers-tests/benchmarks/Map.hs @@ -95,6 +95,8 @@ main = do , bench "fromDistinctDescList" $ whnf M.fromDistinctDescList elems_rev , bench "fromDistinctDescList:fusion" $ whnf (\n -> M.fromDistinctDescList [(i,i) | i <- [n,n-1..1]]) bound , bench "minView" $ whnf (\m' -> case M.minViewWithKey m' of {Nothing -> 0; Just ((k,v),m'') -> k+v+M.size m''}) (M.fromAscList $ zip [1..10::Int] [100..110::Int]) + , bench "eq" $ whnf (\m' -> m' == m') m -- worst case, compares everything + , bench "compare" $ whnf (\m' -> compare m' m') m -- worst case, compares everything ] where bound = 2^12 diff --git a/containers-tests/benchmarks/Set.hs b/containers-tests/benchmarks/Set.hs index f65e2a620..265117e9c 100644 --- a/containers-tests/benchmarks/Set.hs +++ b/containers-tests/benchmarks/Set.hs @@ -55,6 +55,8 @@ main = do , bench "member.powerSet (16)" $ whnf (\ s -> all (flip S.member s) s) (S.powerSet (S.fromList [1..16])) , bench "member.powerSet (17)" $ whnf (\ s -> all (flip S.member s) s) (S.powerSet (S.fromList [1..17])) , bench "member.powerSet (18)" $ whnf (\ s -> all (flip S.member s) s) (S.powerSet (S.fromList [1..18])) + , bench "eq" $ whnf (\s' -> s' == s') s -- worst case, compares everything + , bench "compare" $ whnf (\s' -> compare s' s') s -- worst case, compares everything ] where bound = 2^12 diff --git a/containers-tests/containers-tests.cabal b/containers-tests/containers-tests.cabal index d6857040d..9a098820d 100644 --- a/containers-tests/containers-tests.cabal +++ b/containers-tests/containers-tests.cabal @@ -124,6 +124,7 @@ library Utils.Containers.Internal.PtrEquality Utils.Containers.Internal.State Utils.Containers.Internal.StrictMaybe + Utils.Containers.Internal.EqOrdUtil if impl(ghc) other-modules: diff --git a/containers-tests/tests/map-properties.hs b/containers-tests/tests/map-properties.hs index b2f01f1b7..b6d873c44 100644 --- a/containers-tests/tests/map-properties.hs +++ b/containers-tests/tests/map-properties.hs @@ -33,7 +33,7 @@ import Test.Tasty import Test.Tasty.HUnit import Test.Tasty.QuickCheck import Test.QuickCheck.Function (apply) -import Test.QuickCheck.Poly (A, B) +import Test.QuickCheck.Poly (A, B, OrdA) import Control.Arrow (first) default (Int) @@ -250,6 +250,8 @@ main = defaultMain $ testGroup "map-properties" , testProperty "splitAt" prop_splitAt , testProperty "lookupMin" prop_lookupMin , testProperty "lookupMax" prop_lookupMax + , testProperty "eq" prop_eq + , testProperty "compare" prop_compare ] {-------------------------------------------------------------------- @@ -1636,3 +1638,9 @@ prop_fromArgSet :: [(Int, Int)] -> Bool prop_fromArgSet ys = let xs = List.nubBy ((==) `on` fst) ys in fromArgSet (Set.fromList $ List.map (uncurry Arg) xs) == fromList xs + +prop_eq :: Map Int A -> Map Int A -> Property +prop_eq m1 m2 = (m1 == m2) === (toList m1 == toList m2) + +prop_compare :: Map Int OrdA -> Map Int OrdA -> Property +prop_compare m1 m2 = compare m1 m2 === compare (toList m1) (toList m2) diff --git a/containers-tests/tests/set-properties.hs b/containers-tests/tests/set-properties.hs index aeee9e584..82d30f0c6 100644 --- a/containers-tests/tests/set-properties.hs +++ b/containers-tests/tests/set-properties.hs @@ -110,6 +110,8 @@ main = defaultMain $ testGroup "set-properties" , testProperty "strict foldr" prop_strictFoldr' , testProperty "strict foldl" prop_strictFoldl' #endif + , testProperty "eq" prop_eq + , testProperty "compare" prop_compare ] -- A type with a peculiar Eq instance designed to make sure keys @@ -730,3 +732,9 @@ prop_strictFoldr' m = whnfHasNoThunks (foldr' (:) [] m) prop_strictFoldl' :: Set Int -> Property prop_strictFoldl' m = whnfHasNoThunks (foldl' (flip (:)) [] m) #endif + +prop_eq :: Set Int -> Set Int -> Property +prop_eq s1 s2 = (s1 == s2) === (toList s1 == toList s2) + +prop_compare :: Set Int -> Set Int -> Property +prop_compare s1 s2 = compare s1 s2 === compare (toList s1) (toList s2) diff --git a/containers/containers.cabal b/containers/containers.cabal index 1a89484b1..3ab69d3ac 100644 --- a/containers/containers.cabal +++ b/containers/containers.cabal @@ -80,6 +80,7 @@ Library Utils.Containers.Internal.StrictMaybe Utils.Containers.Internal.PtrEquality Utils.Containers.Internal.Coercions + Utils.Containers.Internal.EqOrdUtil if impl(ghc) other-modules: Utils.Containers.Internal.TypeError diff --git a/containers/src/Data/Map/Internal.hs b/containers/src/Data/Map/Internal.hs index d60af97b0..b0f29021a 100644 --- a/containers/src/Data/Map/Internal.hs +++ b/containers/src/Data/Map/Internal.hs @@ -401,6 +401,7 @@ import Utils.Containers.Internal.PtrEquality (ptrEq) import Utils.Containers.Internal.StrictPair import Utils.Containers.Internal.StrictMaybe import Utils.Containers.Internal.BitQueue +import Utils.Containers.Internal.EqOrdUtil (EqM(..), OrdM(..)) #ifdef DEFINE_ALTERF_FALLBACK import Utils.Containers.Internal.BitUtil (wordSize) #endif @@ -4118,6 +4119,31 @@ deleteFindMax t = case maxViewWithKey t of Nothing -> (error "Map.deleteFindMax: can not return the maximal element of an empty map", Tip) Just res -> res +{-------------------------------------------------------------------- + Iterator +--------------------------------------------------------------------} + +-- See Note [Iterator] in Data.Set.Internal + +iterDown :: Map k a -> Stack k a -> Stack k a +iterDown (Bin _ kx x l r) stk = iterDown l (Push kx x r stk) +iterDown Tip stk = stk + +-- Create an iterator from a Map, starting at the smallest key. +iterator :: Map k a -> Stack k a +iterator m = iterDown m Nada + +-- Get the next key-value and the remaining iterator. +iterNext :: Stack k a -> Maybe (StrictPair (KeyValue k a) (Stack k a)) +iterNext (Push kx x r stk) = Just $! KeyValue kx x :*: iterDown r stk +iterNext Nada = Nothing +{-# INLINE iterNext #-} + +-- Whether there are no more key-values in the iterator. +iterNull :: Stack k a -> Bool +iterNull (Push _ _ _ _) = False +iterNull Nada = True + {-------------------------------------------------------------------- [balance l x r] balances two trees with value x. The sizes of the trees should balance after decreasing the @@ -4284,41 +4310,69 @@ bin k x l r {-------------------------------------------------------------------- - Eq converts the tree to a list. In a lazy setting, this - actually seems one of the faster methods to compare two trees - and it is certainly the simplest :-) + Eq --------------------------------------------------------------------} + instance (Eq k,Eq a) => Eq (Map k a) where - t1 == t2 = (size t1 == size t2) && (toAscList t1 == toAscList t2) + m1 == m2 = liftEq2 (==) (==) m1 m2 + {-# INLINABLE (==) #-} -{-------------------------------------------------------------------- - Ord ---------------------------------------------------------------------} +-- | @since 0.5.9 +instance Eq k => Eq1 (Map k) where + liftEq = liftEq2 (==) + {-# INLINE liftEq #-} -instance (Ord k, Ord v) => Ord (Map k v) where - compare m1 m2 = compare (toAscList m1) (toAscList m2) +-- | @since 0.5.9 +instance Eq2 Map where + liftEq2 keq eq m1 m2 = size m1 == size m2 && sameSizeLiftEq2 keq eq m1 m2 + {-# INLINE liftEq2 #-} + +-- Assumes the maps are of equal size to skip the final check +sameSizeLiftEq2 + :: (ka -> kb -> Bool) -> (a -> b -> Bool) -> Map ka a -> Map kb b -> Bool +sameSizeLiftEq2 keq eq m1 m2 = + case runEqM (foldMapWithKey f m1) (iterator m2) of e :*: _ -> e + where + f kx x = EqM $ \it -> case iterNext it of + Nothing -> False :*: it + Just (KeyValue ky y :*: it') -> (keq kx ky && eq x y) :*: it' +{-# INLINE sameSizeLiftEq2 #-} {-------------------------------------------------------------------- - Lifted instances + Ord --------------------------------------------------------------------} --- | @since 0.5.9 -instance Eq2 Map where - liftEq2 eqk eqv m n = - size m == size n && liftEq (liftEq2 eqk eqv) (toList m) (toList n) +instance (Ord k, Ord v) => Ord (Map k v) where + compare m1 m2 = liftCmp2 compare compare m1 m2 + {-# INLINABLE compare #-} -- | @since 0.5.9 -instance Eq k => Eq1 (Map k) where - liftEq = liftEq2 (==) +instance Ord k => Ord1 (Map k) where + liftCompare = liftCmp2 compare + {-# INLINE liftCompare #-} -- | @since 0.5.9 instance Ord2 Map where - liftCompare2 cmpk cmpv m n = - liftCompare (liftCompare2 cmpk cmpv) (toList m) (toList n) + liftCompare2 = liftCmp2 + {-# INLINE liftCompare2 #-} + +liftCmp2 + :: (ka -> kb -> Ordering) + -> (a -> b -> Ordering) + -> Map ka a + -> Map kb b + -> Ordering +liftCmp2 kcmp cmp m1 m2 = case runOrdM (foldMapWithKey f m1) (iterator m2) of + o :*: it -> o <> if iterNull it then EQ else LT + where + f kx x = OrdM $ \it -> case iterNext it of + Nothing -> GT :*: it + Just (KeyValue ky y :*: it') -> (kcmp kx ky <> cmp x y) :*: it' +{-# INLINE liftCmp2 #-} --- | @since 0.5.9 -instance Ord k => Ord1 (Map k) where - liftCompare = liftCompare2 compare +{-------------------------------------------------------------------- + Lifted instances +--------------------------------------------------------------------} -- | @since 0.5.9 instance Show2 Map where diff --git a/containers/src/Data/Set/Internal.hs b/containers/src/Data/Set/Internal.hs index f1ec29c3a..0d79ac9ce 100644 --- a/containers/src/Data/Set/Internal.hs +++ b/containers/src/Data/Set/Internal.hs @@ -252,6 +252,7 @@ import Control.DeepSeq (NFData(rnf)) import Utils.Containers.Internal.StrictPair import Utils.Containers.Internal.PtrEquality +import Utils.Containers.Internal.EqOrdUtil (EqM(..), OrdM(..)) #if __GLASGOW_HASKELL__ import GHC.Exts ( build, lazy ) @@ -1272,19 +1273,90 @@ foldl'Stack f = go {-# INLINE foldl'Stack #-} {-------------------------------------------------------------------- - Eq converts the set to a list. In a lazy setting, this - actually seems one of the faster methods to compare two trees - and it is certainly the simplest :-) + Iterator --------------------------------------------------------------------} + +-- Note [Iterator] +-- ~~~~~~~~~~~~~~~ +-- Iteration, using a Stack as an iterator, is an efficient way to consume a Set +-- one element at a time. Alternately, this may be done by toAscList. toAscList +-- when consumed via List.foldr will rewrite to Set.foldr (thanks to rewrite +-- rules), which is quite efficient. However, sometimes that is not possible, +-- such as in the second arg of '==' or 'compare', where manifesting the list +-- cons cells is unavoidable and makes things slower. +-- +-- Concretely, compare on Set Int using toAscList takes ~21% more time compared +-- to using Iterator, on GHC 9.6.3. +-- +-- The heart of this implementation is the `iterDown` function. It walks down +-- the left spine of the tree, pushing the value and right child on the stack, +-- until a Tip is reached. The next value is now at the top of the stack. To get +-- to the value after that, `iterDown` is called again with the right child and +-- the remaining stack. + +iterDown :: Set a -> Stack a -> Stack a +iterDown (Bin _ x l r) stk = iterDown l (Push x r stk) +iterDown Tip stk = stk + +-- Create an iterator from a Set, starting at the smallest element. +iterator :: Set a -> Stack a +iterator s = iterDown s Nada + +-- Get the next element and the remaining iterator. +iterNext :: Stack a -> Maybe (StrictPair a (Stack a)) +iterNext (Push x r stk) = Just $! x :*: iterDown r stk +iterNext Nada = Nothing +{-# INLINE iterNext #-} + +-- Whether there are no more elements in the iterator. +iterNull :: Stack a -> Bool +iterNull (Push _ _ _) = False +iterNull Nada = True + +{-------------------------------------------------------------------- + Eq +--------------------------------------------------------------------} + instance Eq a => Eq (Set a) where - t1 == t2 = (size t1 == size t2) && (toAscList t1 == toAscList t2) + s1 == s2 = liftEq (==) s1 s2 + {-# INLINABLE (==) #-} + +-- | @since 0.5.9 +instance Eq1 Set where + liftEq eq s1 s2 = size s1 == size s2 && sameSizeLiftEq eq s1 s2 + {-# INLINE liftEq #-} + +-- Assumes the sets are of equal size to skip the final check. +sameSizeLiftEq :: (a -> b -> Bool) -> Set a -> Set b -> Bool +sameSizeLiftEq eq s1 s2 = + case runEqM (foldMap f s1) (iterator s2) of e :*: _ -> e + where + f x = EqM $ \it -> case iterNext it of + Nothing -> False :*: it + Just (y :*: it') -> eq x y :*: it' +{-# INLINE sameSizeLiftEq #-} {-------------------------------------------------------------------- Ord --------------------------------------------------------------------} instance Ord a => Ord (Set a) where - compare s1 s2 = compare (toAscList s1) (toAscList s2) + compare s1 s2 = liftCmp compare s1 s2 + {-# INLINABLE compare #-} + +-- | @since 0.5.9 +instance Ord1 Set where + liftCompare = liftCmp + {-# INLINE liftCompare #-} + +liftCmp :: (a -> b -> Ordering) -> Set a -> Set b -> Ordering +liftCmp cmp s1 s2 = case runOrdM (foldMap f s1) (iterator s2) of + o :*: it -> o <> if iterNull it then EQ else LT + where + f x = OrdM $ \it -> case iterNext it of + Nothing -> GT :*: it + Just (y :*: it') -> cmp x y :*: it' +{-# INLINE liftCmp #-} {-------------------------------------------------------------------- Show @@ -1293,16 +1365,6 @@ instance Show a => Show (Set a) where showsPrec p xs = showParen (p > 10) $ showString "fromList " . shows (toList xs) --- | @since 0.5.9 -instance Eq1 Set where - liftEq eq m n = - size m == size n && liftEq eq (toList m) (toList n) - --- | @since 0.5.9 -instance Ord1 Set where - liftCompare cmp m n = - liftCompare cmp (toList m) (toList n) - -- | @since 0.5.9 instance Show1 Set where liftShowsPrec sp sl d m = diff --git a/containers/src/Utils/Containers/Internal/EqOrdUtil.hs b/containers/src/Utils/Containers/Internal/EqOrdUtil.hs new file mode 100644 index 000000000..58d39d8bd --- /dev/null +++ b/containers/src/Utils/Containers/Internal/EqOrdUtil.hs @@ -0,0 +1,38 @@ +{-# LANGUAGE CPP #-} +module Utils.Containers.Internal.EqOrdUtil + ( EqM(..) + , OrdM(..) + ) where + +#if !MIN_VERSION_base(4,11,0) +import Data.Semigroup (Semigroup(..)) +#endif +import Utils.Containers.Internal.StrictPair + +newtype EqM a = EqM { runEqM :: a -> StrictPair Bool a } + +-- | Composes left-to-right, short-circuits on False +instance Semigroup (EqM a) where + f <> g = EqM $ \x -> case runEqM f x of + r@(e :*: x') -> if e then runEqM g x' else r + +instance Monoid (EqM a) where + mempty = EqM (True :*:) +#if !MIN_VERSION_base(4,11,0) + mappend = (<>) +#endif + +newtype OrdM a = OrdM { runOrdM :: a -> StrictPair Ordering a } + +-- | Composes left-to-right, short-circuits on non-EQ +instance Semigroup (OrdM a) where + f <> g = OrdM $ \x -> case runOrdM f x of + r@(o :*: x') -> case o of + EQ -> runOrdM g x' + _ -> r + +instance Monoid (OrdM a) where + mempty = OrdM (EQ :*:) +#if !MIN_VERSION_base(4,11,0) + mappend = (<>) +#endif