Skip to content

Commit d5ae284

Browse files
committed
Wasm: specialization of number comparisons
1 parent 0ea2d6f commit d5ae284

File tree

7 files changed

+226
-53
lines changed

7 files changed

+226
-53
lines changed

compiler/lib-wasm/generate.ml

Lines changed: 127 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ module Generate (Target : Target_sig.S) = struct
6767
type repr =
6868
| Value
6969
| Float
70+
| Int
7071
| Int32
7172
| Nativeint
7273
| Int64
@@ -75,24 +76,23 @@ module Generate (Target : Target_sig.S) = struct
7576
match r with
7677
| Value -> Type.value
7778
| Float -> F64
78-
| Int32 -> I32
79-
| Nativeint -> I32
79+
| Int | Int32 | Nativeint -> I32
8080
| Int64 -> I64
8181

8282
let specialized_primitive_type (_, params, result) =
8383
{ W.params = List.map ~f:repr_type params; result = [ repr_type result ] }
8484

8585
let box_value r e =
8686
match r with
87-
| Value -> e
87+
| Value | Int -> e
8888
| Float -> Memory.box_float e
8989
| Int32 -> Memory.box_int32 e
9090
| Nativeint -> Memory.box_nativeint e
9191
| Int64 -> Memory.box_int64 e
9292

9393
let unbox_value r e =
9494
match r with
95-
| Value -> e
95+
| Value | Int -> e
9696
| Float -> Memory.unbox_float e
9797
| Int32 -> Memory.unbox_int32 e
9898
| Nativeint -> Memory.unbox_nativeint e
@@ -105,9 +105,9 @@ module Generate (Target : Target_sig.S) = struct
105105
[ "caml_int32_bswap", (`Pure, [ Int32 ], Int32)
106106
; "caml_nativeint_bswap", (`Pure, [ Nativeint ], Nativeint)
107107
; "caml_int64_bswap", (`Pure, [ Int64 ], Int64)
108-
; "caml_int32_compare", (`Pure, [ Int32; Int32 ], Value)
109-
; "caml_nativeint_compare", (`Pure, [ Nativeint; Nativeint ], Value)
110-
; "caml_int64_compare", (`Pure, [ Int64; Int64 ], Value)
108+
; "caml_int32_compare", (`Pure, [ Int32; Int32 ], Int)
109+
; "caml_nativeint_compare", (`Pure, [ Nativeint; Nativeint ], Int)
110+
; "caml_int64_compare", (`Pure, [ Int64; Int64 ], Int)
111111
; "caml_string_get32", (`Mutator, [ Value; Value ], Int32)
112112
; "caml_string_get64", (`Mutator, [ Value; Value ], Int64)
113113
; "caml_bytes_get32", (`Mutator, [ Value; Value ], Int32)
@@ -124,7 +124,7 @@ module Generate (Target : Target_sig.S) = struct
124124
; "caml_ldexp_float", (`Pure, [ Float; Value ], Float)
125125
; "caml_erf_float", (`Pure, [ Float ], Float)
126126
; "caml_erfc_float", (`Pure, [ Float ], Float)
127-
; "caml_float_compare", (`Pure, [ Float; Float ], Value)
127+
; "caml_float_compare", (`Pure, [ Float; Float ], Int)
128128
];
129129
h
130130

@@ -299,6 +299,38 @@ module Generate (Target : Target_sig.S) = struct
299299
(transl_prim_arg ctx ?typ:tz z)
300300
| _ -> invalid_arity name l ~expected:3)
301301

302+
let register_comparison name cmp_int cmp_boxed_int cmp_float =
303+
register_prim name `Mutable (fun ctx _ l ->
304+
match l with
305+
| [ x; y ] -> (
306+
let x' = transl_prim_arg ctx x in
307+
let y' = transl_prim_arg ctx y in
308+
match get_type ctx x, get_type ctx y with
309+
| Int _, Int _ -> cmp_int ctx x y
310+
| Number Int32, Number Int32 ->
311+
let* x' = Memory.unbox_int32 x' in
312+
let* y' = Memory.unbox_int32 y' in
313+
return (W.BinOp (I32 cmp_boxed_int, x', y'))
314+
| Number Nativeint, Number Nativeint ->
315+
let* x' = Memory.unbox_nativeint x' in
316+
let* y' = Memory.unbox_nativeint y' in
317+
return (W.BinOp (I32 cmp_boxed_int, x', y'))
318+
| Number Int64, Number Int64 ->
319+
let* x' = Memory.unbox_int64 x' in
320+
let* y' = Memory.unbox_int64 y' in
321+
return (W.BinOp (I64 cmp_boxed_int, x', y'))
322+
| Number Float, Number Float -> float_comparison cmp_float x' y'
323+
| _ ->
324+
let* f =
325+
register_import
326+
~name
327+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
328+
in
329+
let* x' = x' in
330+
let* y' = y' in
331+
return (W.Call (f, [ x'; y' ])))
332+
| _ -> invalid_arity name l ~expected:2)
333+
302334
let () =
303335
register_bin_prim
304336
"caml_array_unsafe_get"
@@ -780,7 +812,93 @@ module Generate (Target : Target_sig.S) = struct
780812
l
781813
~init:(return [])
782814
in
783-
Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal ~load l)
815+
Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal ~load l);
816+
register_comparison
817+
"caml_greaterthan"
818+
(fun ctx x y -> translate_int_comparison ctx (fun y x -> Arith.(x < y)) x y)
819+
(Gt S)
820+
Gt;
821+
register_comparison
822+
"caml_greaterequal"
823+
(fun ctx x y -> translate_int_comparison ctx (fun y x -> Arith.(x <= y)) x y)
824+
(Ge S)
825+
Ge;
826+
register_comparison
827+
"caml_lessthan"
828+
(fun ctx x y -> translate_int_comparison ctx Arith.( < ) x y)
829+
(Lt S)
830+
Lt;
831+
register_comparison
832+
"caml_lessequal"
833+
(fun ctx x y -> translate_int_comparison ctx Arith.( <= ) x y)
834+
(Le S)
835+
Le;
836+
register_comparison
837+
"caml_equal"
838+
(fun ctx x y -> translate_int_equality ctx Arith.( = ) Value.eq x y)
839+
Eq
840+
Eq;
841+
register_comparison
842+
"caml_notequal"
843+
(fun ctx x y -> translate_int_equality ctx Arith.( <> ) Value.neq x y)
844+
Ne
845+
Ne;
846+
register_prim "caml_compare" `Mutable (fun ctx _ l ->
847+
match l with
848+
| [ x; y ] -> (
849+
let x' = transl_prim_arg ctx x in
850+
let y' = transl_prim_arg ctx y in
851+
match get_type ctx x, get_type ctx y with
852+
| Int _, Int _ ->
853+
Arith.(
854+
(Value.int_val y' < Value.int_val x')
855+
- (Value.int_val x' < Value.int_val y'))
856+
| Number Int32, Number Int32 ->
857+
let* f =
858+
register_import
859+
~name:"caml_int32_compare"
860+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
861+
in
862+
let* x' = Memory.unbox_int32 x' in
863+
let* y' = Memory.unbox_int32 y' in
864+
return (W.Call (f, [ x'; y' ]))
865+
| Number Nativeint, Number Nativeint ->
866+
let* f =
867+
register_import
868+
~name:"caml_nativeint_compare"
869+
(Fun (Type.primitive_type 2))
870+
in
871+
let* x' = Memory.unbox_nativeint x' in
872+
let* y' = Memory.unbox_nativeint y' in
873+
return (W.Call (f, [ x'; y' ]))
874+
| Number Int64, Number Int64 ->
875+
let* f =
876+
register_import
877+
~name:"caml_int64_compare"
878+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
879+
in
880+
let* x' = Memory.unbox_int64 x' in
881+
let* y' = Memory.unbox_int64 y' in
882+
return (W.Call (f, [ x'; y' ]))
883+
| Number Float, Number Float ->
884+
let* f =
885+
register_import
886+
~name:"caml_float_compare"
887+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
888+
in
889+
let* x' = Memory.unbox_int64 x' in
890+
let* y' = Memory.unbox_int64 y' in
891+
return (W.Call (f, [ x'; y' ]))
892+
| _ ->
893+
let* f =
894+
register_import
895+
~name:"caml_compare"
896+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
897+
in
898+
let* x' = x' in
899+
let* y' = y' in
900+
return (W.Call (f, [ x'; y' ])))
901+
| _ -> invalid_arity "caml_compare" l ~expected:2)
784902

785903
let rec translate_expr ctx context x e =
786904
match e with

compiler/lib-wasm/typing.ml

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ module Integer = struct
1515
| Unnormalized, _ | _, Unnormalized -> Unnormalized
1616
| Ref, Ref -> Ref
1717
| _ -> Normalized
18+
19+
let sub r r' =
20+
match r, r' with
21+
| _, Unnormalized -> true
22+
| Ref, _ -> true
23+
| Normalized, Normalized -> true
24+
| Unnormalized, (Ref | Normalized) -> false
25+
| Normalized, Ref -> false
1826
end
1927

2028
type boxed_number =
@@ -62,6 +70,21 @@ module Domain = struct
6270
Array.length t = Array.length t' && Array.for_all2 ~f:equal t t'
6371
| (Top | Tuple _ | Int _ | Number _ | Bot), _ -> false
6472

73+
let rec sub t t' =
74+
match t, t' with
75+
| _, Top | Bot, _ -> true
76+
| Top, _ | _, Bot -> false
77+
| Int t, Int t' -> Integer.sub t t'
78+
| Number t, Number t' -> Poly.equal t t'
79+
| Tuple t, Tuple t' ->
80+
Array.length t <= Array.length t'
81+
&&
82+
let rec compare t t' i l =
83+
i = l || (sub t.(i) t'.(i) && compare t t' (i + 1) l)
84+
in
85+
compare t t' 0 (Array.length t)
86+
| (Int _ | Number _ | Tuple _), _ -> false
87+
6588
let bot = Bot
6689

6790
let depth_treshold = 4
@@ -186,11 +209,13 @@ let prim_type ~approx prim args =
186209
| "caml_lessthan"
187210
| "caml_lessequal"
188211
| "caml_equal"
189-
| "caml_compare" -> Int Ref
212+
| "caml_notequal"
213+
| "caml_compare" -> Int Normalized
190214
| "caml_int32_bswap" -> Number Int32
191215
| "caml_nativeint_bswap" -> Number Nativeint
192216
| "caml_int64_bswap" -> Number Int64
193-
| "caml_int32_compare" | "caml_nativeint_compare" | "caml_int64_compare" -> Int Ref
217+
| "caml_int32_compare" | "caml_nativeint_compare" | "caml_int64_compare" ->
218+
Int Normalized
194219
| "caml_string_get32" -> Number Int32
195220
| "caml_string_get64" -> Number Int64
196221
| "caml_bytes_get32" -> Number Int32
@@ -201,7 +226,7 @@ let prim_type ~approx prim args =
201226
| "caml_nextafter_float" -> Number Float
202227
| "caml_classify_float" -> Int Ref
203228
| "caml_ldexp_float" | "caml_erf_float" | "caml_erfc_float" -> Number Float
204-
| "caml_float_compare" -> Int Ref
229+
| "caml_float_compare" -> Int Normalized
205230
| "caml_floatarray_unsafe_get" -> Number Float
206231
| "caml_bytes_unsafe_get"
207232
| "caml_string_unsafe_get"
@@ -414,6 +439,40 @@ let solver st =
414439
in
415440
Solver.f () g (propagate st)
416441

442+
let print_opt typ f e =
443+
match e with
444+
| Prim
445+
( Extern
446+
( "caml_greaterthan"
447+
| "caml_greaterequal"
448+
| "caml_lessthan"
449+
| "caml_lessequal"
450+
| "caml_equal"
451+
| "caml_compare" )
452+
, l ) ->
453+
if
454+
List.exists
455+
~f:(fun t' ->
456+
List.for_all
457+
~f:(fun p ->
458+
let t =
459+
match p with
460+
| Pc c -> constant_type c
461+
| Pv x -> Var.Tbl.get typ x
462+
in
463+
Domain.sub t t')
464+
l)
465+
[ Int Ref
466+
; Int Normalized
467+
; Int Unnormalized
468+
; Number Int32
469+
; Number Int64
470+
; Number Nativeint
471+
; Number Float
472+
]
473+
then Format.fprintf f " OPT"
474+
| _ -> ()
475+
417476
let f ~state ~info ~deadcode_sentinal p =
418477
update_deps state p;
419478
let function_parameters = mark_function_parameters p in
@@ -434,7 +493,8 @@ let f ~state ~info ~deadcode_sentinal p =
434493
Format.err_formatter
435494
(fun _ i ->
436495
match i with
437-
| Instr (Let (x, _)) -> Format.asprintf "{%a}" Domain.print (Var.Tbl.get typ x)
496+
| Instr (Let (x, e)) ->
497+
Format.asprintf "{%a}%a" Domain.print (Var.Tbl.get typ x) (print_opt typ) e
438498
| _ -> "")
439499
p);
440500
typ

runtime/js/compare.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ function caml_compare_val(a, b, total) {
251251
b = b[i];
252252
}
253253
}
254-
//Provides: caml_compare (const, const)
254+
//Provides: caml_compare mutable (const, const)
255255
//Requires: caml_compare_val
256256
function caml_compare(a, b) {
257257
return caml_compare_val(a, b, true);

runtime/wasm/compare.wat

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -556,53 +556,49 @@
556556
(i32.const 0))
557557

558558
(func (export "caml_compare")
559-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
559+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
560560
(local $res i32)
561561
(local.set $res
562562
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 1)))
563563
(if (i32.lt_s (local.get $res) (i32.const 0))
564-
(then (return (ref.i31 (i32.const -1)))))
564+
(then (return (i32.const -1))))
565565
(if (i32.gt_s (local.get $res) (i32.const 0))
566-
(then (return (ref.i31 (i32.const 1)))))
567-
(ref.i31 (i32.const 0)))
566+
(then (return (i32.const 1))))
567+
(i32.const 0))
568568

569569
(func (export "caml_equal")
570-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
571-
(ref.i31
572-
(i32.eqz
573-
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))))
570+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
571+
(i32.eqz
572+
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0))))
574573

575574
(func (export "caml_notequal")
576-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
577-
(ref.i31
578-
(i32.ne (i32.const 0)
579-
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))))
575+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
576+
(i32.ne (i32.const 0)
577+
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0))))
580578

581579
(func (export "caml_lessthan")
582-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
580+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
583581
(local $res i32)
584582
(local.set $res
585583
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))
586-
(ref.i31
587-
(i32.and (i32.lt_s (local.get $res) (i32.const 0))
588-
(i32.ne (local.get $res) (global.get $unordered)))))
584+
(i32.and (i32.lt_s (local.get $res) (i32.const 0))
585+
(i32.ne (local.get $res) (global.get $unordered))))
589586

590587
(func (export "caml_lessequal")
591-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
588+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
592589
(local $res i32)
593590
(local.set $res
594591
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))
595-
(ref.i31
596-
(i32.and (i32.le_s (local.get $res) (i32.const 0))
597-
(i32.ne (local.get $res) (global.get $unordered)))))
592+
(i32.and (i32.le_s (local.get $res) (i32.const 0))
593+
(i32.ne (local.get $res) (global.get $unordered))))
598594

599595
(func (export "caml_greaterthan")
600-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
601-
(ref.i31 (i32.lt_s (i32.const 0)
602-
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))))
596+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
597+
(i32.lt_s (i32.const 0)
598+
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0))))
603599

604600
(func (export "caml_greaterequal")
605-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
606-
(ref.i31 (i32.le_s (i32.const 0)
607-
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))))
601+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
602+
(i32.le_s (i32.const 0)
603+
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0))))
608604
)

0 commit comments

Comments
 (0)