Skip to content

Commit ef3adb5

Browse files
committed
feat: Add linear solver to tao tree implementation and make candidate enumeration more performant.
1 parent e5ab413 commit ef3adb5

11 files changed

Lines changed: 2339 additions & 156 deletions

File tree

dataframe-core/src/DataFrame/Internal/Hash.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ module DataFrame.Internal.Hash (
2121
mixShow,
2222
) where
2323

24-
import Data.Array.Byte (ByteArray (ByteArray))
2524
import Data.Bits (rotateL, unsafeShiftL, unsafeShiftR, xor)
2625
import Data.Char (ord)
2726
import qualified Data.Text as T
27+
import Data.Text.Array (Array (ByteArray))
2828
import Data.Text.Internal (Text (Text))
2929
import GHC.Exts (Int (I#), indexWord8Array#, indexWord8ArrayAsWord64#)
3030
import GHC.Word (Word64 (W64#), Word8 (W8#))

dataframe-core/src/DataFrame/Internal/Interpreter.hs

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,37 @@ import Data.Int (Int16, Int32, Int64, Int8)
262262
Column ->
263263
Either DataFrameException Column
264264
#-}
265+
-- Bool-returning binary comparators (hot path for Expr Bool used in
266+
-- DecisionTree splits)
267+
{-# SPECIALIZE zipWithColumns ::
268+
(Double -> Double -> Bool) ->
269+
Column ->
270+
Column ->
271+
Either DataFrameException Column
272+
#-}
273+
{-# SPECIALIZE zipWithColumns ::
274+
(Float -> Float -> Bool) ->
275+
Column ->
276+
Column ->
277+
Either DataFrameException Column
278+
#-}
279+
{-# SPECIALIZE zipWithColumns ::
280+
(Int -> Int -> Bool) ->
281+
Column ->
282+
Column ->
283+
Either DataFrameException Column
284+
#-}
285+
{-# SPECIALIZE zipWithColumns ::
286+
(Bool -> Bool -> Bool) ->
287+
Column ->
288+
Column ->
289+
Either DataFrameException Column
290+
#-}
291+
292+
-- Bool-mapping unary ops (e.g. 'not')
293+
{-# SPECIALIZE mapColumn ::
294+
(Bool -> Bool) -> Column -> Either DataFrameException Column
295+
#-}
265296

266297
-------------------------------------------------------------------------------
267298
-- Value: the unified result type
@@ -273,15 +304,15 @@ broadcast allocations.
273304
-}
274305
data Value a where
275306
-- | A single value, not yet broadcast to any length.
276-
Scalar :: (Columnable a) => a -> Value a
307+
Scalar :: (Columnable a) => !a -> Value a
277308
{- | A flat column (one element per row in the flat case, or one
278309
element per group after aggregation).
279310
-}
280-
Flat :: (Columnable a) => Column -> Value a
311+
Flat :: (Columnable a) => !Column -> Value a
281312
{- | A grouped column: one 'Column' slice per group. Only produced
282313
when interpreting inside a 'GroupCtx'.
283314
-}
284-
Group :: (Columnable a) => V.Vector Column -> Value a
315+
Group :: (Columnable a) => !(V.Vector Column) -> Value a
285316

286317
instance (Show a) => Show (Value a) where
287318
show (Scalar v) = show v
@@ -325,6 +356,7 @@ liftValue ::
325356
liftValue f (Scalar v) = Right (Scalar (f v))
326357
liftValue f (Flat col) = Flat <$> mapColumn f col
327358
liftValue f (Group gs) = Group <$> V.mapM (mapColumn f) gs
359+
{-# INLINEABLE liftValue #-}
328360

329361
{- | Apply a binary function to two 'Value's. When one side is a
330362
'Scalar' the operation degenerates to a 'liftValue' — this is how the
@@ -351,6 +383,7 @@ liftValue2 _ (Group _) (Flat _) =
351383
Left $ AggregatedAndNonAggregatedException "non-aggregated" "aggregated"
352384
liftValue2 _ (Group _) (Group _) =
353385
Left $ InternalException "Group count mismatch in binary operation"
386+
{-# INLINEABLE liftValue2 #-}
354387

355388
-- | Branch on a boolean 'Value', selecting from two same-typed 'Value's.
356389
branchValue ::
@@ -384,6 +417,7 @@ branchValue _ _ _ =
384417
AggregatedAndNonAggregatedException
385418
"if-then-else branches"
386419
"mismatched shapes"
420+
{-# INLINEABLE branchValue #-}
387421

388422
{- | Low-level column branch: given a boolean column and two same-typed
389423
columns, produce the element-wise selection.

dataframe-learn/dataframe-learn.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ library
2929
import: warnings
3030
exposed-modules:
3131
DataFrame.DecisionTree
32+
DataFrame.LinearSolver
3233
DataFrame.Synthesis
3334
build-depends: base >= 4 && < 5,
3435
containers >= 0.6.7 && < 0.9,

0 commit comments

Comments
 (0)