From 311d13de23bd21a5a214b2f9a9e694dafb21ab49 Mon Sep 17 00:00:00 2001 From: David Feuer Date: Sat, 14 Dec 2024 23:02:29 -0500 Subject: [PATCH] Use a fake GADT for sequence folds and traversals --- containers-tests/containers-tests.cabal | 1 + containers/containers.cabal | 1 + containers/src/Data/Sequence/Internal.hs | 153 ++++++++++++++++-- .../src/Data/Sequence/Internal/Depth.hs | 120 ++++++++++++++ 4 files changed, 260 insertions(+), 15 deletions(-) create mode 100644 containers/src/Data/Sequence/Internal/Depth.hs diff --git a/containers-tests/containers-tests.cabal b/containers-tests/containers-tests.cabal index f88c489c1..00f138860 100644 --- a/containers-tests/containers-tests.cabal +++ b/containers-tests/containers-tests.cabal @@ -106,6 +106,7 @@ library Data.Map.Strict.Internal Data.Sequence Data.Sequence.Internal + Data.Sequence.Internal.Depth Data.Sequence.Internal.Sorting Data.Set Data.Set.Internal diff --git a/containers/containers.cabal b/containers/containers.cabal index 3f7f43b45..de4c1eb52 100644 --- a/containers/containers.cabal +++ b/containers/containers.cabal @@ -70,6 +70,7 @@ Library Data.Graph Data.Sequence Data.Sequence.Internal + Data.Sequence.Internal.Depth Data.Sequence.Internal.Sorting Data.Tree Utils.Containers.Internal.BitUtil diff --git a/containers/src/Data/Sequence/Internal.hs b/containers/src/Data/Sequence/Internal.hs index 89dae7fa1..eab11a952 100644 --- a/containers/src/Data/Sequence/Internal.hs +++ b/containers/src/Data/Sequence/Internal.hs @@ -1,3 +1,4 @@ +{- OPTIONS_GHC -ddump-simpl #-} {-# LANGUAGE CPP #-} #include "containers.h" {-# LANGUAGE BangPatterns #-} @@ -7,6 +8,7 @@ {-# LANGUAGE DeriveLift #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskellQuotes #-} @@ -193,6 +195,7 @@ module Data.Sequence.Internal ( node2, node3, #endif + bongo ) where import Utils.Containers.Internal.Prelude 
hiding ( @@ -210,7 +213,7 @@ import Control.Applicative ((<$>), (<**>), Alternative, import qualified Control.Applicative as Applicative import Control.DeepSeq (NFData(rnf)) import Control.Monad (MonadPlus(..)) -import Data.Monoid (Monoid(..)) +import Data.Monoid (Monoid(..), Endo(..), Dual(..)) import Data.Functor (Functor(..)) import Utils.Containers.Internal.State (State(..), execState) import Data.Foldable (foldr', toList) @@ -250,6 +253,7 @@ import Data.Functor.Identity (Identity(..)) import Utils.Containers.Internal.StrictPair (StrictPair (..), toPair) import Control.Monad.Zip (MonadZip (..)) import Control.Monad.Fix (MonadFix (..), fix) +import Data.Sequence.Internal.Depth (Depth_ (..), Depth2_ (..)) default () @@ -394,16 +398,38 @@ fmapSeq f (Seq xs) = Seq (fmap (fmap f) xs) #-} #endif +--type Depth = Depth_ Elem Node +type Depth = Depth_ Node +type Depth2 = Depth2_ Node + instance Foldable Seq where #ifdef __GLASGOW_HASKELL__ foldMap :: forall m a. Monoid m => (a -> m) -> Seq a -> m - foldMap = coerce (foldMap :: (Elem a -> m) -> FingerTree (Elem a) -> m) + foldMap f (Seq t0) = foldMapFT Bottom t0 + where + foldMapBlob :: Depth (Elem a) t -> t -> m + foldMapBlob Bottom (Elem a) = f a + foldMapBlob (Deeper w) (Node2 _ x y) = foldMapBlob w x <> foldMapBlob w y + foldMapBlob (Deeper w) (Node3 _ x y z) = foldMapBlob w x <> foldMapBlob w y <> foldMapBlob w z + + foldMapFT :: Depth (Elem a) t -> FingerTree t -> m + foldMapFT !_ EmptyT = mempty + foldMapFT w (Single t) = foldMapBlob w t + foldMapFT w (Deep _ pr m sf) = + foldMap (foldMapBlob w) pr + <> foldMapFT (Deeper w) m + <> foldMap (foldMapBlob w) sf foldr :: forall a b. (a -> b -> b) -> b -> Seq a -> b - foldr = coerce (foldr :: (Elem a -> b -> b) -> b -> FingerTree (Elem a) -> b) + -- We define this explicitly so we can inline the foldMap. And we don't + -- define it as a coercion of the FingerTree version because we want users + -- to have the option of (effectively) inlining it explicitly. 
+ foldr f z t = appEndo (GHC.Exts.inline foldMap (coerce f) t) z foldl :: forall b a. (b -> a -> b) -> b -> Seq a -> b - foldl = coerce (foldl :: (b -> Elem a -> b) -> b -> FingerTree (Elem a) -> b) + -- Should we define this by hand to associate optimally? Or is GHC + -- clever enough to do that for us? + foldl f z t = appEndo (getDual (GHC.Exts.inline foldMap (Dual . Endo . flip f) t)) z foldr' :: forall a b. (a -> b -> b) -> b -> Seq a -> b foldr' = coerce (foldr' :: (Elem a -> b -> b) -> b -> FingerTree (Elem a) -> b) @@ -442,7 +468,37 @@ instance Foldable Seq where instance Traversable Seq where #if __GLASGOW_HASKELL__ {-# INLINABLE traverse #-} -#endif + traverse :: forall f a b. Applicative f => (a -> f b) -> Seq a -> f (Seq b) + traverse f (Seq t0) = Seq <$> traverseFT Bottom2 t0 + where + traverseFT :: Depth2 (Elem a) t (Elem b) u -> FingerTree t -> f (FingerTree u) + traverseFT !_ EmptyT = pure EmptyT + traverseFT w (Single t) = Single <$> traverseBlob w t + traverseFT w (Deep s pr m sf) = liftA3 (Deep s) + (traverse (traverseBlob w) pr) + (traverseFT (Deeper2 w) m) + (traverse (traverseBlob w) sf) + + -- Traverse a 2-3 tree, given its height. + traverseBlob :: Depth2 (Elem a) t (Elem b) u -> t -> f u + traverseBlob Bottom2 (Elem a) = Elem <$> f a + + -- We have a special case here to avoid needing to `fmap Elem` over + -- each of the leaves, in case that's not free in the relevant functor. + -- We still end up using extra fmaps for the very first level of the + -- FingerTree and the Seq constructor. While we *could* avoid that, + -- doing so requires a good bit of extra code to save *at most* nine + -- fmap applications for the sequence. It would also save on Depth + -- comparisons, but I doubt that matters very much. 
+ traverseBlob (Deeper2 Bottom2) (Node2 s (Elem x) (Elem y)) + = liftA2 (\x' y' -> Node2 s (Elem x') (Elem y')) (f x) (f y) + traverseBlob (Deeper2 Bottom2) (Node3 s (Elem x) (Elem y) (Elem z)) + = liftA3 (\x' y' z' -> Node3 s (Elem x') (Elem y') (Elem z')) + (f x) (f y) (f z) + + traverseBlob (Deeper2 w) (Node2 s x y) = liftA2 (Node2 s) (traverseBlob w x) (traverseBlob w y) + traverseBlob (Deeper2 w) (Node3 s x y z) = liftA3 (Node3 s) (traverseBlob w x) (traverseBlob w y) (traverseBlob w z) +#else traverse _ (Seq EmptyT) = pure (Seq EmptyT) traverse f' (Seq (Single (Elem x'))) = (\x'' -> Seq (Single (Elem x''))) <$> f' x' @@ -514,6 +570,7 @@ instance Traversable Seq where :: Applicative f => (Node a -> f (Node b)) -> Node (Node a) -> f (Node (Node b)) traverseNodeN f t = traverse f t +#endif instance NFData a => NFData (Seq a) where rnf (Seq xs) = rnf xs @@ -1078,7 +1135,33 @@ instance Sized a => Sized (FingerTree a) where size (Single x) = size x size (Deep v _ _ _) = v +-- We don't fold FingerTrees directly, but instead coerce them to +-- Seqs and fold those. This seems backwards! Why do it? We certainly +-- *could* fold FingerTrees directly, but we'd need a slightly different +-- version of the Depth GADT to do so. While that's not a big deal, +-- it is a bit annoying. Note: we need the current version of Depth +-- to deal with the Sized issues for indexed folds. instance Foldable FingerTree where +#ifdef __GLASGOW_HASKELL__ + foldMap :: forall m a. 
Monoid m => (a -> m) -> FingerTree a -> m + foldMap f = foldMapFT Bottom + where + foldMapBlob :: Depth a t -> t -> m + foldMapBlob Bottom a = f a + foldMapBlob (Deeper w) (Node2 _ x y) = foldMapBlob w x <> foldMapBlob w y + foldMapBlob (Deeper w) (Node3 _ x y z) = foldMapBlob w x <> foldMapBlob w y <> foldMapBlob w z + + foldMapFT :: Depth a t -> FingerTree t -> m + foldMapFT !_ EmptyT = mempty + foldMapFT w (Single t) = foldMapBlob w t + foldMapFT w (Deep _ pr m sf) = + foldMap (foldMapBlob w) pr + <> foldMapFT (Deeper w) m + <> foldMap (foldMapBlob w) sf + +-- foldMap = coerce (foldMap :: (a -> m) -> Seq a -> m) + {-# INLINABLE foldMap #-} +#else foldMap _ EmptyT = mempty foldMap f' (Single x') = f' x' foldMap f' (Deep _ pr' m' sf') = @@ -1105,8 +1188,6 @@ instance Foldable FingerTree where foldMapNodeN :: Monoid m => (Node a -> m) -> Node (Node a) -> m foldMapNodeN f t = foldNode (<>) f t -#if __GLASGOW_HASKELL__ - {-# INLINABLE foldMap #-} #endif foldr _ z' EmptyT = z' @@ -1270,7 +1351,7 @@ foldDigit _ f (One a) = f a foldDigit (<+>) f (Two a b) = f a <+> f b foldDigit (<+>) f (Three a b c) = f a <+> f b <+> f c foldDigit (<+>) f (Four a b c d) = f a <+> f b <+> f c <+> f d -{-# INLINE foldDigit #-} +{-# INLINABLE foldDigit #-} instance Foldable Digit where foldMap = foldDigit mappend @@ -3203,15 +3284,56 @@ foldWithIndexNode (<+>) f s (Node3 _ a b c) = f s a <+> f sPsa b <+> f sPsab c -- element in the sequence. -- -- @since 0.5.8 -foldMapWithIndex :: Monoid m => (Int -> a -> m) -> Seq a -> m +foldMapWithIndex :: forall m a. 
Monoid m => (Int -> a -> m) -> Seq a -> m +#ifdef __GLASGOW_HASKELL__ +foldMapWithIndex f (Seq t) = foldMapWithIndexFT Bottom 0 t + where + foldMapWithIndexFT :: Depth (Elem a) t -> Int -> FingerTree t -> m + foldMapWithIndexFT !_ !_ EmptyT = mempty + foldMapWithIndexFT d s (Single xs) = foldMapWithIndexBlob d s xs + foldMapWithIndexFT d s (Deep _ pr m sf) = case depthSized d of { Sizzy -> + foldWithIndexDigit (<>) (foldMapWithIndexBlob d) s pr <> + foldMapWithIndexFT (Deeper d) sPspr m <> + foldWithIndexDigit (<>) (foldMapWithIndexBlob d) sPsprm sf + where + !sPspr = s + size pr + !sPsprm = sPspr + size m + } + + foldMapWithIndexBlob :: Depth (Elem a) t -> Int -> t -> m + foldMapWithIndexBlob Bottom k (Elem a) = f k a + foldMapWithIndexBlob (Deeper yop) k (Node2 _s t1 t2) = + foldMapWithIndexBlob yop k t1 <> + foldMapWithIndexBlob yop (k + sizeBlob yop t1) t2 + foldMapWithIndexBlob (Deeper yop) k (Node3 _s t1 t2 t3) = + foldMapWithIndexBlob yop k t1 <> + foldMapWithIndexBlob yop (k + st1) t2 <> + foldMapWithIndexBlob yop (k + st1t2) t3 + where + st1 = sizeBlob yop t1 + st1t2 = st1 + sizeBlob yop t2 +{-# INLINABLE foldMapWithIndex #-} + +data Sizzy a where + Sizzy :: Sized a => Sizzy a + +depthSized :: Depth (Elem a) t -> Sizzy t +depthSized Bottom = Sizzy +depthSized (Deeper _) = Sizzy + +sizeBlob :: Depth (Elem a) t -> t -> Int +sizeBlob Bottom = size +sizeBlob (Deeper _) = size + +#else foldMapWithIndex f' (Seq xs') = foldMapWithIndexTreeE (lift_elem f') 0 xs' where lift_elem :: (Int -> a -> m) -> (Int -> Elem a -> m) -#ifdef __GLASGOW_HASKELL__ +# ifdef __GLASGOW_HASKELL__ lift_elem g = coerce g -#else +# else lift_elem g = \s (Elem a) -> g s a -#endif +# endif {-# INLINE lift_elem #-} -- We have to specialize these functions by hand, unfortunately, because -- GHC does not specialize until *all* instances are determined. 
@@ -3250,9 +3372,6 @@ foldMapWithIndex f' (Seq xs') = foldMapWithIndexTreeE (lift_elem f') 0 xs' foldMapWithIndexNodeN :: Monoid m => (Int -> Node a -> m) -> Int -> Node (Node a) -> m foldMapWithIndexNodeN f i t = foldWithIndexNode (<>) f i t - -#if __GLASGOW_HASKELL__ -{-# INLINABLE foldMapWithIndex #-} #endif -- | 'traverseWithIndex' is a version of 'traverse' that also offers @@ -4997,3 +5116,7 @@ fromList2 n = execState (replicateA n (State ht)) where ht (x:xs) = (xs, x) ht [] = error "fromList2: short list" + +{-# NOINLINE bongo #-} +bongo :: Seq [a] -> [a] +bongo xs = GHC.Exts.inline foldMap id xs diff --git a/containers/src/Data/Sequence/Internal/Depth.hs b/containers/src/Data/Sequence/Internal/Depth.hs new file mode 100644 index 000000000..723a10e7c --- /dev/null +++ b/containers/src/Data/Sequence/Internal/Depth.hs @@ -0,0 +1,120 @@ +{-# OPTIONS_GHC -ddump-prep #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE Trustworthy #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} + +-- | This module defines efficient representations of GADTs that are shaped +-- like (strict) unary natural numbers. That is, each type looks, from the +-- outside, something like this: +-- +-- @ +-- data NatLike ... where +-- ZeroLike :: NatLike ... +-- SuccLike :: !(NatLike ...) -> NatLike ... +-- @ +-- +-- but in fact it is represented by a single machine word. We put these in a +-- separate module to confine the highly unsafe magic used in the +-- implementation. +-- +-- Caution: Unlike the GADTs they represent, the types in this module are +-- bounded by @maxBound \@Word@, and attempting to take a successor of the +-- maximum bound will throw an overflow error. 
That's okay for our purposes +-- of implementing certain functions in "Data.Sequence.Internal"—the spine +-- of a well-formed sequence can only reach a length of around the word +-- size, not even close to @maxBound \@Word@. + +module Data.Sequence.Internal.Depth + ( Depth_ (Bottom, Deeper) + , Depth2_ (Bottom2, Deeper2) + ) where + +import Data.Kind (Type) +import Unsafe.Coerce (unsafeCoerce) + +-- | @Depth_@ is an optimized representation of the following GADT: +-- +-- @ +-- data Depth_ node a t where +-- Bottom :: Depth_ node a a +-- Deeper :: !(Depth_ node a t) -> Depth_ node a (node t) +-- @ +-- +-- "Data.Sequence.Internal" fills in the @node@ parameter with its @Node@ +-- type constructor; we have to be more general in this module because we don't +-- have access to that. +-- +-- @Depth_@ is represented internally as a 'Word' for performance, and the +-- 'Bottom' and 'Deeper' pattern synonyms implement the above GADT interface. +-- The implementation is "safe"—in the very unlikely event of arithmetic +-- overflow, an error will be thrown. This decision is subject to change; +-- arithmetic overflow on 64-bit systems requires somewhat absurdly long +-- computations on sequences constructed with extensive amounts of internal +-- sharing (e.g., using the '*>' operator repeatedly). +newtype Depth_ (node :: Type -> Type) (a :: Type) (t :: Type) + = Depth_ Word +type role Depth_ nominal nominal nominal + +-- | The depth is 0. +pattern Bottom :: () => t ~ a => Depth_ node a t +pattern Bottom <- (checkBottom -> AtBottom) + where + Bottom = Depth_ 0 + +-- | The depth is non-zero.
+pattern Deeper :: () => t ~ node t' => Depth_ node a t' -> Depth_ node a t +pattern Deeper d <- (checkBottom -> NotBottom d) + where + Deeper (Depth_ d) + | d == maxBound = error "Depth overflow" + | otherwise = Depth_ (d + 1) + +{-# COMPLETE Bottom, Deeper #-} + +data CheckedBottom node a t where + AtBottom :: CheckedBottom node a a + NotBottom :: !(Depth_ node a t) -> CheckedBottom node a (node t) + +checkBottom :: Depth_ node a t -> CheckedBottom node a t +checkBottom (Depth_ 0) = unsafeCoerce AtBottom +checkBottom (Depth_ d) = unsafeCoerce (NotBottom (Depth_ (d - 1))) + + +-- | A version of 'Depth_' for implementing traversals. Conceptually, +-- +-- @ +-- data Depth2_ node a t b u where +-- Bottom2 :: Depth2_ node a a b b +-- Deeper2 :: !(Depth2_ node a t b u) -> Depth2_ node a (node t) b (node u) +-- @ +newtype Depth2_ (node :: Type -> Type) (a :: Type) (t :: Type) (b :: Type) (u :: Type) + = Depth2_ Word +type role Depth2_ nominal nominal nominal nominal nominal + +-- | The depth is 0. +pattern Bottom2 :: () => (t ~ a, u ~ b) => Depth2_ node a t b u +pattern Bottom2 <- (checkBottom2 -> AtBottom2) + where + Bottom2 = Depth2_ 0 + +-- | The depth is non-zero. +pattern Deeper2 :: () => (t ~ node t', u ~ node u') => Depth2_ node a t' b u' -> Depth2_ node a t b u +pattern Deeper2 d <- (checkBottom2 -> NotBottom2 d) + where + Deeper2 (Depth2_ d) + | d == maxBound = error "Depth2 overflow" + | otherwise = Depth2_ (d + 1) + +{-# COMPLETE Bottom2, Deeper2 #-} + +data CheckedBottom2 node a t b u where + AtBottom2 :: CheckedBottom2 node a a b b + NotBottom2 :: !(Depth2_ node a t b u) -> CheckedBottom2 node a (node t) b (node u) + +-- 'unsafeCoerce' supplies the type equalities that the underlying 'Word' +-- cannot carry. This relies on every 'Depth2_' being built with 'Bottom2' or +-- 'Deeper2', whose signatures tie the stored count to the type indices; the +-- same reasoning applies to 'checkBottom' and 'Depth_'. +checkBottom2 :: Depth2_ node a t b u -> CheckedBottom2 node a t b u +checkBottom2 (Depth2_ 0) = unsafeCoerce AtBottom2 +checkBottom2 (Depth2_ d) = unsafeCoerce (NotBottom2 (Depth2_ (d - 1)))