Skip to content

Commit 45634f6

Browse files
committed
feat: sortBy now works with compound columns
1 parent c4b73e0 commit 45634f6

2 files changed

Lines changed: 110 additions & 10 deletions

File tree

src/DataFrame/Operations/Permutation.hs

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ import Data.Vector.Internal.Check (HasCallStack)
2222
import DataFrame.Errors (DataFrameException (..))
2323
import DataFrame.Internal.Column (Column (..), Columnable, atIndicesStable)
2424
import DataFrame.Internal.DataFrame (DataFrame (..), unsafeGetColumn)
25-
import DataFrame.Internal.Expression (Expr (Col))
25+
import DataFrame.Internal.Expression (Expr (Col), getColumns)
2626
import DataFrame.Operations.Core (columnNames, dimensions)
27+
import DataFrame.Operations.Transformations (derive)
2728
import System.Random (Random (randomR), RandomGen)
2829
import Type.Reflection (typeRep)
2930

@@ -38,15 +39,40 @@ instance Eq SortOrder where
3839
(==) (Desc _) (Desc _) = True
3940
(==) _ _ = False
4041

41-
getSortColumnName :: SortOrder -> T.Text
42-
getSortColumnName (Asc (Col n)) = n
43-
getSortColumnName (Desc (Col n)) = n
44-
getSortColumnName _ = error "Sorting on compound column"
42+
sortOrderColumns :: SortOrder -> [T.Text]
43+
sortOrderColumns (Asc e) = getColumns e
44+
sortOrderColumns (Desc e) = getColumns e
4545

4646
mustFlipCompare :: SortOrder -> Bool
4747
mustFlipCompare (Asc _) = True
4848
mustFlipCompare (Desc _) = False
4949

50+
{- | Materialize any compound sort expressions into synthetic columns on
51+
a working dataframe, returning rewritten 'SortOrder's that reference
52+
those columns by name.
53+
-}
54+
prepareSortColumns :: [SortOrder] -> DataFrame -> ([SortOrder], DataFrame)
55+
prepareSortColumns = go 0
56+
where
57+
go _ [] acc = ([], acc)
58+
go i (ord : rest) acc =
59+
let (ord', acc') = materializeSortOrder i ord acc
60+
(rest', acc'') = go (i + 1) rest acc'
61+
in (ord' : rest', acc'')
62+
63+
materializeSortOrder :: Int -> SortOrder -> DataFrame -> (SortOrder, DataFrame)
64+
materializeSortOrder _ ord@(Asc (Col _)) df = (ord, df)
65+
materializeSortOrder _ ord@(Desc (Col _)) df = (ord, df)
66+
materializeSortOrder i (Asc (e :: Expr a)) df =
67+
let name = syntheticName i
68+
in (Asc (Col name :: Expr a), derive name e df)
69+
materializeSortOrder i (Desc (e :: Expr a)) df =
70+
let name = syntheticName i
71+
in (Desc (Col name :: Expr a), derive name e df)
72+
73+
syntheticName :: Int -> T.Text
74+
syntheticName i = "__sortBy_synthetic_" <> T.pack (show i) <> "__"
75+
5076
{- | O(k log n) Sorts the rows of the dataframe by the given sort orders (columns or compound expressions).
5177
5278
> sortBy [Asc (col @Int "Age")] df
@@ -56,22 +82,24 @@ sortBy ::
5682
DataFrame ->
5783
DataFrame
5884
sortBy sortOrds df
59-
| any (`notElem` columnNames df) names =
85+
| not (null missing) =
6086
throw $
6187
ColumnsNotFoundException
62-
(names L.\\ columnNames df)
88+
missing
6389
"sortBy"
6490
(columnNames df)
6591
| otherwise =
6692
let
67-
comparators = map (`sortOrderComparator` df) sortOrds
93+
(sortOrds', df') = prepareSortColumns sortOrds df
94+
comparators = map (`sortOrderComparator` df') sortOrds'
6895
compositeCompare i j = mconcat [c i j | c <- comparators]
69-
nRows = fst (dataframeDimensions df)
96+
nRows = fst (dataframeDimensions df')
7097
indexes = sortIndices compositeCompare nRows
7198
in
7299
df{columns = V.map (atIndicesStable indexes) (columns df)}
73100
where
74-
names = map getSortColumnName sortOrds
101+
referenced = L.nub (concatMap sortOrderColumns sortOrds)
102+
missing = referenced L.\\ columnNames df
75103

76104
{- | Build a row-index comparator from a SortOrder and a DataFrame.
77105
The Ord dictionary is recovered from the SortOrder GADT.

tests/Operations/Sort.hs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,83 @@ sortByColumnDoesNotExist =
8989
(print $ D.sortBy [D.Asc (F.col @Int "test0")] testData)
9090
)
9191

92+
compoundTestData :: D.DataFrame
93+
compoundTestData =
94+
D.fromNamedColumns
95+
[ ("a", DI.fromList ([3, 1, 4, 1, 5] :: [Int]))
96+
, ("b", DI.fromList ([10, 40, 20, 30, 50] :: [Int]))
97+
]
98+
99+
sortByCompoundExpression :: Test
100+
sortByCompoundExpression =
101+
TestCase
102+
( assertEqual
103+
"Sorting by Asc (a + b) orders rows by the sum without leaking a synthetic column"
104+
( D.fromNamedColumns
105+
[ ("a", DI.fromList ([3, 4, 1, 1, 5] :: [Int]))
106+
, ("b", DI.fromList ([10, 20, 30, 40, 50] :: [Int]))
107+
]
108+
)
109+
(D.sortBy [D.Asc (F.col @Int "a" + F.col @Int "b")] compoundTestData)
110+
)
111+
112+
sortByCompoundExpressionDescending :: Test
113+
sortByCompoundExpressionDescending =
114+
TestCase
115+
( assertEqual
116+
"Sorting by Desc (b - a) orders rows by descending difference"
117+
( D.fromNamedColumns
118+
[ ("a", DI.fromList ([5, 1, 1, 4, 3] :: [Int]))
119+
, ("b", DI.fromList ([50, 40, 30, 20, 10] :: [Int]))
120+
]
121+
)
122+
(D.sortBy [D.Desc (F.col @Int "b" - F.col @Int "a")] compoundTestData)
123+
)
124+
125+
sortByCompoundMixedWithBareColumn :: Test
126+
sortByCompoundMixedWithBareColumn =
127+
TestCase
128+
( assertEqual
129+
"Mixing a compound Asc key with a bare Desc tie-breaker works"
130+
( D.fromNamedColumns
131+
[ ("a", DI.fromList ([1, 1, 3, 4, 5] :: [Int]))
132+
, ("b", DI.fromList ([40, 30, 10, 20, 50] :: [Int]))
133+
]
134+
)
135+
( D.sortBy
136+
[D.Asc (F.col @Int "a" * 2), D.Desc (F.col @Int "b")]
137+
compoundTestData
138+
)
139+
)
140+
141+
sortByCompoundMissingColumn :: Test
142+
sortByCompoundMissingColumn =
143+
TestCase
144+
( assertExpectException
145+
"[Error Case]"
146+
( D.columnsNotFound
147+
["nope"]
148+
"sortBy"
149+
(D.columnNames compoundTestData)
150+
)
151+
( print $
152+
D.sortBy
153+
[D.Asc (F.col @Int "nope" + F.col @Int "a")]
154+
compoundTestData
155+
)
156+
)
157+
92158
tests :: [Test]
93159
tests =
94160
[ TestLabel "sortByAscendingWAI" sortByAscendingWAI
95161
, TestLabel "sortByDescendingWAI" sortByDescendingWAI
96162
, TestLabel "sortByColumnDoesNotExist" sortByColumnDoesNotExist
97163
, TestLabel "sortByTwoColumns" sortByTwoColumns
98164
, TestLabel "sortByOneColumnAscOneColumnDesc" sortByOneColumnAscOneColumnDesc
165+
, TestLabel "sortByCompoundExpression" sortByCompoundExpression
166+
, TestLabel
167+
"sortByCompoundExpressionDescending"
168+
sortByCompoundExpressionDescending
169+
, TestLabel "sortByCompoundMixedWithBareColumn" sortByCompoundMixedWithBareColumn
170+
, TestLabel "sortByCompoundMissingColumn" sortByCompoundMissingColumn
99171
]

0 commit comments

Comments
 (0)