diff --git a/docs/experimental_call_torch_tutorial.ipynb b/docs/experimental_call_torch_tutorial.ipynb new file mode 100644 index 0000000..8e92b44 --- /dev/null +++ b/docs/experimental_call_torch_tutorial.ipynb @@ -0,0 +1,283 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ym4Xu6mWg2Na" + }, + "outputs": [], + "source": [ + "# For tips on running notebooks in Google Colab, see\n", + "# https://pytorch.org/tutorials/beginner/colab\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PHmroXZmg2Nb" + }, + "source": [ + "\n", + "# jaxonnxruntime call_torch Tutorial\n", + "**Author:** John Zhang\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6ez7jdFpg2Nc" + }, + "source": [ + "Here we introduce the call_torch API which can seamlessly translate PyTorch models into JAX functions. This integration unites PyTorch with the extensive JAX software ecosystem and harnesses the power of XLA hardware (TPU/GPU/CPU and openXLA ), enhancing cross-framework collaboration and performance potential\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install git+https://github.com/google/jaxonnxruntime.git\n" + ], + "metadata": { + "id": "Kkfn_llxzdDa" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!pip install onnx torch" + ], + "metadata": { + "id": "1GX2-ZCwzmPA" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Yra_JCiTg2Nd" + }, + "source": [ + "## Basic Usage\n", + "\n", + "Generally, we describe all models with format. We use JAX PyTree data structure for any type model parameters and model inputs.Broadly, our approach involves characterizing all models using a standardized format. This entails employing the JAX PyTree data structure to encapsulate model parameters and inputs of varying types.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4RF7lO5ng2Nd" + }, + "outputs": [], + "source": [ + "import torch\n", + "import jax\n", + "from jaxonnxruntime.experimental import call_torch\n", + "\n", + "def foo(x, y):\n", + " a = torch.sin(x)\n", + " b = torch.cos(y)\n", + " return a + b\n", + "\n", + "torch_inputs =(torch.randn(10, 10), torch.randn(10, 10))\n", + "torch_module = torch.jit.trace(foo, torch_inputs)\n", + "\n", + "print(\"torch_output: \", torch_module(*torch_inputs))\n" + ] + }, + { + "cell_type": "code", + "source": [ + "jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)\n", + "jax_inputs = jax.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)\n", + "print(\"jax_output:\", jax_fn(jax_params, jax_inputs))" + ], + "metadata": { + "id": "kEEVNkLP1Owa" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N0FJif4_g2Ne" + }, + "source": [ + "*We* can also take ``torch.nn.Module``.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zR31ReFKg2Ne" + }, + "outputs": [], + "source": [ + "class MyModule(torch.nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.lin = torch.nn.Linear(100, 10)\n", + "\n", + " def forward(self, x):\n", + " return torch.nn.functional.relu(self.lin(x))\n", + "\n", + "torch_module = MyModule()\n", + "torch_inputs = (torch.randn(10, 100), )" + ] + }, + { + "cell_type": "code", + "source": [ + "torch_module.eval()\n", + "jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)\n", + "jax_inputs = jax.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)\n", + "print(\"jax_output:\", jax_fn(jax_params, jax_inputs))" + ], + "metadata": { + "id": "H--BoG-E1d6w" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# A real testing model\n", + "\n" + ], + "metadata": { + "id": "g1LoU8pa3-uK" + } + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "from torchvision.models import resnet50\n", + "# Generates random input and targets data for the model, where `b` is\n", + "# batch size.\n", + "\n", + "def generate_data(b):\n", + " return (\n", + " torch.randn(b, 3, 128, 128).to(torch.float32),\n", + " )\n", + "\n", + "torch_inputs = generate_data(1)\n", + "torch_module = resnet50()\n", + "torch_module.eval()\n", + "torch_module = torch.jit.trace(torch_module, torch_inputs)\n", + "torch_outputs = [torch_module(*torch_inputs)]\n", + "\n" + ], + "metadata": { + "id": "p03lb4Ix4CvN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from jaxonnxruntime.experimental import call_torch\n", + "import jax\n", + "jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)\n", + "jax_fn = jax.jit(jax_fn)\n", + "jax_inputs = jax.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)\n", + "jax_outputs = jax_fn(jax_params, jax_inputs)" + ], + "metadata": { + "id": "d1UB2by-4Q8R" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from jaxonnxruntime.experimental.call_torch import CallTorchTestCase\n", + "test_case = CallTorchTestCase()\n", + "test_case.assert_allclose(jax.tree_map(call_torch.torch_tensor_to_np_array,torch_outputs), jax_outputs, rtol=1e-07, atol=1e-03)\n" + ], + "metadata": { + "id": "GrNFj9z27FTC" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%timeit _ = torch_module(*torch_inputs)" + ], + "metadata": { + "id": "P-aPak7H9FkV" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%timeit _ = jax_fn(jax_params, jax_inputs)\n" + ], + "metadata": { + "id": "LAEbgDR59LPV" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "aoQ9-ggb97KP" + }, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "colab": { + "provenance": [], + "include_colab_link": true + }, + "accelerator": "TPU" + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file