Skip to content

Commit c9bc13d

Browse files
committed
feat: derive read schema from record for CSV.
1 parent 088c884 commit c9bc13d

5 files changed

Lines changed: 285 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
### New features
55
* New `DataFrame.Typed.TH.deriveSchemaFromType` Template Haskell splice generates a typed schema synonym and a `HasSchema` instance from a Haskell record ADT. Pair with `DataFrame.fromRecords` / `DataFrame.toRecords` (or `DataFrame.Typed.fromRecordsTyped` / `toRecordsTyped`) to convert between `[Order]` and `DataFrame`/`TypedDataFrame OrderSchema`. Field names are translated `camelCase → snake_case` by default; the transform is configurable via `SchemaOptions`.
66
* New `DataFrame.Typed.Generic` exposes `SchemaOf`/`SchemaOfRaw` plus `genericToColumns` / `genericFromColumns`, so users who prefer `GHC.Generics` over a TH splice can derive the same schema and row bridge.
7+
* New `DataFrame.Internal.Schema.deriveSchema` Template Haskell splice generates, from a record ADT, both a runtime `Schema` value (`orderSchema :: Schema`, suitable for `readCsvWithSchema` / `readCsvWithOpts`) and one `Expr` accessor per field (`orderCustomerId :: Expr Int`, etc.), so expression-DSL code can refer to columns by typed name without writing `col @T "snake_case_name"` at each call site. Re-exported from `DataFrame`.
78

89
### Refactor
910
* The untyped Template Haskell splices (`declareColumns`, `declareColumnsFromCsvFile`, `declareColumnsFromCsvWithOpts`, `declareColumnsFromParquetFile`, `declareColumnsWithPrefix`, `declareColumnsWithPrefix'`) have moved from `DataFrame.Functions` to a new `DataFrame.TH` module (re-exported from `DataFrame`). Update imports accordingly; the bundled `dataframe.ghci` already points to the new module.

README.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,33 @@ Field names are translated `camelCase → snake_case` by default; override
305305
the translation with `deriveSchemaFromTypeWith
306306
defaultSchemaOptions{nameTransform = id}` (or any `String -> String`).
307307

308+
If all you need is a runtime `Schema` to drive `readCsvWithSchema` (no
309+
typed-dataframe machinery), there's a companion splice in
310+
`DataFrame.Internal.Schema` (re-exported from `DataFrame`):
311+
312+
```haskell
313+
$(D.deriveSchema ''Order)
314+
-- emits:
315+
-- orderSchema :: Schema
316+
-- orderSchema = makeSchema [("order_id", schemaType @Int64), ...]
317+
-- orderOrderId :: Expr Int64
318+
-- orderOrderId = col "order_id"
319+
-- orderRegion :: Expr Text
320+
-- orderRegion = col "region"
321+
-- orderAmount :: Expr Double
322+
-- orderAmount = col "amount"
323+
324+
orders :: IO D.DataFrame
325+
orders = do
326+
df <- D.readCsvWithSchema orderSchema "orders.csv"
327+
pure (D.filter orderAmount (> 100) df)
328+
```
329+
330+
Each record field gets a typed accessor named `<lower-first TyConName><UpperFirst FieldName>`,
331+
so `data Order = Order { customerId :: Int }` yields `orderCustomerId :: Expr Int`, defined as `col "customer_id"`.
332+
That's the same shape as `$(D.declareColumns df)` produces from a runtime
333+
`DataFrame`, but driven off the ADT instead of an existing frame.
334+
308335
If you'd rather not depend on Template Haskell, the same schema is
309336
available via `GHC.Generics`:
310337

src/DataFrame.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ import DataFrame.IO.CSV as CSV (
260260
fromCsvBytes,
261261
readCsv,
262262
readCsvWithOpts,
263+
readCsvWithSchema,
263264
readSeparated,
264265
readTsv,
265266
writeCsv,
@@ -309,6 +310,7 @@ import DataFrame.Internal.Row as Row (
309310
toRowVector,
310311
)
311312
import DataFrame.Internal.Schema as Schema (
313+
deriveSchema,
312314
makeSchema,
313315
schemaType,
314316
)

src/DataFrame/Internal/Schema.hs

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,23 @@
44
{-# LANGUAGE InstanceSigs #-}
55
{-# LANGUAGE RankNTypes #-}
66
{-# LANGUAGE ScopedTypeVariables #-}
7+
{-# LANGUAGE TemplateHaskellQuotes #-}
78
{-# LANGUAGE TypeApplications #-}
89
{-# LANGUAGE TypeFamilies #-}
910

1011
module DataFrame.Internal.Schema where
1112

13+
import Data.Char (isUpper, toLower, toUpper)
1214
import qualified Data.Map as M
1315
import qualified Data.Proxy as P
1416
import qualified Data.Text as T
1517

1618
import Data.Maybe (isJust)
1719
import Data.Type.Equality (TestEquality (..))
1820
import DataFrame.Internal.Column (Columnable)
21+
import DataFrame.Internal.Expression (Expr)
22+
import DataFrame.Operators (col)
23+
import Language.Haskell.TH
1924
import Type.Reflection (typeRep)
2025

2126
-- | A runtime tag for a column’s element type.
@@ -108,3 +113,130 @@ True
108113
-}
109114
makeSchema :: [(T.Text, SchemaType)] -> Schema
110115
makeSchema = Schema . M.fromList
116+
117+
{- | Auto-generate a runtime 'Schema' (and per-column @'Expr'@ accessors)
118+
from a record ADT.
119+
120+
The splice reifies the record, applies @camelCase -> snake_case@ to each
121+
record-selector name, and emits:
122+
123+
* a top-level @\<lower-first TyConName\>Schema :: 'Schema'@ binding suitable
124+
for passing to 'DataFrame.IO.CSV.readCsvWithSchema' /
125+
'DataFrame.IO.CSV.readCsvWithOpts'.
126+
* one @\<lower-first TyConName\>\<UpperFirst FieldName\> :: 'Expr' /ty/@ binding
127+
per field, so you can refer to columns in expression DSL code by name
128+
without writing @col \@/ty/ "snake_case_name"@ at every call site.
129+
130+
@
131+
data Order = Order { customerId :: Int, region :: Text, amount :: Double }
132+
133+
\$(deriveSchema ''Order)
134+
-- expands to:
135+
-- orderSchema :: Schema
136+
-- orderSchema = makeSchema
137+
-- [ ("customer_id", schemaType \@Int)
138+
-- , ("region", schemaType \@Text)
139+
-- , ("amount", schemaType \@Double)
140+
-- ]
141+
-- orderCustomerId :: Expr Int
142+
-- orderCustomerId = col "customer_id"
143+
-- orderRegion :: Expr Text
144+
-- orderRegion = col "region"
145+
-- orderAmount :: Expr Double
146+
-- orderAmount = col "amount"
147+
148+
main = do
149+
df <- D.readCsvWithSchema orderSchema "orders.csv"
150+
let bigOrders = D.filterWhere (orderAmount .>. 100) df
151+
...
152+
@
153+
154+
The data type must have exactly one record constructor; sum types or
155+
positional constructors fail the splice with a descriptive error. Field
156+
types must satisfy @('Columnable' a, 'Read' a)@ — the same constraints
157+
'schemaType' already requires.
158+
-}
159+
deriveSchema :: Name -> DecsQ
deriveSchema tyName = do
    -- Reify the type and pull out the fields of its single record
    -- constructor; 'extractRecordFields' fails the splice for anything else.
    info <- reify tyName
    fields <- extractRecordFields tyName info
    let prefix = lowerFirst (nameBase tyName)
        schemaBase = prefix ++ "Schema"
        -- One (column name, accessor name, field type) triple per field,
        -- in declaration order.
        entries =
            [ (camelToSnake fieldBase, prefix ++ upperFirst fieldBase, fTy)
            | (fName, _bang, fTy) <- fields
            , let fieldBase = nameBase fName
            ]
        -- Column names produced by more than one field (e.g. @fooBar@ and
        -- @foo_bar@ both become "foo_bar"). 'makeSchema' builds its map with
        -- 'M.fromList', which would silently keep only one of them.
        duplicateCols =
            M.keys
                ( M.filter
                    (> 1)
                    (M.fromListWith (+) [(c, 1 :: Int) | (c, _, _) <- entries])
                )
        -- Accessors whose generated name equals the schema binding's name
        -- (a field called @schema@); emitting both would produce a
        -- duplicate top-level definition.
        schemaClashes = [acc | (_, acc, _) <- entries, acc == schemaBase]
    case duplicateCols of
        [] -> pure ()
        cols ->
            fail $
                "deriveSchema: "
                    ++ show tyName
                    ++ " has several fields that map to the same column name: "
                    ++ show cols
    case schemaClashes of
        [] -> pure ()
        _ ->
            fail $
                "deriveSchema: "
                    ++ show tyName
                    ++ " has a field whose accessor would be named "
                    ++ show schemaBase
                    ++ ", which clashes with the generated schema binding"
    let schemaName = mkName schemaBase
        -- AST for @T.pack "<str>"@, shared by schema keys and accessors.
        textLit str = AppE (VarE 'T.pack) (LitE (StringL str))
        -- AST for @(T.pack "<column>", schemaType \@<fieldType>)@.
        schemaEntry (colName, _, fTy) =
            TupE
                [ Just (textLit colName)
                , Just (AppTypeE (VarE 'schemaType) fTy)
                ]
        -- @<prefix>Schema :: Schema@ and its definition via 'makeSchema'.
        schemaDecls =
            [ SigD schemaName (ConT ''Schema)
            , ValD
                (VarP schemaName)
                (NormalB (AppE (VarE 'makeSchema) (ListE (map schemaEntry entries))))
                []
            ]
        -- @<prefix><Field> :: Expr <fieldType>@ and @= col (T.pack "<column>")@
        -- for every field.
        accessorDecls =
            concat
                [ [ SigD accName (AppT (ConT ''Expr) fTy)
                  , ValD
                        (VarP accName)
                        (NormalB (AppE (VarE 'col) (textLit colName)))
                        []
                  ]
                | (colName, accBase, fTy) <- entries
                , let accName = mkName accBase
                ]
    pure (schemaDecls ++ accessorDecls)
201+
202+
extractRecordFields :: Name -> Info -> Q [VarBangType]
extractRecordFields tyName info =
    -- Alternatives are tried top to bottom: the two accepting shapes come
    -- first, then progressively more general rejections.
    case info of
        -- data T = C { ... } with exactly one record constructor
        TyConI (DataD _ _ _ _ [RecC _ recFields] _) -> pure recFields
        -- newtype T = C { ... }
        TyConI (NewtypeD _ _ _ _ (RecC _ recFields) _) -> pure recFields
        -- any other data declaration (sum type, positional constructor, ...)
        TyConI (DataD _ declName _ _ _ _) ->
            reject (show declName ++ " must have exactly one record constructor")
        -- newtype without record syntax
        TyConI (NewtypeD _ declName _ _ _ _) ->
            reject (show declName ++ " newtype must use record syntax")
        -- some other kind of type-constructor declaration
        TyConI unsupported ->
            reject ("unsupported declaration: " ++ show unsupported)
        -- the name did not reify to a type constructor at all
        _ ->
            reject (show tyName ++ " is not a data/newtype declaration")
  where
    -- Abort the splice with the common "deriveSchema: " prefix.
    reject msg = fail ("deriveSchema: " ++ msg)
222+
223+
-- Local @camelCase -> snake_case@: lowercase the first char, then prefix
224+
-- @\'_\'@ before any uppercase character (lowercased). Duplicated from
225+
-- 'DataFrame.Typed.TH.camelToSnake' to keep this module free of any
226+
-- @DataFrame.Typed.*@ imports.
227+
camelToSnake :: String -> String
camelToSnake [] = []
camelToSnake (firstChar : rest) = toLower firstChar : concatMap snakeChar rest
  where
    -- Every uppercase character after the first becomes '_' followed by its
    -- lowercase form; all other characters pass through untouched.
    snakeChar ch
        | isUpper ch = ['_', toLower ch]
        | otherwise = [ch]
235+
236+
-- Lowercase the first character of a string; the rest is left unchanged.
lowerFirst :: String -> String
lowerFirst str = case str of
    [] -> []
    (h : t) -> toLower h : t
239+
240+
-- Uppercase the first character of a string; the rest is left unchanged.
upperFirst :: String -> String
upperFirst str
    | (h : t) <- str = toUpper h : t
    | otherwise = str

tests/Operations/Record.hs

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@ module Operations.Record where
1515

1616
import Data.Int (Int64)
1717
import qualified Data.Map.Strict as M
18-
import Data.Proxy (Proxy (..))
1918
import qualified Data.Text as T
19+
import qualified Data.Text.IO as TIO
2020
import GHC.Generics (Generic)
2121

2222
import qualified DataFrame as D
23+
import qualified DataFrame.Functions as F
2324
import qualified DataFrame.Internal.Column as DI
2425
import qualified DataFrame.Internal.Schema as IS
26+
import DataFrame.Operators
2527
import DataFrame.Typed (Schema)
2628
import qualified DataFrame.Typed as DT
2729

@@ -36,6 +38,7 @@ data Order = Order
3638
deriving (Show, Eq)
3739

3840
$(DT.deriveSchemaFromType ''Order)
41+
$(D.deriveSchema ''Order)
3942

4043
-- Nullable fields (Maybe Text -> RNullableBoxed; Maybe Int -> RNullableUnboxed).
4144
data User = User
@@ -46,6 +49,7 @@ data User = User
4649
deriving (Show, Eq)
4750

4851
$(DT.deriveSchemaFromType ''User)
52+
$(D.deriveSchema ''User)
4953

5054
-- Identity-cased: keep the record selector names verbatim.
5155
data Account = Account
@@ -85,6 +89,7 @@ data Wide = Wide
8589
deriving (Show, Eq)
8690

8791
$(DT.deriveSchemaFromType ''Wide)
92+
$(D.deriveSchema ''Wide)
8893

8994
-- Generics opt-in: derive the schema via Generic, not TH.
9095
data Foo = Foo
@@ -220,6 +225,117 @@ genericColumnNames = TestCase $ do
220225
["foo_id", "foo_name", "foo_value"]
221226
(D.columnNames df)
222227

228+
-- The splice-generated orderSchema lists every Order column with the
-- expected element type.
deriveSchemaSplice :: Test
deriveSchemaSplice = TestCase $ do
    assertEqual
        "orderSchema column names"
        ["amount", "order_id", "region"]
        (M.keys orderCols)
    mapM_
        expectColumn
        [ ("order_id is Int64", "order_id", IS.schemaType @Int64)
        , ("region is Text", "region", IS.schemaType @T.Text)
        , ("amount is Double", "amount", IS.schemaType @Double)
        ]
  where
    orderCols = IS.elements orderSchema
    -- One assertion per (label, column name, expected schema type) row.
    expectColumn (label, key, ty) =
        assertEqual label (Just ty) (M.lookup key orderCols)
246+
247+
-- Maybe-typed record fields keep their Maybe wrapper in the derived schema.
deriveSchemaNullable :: Test
deriveSchemaNullable = TestCase $ do
    let userCols = IS.elements userSchema
        columnIs label key ty =
            assertEqual label (Just ty) (M.lookup key userCols)
    assertEqual
        "userSchema column names"
        ["user_age", "user_id", "user_name"]
        (M.keys userCols)
    columnIs "user_id is Int64" "user_id" (IS.schemaType @Int64)
    columnIs "user_name is Maybe Text" "user_name" (IS.schemaType @(Maybe T.Text))
    columnIs "user_age is Maybe Int" "user_age" (IS.schemaType @(Maybe Int))
265+
266+
-- An eight-field record yields eight schema entries; the first and last
-- columns are spot-checked for their type.
deriveSchemaWide :: Test
deriveSchemaWide = TestCase $ do
    assertEqual "wideSchema has 8 keys" 8 (M.size wideCols)
    mapM_
        ( \(label, key) ->
            assertEqual label (Just (IS.schemaType @Int)) (M.lookup key wideCols)
        )
        [ ("f1 is Int", "f1")
        , ("f8 is Int", "f8")
        ]
  where
    wideCols = IS.elements wideSchema
280+
281+
-- Round trip: write a small CSV, read it back with the derived orderSchema,
-- then convert the frame to [Order].
deriveSchemaReadsCsv :: Test
deriveSchemaReadsCsv = TestCase $ do
    TIO.writeFile csvPath csvBody
    loaded <- D.readCsvWithSchema orderSchema csvPath
    assertEqual
        "deriveSchema-driven readCsvWithSchema column names"
        ["order_id", "region", "amount"]
        (D.columnNames loaded)
    either
        (assertFailure . T.unpack)
        ( assertEqual
            "deriveSchema-driven CSV parses back to records"
            [Order 1 "us" 10.0, Order 2 "eu" 20.5, Order 3 "ap" 30.0]
        )
        (D.toRecords loaded :: Either T.Text [Order])
  where
    -- NOTE(review): fixed path under /tmp that is never removed; this
    -- presumably only works on POSIX-like hosts.
    csvPath = "/tmp/dataframe_test_deriveSchema.csv"
    csvBody =
        T.unlines
            [ "order_id,region,amount"
            , "1,us,10.0"
            , "2,eu,20.5"
            , "3,ap,30.0"
            ]
304+
305+
-- The generated accessors (orderAmount, orderRegion) can be combined into a
-- predicate for D.filterWhere.
deriveSchemaAccessorFilter :: Test
deriveSchemaAccessorFilter = TestCase $ do
    let df =
            D.fromRecords
                [ Order 1 "us" 20.0
                , Order 2 "eu" 20.5
                , Order 3 "ap" 30.0
                , Order 4 "us" 25.0
                ]
        -- Parenthesised so the grouping does not depend on operator fixities.
        big =
            D.filterWhere
                ( (orderAmount .>. F.lit @Double 15.0)
                    .&&. (orderRegion .==. F.lit @T.Text "us")
                )
                df
    case D.toRecords big :: Either T.Text [Order] of
        Left e -> assertFailure (T.unpack e)
        Right xs ->
            assertEqual
                -- The label names the function actually under test.
                "accessor drives D.filterWhere (amount > 15.0 && region == \"us\")"
                [Order 1 "us" 20.0, Order 4 "us" 25.0]
                xs
325+
326+
-- A generated accessor can be used inside an arithmetic expression passed
-- to D.derive.
deriveSchemaAccessorDerive :: Test
deriveSchemaAccessorDerive =
    TestCase $
        assertEqual
            "accessor composes in derive expression"
            [20.0, 40.0]
            (D.columnAsList (D.col @Double "double_amount") doubled)
  where
    source =
        D.fromRecords
            [ Order 1 "us" 10.0
            , Order 2 "eu" 20.0
            ]
    -- Adds a "double_amount" column computed as amount + amount.
    doubled = D.derive "double_amount" (orderAmount + orderAmount) source
338+
223339
tests :: [Test]
224340
tests =
225341
[ TestLabel "basicTypedRoundTrip" basicTypedRoundTrip
@@ -233,4 +349,10 @@ tests =
233349
, TestLabel "wideRoundTrip" wideRoundTrip
234350
, TestLabel "genericRoundTrip" genericRoundTrip
235351
, TestLabel "genericColumnNames" genericColumnNames
352+
, TestLabel "deriveSchemaSplice" deriveSchemaSplice
353+
, TestLabel "deriveSchemaNullable" deriveSchemaNullable
354+
, TestLabel "deriveSchemaWide" deriveSchemaWide
355+
, TestLabel "deriveSchemaReadsCsv" deriveSchemaReadsCsv
356+
, TestLabel "deriveSchemaAccessorFilter" deriveSchemaAccessorFilter
357+
, TestLabel "deriveSchemaAccessorDerive" deriveSchemaAccessorDerive
236358
]

0 commit comments

Comments
 (0)