Skip to content

Commit

Permalink
Untested: infrastructure for ternary operations
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Jan 27, 2025
1 parent 842daaa commit 9682b4a
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 14 deletions.
162 changes: 161 additions & 1 deletion lib/ppx_cd.ml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type expr_type =

let is_unknown = function Unknown -> true | _ -> false

type projections_slot = LHS | RHS1 | RHS2 | Nonslot | Undet [@@deriving equal, sexp]
type projections_slot = LHS | RHS1 | RHS2 | RHS3 | Nonslot | Undet [@@deriving equal, sexp]

let assignment_op expr =
(* This should stay in sync with Arrayjit.Ops.assign_op_cd_syntax. *)
Expand Down Expand Up @@ -72,6 +72,7 @@ let assignment_op expr =

let binary_op expr =
(* This and is_binary_op should stay in sync with Arrayjit.Ops.binop_cd_syntax. *)
(* FIXME: get rid of this and use binary_ops table instead. *)
let loc = expr.pexp_loc in
match expr with
| [%expr ( + )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Add])
Expand Down Expand Up @@ -106,6 +107,18 @@ let binary_op expr =
"+ (Add), - (Sub), * (Mul), / (Div), ** (ToPowOf), -?/ (Relu_gate), -/> (Arg2), < \
(Cmplt), <> (Cmpne), || (Or), && (And), % (Mod), @^ (Max), ^^ (Min)" )

(* Maps a ternary-operator identifier to its (shape logic, primitive op) expression pair.
   This should stay in sync with Arrayjit.Ops.ternop_cd_syntax. *)
let ternary_op expr =
  (* FIXME: get rid of this and use ternary_ops table instead. *)
  let loc = expr.pexp_loc in
  match expr with
  | [%expr where] -> ([%expr Shape.Pointwise_tern], [%expr Arrayjit.Ops.Where])
  | [%expr fma] -> ([%expr Shape.Compose_accumulate], [%expr Arrayjit.Ops.FMA])
  | _ ->
      (* BUG FIX: the fallback logic was [Shape.Pointwise_bin]; in the ternary domain the
         neutral default is [Shape.Pointwise_tern] (the error extension below is what
         surfaces to the user, but the logic component should still be well-typed for
         ternary shape inference). *)
      ( [%expr Shape.Pointwise_tern],
        Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: expected a ternary operator, one of: %s"
             "where, fma" )

type result = {
vbs : value_binding Map.M(String).t;
(** [vbs] are the bindings introduced by inline tensor declarations (aka. punning). These
Expand Down Expand Up @@ -206,6 +219,7 @@ let project_p_slot debug loc slot =
| LHS -> [%expr p.project_lhs]
| RHS1 -> [%expr p.project_rhs.(0)]
| RHS2 -> [%expr p.project_rhs.(1)]
| RHS3 -> [%expr p.project_rhs.(2)]
| Nonslot ->
Ast_builder.Default.pexp_extension ~loc
@@ Location.error_extensionf ~loc
Expand All @@ -221,6 +235,7 @@ let project_p_dims debug loc slot =
| LHS -> [%expr p.lhs_dims]
| RHS1 -> [%expr p.rhs_dims.(0)]
| RHS2 -> [%expr p.rhs_dims.(1)]
| RHS3 -> [%expr p.rhs_dims.(2)]
| Nonslot ->
Ast_builder.Default.pexp_extension ~loc
@@ Location.error_extensionf ~loc
Expand Down Expand Up @@ -344,6 +359,7 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
| LHS -> [%pat? nondiff__lhs]
| RHS1 -> [%pat? nondiff__rhs1]
| RHS2 -> [%pat? nondiff__rhs2]
| RHS3 -> [%pat? nondiff__rhs3]
| Nonslot | Undet -> [%pat? nondiff__tensor]
in
let t = pat2expr v in
Expand Down Expand Up @@ -444,6 +460,74 @@ let translate (expr : expression) : result =
{ vbs = no_vbs; typ = Tensor; slot = Undet; expr; array_opt_of_code = None }
in
let loop = transl ~bad_pun_hints in
(* FIXME: collapse these (code reuse) *)
let process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ?projections ~proj_in_scope
() =
let initialize_neutral, accu_op = assignment_op accu_op in
let setup_l =
setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ~proj_in_scope:true lhs
in
let _, tern_op = ternary_op tern_op in
let setup_r1 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs1 in
let setup_r2 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs2 in
let setup_r3 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs3 in
let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
let projections =
match projections with
| Some prjs -> prjs
| None ->
let lhs_dims = project_p_dims "LHS" lhs.pexp_loc setup_l.slot in
let rhs1_dims = project_p_dims "RHS1" lhs.pexp_loc setup_r1.slot in
let rhs2_dims = project_p_dims "RHS2" lhs.pexp_loc setup_r2.slot in
let rhs3_dims = project_p_dims "RHS3" lhs.pexp_loc setup_r3.slot in
let project_lhs = project_p_slot "LHS" lhs.pexp_loc setup_l.slot in
let project_rhs1 = project_p_slot "RHS1" rhs1.pexp_loc setup_r1.slot in
let project_rhs2 = project_p_slot "RHS2" rhs2.pexp_loc setup_r2.slot in
let project_rhs3 = project_p_slot "RHS3" rhs3.pexp_loc setup_r3.slot in
[%expr
lazy
(let p = Lazy.force projections in
Arrayjit.Indexing.
{
product_space = p.product_space;
product_iterators = p.product_iterators;
lhs_dims = [%e lhs_dims];
rhs_dims = [| [%e rhs1_dims]; [%e rhs2_dims]; [%e rhs3_dims] |];
project_lhs = [%e project_lhs];
project_rhs = [| [%e project_rhs1]; [%e project_rhs2]; [%e project_rhs3] |];
debug_info =
{
p.debug_info with
trace =
( "ppx_cd " ^ [%e expr2string_or_empty accu_op] ^ " "
^ [%e expr2string_or_empty tern_op],
Arrayjit.Indexing.unique_debug_id () )
:: p.debug_info.trace;
};
})]
in
(* TODO: might be better to treat missing [rhs1, rhs2, rhs3] as zeros or errors rather than
eliding the code. *)
let body =
[%expr
Option.value ~default:Arrayjit.Assignments.Noop
@@ Option.map [%e setup_l.array_opt] ~f:(fun lhs ->
Option.map3 [%e setup_r1.array_opt] [%e setup_r2.array_opt] [%e setup_r2.array_opt]
~f:(fun rhs1 rhs2 rhs3 ->
Arrayjit.Assignments.Accum_ternop
{
initialize_neutral = [%e initialize_neutral];
accum = [%e accu_op];
lhs;
op = [%e tern_op];
rhs1;
rhs2;
rhs3;
projections = [%e projections];
}))]
in
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r1; setup_r2; setup_r3 ] body
in
let process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ?projections ~proj_in_scope () =
let initialize_neutral, accu_op = assignment_op accu_op in
let setup_l =
Expand Down Expand Up @@ -561,6 +645,27 @@ let translate (expr : expression) : result =
in
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r ] body
in
  (* Emits a [Tensor.raw_ternop] call for ternary assignments that carry an explicit shape
     [~logic] (rather than in-scope projections). Mirrors [process_raw_binop] below. *)
  let process_raw_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~logic =
    (* [initialize_neutral] reflects whether the assignment operator accumulates. *)
    let initialize_neutral, accu_op = assignment_op accu_op in
    let setup_l = setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ~proj_in_scope lhs in
    let setup_r1 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs1 in
    let setup_r2 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs2 in
    let setup_r3 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs3 in
    let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
    (* The LHS merge flag is unused: only RHS arguments can be merge buffers. *)
    let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
    let t1_expr, rhs1_is_grad, rhs1_is_merge = args_for ~loc setup_r1 in
    let t2_expr, rhs2_is_grad, rhs2_is_merge = args_for ~loc setup_r2 in
    let t3_expr, rhs3_is_grad, rhs3_is_merge = args_for ~loc setup_r3 in
    let body =
      [%expr
        Tensor.raw_ternop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op]
          ~t:[%e t_expr] ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e tern_op] ~t1:[%e t1_expr]
          ~rhs1_is_grad:[%e rhs1_is_grad] ~rhs1_is_merge:[%e rhs1_is_merge] ~t2:[%e t2_expr]
          ~rhs2_is_grad:[%e rhs2_is_grad] ~rhs2_is_merge:[%e rhs2_is_merge] ~t3:[%e t3_expr]
          ~rhs3_is_grad:[%e rhs3_is_grad] ~rhs3_is_merge:[%e rhs3_is_merge] ~logic:[%e logic]]
    in
    assignment ~punned ~lhs:setup_l ~rhses:[ setup_r1; setup_r2; setup_r3 ] body
  in
let process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic =
let initialize_neutral, accu_op = assignment_op accu_op in
let setup_l = setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ~proj_in_scope lhs in
Expand Down Expand Up @@ -655,6 +760,19 @@ let translate (expr : expression) : result =
slot = RHS2;
expr = [%expr Option.map t2.Tensor.diff ~f:(fun d -> d.Tensor.grad)];
}
| { pexp_desc = Pexp_ident { txt = Lident "rhs3"; _ }; _ } ->
{ default_result with typ = Array; slot = RHS3 }
| { pexp_desc = Pexp_ident { txt = Lident "t3"; _ }; _ } ->
{ default_result with typ = Tensor; slot = RHS3 }
| { pexp_desc = Pexp_ident { txt = Lident "v3"; _ }; _ } ->
{ default_result with typ = Array; slot = RHS3; expr = [%expr t3.Tensor.value] }
| { pexp_desc = Pexp_ident { txt = Lident "g3"; _ }; _ } ->
{
default_result with
typ = Grad_of_tensor [%expr t3];
slot = RHS3;
expr = [%expr Option.map t3.Tensor.diff ~f:(fun d -> d.Tensor.grad)];
}
| { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } when is_primitive_op op_ident ->
default_result
| [%expr [%e? expr1] **. [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
Expand Down Expand Up @@ -811,7 +929,15 @@ let translate (expr : expression) : result =
[%e? accu_op]
[%e? lhs]
([%e? bin_op] [%e? rhs1] ([%e? rhs2] ~projections:[%e? projections]))] ->
(* Note: when clause not needed here and below, it's an error if bin_op is not a primitive
binary op. *)
process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~projections ~proj_in_scope:true ()
| [%expr
[%e? accu_op]
[%e? lhs]
([%e? tern_op] ([%e? rhs1], [%e? rhs2], [%e? rhs3]) ~projections:[%e? projections])] ->
process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~projections
~proj_in_scope:true ()
| [%expr
[%e? accu_op]
[%e? lhs]
Expand Down Expand Up @@ -852,6 +978,25 @@ let translate (expr : expression) : result =
in
let _, bin_op = binary_op bin_op in
process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic
| [%expr
[%e? accu_op]
[%e? lhs]
([%e? tern_op]
([%e? rhs1], [%e? rhs2], [%e? rhs3])
~logic:[%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ }])] ->
let logic =
let loc = s_loc in
if String.equal spec "." then [%expr Shape.Pointwise_bin]
else if String.equal spec "@" then [%expr Shape.Compose]
else
Ast_builder.Default.pexp_extension ~loc
@@ Location.error_extensionf ~loc
"ppx_ocannl %%cd: expected <.> or <@>, found <%s> -- einsum notation for ternary \
operators not supported yet, see issue #305"
spec
in
let _, tern_op = binary_op tern_op in
process_raw_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~logic
| [%expr
[%e? accu_op]
[%e? lhs]
Expand Down Expand Up @@ -882,6 +1027,13 @@ let translate (expr : expression) : result =
[%e? rhs2])]
when is_assignment accu_ident && Hashtbl.mem binary_ops binop_ident && proj_in_scope ->
process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~proj_in_scope ()
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
[%e? lhs]
([%e? { pexp_desc = Pexp_ident { txt = Lident ternop_ident; _ }; _ } as tern_op]
([%e? rhs1], [%e? rhs2], [%e? rhs3]))]
when is_assignment accu_ident && Hashtbl.mem ternary_ops ternop_ident && proj_in_scope ->
process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~proj_in_scope ()
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
[%e? lhs]
Expand All @@ -905,6 +1057,14 @@ let translate (expr : expression) : result =
when is_assignment accu_ident && Hashtbl.mem binary_ops binop_ident ->
let logic, bin_op = binary_op bin_op in
process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
[%e? lhs]
([%e? { pexp_desc = Pexp_ident { txt = Lident ternop_ident; _ }; _ } as tern_op]
([%e? rhs1], [%e? rhs2], [%e? rhs3]))]
when is_assignment accu_ident && Hashtbl.mem ternary_ops ternop_ident ->
let logic, tern_op = ternary_op tern_op in
process_raw_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~logic
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
[%e? lhs]
Expand Down
56 changes: 43 additions & 13 deletions lib/shape.ml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ type transpose_type =
| Batch_slice of Arrayjit.Indexing.static_symbol
[@@deriving equal, sexp]

type ternary_type = Pointwise_tern | Compose_accumulate [@@deriving sexp, equal]

let identifier_multichar = Angstrom.take_while1 Char.is_alphanum

let opt_separators : _ Angstrom.t =
Expand Down Expand Up @@ -203,26 +205,19 @@ let einsum_of_spec spec =
| Error msg ->
raise @@ Utils.User_error ("Shape.einsum_of_spec: while parsing: " ^ spec ^ " error: " ^ msg)

(** How to propagate shape updates and do the last update of [Tensor.t.shape] when finalizing the
tensor. Axes are broadcast-expanded on a bottom-up update to fit the incoming shape. *)
type logic =
| Broadcast of compose_type * t * t
(** Matches the shapes for a binary operation.
For [Broadcast (Einsum (ls1, ls2, ls3), s1, s2)], the labels of [s1] and [s2] must match
according to the [ls1], [ls2] lineup, and the resulting shape inherits the labels
according to the [ls3] lineup. *)
| Transpose of transpose_type * t
(** Permutes the axes of a shape. One case of [Transpose] is to swap inputs with outputs of
[s1], hence the name. *)
| Broadcast_tern of ternary_type * t * t * t
| Terminal of Arrayjit.Ops.init_op
(** Extracts any available shape information from the initialization. E.g. for
[File_mapped fn], opens the file [fn] to check its length. *)
[@@deriving equal, sexp]

let logic_to_spec = function
| Broadcast (Pointwise_bin, _, _) | Transpose (Pointwise_un, _) -> "."
| Broadcast (Compose, _, _) -> "@"
| Broadcast (Pointwise_bin, _, _)
| Transpose (Pointwise_un, _)
| Broadcast_tern (Pointwise_tern, _, _, _) ->
"."
| Broadcast (Compose, _, _) | Broadcast_tern (Compose_accumulate, _, _, _) -> "@"
| Broadcast (Einsum spec, _, _) | Transpose (Permute spec, _) -> spec
| Transpose (Transpose, _) -> "T"
| Transpose (Batch_slice _, _) -> "@|"
Expand Down Expand Up @@ -430,6 +425,31 @@ let get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : update_step) :
Row_ineq { cur = cur_sh.output; subr = sh1.output };
Row_ineq { cur = cur_sh.output; subr = sh2.output };
] )
| Broadcast_tern (Compose_accumulate, sh1, sh2, sh3) ->
( Row.dim_map_empty,
[
Row_ineq { cur = sh1.input; subr = sh2.output };
Row_ineq { cur = cur_sh.batch; subr = sh1.batch };
Row_ineq { cur = cur_sh.batch; subr = sh2.batch };
Row_ineq { cur = cur_sh.input; subr = sh2.input };
Row_ineq { cur = cur_sh.output; subr = sh1.output };
Row_ineq { cur = cur_sh.batch; subr = sh3.batch };
Row_ineq { cur = cur_sh.input; subr = sh3.input };
Row_ineq { cur = cur_sh.output; subr = sh3.output };
] )
| Broadcast_tern (Pointwise_tern, sh1, sh2, sh3) ->
( Row.dim_map_empty,
[
Row_ineq { cur = cur_sh.batch; subr = sh1.batch };
Row_ineq { cur = cur_sh.batch; subr = sh2.batch };
Row_ineq { cur = cur_sh.batch; subr = sh3.batch };
Row_ineq { cur = cur_sh.input; subr = sh1.input };
Row_ineq { cur = cur_sh.input; subr = sh2.input };
Row_ineq { cur = cur_sh.input; subr = sh3.input };
Row_ineq { cur = cur_sh.output; subr = sh1.output };
Row_ineq { cur = cur_sh.output; subr = sh2.output };
Row_ineq { cur = cur_sh.output; subr = sh3.output };
] )
| Transpose (Batch_slice { static_range; static_symbol }, sh) ->
let slice_v = get_var () in
let slice_var = Var slice_v in
Expand Down Expand Up @@ -553,6 +573,10 @@ let iter_shapes update_step ~f =
| Broadcast (_, sh1, sh2) ->
f sh1;
f sh2
| Broadcast_tern (_, sh1, sh2, sh3) ->
f sh1;
f sh2;
f sh3

let all_rows update_step =
let rows_sh sh = [ sh.batch; sh.input; sh.output ] in
Expand All @@ -562,6 +586,7 @@ let all_rows update_step =
| Terminal _ -> []
| Transpose (_, sh1) -> rows_sh sh1
| Broadcast (_, sh1, sh2) -> rows_sh sh1 @ rows_sh sh2
| Broadcast_tern (_, sh1, sh2, sh3) -> rows_sh sh1 @ rows_sh sh2 @ rows_sh sh3

let apply_env_t env sh =
sh.batch <- Row.subst_row env sh.batch;
Expand Down Expand Up @@ -661,6 +686,10 @@ let fresh_proj_ids update =
| Broadcast (_, sh1, sh2) ->
fresh_shape sh1;
fresh_shape sh2
| Broadcast_tern (_, sh1, sh2, sh3) ->
fresh_shape sh1;
fresh_shape sh2;
fresh_shape sh3

(** Computes the indexing into subtensors given the shape information of a tensor.
[derive_projections] should only be invoked when the shapes are fully inferred already! *)
Expand Down Expand Up @@ -692,6 +721,7 @@ let derive_projections (update_step : update_step) : Idx.projections =
| Terminal _ -> []
| Transpose (_, sh) -> [ sh ]
| Broadcast (_, sh1, sh2) -> [ sh1; sh2 ]
| Broadcast_tern (_, sh1, sh2, sh3) -> [ sh1; sh2; sh3 ]
in
let lhs_dims = to_dims lhs in
let rhs_dims = Array.of_list_map ~f:to_dims rhs in
Expand Down
7 changes: 7 additions & 0 deletions lib/shape.mli
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ type transpose_type =
| Batch_slice of Arrayjit.Indexing.static_symbol (** Removes the leftmost batch axis. *)
[@@deriving equal, sexp]

(** If you miss expressivity here, leave a note on {{:https://github.com/ahrefs/ocannl/issues/305}issue 305}. *)
type ternary_type =
| Pointwise_tern (** As in the operation [Where]. *)
| Compose_accumulate (** As in the operation [FMA]. *)
[@@deriving equal, sexp]

val make :
?batch_dims:int list ->
?input_dims:int list ->
Expand Down Expand Up @@ -123,6 +129,7 @@ type logic =
| Transpose of transpose_type * t
(** Permutes the axes of a shape. One case of [Transpose] is to swap inputs with outputs of
[s1], hence the name. *)
| Broadcast_tern of ternary_type * t * t * t (** Matches the shapes for a ternary operation. *)
| Terminal of Arrayjit.Ops.init_op
(** Extracts any available shape information from the initialization. E.g. for
[File_mapped fn], opens the file [fn] to check its length. *)
Expand Down
5 changes: 5 additions & 0 deletions lib/tensor.ml
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,11 @@ let binop ~label ?compose_op ~op_asn ~grad_asn ?grad_spec t1 t2 =
let grad_asn ~v ~g ~projections = grad_asn ~v ~g ~t1 ~t2 ~projections in
op ~label ?compose_op ?transpose_op:None ~op_asn ~grad_asn ?grad_spec (Shape.make ()) [ t1; t2 ]

(* Lifts ternary forward/backward assignment builders into a tensor-level ternary
   operation; mirrors [binop] above but threads three subtensors. *)
let ternop ~label ?compose_op ~op_asn ~grad_asn ?grad_spec t1 t2 t3 =
  (* Close over the subtensors so the builders passed to [op] only see [v]/[g]/[projections]. *)
  let op_asn ~v ~projections = op_asn ~v ~t1 ~t2 ~t3 ~projections in
  let grad_asn ~v ~g ~projections = grad_asn ~v ~g ~t1 ~t2 ~t3 ~projections in
  op ~label ?compose_op ?transpose_op:None ~op_asn ~grad_asn ?grad_spec (Shape.make ()) [ t1; t2; t3 ]

let unop ~label ?transpose_op ~op_asn ~grad_asn ?grad_spec t1 =
let op_asn ~v ~projections = op_asn ~v ~t1 ~projections in
let grad_asn ~v ~g ~projections = grad_asn ~v ~g ~t1 ~projections in
Expand Down
Loading

0 comments on commit 9682b4a

Please sign in to comment.