Skip to content

Commit

Permalink
add torch -> tensorflow model transpilation demo
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Sep 10, 2024
1 parent 90d2d7c commit 8db7b5d
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 0 deletions.
6 changes: 6 additions & 0 deletions learn_the_basics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ Learn the basics

Transpiling Kornia functions to TensorFlow.

.. grid-item-card:: Transpiling Models from PyTorch to TensorFlow
:link: learn_the_basics/torch_to_tf_models.ipynb

Transpiling PyTorch models to TensorFlow.

.. grid-item-card:: Trace Code
:link: learn_the_basics/03_trace_code.ipynb

Expand All @@ -29,6 +34,7 @@ Learn the basics
:maxdepth: -1

learn_the_basics/torch_to_tf_functions.ipynb
learn_the_basics/torch_to_tf_models.ipynb
learn_the_basics/03_trace_code.ipynb
learn_the_basics/05_lazy_vs_eager.ipynb
learn_the_basics/06_how_to_use_decorators.ipynb
16 changes: 16 additions & 0 deletions learn_the_basics/example_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch


class SimpleModel(torch.nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 3, kernel_size=3)
self.relu = torch.nn.ReLU()
self.fc = torch.nn.Linear(3 * 26 * 26, 10)

def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
160 changes: 160 additions & 0 deletions learn_the_basics/torch_to_tf_models.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Transpiling Models from PyTorch to TensorFlow"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can install the dependencies required for this notebook by running the cell below ⬇️, or check out the [Get Started](https://ivy.dev/docs/overview/get_started.html) section of the docs to find out more about installing ivy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install ivy\n",
"!pip install torch\n",
"!pip install tensorflow"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we'll go through an example of how any model written in PyTorch can be converted, and used in, TensorFlow via `ivy.transpile`. First, lets import a simple torch model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from example_models import SimpleModel\n",
"\n",
"\"\"\"\n",
"This model is defined as follows:\n",
"\n",
"class SimpleModel(torch.nn.Module):\n",
" def __init__(self):\n",
" super(SimpleModel, self).__init__()\n",
" self.conv1 = torch.nn.Conv2d(1, 3, kernel_size=3)\n",
" self.relu = torch.nn.ReLU()\n",
" self.fc = torch.nn.Linear(3 * 26 * 26, 10)\n",
"\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = self.relu(x)\n",
" x = torch.flatten(x, 1)\n",
" x = self.fc(x)\n",
" return x\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we can convert the model to tensorflow"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ivy\n",
"\n",
"TFSimpleModel = ivy.transpile(SimpleModel, source=\"torch\", target=\"tensorflow\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can use the model with TensorFlow"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorShape([1, 10])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import tensorflow as tf\n",
"\n",
"tf_model = TFSimpleModel()\n",
"tf_model(tf.random.normal((1, 1, 28, 28))).shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also take advantage of TensorFlow-specific features, such as `tf.function`:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorShape([1, 10])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"compiled_model = tf.function(tf_model)\n",
"compiled_model(tf.random.normal((1, 1, 28, 28))).shape"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 8db7b5d

Please sign in to comment.