@@ -262,6 +262,37 @@ import Data.Int (Int16, Int32, Int64, Int8)
262262 Column ->
263263 Either DataFrameException Column
264264 #-}
265+ -- Bool-returning binary comparators (hot path for Expr Bool used in
266+ -- DecisionTree splits)
267+ {-# SPECIALIZE zipWithColumns ::
268+ (Double -> Double -> Bool) ->
269+ Column ->
270+ Column ->
271+ Either DataFrameException Column
272+ #-}
273+ {-# SPECIALIZE zipWithColumns ::
274+ (Float -> Float -> Bool) ->
275+ Column ->
276+ Column ->
277+ Either DataFrameException Column
278+ #-}
279+ {-# SPECIALIZE zipWithColumns ::
280+ (Int -> Int -> Bool) ->
281+ Column ->
282+ Column ->
283+ Either DataFrameException Column
284+ #-}
285+ {-# SPECIALIZE zipWithColumns ::
286+ (Bool -> Bool -> Bool) ->
287+ Column ->
288+ Column ->
289+ Either DataFrameException Column
290+ #-}
291+
292+ -- Bool-mapping unary ops (e.g. 'not')
293+ {-# SPECIALIZE mapColumn ::
294+ (Bool -> Bool) -> Column -> Either DataFrameException Column
295+ #-}
265296
266297-------------------------------------------------------------------------------
267298-- Value: the unified result type
@@ -273,15 +304,15 @@ broadcast allocations.
273304-}
274305data Value a where
275306 -- | A single value, not yet broadcast to any length.
276- Scalar :: (Columnable a ) => a -> Value a
307+ Scalar :: (Columnable a ) => ! a -> Value a
277308 {- | A flat column (one element per row in the flat case, or one
278309 element per group after aggregation).
279310 -}
280- Flat :: (Columnable a ) => Column -> Value a
311+ Flat :: (Columnable a ) => ! Column -> Value a
281312 {- | A grouped column: one 'Column' slice per group. Only produced
282313 when interpreting inside a 'GroupCtx'.
283314 -}
284- Group :: (Columnable a ) => V. Vector Column -> Value a
315+ Group :: (Columnable a ) => ! ( V. Vector Column ) -> Value a
285316
286317instance (Show a ) => Show (Value a ) where
287318 show (Scalar v) = show v
@@ -325,6 +356,7 @@ liftValue ::
325356liftValue f (Scalar v) = Right (Scalar (f v))
326357liftValue f (Flat col) = Flat <$> mapColumn f col
327358liftValue f (Group gs) = Group <$> V. mapM (mapColumn f) gs
359+ {-# INLINEABLE liftValue #-}
328360
329361{- | Apply a binary function to two 'Value's. When one side is a
330362'Scalar' the operation degenerates to a 'liftValue' — this is how the
@@ -351,6 +383,7 @@ liftValue2 _ (Group _) (Flat _) =
351383 Left $ AggregatedAndNonAggregatedException " non-aggregated" " aggregated"
352384liftValue2 _ (Group _) (Group _) =
353385 Left $ InternalException " Group count mismatch in binary operation"
386+ {-# INLINEABLE liftValue2 #-}
354387
355388-- | Branch on a boolean 'Value', selecting from two same-typed 'Value's.
356389branchValue ::
@@ -384,6 +417,7 @@ branchValue _ _ _ =
384417 AggregatedAndNonAggregatedException
385418 " if-then-else branches"
386419 " mismatched shapes"
420+ {-# INLINEABLE branchValue #-}
387421
388422{- | Low-level column branch: given a boolean column and two same-typed
389423columns, produce the element-wise selection.
0 commit comments