Skip to content

Commit

Permalink
Tiny update to syntax_extensions.md
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Jan 29, 2025
1 parent f279856 commit af79521
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
3 changes: 2 additions & 1 deletion arrayjit/lib/ops.ml
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ let interpret_binop op v1 v2 =
| Sub -> v1 - v2
| Mul -> v1 * v2
| Div -> v1 / v2
| ToPowOf -> if is_integer v2 then int_pow v1 @@ to_int v2 else v1 ** v2
| ToPowOf when is_integer v2 -> int_pow v1 @@ to_int v2
| ToPowOf -> v1 ** v2
| Relu_gate -> if v1 > 0.0 then v2 else 0.0
| Max -> max v1 v2
| Min -> min v1 v2
Expand Down
51 changes: 51 additions & 0 deletions lib/syntax_extensions.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ The unary primitive operations:
| `recip_sqrt` | pointwise | `Recip_sqrt` |
| `neg` | pointwise | `Neg` |
| `tanh` | pointwise | `Tanh_approx` |
| `not` | pointwise | `Not` |

The binary primitive operations:

Expand Down Expand Up @@ -88,6 +89,56 @@ The ternary primitive operations:
| `where` | pointwise | `Where` |
| `fma` | compose-accumulate | `FMA` |

The interpretation functions also state the semantics:

```ocaml
let interpret_unop op v =
let open Float in
match op with
| Identity -> v
| Relu when v >= 0. -> v
| Relu -> 0.
| Satur01 when v <= 0. -> 0.
| Satur01 when v >= 1. -> 1.
| Satur01 -> v
| Exp -> exp v
| Log -> log v
| Exp2 -> 2. ** v
| Log2 -> log v / log 2.
| Sin -> sin v
| Cos -> cos v
| Sqrt -> sqrt v
| Recip -> 1. / v
| Recip_sqrt -> 1. / sqrt v
| Neg -> ~-.v
| Tanh_approx -> tanh v
| Not -> if v = 0. then 1. else 0.
let interpret_binop op v1 v2 =
let open Float in
match op with
| Arg1 -> v1
| Arg2 -> v2
| Add -> v1 + v2
| Sub -> v1 - v2
| Mul -> v1 * v2
| Div -> v1 / v2
| ToPowOf when is_integer v2 -> int_pow v1 @@ to_int v2
| ToPowOf -> v1 ** v2
| Relu_gate -> if v1 > 0.0 then v2 else 0.0
| Max -> max v1 v2
| Min -> min v1 v2
| Mod -> v1 % v2
| Cmplt -> if v1 < v2 then 1. else 0.
| Cmpeq -> if v1 = v2 then 1. else 0.
| Or -> if v1 <> 0. || v2 <> 0. then 1. else 0.
| And -> if v1 <> 0. && v2 <> 0. then 1. else 0.
let interpret_ternop op v1 v2 v3 =
let open Float in
match op with Where -> if v1 <> 0. then v2 else v3 | FMA -> (v1 * v2) + v3
```

## The syntax for %op

The `%op` syntax is simpler than the `%cd` syntax since it relies more on regular OCaml expressions. For example, we can write without syntax extensions:
Expand Down

0 comments on commit af79521

Please sign in to comment.