Scala has a function groupBy on lists that accepts a function for extracting keys from list items, and returns another list where the items are tuples consisting of the key and the list of items that produce that key.
Specifically, the following should work:
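-- needs: import Data.List (groupBy, sortBy); import Data.Function (on); import Data.Ord (comparing)
scalaGroupBy f = groupBy ((==) `on` f) . sortBy (comparing f)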
modulo that this doesn't get you the result of f in each group, but if you really need it you can always post-process with
map (\xs -> (f (head xs), xs)) . scalaGroupBy f
You can write the function yourself rather easily, but you need to place an Ord or Hashable constraint on the result of the classifier function if you want an efficient solution. Example:
import Control.Arrow ((&&&))
import Data.List
import Data.Function
myGroupBy :: (Ord b) => (a -> b) -> [a] -> [(b, [a])]
myGroupBy f = map (f . head &&& id)
            . groupBy ((==) `on` f)
            . sortBy (compare `on` f)
> myGroupBy (`mod` 2) [1..9]
[(0,[2,4,6,8]),(1,[1,3,5,7,9])]
You can also use a hash map like Data.HashMap.Strict instead of sorting for expected linear time.
This isn't a function in the List library.
You can write it as the composition of sortBy and groupBy.
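One way to write that composition without the partial head is to group into non-empty lists. This is only a sketch: it assumes Data.List.NonEmpty from base, whose groupAllWith sorts by the key and then groups equal keys, and the name groupByKey is made up for illustration.

```haskell
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NE

-- groupByKey is a hypothetical name, not a library function.
-- NE.groupAllWith sorts by the key and groups elements with equal keys,
-- so every group is non-empty by construction and NE.head is total.
groupByKey :: Ord k => (a -> k) -> [a] -> [(k, NonEmpty a)]
groupByKey f = map (\g -> (f (NE.head g), g)) . NE.groupAllWith f
```

For example, groupByKey (`mod` 2) [1..9] yields one group for key 0 and one for key 1, with the elements of each group in their original order, since the underlying sort is stable.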
Since Scala's groupBy returns an immutable HashMap, which does not require ordering, the corresponding Haskell implementation should return a HashMap as well.
import Data.Hashable (Hashable)
import qualified Data.HashMap.Strict as M
scalaGroupBy :: (Eq k, Hashable k) => (v -> k) -> [v] -> M.HashMap k [v]
scalaGroupBy f l = M.fromListWith (++) [ (f a, [a]) | a <- l]
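One detail to watch: fromListWith (++) combines values as new ++ old, so each group ends up in reverse input order. If the original order within a group matters, a possible variant is to fold from the right and prepend. This is a sketch only; it assumes the unordered-containers and hashable packages, and groupByStable is a hypothetical name.

```haskell
import Data.Hashable (Hashable)
import qualified Data.HashMap.Strict as M

-- groupByStable is a hypothetical name, not part of any library.
-- insertWith applies its function as (new `op` old), so folding from the
-- right and prepending each element keeps every group in input order.
groupByStable :: (Eq k, Hashable k) => (v -> k) -> [v] -> M.HashMap k [v]
groupByStable f = foldr (\a -> M.insertWith (++) (f a) [a]) M.empty
```

With this, M.lookup 1 (groupByStable (`mod` 2) [1..9]) gives Just [1,3,5,7,9]. The right fold is not tail-recursive, so for very long inputs a strict left fold followed by reversing each group may be preferable.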
We can also use the SQL-like then group by syntax in list comprehensions, which requires the TransformListComp language extension. Since Scala's groupBy returns a Map, we can call fromDistinctAscList to convert the list comprehension to a Map.
$ stack repl --package containers
Prelude> :set -XTransformListComp
Prelude> import Data.Map.Strict ( fromDistinctAscList, Map )
Prelude Data.Map.Strict> import GHC.Exts ( groupWith, the )
Prelude Data.Map.Strict GHC.Exts> :{
Prelude Data.Map.Strict GHC.Exts| scalaGroupBy f l =
Prelude Data.Map.Strict GHC.Exts|   fromDistinctAscList
Prelude Data.Map.Strict GHC.Exts|     [ (the key, value)
Prelude Data.Map.Strict GHC.Exts|     | value <- l
Prelude Data.Map.Strict GHC.Exts|     , let key = f value
Prelude Data.Map.Strict GHC.Exts|     , then group by key using groupWith
Prelude Data.Map.Strict GHC.Exts|     ]
Prelude Data.Map.Strict GHC.Exts| :}
Prelude Data.Map.Strict GHC.Exts> :type scalaGroupBy
scalaGroupBy :: Ord b => (t -> b) -> [t] -> Map b [t]
Prelude Data.Map.Strict GHC.Exts> scalaGroupBy (`mod` 2) [1, 2, 3, 4, 5, 6, 7, 8, 9]
fromList [(0,[2,4,6,8]),(1,[1,3,5,7,9])]
The only difference from Scala's groupBy is that the above implementation returns a sorted map instead of a hash map. For an implementation that returns a hash map, see my other answer at https://stackoverflow.com/a/64204797/955091.
Putting a trace in f reveals that, with @Niklas's solution, f is evaluated 3 times for each element on any list of length 2 or more. I took the liberty of modifying it so that f is applied to each element only once. It's not clear, however, whether the cost of creating and destroying tuples is less than the cost of evaluating f multiple times (since f can be arbitrary).
import Control.Arrow ((&&&))
import Data.List
import Data.Function
myGroupBy' :: (Ord b) => (a -> b) -> [a] -> [(b, [a])]
myGroupBy' f = map (fst . head &&& map snd)
             . groupBy ((==) `on` fst)
             . sortBy (compare `on` fst)
             . map (f &&& id)
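To reproduce the comparison, one option is to wrap the classifier in Debug.Trace.trace and count the lines printed for each variant. The sketch below repeats both definitions so that it loads on its own; tracedKey is a hypothetical classifier chosen for illustration.

```haskell
import Control.Arrow ((&&&))
import Data.Function (on)
import Data.List (groupBy, sortBy)
import Debug.Trace (trace)

-- Hypothetical classifier: writes a line to stderr each time it is evaluated.
tracedKey :: Int -> Int
tracedKey x = trace ("f " ++ show x) (x `mod` 2)

-- The two variants from the answers, repeated here for self-containment.
myGroupBy, myGroupBy' :: Ord b => (a -> b) -> [a] -> [(b, [a])]
myGroupBy f =
  map (f . head &&& id)
    . groupBy ((==) `on` f)
    . sortBy (compare `on` f)
myGroupBy' f =
  map (fst . head &&& map snd)
    . groupBy ((==) `on` fst)
    . sortBy (compare `on` fst)
    . map (f &&& id)
```

In GHCi, evaluating myGroupBy tracedKey [1..4] and then myGroupBy' tracedKey [1..4] and counting the "f n" lines shows how often the classifier runs in each case. The exact count for the first variant depends on how many comparisons the sort performs, and compiling with optimizations may change what gets traced.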