Skip to content

Commit

Permalink
Added missing curried or uncurried syntax variants; syntax_extensions…
Browse files Browse the repository at this point in the history
….md update
  • Loading branch information
lukstafi committed Jan 28, 2025
1 parent b1e31fd commit a7c2053
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 40 deletions.
66 changes: 59 additions & 7 deletions lib/ppx_cd.ml
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ let translate (expr : expression) : result =
( [%expr Shape.Pointwise_tern],
Ast_builder.Default.pexp_extension ~loc
@@ Location.error_extensionf ~loc
"ppx_ocannl %%cd: expected a ternary operator, one of: %s" "where, fma" ))
"ppx_ocannl %%cd: expected a ternary operator, one of: where, fma" ))
in
(* FIXME: collapse these (code reuse) *)
let process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ?projections ~proj_in_scope
Expand Down Expand Up @@ -882,6 +882,19 @@ let translate (expr : expression) : result =
embedded_nodes = __comment_block.Arrayjit.Assignments.embedded_nodes;
}];
}
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
[%e? lhs]
([%e? { pexp_desc = Pexp_ident { txt = Lident bin_op; _ }; _ }]
([%e? rhs1], [%e? rhs2])
~projections:[%e? projections])]
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
[%e? lhs]
([%e? { pexp_desc = Pexp_ident { txt = Lident bin_op; _ }; _ }]
[%e? rhs1]
[%e? rhs2]
~projections:[%e? projections])]
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
[%e? lhs]
Expand All @@ -896,6 +909,14 @@ let translate (expr : expression) : result =
[%e? lhs]
([%e? { pexp_desc = Pexp_ident { txt = Lident tern_op; _ }; _ }]
([%e? rhs1], [%e? rhs2], [%e? rhs3])
~projections:[%e? projections])]
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
[%e? lhs]
([%e? { pexp_desc = Pexp_ident { txt = Lident tern_op; _ }; _ }]
[%e? rhs1]
[%e? rhs2]
[%e? rhs3]
~projections:[%e? projections])] ->
process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~projections
~proj_in_scope:true ()
Expand All @@ -908,15 +929,10 @@ let translate (expr : expression) : result =
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
[%e? lhs]
(([%e? { pexp_desc = Pexp_ident { txt = Lident un_op; _ }; _ }] [%e? rhs])
~projections:[%e? projections])]
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
[%e? lhs]
(* FIXME: this was never needed as prefix operators bind tighter? *)
([%e? { pexp_desc = Pexp_ident { txt = Lident un_op; _ }; _ }]
([%e? rhs] ~projections:[%e? projections]))]
when Hashtbl.mem unary_ops un_op ->
(* Handle both un_op priority levels -- where application binds tighter and less tight. *)
process_assign_unop ~accu_op ~lhs ~un_op ~rhs ~projections ~proj_in_scope:true ()
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
Expand All @@ -929,6 +945,13 @@ let translate (expr : expression) : result =
([%e? { pexp_desc = Pexp_ident { txt = Lident bin_op; _ }; _ }]
([%e? rhs1], [%e? rhs2])
~logic:[%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic])]
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
[%e? lhs]
([%e? { pexp_desc = Pexp_ident { txt = Lident bin_op; _ }; _ }]
[%e? rhs1]
[%e? rhs2]
~logic:[%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic])]
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
[%e? lhs]
Expand All @@ -951,6 +974,14 @@ let translate (expr : expression) : result =
[%e? lhs]
([%e? { pexp_desc = Pexp_ident { txt = Lident tern_op; _ }; _ }]
([%e? rhs1], [%e? rhs2], [%e? rhs3])
~logic:[%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ }])]
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
[%e? lhs]
([%e? { pexp_desc = Pexp_ident { txt = Lident tern_op; _ }; _ }]
[%e? rhs1]
[%e? rhs2]
[%e? rhs3]
~logic:[%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ }])] ->
let logic =
let loc = s_loc in
Expand All @@ -973,6 +1004,13 @@ let translate (expr : expression) : result =
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
[%e? lhs]
([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }]
[%e? rhs]
~logic:[%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic])]
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
[%e? lhs]
(* FIXME: this was never needed as prefix operators bind tighter? *)
([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }]
([%e? rhs]
~logic:
Expand Down Expand Up @@ -1002,6 +1040,13 @@ let translate (expr : expression) : result =
[%e? lhs]
([%e? { pexp_desc = Pexp_ident { txt = Lident tern_op; _ }; _ }]
([%e? rhs1], [%e? rhs2], [%e? rhs3]))]
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
[%e? lhs]
([%e? { pexp_desc = Pexp_ident { txt = Lident tern_op; _ }; _ }]
[%e? rhs1]
[%e? rhs2]
[%e? rhs3])]
when is_assignment accu_op && Hashtbl.mem ternary_ops tern_op && proj_in_scope ->
process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~proj_in_scope ()
| [%expr
Expand Down Expand Up @@ -1029,6 +1074,13 @@ let translate (expr : expression) : result =
[%e? lhs]
([%e? { pexp_desc = Pexp_ident { txt = Lident tern_op; _ }; _ }]
([%e? rhs1], [%e? rhs2], [%e? rhs3]))]
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
[%e? lhs]
([%e? { pexp_desc = Pexp_ident { txt = Lident tern_op; _ }; _ }]
[%e? rhs1]
[%e? rhs2]
[%e? rhs3])]
when is_assignment accu_op && Hashtbl.mem ternary_ops tern_op ->
let logic, tern_op = ternary_op tern_op in
process_raw_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~logic
Expand Down
9 changes: 9 additions & 0 deletions lib/ppx_op.ml
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,15 @@ let rec translate ~num_configs ~is_toplevel ~has_config ?label expr =
let vbs2, e2 = loop expr2 in
let vbs3, e3 = loop expr3 in
(reduce_vbss [ vbs2; vbs3 ], [%expr [%e e1] [%e e2] [%e e3]])
| [%expr
[%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ }]
([%e? expr2], [%e? expr3], [%e? expr4])]
when Hashtbl.mem ternary_ops op_ident ->
let e1 = [%expr [%e expr] ?label:[%e opt_expr ~loc label]] in
let vbs2, e2 = loop expr2 in
let vbs3, e3 = loop expr3 in
let vbs4, e4 = loop expr4 in
(reduce_vbss [ vbs2; vbs3; vbs4 ], [%expr [%e e1] [%e e2] [%e e3] [%e e4]])
| [%expr [%e? expr1] [%e? expr2] [%e? expr3]] ->
let vbs1, e1 = loop ?label expr1 in
let vbs2, e2 = loop expr2 in
Expand Down
30 changes: 0 additions & 30 deletions lib/ppx_shared.ml
Original file line number Diff line number Diff line change
Expand Up @@ -114,36 +114,6 @@ let is_assignment ident =
&& Char.equal ident.[0] '='
&& (not @@ List.mem [ "=="; "==="; "=>"; "==>"; "=>>" ] ident ~equal:String.equal)

(* let binary_op expr = (* This and is_binary_op should stay in sync with
Arrayjit.Ops.binop_cd_syntax. *) (* FIXME: get rid of this and use binary_ops table instead. *)
let loc = expr.pexp_loc in match expr with | [%expr ( + )] -> ([%expr Shape.Pointwise_bin],
[%expr Arrayjit.Ops.Add]) | [%expr ( - )] -> ([%expr Shape.Pointwise_bin], [%expr
Arrayjit.Ops.Sub]) | [%expr ( * )] -> ( Ast_builder.Default.pexp_extension ~loc @@
Location.error_extensionf ~loc "No default compose type for binary `*`, try e.g. ~logic:\".\" for
pointwise, %s" "~logic:\"@\" for matrix multiplication", [%expr Arrayjit.Ops.Mul] ) | [%expr ( /
)] -> ( Ast_builder.Default.pexp_extension ~loc @@ Location.error_extensionf ~loc "For clarity,
no default compose type for binary `/`, use ~logic:\".\" for pointwise \ division", [%expr
Arrayjit.Ops.Div] ) | [%expr ( ** )] -> ([%expr Shape.Pointwise_bin], [%expr
Arrayjit.Ops.ToPowOf]) | [%expr ( -?/ )] -> ([%expr Shape.Pointwise_bin], [%expr
Arrayjit.Ops.Relu_gate]) | [%expr ( -/> )] -> ([%expr Shape.Pointwise_bin], [%expr
Arrayjit.Ops.Arg2]) | [%expr ( -@> )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Arg1])
| [%expr ( < )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmplt]) | [%expr ( <> )] ->
([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmpne]) | [%expr ( || )] -> ([%expr
Shape.Pointwise_bin], [%expr Arrayjit.Ops.Or]) | [%expr ( && )] -> ([%expr Shape.Pointwise_bin],
[%expr Arrayjit.Ops.And]) | [%expr ( % )] -> ([%expr Shape.Pointwise_bin], [%expr
Arrayjit.Ops.Mod]) | [%expr ( @^ )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Max]) |
[%expr ( ^^ )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Min]) | _ -> ( [%expr
Shape.Pointwise_bin], Ast_builder.Default.pexp_extension ~loc @@ Location.error_extensionf ~loc
"ppx_ocannl %%cd: expected a binary operator, one of: %s" "+ (Add), - (Sub), * (Mul), / (Div), **
(ToPowOf), -?/ (Relu_gate), -/> (Arg2), < \ (Cmplt), <> (Cmpne), || (Or), && (And), % (Mod), @^
(Max), ^^ (Min)" ) *)
(* let ternary_op expr = let loc =
expr.pexp_loc in match expr with | [%expr where] -> ([%expr Shape.Pointwise_tern], [%expr
Arrayjit.Ops.Where]) | [%expr fma] -> ([%expr Shape.Compose_accumulate], [%expr
Arrayjit.Ops.FMA]) | _ -> ( [%expr Shape.Pointwise_bin], Ast_builder.Default.pexp_extension ~loc
@@ Location.error_extensionf ~loc "ppx_ocannl %%cd: expected a ternary operator, one of: %s"
"where, fma" ) *)

(** Binary primitive ops, both infix operator and function name variants. *)
let binary_ops =
Hashtbl.of_alist_exn
Expand Down
57 changes: 54 additions & 3 deletions lib/syntax_extensions.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

- Table of contents
- [Preliminaries](#preliminaries)
- [Primitive operations](#primitive-operations)
- [The syntax for %op](#the-syntax-for-op)
- [The syntax for %cd](#the-syntax-for-cd)
- [Numeric and N-dimensional array literals](#numeric-and-n-dimensional-array-literals)
Expand All @@ -26,7 +27,7 @@

## Preliminaries

OCANNL, and arrayjit specifically, is built around a fixed number of numeric operations, declared in `arrayjit/ops.ml`. We assign lexical operators to many of the operations, inventing novel operators if needed. For example, Rectified Linear Unit `Relu` operation, which computes `f(x) = max(0,x)`, gets the operator `relu`, and the ReLU-Gate `Relu_gate` operation, which computes `f(x,y) = if x > 0.0 then y else 0.0`, gets the operator `-?/`. These built-in numeric operations are used to construct assignments (`Assignments.t` packaged as `Assignments.comp`). The syntax `%cd` is needed to build assignments concisely. On the other hand, while the syntax `%op` helps build tensors (`Tensor.t`), they can be expressed concisely in pure OCaml. Unlike for assignments, the building blocks for tensor expressions are easy to extend. The meaningful basic ones are provided in `lib/operation.ml`.
OCANNL, and arrayjit specifically, is built around a fixed number of numeric operations, declared in `arrayjit/ops.ml`. We assign lexical operators to the binary operations, inventing novel operators if needed. For example, the Rectified Linear Unit `Relu` operation, which computes `f(x) = max(0,x)`, is called `relu`, while the ReLU-Gate `Relu_gate` operation, which computes `f(x,y) = if x > 0.0 then y else 0.0`, gets the operator `-?/` in addition to the name `relu_gate`. These built-in numeric operations are used to construct assignments (`Assignments.t` packaged as `Assignments.comp`). The syntax `%cd` is needed to build assignments concisely, and the assignment operators always start with `=` (unlike in C where they end with `=`). On the other hand, while the syntax `%op` helps build tensors (`Tensor.t`), they can be expressed concisely in pure OCaml. Unlike for assignments, the building blocks for tensor expressions are easy to extend. The meaningful basic ones are provided in `lib/operation.ml`.

In OCANNL, we call a tensor that is prohibited from propagating gradients, and that has neither a gradient node nor backprop code, a _non-differentiable tensor_. Accordingly, we can call the "plain" tensors with a gradient node _differentiable tensors_. Expressions in the `%cd` syntax will sometimes build new non-differentiable tensors as components of assignments (they will never build new differentiable tensors). The syntax extensions make the following assumption:

Expand All @@ -37,6 +38,56 @@ Functions inside `Operation.NTDSL` use `~grad_spec:Prohibit_grad` when calling i

The extension points open `NTDSL.O`, resp. `TDSL.O`, for the scope of the extension point, to expose the corresponding operators.

## Primitive operations

To accommodate stylistic preferences, OCANNL supports both curried and uncurried syntaxes for primitive operation application. Binary operators are associated with infix operators, in addition to having alphabetic identifiers. This stems from the following restriction: in the `%cd` syntax, the assignment is always an infix operator, and it needs to pick the accumulation operation.

The unary primitive operations:

| Identifier | Default projection | Constructor in `Arrayjit.Ops` |
|------------|--------------------|-------------|
| `id` | pointwise | `Identity` |
| `relu` | pointwise | `Relu` |
| `sat01` | pointwise | `Satur01` |
| `exp` | pointwise | `Exp` |
| `log` | pointwise | `Log` |
| `exp2` | pointwise | `Exp2` |
| `log2` | pointwise | `Log2` |
| `sin` | pointwise | `Sin` |
| `cos` | pointwise | `Cos` |
| `sqrt` | pointwise | `Sqrt` |
| `recip` | pointwise | `Recip` |
| `recip_sqrt` | pointwise | `Recip_sqrt` |
| `neg` | pointwise | `Neg` |
| `tanh` | pointwise | `Tanh_approx` |

The binary primitive operations:

| Identifier | Infix operator | Default projection | Constructor in `Arrayjit.Ops` | Assignments |
|------------|----------------|--------------------|-------------|-------------|
| `fst` | `-@>` | pointwise | `Arg1` | none |
| `snd` | `-/>` | pointwise | `Arg2` | `=:` |
| `add` | `+` | pointwise | `Add` | `=+`, `=:+` |
| `sub` | `-` | pointwise | `Sub` | `=-`, `=:-` |
| `mul` | `*` | none | `Mul` | `=*`, `=:*` |
| `div` | `/` | none | `Div` | `=/`, `=:/` |
| `pow` | `**` | pointwise | `ToPowOf` | `=**`, `=:**` |
| `relu_gate` | `-?/` | pointwise | `Relu_gate` | `=?/`, `=:?/` |
| `lt` | `<` | pointwise | `Cmplt` | none |
| `ne` | `<>` | pointwise | `Cmpne` | none |
| `or_` | `\|\|` | pointwise | `Or` | `=\|\|`, `=:\|\|` |
| `and_` | `&&` | pointwise | `And` | `=&&`, `=:&&` |
| `mod_` | `%` | pointwise | `Mod` | `=%`, `=:%` |
| `max` | `@^` | pointwise | `Max` | `=@^`, `=:@^` |
| `min` | `^^` | pointwise | `Min` | `=^^`, `=:^^` |

The ternary primitive operations:

| Identifier | Default projection | Constructor in `Arrayjit.Ops` |
|------------|--------------------|-------------|
| `where` | pointwise | `Where` |
| `fma` | compose-accumulate | `FMA` |

## The syntax for %op

The `%op` syntax is simpler than the `%cd` syntax since it relies more on regular OCaml expressions. For example, we can write without syntax extensions:
Expand Down Expand Up @@ -99,9 +150,9 @@ type Assignments.t =

For example the binary case in pseudocode: `if initialize_neutral then lhs = 0; lhs = lhs accum (rhs1 op rhs2)` (assuming the neutral element of `accum` is 0). The representation also has a field `projections` which determines which loops should be run and how the tensor nodes should be indexed to perform the computation.

The basic `%cd` syntax for binary operator assignments has the form: `<lhs> <asgn-op> <rhs1> <op> <rhs2>` (or `<lhs> <asgn-op> <op> <rhs1> <rhs2>` when `<op>` is not an operator). The binary operators in the `<rhs1> <op> <rhs2>` part have a straightforward syntax: `<op>` is one of `+`, `-`, `*`, `/`, `**` (to-power-of), `-?/` (ReLU-Gate). `<asgn-op>` starts with `=`, followed by `:` only if `initialize_neutral` is true, then followed by one of `+`, `-`, `*`, `/`, `**`, `relu`. The fields `<lhs>`, `<rhs1>`, `<rhs2>` will often be either special-purpose identifiers (e.g. `t`, `t1`, `t2`, `g`, `g1`, `g2`) or identifiers bound to tensors. `<rhs1>`, `<rhs2>` will also often be (non-differentiable) tensor expressions. The notation `<tensor>.grad` stands for the gradient node of the given tensor. For more about "slot fillers", and to learn about the operators `*+` and `++`, see the section [further features of the syntax extension %cd](#further-features-of-the-syntax-extension-cd).
The basic `%cd` syntax for assignments has the form: `<lhs> <asgn-op> <primitive-op-application[rhs1, rhs2?, rhs3?]>`. See [Primitive operations](#primitive-operations) for the syntax of primitive operation application, where `<rhs1>`, `<rhs2>` (for binary and ternary ops), `<rhs3>` (for ternary ops) are subexpressions. `<asgn-op>` starts with `=`, followed by `:` only if `initialize_neutral` is true, then followed by the operator syntax variant of a binary primitive operation. The fields `<lhs>`, `<rhs1>`, `<rhs2>`, `<rhs3>` will often be either special-purpose identifiers (e.g. `t`, `t1`, `t2`, `t3`, `g`, `g1`, `g2`, `g3`) or identifiers bound to tensors. `<rhs1>`, `<rhs2>`, `<rhs3>` will also often be (non-differentiable) tensor expressions. The notation `<tensor>.grad` stands for the gradient node of the given tensor. For more about "slot fillers", and to learn about the operators `*+` and `++`, see the section [further features of the syntax extension %cd](#further-features-of-the-syntax-extension-cd).

How is the `projections` field determined? `projections` can be given explicitly as a labeled argument `~projections`. If they aren't but `%cd` realizes there is a `~projections` parameter in scope, it uses it -- see `lib/operation.ml` where this option is used to define tensor operations. If instead of `~projections` a `~logic` labeled argument is given, the string passed is used to determine projections. `~logic:"."` means a pointwise operation. `~logic:"@"` means an "output axes of rhs2 match input axes of rhs1" operation (matrix multiplication is a special case). `~logic:"T"` means transpose of input and output axes. The string passed to `~logic` can also use OCANNL's generalization of the einsum notation, allowing arbitrary permutations and reductions of axes. If no information is given, the default is a pointwise operation.
How is the `projections` field determined? `projections` can be given explicitly as a labeled argument `~projections`. If they aren't but `%cd` realizes there is a `~projections` parameter in scope, it uses it -- see `lib/operation.ml` where this option is used to define tensor operations. If instead of `~projections` a `~logic` labeled argument is given, the string passed is used to determine projections. `~logic:"."` means a pointwise operation. `~logic:"@"` means an "output axes of rhs2 match input axes of rhs1" operation (matrix multiplication is a special case). `~logic:"T"` means transpose of input and output axes. The string passed to `~logic` can also use OCANNL's generalization of the einsum notation, allowing arbitrary permutations and reductions of axes. If no information is given, the default depends on the primitive operation, but it is almost always a pointwise operation.

Here we see an example of tensor multiplication -- extending matrix multiplication to an arbitrary number of axes -- multiplying `a` by `b` to get `c`. In `=:+`, `=` is required to separate the assigned-to part from the computation, `:` clears out `c` before the computation, `+` selects addition to accumulate the results.

Expand Down

0 comments on commit a7c2053

Please sign in to comment.