Skip to content

Commit fe49f37

Browse files
committed
Wasm: specialization of number comparisons
1 parent 9c334eb commit fe49f37

File tree

6 files changed

+566
-11
lines changed

6 files changed

+566
-11
lines changed

compiler/lib-wasm/generate.ml

Lines changed: 110 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ module Generate (Target : Target_sig.S) = struct
3636
{ live : int array
3737
; in_cps : Effects.in_cps
3838
; deadcode_sentinal : Var.t
39+
; types : Typing.typ Var.Tbl.t
3940
; blocks : block Addr.Map.t
4041
; closures : Closure_conversion.closure Var.Map.t
4142
; global_context : Code_generation.context
@@ -230,6 +231,39 @@ module Generate (Target : Target_sig.S) = struct
230231
f context (transl_prim_arg x) (transl_prim_arg y) (transl_prim_arg z)
231232
| _ -> invalid_arity name l ~expected:3)
232233

234+
let get_type ctx p =
235+
match p with
236+
| Pv x -> Var.Tbl.get ctx.types x
237+
| Pc c -> Typing.constant_type c
238+
239+
let register_comparison name cmp_int cmp_boxed_int cmp_float =
240+
register_prim name `Mutable (fun ctx _ transl_prim_arg l ->
241+
match l with
242+
| [ x; y ] -> (
243+
let x' = transl_prim_arg x in
244+
let y' = transl_prim_arg y in
245+
match get_type ctx x, get_type ctx y with
246+
| Number Int, Number Int -> cmp_int x' y'
247+
| Number Int32, Number Int32 ->
248+
let* x' = Memory.unbox_int32 x' in
249+
let* y' = Memory.unbox_int32 y' in
250+
Value.val_int (return (W.BinOp (I32 cmp_boxed_int, x', y')))
251+
| Number Nativeint, Number Nativeint ->
252+
let* x' = Memory.unbox_nativeint x' in
253+
let* y' = Memory.unbox_nativeint y' in
254+
Value.val_int (return (W.BinOp (I32 cmp_boxed_int, x', y')))
255+
| Number Int64, Number Int64 ->
256+
let* x' = Memory.unbox_int64 x' in
257+
let* y' = Memory.unbox_int64 y' in
258+
Value.val_int (return (W.BinOp (I64 cmp_boxed_int, x', y')))
259+
| Number Float, Number Float -> float_comparison cmp_float x' y'
260+
| _ ->
261+
let* f = register_import ~name (Fun (Type.primitive_type 2)) in
262+
let* x' = x' in
263+
let* y' = y' in
264+
return (W.Call (f, [ x'; y' ])))
265+
| _ -> invalid_arity name l ~expected:2)
266+
233267
let () =
234268
register_bin_prim "caml_array_unsafe_get" `Mutable Memory.gen_array_get;
235269
register_bin_prim "caml_floatarray_unsafe_get" `Mutable Memory.float_array_get;
@@ -602,7 +636,76 @@ module Generate (Target : Target_sig.S) = struct
602636
l
603637
~init:(return [])
604638
in
605-
Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal l)
639+
Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal l);
640+
register_comparison "caml_greaterthan" (fun y x -> Value.lt x y) (Gt S) Gt;
641+
register_comparison "caml_greaterequal" (fun y x -> Value.le x y) (Ge S) Ge;
642+
register_comparison "caml_lessthan" Value.lt (Lt S) Lt;
643+
register_comparison "caml_lessequal" Value.le (Le S) Le;
644+
register_comparison
645+
"caml_equal"
646+
(fun x y ->
647+
let* x = x in
648+
let* y = y in
649+
Value.val_int (return (W.RefEq (x, y))))
650+
Eq
651+
Eq;
652+
register_comparison
653+
"caml_notequal"
654+
(fun x y ->
655+
let* x = x in
656+
let* y = y in
657+
Value.val_int (return (W.UnOp (I32 Eqz, RefEq (x, y)))))
658+
Ne
659+
Ne;
660+
register_prim "caml_compare" `Mutable (fun ctx _ transl_prim_arg l ->
661+
match l with
662+
| [ x; y ] -> (
663+
let x' = transl_prim_arg x in
664+
let y' = transl_prim_arg y in
665+
match get_type ctx x, get_type ctx y with
666+
| Number Int, Number Int ->
667+
Value.val_int
668+
Arith.(
669+
(Value.int_val y' < Value.int_val x')
670+
- (Value.int_val x' < Value.int_val y'))
671+
| Number Int32, Number Int32 ->
672+
let* f =
673+
register_import ~name:"caml_int32_compare" (Fun (Type.primitive_type 2))
674+
in
675+
let* x' = Memory.unbox_int32 x' in
676+
let* y' = Memory.unbox_int32 y' in
677+
return (W.Call (f, [ x'; y' ]))
678+
| Number Nativeint, Number Nativeint ->
679+
let* f =
680+
register_import
681+
~name:"caml_nativeint_compare"
682+
(Fun (Type.primitive_type 2))
683+
in
684+
let* x' = Memory.unbox_nativeint x' in
685+
let* y' = Memory.unbox_nativeint y' in
686+
return (W.Call (f, [ x'; y' ]))
687+
| Number Int64, Number Int64 ->
688+
let* f =
689+
register_import ~name:"caml_int64_compare" (Fun (Type.primitive_type 2))
690+
in
691+
let* x' = Memory.unbox_int64 x' in
692+
let* y' = Memory.unbox_int64 y' in
693+
return (W.Call (f, [ x'; y' ]))
694+
| Number Float, Number Float ->
695+
let* f =
696+
register_import ~name:"caml_float_compare" (Fun (Type.primitive_type 2))
697+
in
698+
let* x' = Memory.unbox_int64 x' in
699+
let* y' = Memory.unbox_int64 y' in
700+
return (W.Call (f, [ x'; y' ]))
701+
| _ ->
702+
let* f =
703+
register_import ~name:"caml_compare" (Fun (Type.primitive_type 2))
704+
in
705+
let* x' = x' in
706+
let* y' = y' in
707+
return (W.Call (f, [ x'; y' ])))
708+
| _ -> invalid_arity "caml_compare" l ~expected:2)
606709

607710
let rec translate_expr ctx context x e =
608711
match e with
@@ -1183,7 +1286,8 @@ module Generate (Target : Target_sig.S) = struct
11831286
~should_export
11841287
~warn_on_unhandled_effect
11851288
*)
1186-
~deadcode_sentinal =
1289+
~deadcode_sentinal
1290+
~types =
11871291
global_context.unit_name <- unit_name;
11881292
let p, closures = Closure_conversion.f p in
11891293
(*
@@ -1193,6 +1297,7 @@ module Generate (Target : Target_sig.S) = struct
11931297
{ live = live_vars
11941298
; in_cps
11951299
; deadcode_sentinal
1300+
; types
11961301
; blocks = p.blocks
11971302
; closures
11981303
; global_context
@@ -1306,8 +1411,10 @@ let start () = make_context ~value_type:Gc_target.Type.value
13061411

13071412
let f ~context ~unit_name p ~live_vars ~in_cps ~deadcode_sentinal =
13081413
let t = Timer.make () in
1414+
let state, info = Global_flow.f' ~fast:false p in
1415+
let types = Typing.f ~state ~info p in
13091416
let p = fix_switch_branches p in
1310-
let res = G.f ~context ~unit_name ~live_vars ~in_cps ~deadcode_sentinal p in
1417+
let res = G.f ~context ~unit_name ~live_vars ~in_cps ~deadcode_sentinal ~types p in
13111418
if times () then Format.eprintf " code gen.: %a@." Timer.print t;
13121419
res
13131420

0 commit comments

Comments
 (0)