Replies: 2 comments 2 replies
There's no precise set of rules for when to make a new Primitive and when to use existing ops. Generally the default is to use existing ops unless there is a really good need for a primitive. Why? Primitives have a contract: they must be transformable under vmap, vjp, jvp, etc. Every new primitive is hence a lot of work if we don't want to break that contract (which we don't). If you can implement things as a composition of existing operations instead, then you get all of that for free. (As you also probably observed, adding a new primitive is a lot more code and requires a much deeper integration. It's more complicated, and the maintenance burden is higher. Best to avoid when possible.) So, in that case, when should you add a new Primitive?
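To make the "you get all of that for free" point concrete, here is a toy forward-mode autodiff sketch in plain Python. It is purely illustrative and has nothing to do with MLX internals: each *primitive* (`__mul__`, `__add__`, `exp`) carries its own derivative rule, and any function composed from those primitives is differentiable with no extra work.

```python
import math

class Dual:
    """Toy dual number: carries a value and a tangent (derivative)."""
    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    # Primitive: multiplication, with its derivative (product) rule baked in.
    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val,
                    self.tan * other.val + self.val * other.tan)

    # Primitive: addition, with its derivative rule baked in.
    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.tan + other.tan)

def exp(x):
    # Primitive: exp, with its rule d(exp)/dx = exp(x).
    return Dual(math.exp(x.val), math.exp(x.val) * x.tan)

# A *composed* op: f(x) = x * e^x + x. No new derivative rule needed;
# its derivative e^x (1 + x) + 1 falls out of the primitives' rules.
def composed(x):
    return exp(x) * x + x

x = Dual(0.5, 1.0)  # seed tangent 1.0 => y.tan is df/dx at 0.5
y = composed(x)
```

A new primitive in this toy world would be a new method or function that must ship its own derivative rule; a composed function ships nothing, which is the trade-off being described.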
Thanks Awni!
Today, this contract isn't enforced at compile time, right? IIRC, it seems possible to add operation code as a Primitive without adding the vmap, vjp, or jvp code. I'm not saying it should be enforced; I'm just curious whether it is, and about the design overall: where it is strict versus where it is left flexible for tinkering/hacking. IIUC, this contract exists to ensure computations can occur on the backward pass, correct? I.e., for training.
Ah, gotcha! I think I understand the "fat op" term better now. The catch with these is that they're only suitable for inference, right? I don't recall whether any of them have vjp or vmap implementations (which I'm presuming are necessary only for training), and consequently these fused fat-op kernels cannot easily or reliably implement vmap, vjp, and jvp.
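Both observations above can be sketched in a few lines of Python. This mimics only the *shape* of a Primitive-style class (virtual transform methods), not MLX's real C++ API: the contract is by convention, so a forward-only "fat" primitive builds and runs fine until someone actually asks for a gradient.

```python
class Primitive:
    """Hypothetical stand-in for a primitive base class (not MLX's API)."""
    def eval(self, inputs):
        raise NotImplementedError
    def vjp(self, primals, cotangents):
        # Default: no backward rule. Nothing forces subclasses to override.
        raise NotImplementedError("no vjp rule")
    def jvp(self, primals, tangents):
        raise NotImplementedError("no jvp rule")

class InferenceOnlyOp(Primitive):
    # Only the forward pass is provided; this "compiles" fine...
    def eval(self, inputs):
        return [sum(inputs)]

op = InferenceOnlyOp()
print(op.eval([1, 2, 3]))  # forward works: [6]

# ...but the broken contract only surfaces at runtime, when a
# gradient is requested (i.e., during training, not inference):
try:
    op.vjp([1, 2, 3], [1.0])
except NotImplementedError as e:
    print("gradient failed:", e)
```

This is the sense in which a fat op can be inference-only: the forward kernel alone satisfies nothing but the forward contract.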
While working on `pinv` support in #875, it took me an embarrassingly long time to understand what the MLX reviewers (Angelos, Awni) meant when they said "...there isn't a need for a primitive; this op can be built in the linalg namespace using svd...". As I grappled with the design of the repo, the C++ API, and the cascade of other dependencies, it took a while before I could write reliable unit tests and confidently run them to validate iterative changes.

One such iteration was a full-blown `PseudoInverse` Primitive, which called the `SVD` primitive and then handed the matrix multiplication off to LAPACK functions such as `sgemm`. Once I had reliable unit tests and a general build-change-test loop going, it was easier to remove all the Primitive code and simply build a `linalg::pinv()` C++ op that describes the pinv array as a graph of ops using `linalg::svd`.

What I am curious about is: why don't we want to build a Primitive here for pinv? Is it because the lower-level LAPACK functions it calls, svd and `cblas_sgemm`, both already have their MLX-specific functions? Is there a need to build a Primitive only when we discover a new algorithm (or package or library dependency) that provably performs a given operation, say pinv, better (in terms of speed, memory, etc.)?
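For concreteness, here is a hedged numpy sketch of the "graph of ops" version of pinv (numpy stands in for MLX's linalg ops; `pinv_from_svd` and `rcond` are illustrative names, not anything from the MLX codebase): pinv is just svd, an elementwise reciprocal, and two matmuls, so no new primitive and no direct LAPACK call is needed.

```python
import numpy as np

def pinv_from_svd(a, rcond=1e-15):
    """Moore-Penrose pseudo-inverse built as a composition of existing ops."""
    u, s, vt = np.linalg.svd(a, full_matrices=False)
    # Zero out near-zero singular values instead of dividing by them.
    cutoff = rcond * s.max()
    s_inv = np.where(s > cutoff, 1.0 / s, 0.0)
    # pinv(A) = V @ diag(1/s) @ U^T; the diag is folded into a broadcast.
    return (vt.T * s_inv) @ u.T

a = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
print(np.allclose(pinv_from_svd(a), np.linalg.pinv(a)))  # True
```

In MLX terms, because svd, division, and matmul each already have their transform rules, a pinv composed this way would inherit vjp, jvp, and vmap automatically, which is presumably the reviewers' point.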