Skip to content

Commit 23fc32b

Browse files
committed
chore: Update California housing example to work with hasktorch.
1 parent 23b181f commit 23fc32b

4 files changed

Lines changed: 129 additions & 5 deletions

File tree

dataframe.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ executable california_housing
141141
bytestring >= 0.11 && <= 0.12.2.0,
142142
containers >= 0.6.7 && < 0.8,
143143
directory >= 1.3.0.0 && <= 1.3.9.0,
144+
hasktorch,
144145
hashable >= 1.2 && <= 1.5.0.0,
145146
statistics >= 0.16.2.1 && <= 0.16.3.0,
146147
template-haskell >= 2.0 && <= 2.30,

examples/CaliforniaHousing.hs

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,67 @@
1-
{-# LANGUAGE OverloadedStrings #-}
1+
{-# LANGUAGE OverloadedStrings #-}
2+
{-# LANGUAGE TypeApplications #-}
3+
{-# LANGUAGE ScopedTypeVariables #-}
4+
{-# LANGUAGE NumericUnderscores #-}
25
module Main where
36

7+
import Control.Monad (when)
8+
import Data.Maybe
9+
import qualified Data.Map as M
10+
import qualified Data.Text as T
411
import qualified DataFrame as D
12+
import qualified DataFrame.Functions as F
13+
import qualified Data.Vector.Unboxed as VU
14+
import qualified Data.Vector as V
15+
import Torch
16+
17+
import DataFrame ((|>))
518

619
main :: IO ()
720
main = do
8-
parsed <- D.readCsv "./data/housing.csv"
21+
{- Feature ingestion and engineering -}
22+
df <- fmap (D.apply (\(op :: T.Text) -> oceanProximity M.! op) "ocean_proximity") (D.readCsv "./data/housing.csv")
23+
-- This column has nulls so we:
24+
-- * Remove all nulls with filterJust
25+
-- * Calculate the mean of total_bedrooms
26+
-- * impute the mean.
27+
-- This could probably be a utility function.
28+
let meanTotalBedrooms = fromMaybe 0 $ df |> D.filterJust "total_bedrooms"
29+
|> D.mean "total_bedrooms"
30+
imputed = df |> D.impute "total_bedrooms" meanTotalBedrooms
31+
|> D.exclude ["median_house_value"]
32+
|> normalizeFeatures
33+
(r, c) = D.dimensions imputed
34+
features = reshape [r,c] $ asTensor (flattenFeatures imputed)
35+
labels = asTensor ((VU.map realToFrac . VU.convert) (D.columnAsVector @Double "median_house_value" df) :: VU.Vector Float)
36+
37+
{- Train the model -}
38+
putStrLn "Training linear regression model..."
39+
init <- sample $ LinearSpec{in_features = (snd (D.dimensions df) - 1), out_features = 1}
40+
trained <- foldLoop init 100_000 $ \state i -> do
41+
let labels' = model state features
42+
loss = mseLoss labels labels'
43+
when (i `mod` 10_000 == 0) $ do
44+
putStrLn $ "Iteration: " ++ show i ++ " | Loss: " ++ show loss
45+
(state', _) <- runStep state GD loss 0.1
46+
pure state'
47+
48+
{- Show predictions -}
49+
let predictions = D.insertUnboxedVector "predicted_house_value" (asValue @(VU.Vector Float) (model trained features)) df
50+
print $ D.select ["median_house_value", "predicted_house_value"] predictions |> D.take 10
51+
52+
normalizeFeatures :: D.DataFrame -> D.DataFrame
53+
normalizeFeatures df = df |> D.fold (\name d -> let
54+
m = fromMaybe 0 (D.mean name d)
55+
stdDev = fromMaybe 0.01 (D.standardDeviation name d)
56+
col = F.col @Double name
57+
in D.derive name ((col - (F.minimum col)) / (F.maximum col - F.minimum col)) d) (D.columnNames df)
58+
959

10-
print $ D.describeColumns parsed
60+
model :: Linear -> Tensor -> Tensor
61+
model state input = squeezeAll $ linear state input
1162

12-
print $ D.take 5 parsed
63+
oceanProximity :: M.Map T.Text Double
64+
oceanProximity = M.fromList [("ISLAND", 0), ("NEAR OCEAN", 1), ("NEAR BAY", 2), ("<1H OCEAN", 3), ("INLAND", 4)]
1365

14-
D.plotHistograms D.PlotAll D.VerticalHistogram parsed
66+
flattenFeatures :: D.DataFrame -> VU.Vector Float
67+
flattenFeatures df = V.foldl' (\acc v -> acc VU.++ v) VU.empty (D.toMatrix df)

examples/README.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Running the examples
2+
3+
## California housing
4+
5+
Preparation:
6+
This has a hasktorch integration that requires some setup. Copy the [`get-deps.sh`](https://github.com/hasktorch/hasktorch/blob/master/deps/get-deps.sh) file into the dataframe home directory. This will download and link some pytorch files required to run the examples.
7+
8+
After this is done you'll need to run `./set_hasktorch_env` to put the hasktorch libraries in your `LD_LIBRARY_PATH`.
9+
10+
Running:
11+
`cabal run california_housing`.
12+
13+
Expected output:
14+
15+
```
16+
Training linear regression model...
17+
Iteration: 10000 | Loss: Tensor Float [] 5.0225e9
18+
Iteration: 20000 | Loss: Tensor Float [] 4.9093e9
19+
Iteration: 30000 | Loss: Tensor Float [] 4.8576e9
20+
Iteration: 40000 | Loss: Tensor Float [] 4.8333e9
21+
Iteration: 50000 | Loss: Tensor Float [] 4.8217e9
22+
Iteration: 60000 | Loss: Tensor Float [] 4.8160e9
23+
Iteration: 70000 | Loss: Tensor Float [] 4.8130e9
24+
Iteration: 80000 | Loss: Tensor Float [] 4.8114e9
25+
Iteration: 90000 | Loss: Tensor Float [] 4.8105e9
26+
Iteration: 100000 | Loss: Tensor Float [] 4.8099e9
27+
--------------------------------------------------
28+
index | median_house_value | predicted_house_value
29+
------|--------------------|----------------------
30+
Int | Double | Float
31+
------|--------------------|----------------------
32+
0 | 452600.0 | 414079.94
33+
1 | 358500.0 | 423011.94
34+
2 | 352100.0 | 383239.06
35+
3 | 341300.0 | 324928.94
36+
4 | 342200.0 | 256934.23
37+
5 | 269700.0 | 264944.84
38+
6 | 299200.0 | 259094.13
39+
7 | 241400.0 | 257224.55
40+
8 | 226700.0 | 201753.69
41+
9 | 261100.0 | 268698.7
42+
```

set_hasktorch_env

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#!/usr/bin/env bash
2+
3+
if [ -z "${XDG_CACHE_HOME:-}" ]; then
4+
export XDG_CACHE_HOME="$HOME/.cache"
5+
fi
6+
7+
# execute this command if needed when building on OSX if there are linker errors.
8+
# dylib files in extra-lib-dirs don't get forwarded to ghc
9+
# in some versions of OSX. See https://github.com/commercialhaskell/stack/issues/1826
10+
HASKTORCH_LIB_PATH="$XDG_CACHE_HOME/libtorch/lib:$XDG_CACHE_HOME/mklml/lib/:$XDG_CACHE_HOME/libtokenizers/lib"
11+
12+
function add_vendor_lib_path {
13+
case "$(uname)" in
14+
"Darwin")
15+
DYLD_LIBRARY_PATH=/opt/homebrew/lib:/opt/homebrew/opt/libomp/lib:$HASKTORCH_LIB_PATH:$DYLD_LIBRARY_PATH
16+
export DYLD_LIBRARY_PATH
17+
;;
18+
"Linux"|"FreeBSD")
19+
LD_LIBRARY_PATH=$HASKTORCH_LIB_PATH:$LD_LIBRARY_PATH
20+
export LD_LIBRARY_PATH
21+
;;
22+
*)
23+
echo "OS doesn't have known environment variable hacks to set"
24+
;;
25+
esac
26+
}
27+
28+
add_vendor_lib_path

0 commit comments

Comments
 (0)