Skip to content

Commit ee915d6

Browse files
committed
Wasm: specialization of bigarray accesses
1 parent 362eabc commit ee915d6

File tree

7 files changed

+650
-16
lines changed

7 files changed

+650
-16
lines changed

compiler/lib-wasm/code_generation.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ module Arith = struct
368368
(match e, e' with
369369
| W.Const (I32 n), W.Const (I32 n') when Int32.(n' < 31l) ->
370370
W.Const (I32 (Int32.shift_left n (Int32.to_int n')))
371+
| _, W.Const (I32 0l) -> e
371372
| _ -> W.BinOp (I32 Shl, e, e'))
372373

373374
let ( lsr ) = binary (Shr U)

compiler/lib-wasm/gc_target.ml

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,38 @@ module Type = struct
419419
}
420420
])
421421
})
422+
423+
let int_array_type =
424+
register_type "int_array" (fun () ->
425+
return
426+
{ supertype = None
427+
; final = true
428+
; typ = W.Array { mut = true; typ = Value I32 }
429+
})
430+
431+
let bigarray_type =
432+
register_type "bigarray" (fun () ->
433+
let* custom_operations = custom_operations_type in
434+
let* int_array = int_array_type in
435+
let* custom = custom_type in
436+
return
437+
{ supertype = Some custom
438+
; final = true
439+
; typ =
440+
W.Struct
441+
[ { mut = false
442+
; typ = Value (Ref { nullable = false; typ = Type custom_operations })
443+
}
444+
; { mut = true; typ = Value (Ref { nullable = false; typ = Extern }) }
445+
; { mut = true; typ = Value (Ref { nullable = false; typ = Extern }) }
446+
; { mut = false
447+
; typ = Value (Ref { nullable = false; typ = Type int_array })
448+
}
449+
; { mut = false; typ = Packed I8 }
450+
; { mut = false; typ = Packed I8 }
451+
; { mut = false; typ = Packed I8 }
452+
]
453+
})
422454
end
423455

424456
module Value = struct
@@ -1360,6 +1392,237 @@ module Math = struct
13601392
let exp2 x = power (return (W.Const (F64 2.))) x
13611393
end
13621394

1395+
module Bigarray = struct
1396+
let dimension n a =
1397+
let* ty = Type.bigarray_type in
1398+
Memory.wasm_array_get
1399+
~ty:Type.int_array_type
1400+
(Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 3)
1401+
(Arith.const (Int32.of_int n))
1402+
1403+
let get_at_offset ~(kind : Typing.Bigarray.kind) a i =
1404+
let name, (typ : Wasm_ast.value_type), size, box =
1405+
match kind with
1406+
| Float32 ->
1407+
( "dv_get_f32"
1408+
, F32
1409+
, 2
1410+
, fun x ->
1411+
let* x = x in
1412+
Memory.box_float (return (W.F64PromoteF32 x)) )
1413+
| Float64 -> "dv_get_f64", F64, 3, Memory.box_float
1414+
| Int8_signed -> "dv_get_i8", I32, 0, Fun.id
1415+
| Int8_unsigned | Char -> "dv_get_ui8", I32, 0, Fun.id
1416+
| Int16_signed -> "dv_get_i16", I32, 1, Fun.id
1417+
| Int16_unsigned -> "dv_get_ui16", I32, 1, Fun.id
1418+
| Int32 -> "dv_get_i32", I32, 2, Memory.box_int32
1419+
| Nativeint -> "dv_get_i32", I32, 2, Memory.box_nativeint
1420+
| Int64 -> "dv_get_i64", I64, 3, Memory.box_int64
1421+
| Int -> "dv_get_i32", I32, 2, Fun.id
1422+
| Float16 ->
1423+
( "dv_get_i16"
1424+
, I32
1425+
, 1
1426+
, fun x ->
1427+
let* conv =
1428+
register_import
1429+
~name:"caml_float16_to_double"
1430+
(Fun { W.params = [ I32 ]; result = [ F64 ] })
1431+
in
1432+
let* x = x in
1433+
Memory.box_float (return (W.Call (conv, [ x ]))) )
1434+
| Complex32 ->
1435+
( "dv_get_f32"
1436+
, F32
1437+
, 3
1438+
, fun x ->
1439+
let* x = x in
1440+
return (W.F64PromoteF32 x) )
1441+
| Complex64 -> "dv_get_f64", F64, 4, Fun.id
1442+
in
1443+
let* little_endian =
1444+
register_import
1445+
~import_module:"bindings"
1446+
~name:"littleEndian"
1447+
(Global { mut = false; typ = I32 })
1448+
in
1449+
let* f =
1450+
register_import
1451+
~import_module:"bindings"
1452+
~name
1453+
(Fun
1454+
{ W.params =
1455+
Ref { nullable = true; typ = Extern }
1456+
:: I32
1457+
:: (if size = 0 then [] else [ I32 ])
1458+
; result = [ typ ]
1459+
})
1460+
in
1461+
let* ty = Type.bigarray_type in
1462+
let* ta = Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 2 in
1463+
let* ofs = Arith.(i lsl const (Int32.of_int size)) in
1464+
match kind with
1465+
| Float32
1466+
| Float64
1467+
| Int8_signed
1468+
| Int8_unsigned
1469+
| Int16_signed
1470+
| Int16_unsigned
1471+
| Int32
1472+
| Int64
1473+
| Int
1474+
| Nativeint
1475+
| Char
1476+
| Float16 ->
1477+
box
1478+
(return
1479+
(W.Call
1480+
(f, ta :: ofs :: (if size = 0 then [] else [ W.GlobalGet little_endian ]))))
1481+
| Complex32 | Complex64 ->
1482+
let delta = Int32.shift_left 1l (size - 1) in
1483+
let* ofs' = Arith.(return ofs + const delta) in
1484+
let* x = box (return (W.Call (f, [ ta; ofs; W.GlobalGet little_endian ]))) in
1485+
let* y = box (return (W.Call (f, [ ta; ofs'; W.GlobalGet little_endian ]))) in
1486+
let* ty = Type.float_array_type in
1487+
return (W.ArrayNewFixed (ty, [ x; y ]))
1488+
1489+
let set_at_offset ~kind a i v =
1490+
let name, (typ : Wasm_ast.value_type), size, unbox =
1491+
match (kind : Typing.Bigarray.kind) with
1492+
| Float32 ->
1493+
( "dv_set_f32"
1494+
, F32
1495+
, 2
1496+
, fun x ->
1497+
let* e = Memory.unbox_float x in
1498+
return (W.F32DemoteF64 e) )
1499+
| Float64 -> "dv_set_f64", F64, 3, Memory.unbox_float
1500+
| Int8_signed | Int8_unsigned | Char -> "dv_set_i8", I32, 0, Fun.id
1501+
| Int16_signed | Int16_unsigned -> "dv_set_i16", I32, 1, Fun.id
1502+
| Int32 -> "dv_set_i32", I32, 2, Memory.unbox_int32
1503+
| Nativeint -> "dv_set_i32", I32, 2, Memory.unbox_nativeint
1504+
| Int64 -> "dv_set_i64", I64, 3, Memory.unbox_int64
1505+
| Int -> "dv_set_i32", I32, 2, Fun.id
1506+
| Float16 ->
1507+
( "dv_set_i16"
1508+
, I32
1509+
, 1
1510+
, fun x ->
1511+
let* conv =
1512+
register_import
1513+
~name:"caml_double_to_float16"
1514+
(Fun { W.params = [ F64 ]; result = [ I32 ] })
1515+
in
1516+
let* x = Memory.unbox_float x in
1517+
return (W.Call (conv, [ x ])) )
1518+
| Complex32 ->
1519+
( "dv_set_f32"
1520+
, F32
1521+
, 3
1522+
, fun x ->
1523+
let* x = x in
1524+
return (W.F32DemoteF64 x) )
1525+
| Complex64 -> "dv_set_f64", F64, 4, Fun.id
1526+
in
1527+
let* ty = Type.bigarray_type in
1528+
let* ta = Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 2 in
1529+
let* ofs = Arith.(i lsl const (Int32.of_int size)) in
1530+
let* little_endian =
1531+
register_import
1532+
~import_module:"bindings"
1533+
~name:"littleEndian"
1534+
(Global { mut = false; typ = I32 })
1535+
in
1536+
let* f =
1537+
register_import
1538+
~import_module:"bindings"
1539+
~name
1540+
(Fun
1541+
{ W.params =
1542+
Ref { nullable = true; typ = Extern }
1543+
:: I32
1544+
:: typ
1545+
:: (if size = 0 then [] else [ I32 ])
1546+
; result = []
1547+
})
1548+
in
1549+
match kind with
1550+
| Float32
1551+
| Float64
1552+
| Int8_signed
1553+
| Int8_unsigned
1554+
| Int16_signed
1555+
| Int16_unsigned
1556+
| Int32
1557+
| Int64
1558+
| Int
1559+
| Nativeint
1560+
| Char
1561+
| Float16 ->
1562+
let* v = unbox v in
1563+
instr
1564+
(W.CallInstr
1565+
( f
1566+
, ta :: ofs :: v :: (if size = 0 then [] else [ W.GlobalGet little_endian ])
1567+
))
1568+
| Complex32 | Complex64 ->
1569+
let delta = Int32.shift_left 1l (size - 1) in
1570+
let* ofs' = Arith.(return ofs + const delta) in
1571+
let ty = Type.float_array_type in
1572+
let* x = unbox (Memory.wasm_array_get ~ty v (Arith.const 0l)) in
1573+
let* () = instr (W.CallInstr (f, [ ta; ofs; x; W.GlobalGet little_endian ])) in
1574+
let* y = unbox (Memory.wasm_array_get ~ty v (Arith.const 1l)) in
1575+
instr (W.CallInstr (f, [ ta; ofs'; y; W.GlobalGet little_endian ]))
1576+
1577+
let offset ~bound_error_index ~(layout : Typing.Bigarray.layout) ta ~indices =
1578+
let l =
1579+
List.mapi
1580+
~f:(fun pos i ->
1581+
let i =
1582+
match layout with
1583+
| C -> i
1584+
| Fortran -> Arith.(i - const 1l)
1585+
in
1586+
let i' = Code.Var.fresh () in
1587+
let dim = Code.Var.fresh () in
1588+
( (let* () = store ~typ:I32 i' i in
1589+
let* () = store ~typ:I32 dim (dimension pos ta) in
1590+
let* cond = Arith.uge (load i') (load dim) in
1591+
instr (W.Br_if (bound_error_index, cond)))
1592+
, i'
1593+
, dim ))
1594+
indices
1595+
in
1596+
let l =
1597+
match layout with
1598+
| C -> l
1599+
| Fortran -> List.rev l
1600+
in
1601+
match l with
1602+
| (instrs, i', _) :: rem ->
1603+
List.fold_left
1604+
~f:(fun (instrs, ofs) (instrs', i', dim) ->
1605+
let ofs' = Code.Var.fresh () in
1606+
( (let* () = instrs in
1607+
let* () = instrs' in
1608+
store ~typ:I32 ofs' Arith.((ofs * load dim) + load i'))
1609+
, load ofs' ))
1610+
~init:(instrs, load i')
1611+
rem
1612+
| [] -> return (), Arith.const 0l
1613+
1614+
let get ~bound_error_index ~kind ~layout ta ~indices =
1615+
let instrs, ofs = offset ~bound_error_index ~layout ta ~indices in
1616+
seq instrs (get_at_offset ~kind ta ofs)
1617+
1618+
let set ~bound_error_index ~kind ~layout ta ~indices v =
1619+
let instrs, ofs = offset ~bound_error_index ~layout ta ~indices in
1620+
seq
1621+
(let* () = instrs in
1622+
set_at_offset ~kind ta ofs v)
1623+
Value.unit
1624+
end
1625+
13631626
module JavaScript = struct
13641627
let anyref = W.Ref { nullable = true; typ = Any }
13651628

0 commit comments

Comments
 (0)