Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 283 additions & 0 deletions docs/experimental_call_torch_tutorial.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/google/jaxonnxruntime/blob/call_torch/docs/experimental_call_torch_tutorial.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ym4Xu6mWg2Na"
},
"outputs": [],
"source": [
"# For tips on running notebooks in Google Colab, see\n",
"# https://pytorch.org/tutorials/beginner/colab\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PHmroXZmg2Nb"
},
"source": [
"\n",
"# jaxonnxruntime call_torch Tutorial\n",
"**Author:** John Zhang\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6ez7jdFpg2Nc"
},
"source": [
"Here we introduce the call_torch API which can seamlessly translate PyTorch models into JAX functions. This integration unites PyTorch with the extensive JAX software ecosystem and harnesses the power of XLA hardware (TPU/GPU/CPU and openXLA ), enhancing cross-framework collaboration and performance potential\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"source": [
"!pip install git+https://github.com/google/jaxonnxruntime.git\n"
],
"metadata": {
"id": "Kkfn_llxzdDa"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!pip install onnx torch"
],
"metadata": {
"id": "1GX2-ZCwzmPA"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Yra_JCiTg2Nd"
},
"source": [
"## Basic Usage\n",
"\n",
"Generally, we describe all models with format. We use JAX PyTree data structure for any type model parameters and model inputs.Broadly, our approach involves characterizing all models using a standardized format. This entails employing the JAX PyTree data structure to encapsulate model parameters and inputs of varying types.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4RF7lO5ng2Nd"
},
"outputs": [],
"source": [
"import torch\n",
"import jax\n",
"from jaxonnxruntime.experimental import call_torch\n",
"\n",
"def foo(x, y):\n",
" a = torch.sin(x)\n",
" b = torch.cos(y)\n",
" return a + b\n",
"\n",
"torch_inputs =(torch.randn(10, 10), torch.randn(10, 10))\n",
"torch_module = torch.jit.trace(foo, torch_inputs)\n",
"\n",
"print(\"torch_output: \", torch_module(*torch_inputs))\n"
]
},
{
"cell_type": "code",
"source": [
"jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)\n",
"jax_inputs = jax.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)\n",
"print(\"jax_output:\", jax_fn(jax_params, jax_inputs))"
],
"metadata": {
"id": "kEEVNkLP1Owa"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "N0FJif4_g2Ne"
},
"source": [
"*We* can also take ``torch.nn.Module``.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zR31ReFKg2Ne"
},
"outputs": [],
"source": [
"class MyModule(torch.nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.lin = torch.nn.Linear(100, 10)\n",
"\n",
" def forward(self, x):\n",
" return torch.nn.functional.relu(self.lin(x))\n",
"\n",
"torch_module = MyModule()\n",
"torch_inputs = (torch.randn(10, 100), )"
]
},
{
"cell_type": "code",
"source": [
"torch_module.eval()\n",
"jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)\n",
"jax_inputs = jax.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)\n",
"print(\"jax_output:\", jax_fn(jax_params, jax_inputs))"
],
"metadata": {
"id": "H--BoG-E1d6w"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# A real testing model\n",
"\n"
],
"metadata": {
"id": "g1LoU8pa3-uK"
}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"from torchvision.models import resnet50\n",
"# Generates random input and targets data for the model, where `b` is\n",
"# batch size.\n",
"\n",
"def generate_data(b):\n",
" return (\n",
" torch.randn(b, 3, 128, 128).to(torch.float32),\n",
" )\n",
"\n",
"torch_inputs = generate_data(1)\n",
"torch_module = resnet50()\n",
"torch_module.eval()\n",
"torch_module = torch.jit.trace(torch_module, torch_inputs)\n",
"torch_outputs = [torch_module(*torch_inputs)]\n",
"\n"
],
"metadata": {
"id": "p03lb4Ix4CvN"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from jaxonnxruntime.experimental import call_torch\n",
"import jax\n",
"jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)\n",
"jax_fn = jax.jit(jax_fn)\n",
"jax_inputs = jax.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)\n",
"jax_outputs = jax_fn(jax_params, jax_inputs)"
],
"metadata": {
"id": "d1UB2by-4Q8R"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from jaxonnxruntime.experimental.call_torch import CallTorchTestCase\n",
"test_case = CallTorchTestCase()\n",
"test_case.assert_allclose(jax.tree_map(call_torch.torch_tensor_to_np_array,torch_outputs), jax_outputs, rtol=1e-07, atol=1e-03)\n"
],
"metadata": {
"id": "GrNFj9z27FTC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"%timeit _ = torch_module(*torch_inputs)"
],
"metadata": {
"id": "P-aPak7H9FkV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"%timeit _ = jax_fn(jax_params, jax_inputs)\n"
],
"metadata": {
"id": "LAEbgDR59LPV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "aoQ9-ggb97KP"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"colab": {
"provenance": [],
"include_colab_link": true
},
"accelerator": "TPU"
},
"nbformat": 4,
"nbformat_minor": 0
}