diff --git a/examples/pytorch/image-classification/mnist.ipynb b/examples/pytorch/image-classification/mnist.ipynb new file mode 100644 index 0000000000..4710c1ee9c --- /dev/null +++ b/examples/pytorch/image-classification/mnist.ipynb @@ -0,0 +1,770 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# PyTorch DDP Fashion MNIST Training Example" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "This example demonstrates how to train a convolutional neural network to classify images using the [Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset and [PyTorch Distributed Data Parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).\n", + "\n", + "This notebook walks you through running that example locally and shows how to scale PyTorch DDP across multiple nodes with Kubeflow TrainJob." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "## Install the Kubeflow SDK" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You need to install the Kubeflow SDK to interact with Kubeflow APIs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO (astefanutti): Change to the Kubeflow SDK when it's available.\n", + "!pip install git+https://github.com/kubeflow/training-operator.git@master#subdirectory=sdk_v2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install the PyTorch dependencies" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You also need to install PyTorch and Torchvision to be able to run the example locally:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install torch==2.5.1\n", + "!pip install torchvision==0.20.0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the training function" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "def train_fashion_mnist(params=None):\n", + "    import os\n", + "\n", + "    import torch\n", + "    import torch.distributed as dist\n", + "    import torch.nn.functional as F\n", + "    from torch import nn\n", + "    from torch.utils.data import DataLoader, DistributedSampler\n", + "    from torchvision import datasets, transforms\n", + "\n", + "    # Define the PyTorch CNN model to be trained\n", + "    class Net(nn.Module):\n", + "        def __init__(self):\n", + "            super(Net, self).__init__()\n", + "            self.conv1 = nn.Conv2d(1, 20, 5, 1)\n", + "            self.conv2 = nn.Conv2d(20, 50, 5, 1)\n", + "            self.fc1 = nn.Linear(4 * 4 * 50, 500)\n", + "            self.fc2 = nn.Linear(500, 10)\n", + "\n", + "        def forward(self, x):\n", + "            x = F.relu(self.conv1(x))\n", + "            x = F.max_pool2d(x, 2, 2)\n", + "            x = F.relu(self.conv2(x))\n", + "            x = F.max_pool2d(x, 2, 2)\n", + "            x = x.view(-1, 4 * 4 * 50)\n", + "            x = F.relu(self.fc1(x))\n", + "            x = self.fc2(x)\n", + "            return F.log_softmax(x, dim=1)\n", + "\n", + "    # Use NCCL if a GPU is available, otherwise use Gloo as the communication backend\n", + "    device, backend = (\"cuda\", \"nccl\") if torch.cuda.is_available() else (\"cpu\", \"gloo\")\n", + "\n", + "    print(f\"Using Device: {device}, Backend: {backend}\")\n", + "\n", + "    # Setup PyTorch Distributed\n", + "    local_rank = int(os.getenv(\"LOCAL_RANK\", 0))\n",
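+ "    # init_process_group uses the default env:// rendezvous: it reads the MASTER_ADDR,\n", + "    # MASTER_PORT, WORLD_SIZE and RANK environment variables, so the same function runs\n", + "    # unchanged in this notebook, under torchrun, or in a Kubeflow TrainJob.\n",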
dist.init_process_group(backend=backend)\n", + "\n", + "    print(\n", + "        \"Distributed Training for WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}\".format(\n", + "            dist.get_world_size(),\n", + "            dist.get_rank(),\n", + "            local_rank,\n", + "        )\n", + "    )\n", + "\n", + "    # Create the model and load it into the device\n", + "    device = torch.device(f\"{device}:{local_rank}\")\n", + "    model = nn.parallel.DistributedDataParallel(Net().to(device))\n", + "\n", + "    # Retrieve the Fashion-MNIST dataset\n", + "    dataset = datasets.FashionMNIST(\n", + "        \"./data\",\n", + "        train=True,\n", + "        download=True,\n", + "        transform=transforms.Compose([transforms.ToTensor()]),\n", + "    )\n", + "\n", + "    # Shard the dataset across workers\n", + "    train_loader = DataLoader(\n", + "        dataset,\n", + "        batch_size=100,\n", + "        sampler=DistributedSampler(dataset),\n", + "        pin_memory=torch.cuda.is_available(),\n", + "    )\n", + "\n", + "    # Setup the optimization loop\n", + "    criterion = nn.CrossEntropyLoss().to(device)\n", + "    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n", + "\n", + "    # TODO(astefanutti): add parameters to the training function\n", + "    for epoch in range(1, 5):\n", + "        model.train()\n", + "\n", + "        # Iterate over mini-batches from the training set\n", + "        for batch_idx, (inputs, labels) in enumerate(train_loader):\n", + "            # Copy the data to the GPU device if available\n", + "            inputs, labels = inputs.to(device), labels.to(device)\n", + "            # Forward pass\n", + "            outputs = model(inputs)\n", + "            loss = criterion(outputs, labels)\n", + "            # Backward pass\n", + "            optimizer.zero_grad()\n", + "            loss.backward()\n", + "            optimizer.step()\n", + "\n", + "            if batch_idx % 10 == 0 and dist.get_rank() == 0:\n", + "                print(\n", + "                    \"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\".format(\n", + "                        epoch,\n", + "                        batch_idx * len(inputs),\n", + "                        len(train_loader.dataset),\n", + "                        100.0 * batch_idx / len(train_loader),\n", + "                        loss.item(),\n", + "                    )\n", + "                )\n", + "\n", + "    # Wait for the distributed training to complete\n", + "    dist.barrier()\n", + "    if dist.get_rank() == 0:\n", + "        print(\"Training is finished\")\n", + "\n", + "    # Finally clean up PyTorch distributed\n", + "    dist.destroy_process_group()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dry-run the training locally" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using Device: cpu, Backend: gloo\n", + "Distributed Training for WORLD_SIZE: 1, RANK: 0, LOCAL_RANK: 0\n", + "Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.305710\n", + "Train Epoch: 1 [1000/60000 (2%)]\tLoss: 2.196936\n", + "Train Epoch: 1 [2000/60000 (3%)]\tLoss: 1.686612\n", + "Train Epoch: 1 [3000/60000 (5%)]\tLoss: 1.011347\n", + "Train Epoch: 1 [4000/60000 (7%)]\tLoss: 0.865053\n", + "Train Epoch: 1 [5000/60000 (8%)]\tLoss: 0.836808\n", + "Train Epoch: 1 [6000/60000 (10%)]\tLoss: 0.621258\n", + "Train Epoch: 1 [7000/60000 (12%)]\tLoss: 0.766442\n", + "Train Epoch: 1 [8000/60000 (13%)]\tLoss: 0.638275\n", + "Train Epoch: 1 [9000/60000 (15%)]\tLoss: 0.640819\n", + "Train Epoch: 1 [10000/60000 (17%)]\tLoss: 0.629251\n", + "Train Epoch: 1 [11000/60000 (18%)]\tLoss: 0.604344\n", + "Train Epoch: 1 [12000/60000 (20%)]\tLoss: 0.594033\n", + "Train Epoch: 1 [13000/60000 (22%)]\tLoss: 0.601909\n", + "Train Epoch: 1 [14000/60000 (23%)]\tLoss: 0.644867\n", + "Train Epoch: 1 [15000/60000 (25%)]\tLoss: 0.650109\n", + "Train Epoch: 1 [16000/60000 
(27%)]\tLoss: 0.496788\n", + "Train Epoch: 1 [17000/60000 (28%)]\tLoss: 0.524327\n", + "Train Epoch: 1 [18000/60000 (30%)]\tLoss: 0.384348\n", + "Train Epoch: 1 [19000/60000 (32%)]\tLoss: 0.560006\n", + "Train Epoch: 1 [20000/60000 (33%)]\tLoss: 0.442501\n", + "Train Epoch: 1 [21000/60000 (35%)]\tLoss: 0.376513\n", + "Train Epoch: 1 [22000/60000 (37%)]\tLoss: 0.325805\n", + "Train Epoch: 1 [23000/60000 (38%)]\tLoss: 0.577412\n", + "Train Epoch: 1 [24000/60000 (40%)]\tLoss: 0.481250\n", + "Train Epoch: 1 [25000/60000 (42%)]\tLoss: 0.367576\n", + "Train Epoch: 1 [26000/60000 (43%)]\tLoss: 0.525611\n", + "Train Epoch: 1 [27000/60000 (45%)]\tLoss: 0.298134\n", + "Train Epoch: 1 [28000/60000 (47%)]\tLoss: 0.393725\n", + "Train Epoch: 1 [29000/60000 (48%)]\tLoss: 0.428786\n", + "Train Epoch: 1 [30000/60000 (50%)]\tLoss: 0.504138\n", + "Train Epoch: 1 [31000/60000 (52%)]\tLoss: 0.380300\n", + "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.368163\n", + "Train Epoch: 1 [33000/60000 (55%)]\tLoss: 0.334117\n", + "Train Epoch: 1 [34000/60000 (57%)]\tLoss: 0.473941\n", + "Train Epoch: 1 [35000/60000 (58%)]\tLoss: 0.427030\n", + "Train Epoch: 1 [36000/60000 (60%)]\tLoss: 0.432930\n", + "Train Epoch: 1 [37000/60000 (62%)]\tLoss: 0.449589\n", + "Train Epoch: 1 [38000/60000 (63%)]\tLoss: 0.415637\n", + "Train Epoch: 1 [39000/60000 (65%)]\tLoss: 0.461157\n", + "Train Epoch: 1 [40000/60000 (67%)]\tLoss: 0.293386\n", + "Train Epoch: 1 [41000/60000 (68%)]\tLoss: 0.522025\n", + "Train Epoch: 1 [42000/60000 (70%)]\tLoss: 0.418245\n", + "Train Epoch: 1 [43000/60000 (72%)]\tLoss: 0.415467\n", + "Train Epoch: 1 [44000/60000 (73%)]\tLoss: 0.382607\n", + "Train Epoch: 1 [45000/60000 (75%)]\tLoss: 0.271329\n", + "Train Epoch: 1 [46000/60000 (77%)]\tLoss: 0.531167\n", + "Train Epoch: 1 [47000/60000 (78%)]\tLoss: 0.369711\n", + "Train Epoch: 1 [48000/60000 (80%)]\tLoss: 0.449973\n", + "Train Epoch: 1 [49000/60000 (82%)]\tLoss: 0.341356\n", + "Train Epoch: 1 [50000/60000 (83%)]\tLoss: 0.214439\n", + "Train Epoch: 1 [51000/60000 (85%)]\tLoss: 0.393635\n", + "Train Epoch: 1 [52000/60000 (87%)]\tLoss: 0.357560\n", + "Train Epoch: 1 [53000/60000 (88%)]\tLoss: 0.340900\n", + "Train Epoch: 1 [54000/60000 (90%)]\tLoss: 0.357980\n", + "Train Epoch: 1 [55000/60000 (92%)]\tLoss: 0.333233\n", + "Train Epoch: 1 [56000/60000 (93%)]\tLoss: 0.448625\n", + "Train Epoch: 1 [57000/60000 (95%)]\tLoss: 0.447581\n", + "Train Epoch: 1 [58000/60000 (97%)]\tLoss: 0.335855\n", + "Train Epoch: 1 [59000/60000 (98%)]\tLoss: 0.236294\n", + "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.339131\n", + "Train Epoch: 2 [1000/60000 (2%)]\tLoss: 0.417863\n", + "Train Epoch: 2 [2000/60000 (3%)]\tLoss: 0.253763\n", + "Train Epoch: 2 [3000/60000 (5%)]\tLoss: 0.289723\n", + "Train Epoch: 2 [4000/60000 (7%)]\tLoss: 0.379311\n", + "Train Epoch: 2 [5000/60000 (8%)]\tLoss: 0.355971\n", + "Train Epoch: 2 [6000/60000 (10%)]\tLoss: 0.406986\n", + "Train Epoch: 2 [7000/60000 (12%)]\tLoss: 0.368667\n", + "Train Epoch: 2 [8000/60000 (13%)]\tLoss: 0.378213\n", + "Train Epoch: 2 [9000/60000 (15%)]\tLoss: 0.353444\n", + "Train Epoch: 2 [10000/60000 (17%)]\tLoss: 0.411004\n", + "Train Epoch: 2 [11000/60000 (18%)]\tLoss: 0.368378\n", + "Train Epoch: 2 [12000/60000 (20%)]\tLoss: 0.292819\n", + "Train Epoch: 2 [13000/60000 (22%)]\tLoss: 0.308447\n", + "Train Epoch: 2 [14000/60000 (23%)]\tLoss: 0.431134\n", + "Train Epoch: 2 [15000/60000 (25%)]\tLoss: 0.377884\n", + "Train Epoch: 2 [16000/60000 (27%)]\tLoss: 0.307553\n", + "Train Epoch: 2 [17000/60000 (28%)]\tLoss: 0.318508\n", + "Train 
Epoch: 2 [18000/60000 (30%)]\tLoss: 0.233857\n", + "Train Epoch: 2 [19000/60000 (32%)]\tLoss: 0.467560\n", + "Train Epoch: 2 [20000/60000 (33%)]\tLoss: 0.437345\n", + "Train Epoch: 2 [21000/60000 (35%)]\tLoss: 0.193662\n", + "Train Epoch: 2 [22000/60000 (37%)]\tLoss: 0.285555\n", + "Train Epoch: 2 [23000/60000 (38%)]\tLoss: 0.349669\n", + "Train Epoch: 2 [24000/60000 (40%)]\tLoss: 0.263075\n", + "Train Epoch: 2 [25000/60000 (42%)]\tLoss: 0.246990\n", + "Train Epoch: 2 [26000/60000 (43%)]\tLoss: 0.393016\n", + "Train Epoch: 2 [27000/60000 (45%)]\tLoss: 0.264355\n", + "Train Epoch: 2 [28000/60000 (47%)]\tLoss: 0.362251\n", + "Train Epoch: 2 [29000/60000 (48%)]\tLoss: 0.311958\n", + "Train Epoch: 2 [30000/60000 (50%)]\tLoss: 0.345881\n", + "Train Epoch: 2 [31000/60000 (52%)]\tLoss: 0.282521\n", + "Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.288107\n", + "Train Epoch: 2 [33000/60000 (55%)]\tLoss: 0.219091\n", + "Train Epoch: 2 [34000/60000 (57%)]\tLoss: 0.467342\n", + "Train Epoch: 2 [35000/60000 (58%)]\tLoss: 0.548685\n", + "Train Epoch: 2 [36000/60000 (60%)]\tLoss: 0.361122\n", + "Train Epoch: 2 [37000/60000 (62%)]\tLoss: 0.238097\n", + "Train Epoch: 2 [38000/60000 (63%)]\tLoss: 0.301981\n", + "Train Epoch: 2 [39000/60000 (65%)]\tLoss: 0.320932\n", + "Train Epoch: 2 [40000/60000 (67%)]\tLoss: 0.201068\n", + "Train Epoch: 2 [41000/60000 (68%)]\tLoss: 0.392767\n", + "Train Epoch: 2 [42000/60000 (70%)]\tLoss: 0.311765\n", + "Train Epoch: 2 [43000/60000 (72%)]\tLoss: 0.302623\n", + "Train Epoch: 2 [44000/60000 (73%)]\tLoss: 0.321070\n", + "Train Epoch: 2 [45000/60000 (75%)]\tLoss: 0.223466\n", + "Train Epoch: 2 [46000/60000 (77%)]\tLoss: 0.372748\n", + "Train Epoch: 2 [47000/60000 (78%)]\tLoss: 0.282145\n", + "Train Epoch: 2 [48000/60000 (80%)]\tLoss: 0.457679\n", + "Train Epoch: 2 [49000/60000 (82%)]\tLoss: 0.336742\n", + "Train Epoch: 2 [50000/60000 (83%)]\tLoss: 0.241133\n", + "Train Epoch: 2 [51000/60000 (85%)]\tLoss: 0.389549\n", + "Train Epoch: 2 [52000/60000 (87%)]\tLoss: 0.243806\n", + "Train Epoch: 2 [53000/60000 (88%)]\tLoss: 0.245917\n", + "Train Epoch: 2 [54000/60000 (90%)]\tLoss: 0.260870\n", + "Train Epoch: 2 [55000/60000 (92%)]\tLoss: 0.295636\n", + "Train Epoch: 2 [56000/60000 (93%)]\tLoss: 0.462265\n", + "Train Epoch: 2 [57000/60000 (95%)]\tLoss: 0.445578\n", + "Train Epoch: 2 [58000/60000 (97%)]\tLoss: 0.340570\n", + "Train Epoch: 2 [59000/60000 (98%)]\tLoss: 0.242849\n", + "Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.304856\n", + "Train Epoch: 3 [1000/60000 (2%)]\tLoss: 0.396272\n", + "Train Epoch: 3 [2000/60000 (3%)]\tLoss: 0.204409\n", + "Train Epoch: 3 [3000/60000 (5%)]\tLoss: 0.254451\n", + "Train Epoch: 3 [4000/60000 (7%)]\tLoss: 0.371637\n", + "Train Epoch: 3 [5000/60000 (8%)]\tLoss: 0.319432\n", + "Train Epoch: 3 [6000/60000 (10%)]\tLoss: 0.413572\n", + "Train Epoch: 3 [7000/60000 (12%)]\tLoss: 0.360567\n", + "Train Epoch: 3 [8000/60000 (13%)]\tLoss: 0.277111\n", + "Train Epoch: 3 [9000/60000 (15%)]\tLoss: 0.333096\n", + "Train Epoch: 3 [10000/60000 (17%)]\tLoss: 0.356821\n", + "Train Epoch: 3 [11000/60000 (18%)]\tLoss: 0.299927\n", + "Train Epoch: 3 [12000/60000 (20%)]\tLoss: 0.241802\n", + "Train Epoch: 3 [13000/60000 (22%)]\tLoss: 0.234447\n", + "Train Epoch: 3 [14000/60000 (23%)]\tLoss: 0.377867\n", + "Train Epoch: 3 [15000/60000 (25%)]\tLoss: 0.259782\n", + "Train Epoch: 3 [16000/60000 (27%)]\tLoss: 0.314048\n", + "Train Epoch: 3 [17000/60000 (28%)]\tLoss: 0.268746\n", + "Train Epoch: 3 [18000/60000 (30%)]\tLoss: 0.171334\n", + "Train Epoch: 3 [19000/60000 (32%)]\tLoss: 
0.409583\n", + "Train Epoch: 3 [20000/60000 (33%)]\tLoss: 0.325467\n", + "Train Epoch: 3 [21000/60000 (35%)]\tLoss: 0.158008\n", + "Train Epoch: 3 [22000/60000 (37%)]\tLoss: 0.219097\n", + "Train Epoch: 3 [23000/60000 (38%)]\tLoss: 0.334556\n", + "Train Epoch: 3 [24000/60000 (40%)]\tLoss: 0.205891\n", + "Train Epoch: 3 [25000/60000 (42%)]\tLoss: 0.246978\n", + "Train Epoch: 3 [26000/60000 (43%)]\tLoss: 0.402184\n", + "Train Epoch: 3 [27000/60000 (45%)]\tLoss: 0.234204\n", + "Train Epoch: 3 [28000/60000 (47%)]\tLoss: 0.377011\n", + "Train Epoch: 3 [29000/60000 (48%)]\tLoss: 0.294795\n", + "Train Epoch: 3 [30000/60000 (50%)]\tLoss: 0.387394\n", + "Train Epoch: 3 [31000/60000 (52%)]\tLoss: 0.244570\n", + "Train Epoch: 3 [32000/60000 (53%)]\tLoss: 0.235462\n", + "Train Epoch: 3 [33000/60000 (55%)]\tLoss: 0.235727\n", + "Train Epoch: 3 [34000/60000 (57%)]\tLoss: 0.384234\n", + "Train Epoch: 3 [35000/60000 (58%)]\tLoss: 0.536794\n", + "Train Epoch: 3 [36000/60000 (60%)]\tLoss: 0.282465\n", + "Train Epoch: 3 [37000/60000 (62%)]\tLoss: 0.216106\n", + "Train Epoch: 3 [38000/60000 (63%)]\tLoss: 0.246326\n", + "Train Epoch: 3 [39000/60000 (65%)]\tLoss: 0.298494\n", + "Train Epoch: 3 [40000/60000 (67%)]\tLoss: 0.206228\n", + "Train Epoch: 3 [41000/60000 (68%)]\tLoss: 0.359189\n", + "Train Epoch: 3 [42000/60000 (70%)]\tLoss: 0.305101\n", + "Train Epoch: 3 [43000/60000 (72%)]\tLoss: 0.222482\n", + "Train Epoch: 3 [44000/60000 (73%)]\tLoss: 0.274947\n", + "Train Epoch: 3 [45000/60000 (75%)]\tLoss: 0.200588\n", + "Train Epoch: 3 [46000/60000 (77%)]\tLoss: 0.300547\n", + "Train Epoch: 3 [47000/60000 (78%)]\tLoss: 0.286985\n", + "Train Epoch: 3 [48000/60000 (80%)]\tLoss: 0.386801\n", + "Train Epoch: 3 [49000/60000 (82%)]\tLoss: 0.403562\n", + "Train Epoch: 3 [50000/60000 (83%)]\tLoss: 0.216983\n", + "Train Epoch: 3 [51000/60000 (85%)]\tLoss: 0.466640\n", + "Train Epoch: 3 [52000/60000 (87%)]\tLoss: 0.213542\n", + "Train Epoch: 3 [53000/60000 (88%)]\tLoss: 0.205078\n", + "Train Epoch: 3 [54000/60000 (90%)]\tLoss: 0.226228\n", + "Train Epoch: 3 [55000/60000 (92%)]\tLoss: 0.296321\n", + "Train Epoch: 3 [56000/60000 (93%)]\tLoss: 0.360502\n", + "Train Epoch: 3 [57000/60000 (95%)]\tLoss: 0.379296\n", + "Train Epoch: 3 [58000/60000 (97%)]\tLoss: 0.326375\n", + "Train Epoch: 3 [59000/60000 (98%)]\tLoss: 0.209175\n", + "Train Epoch: 4 [0/60000 (0%)]\tLoss: 0.253068\n", + "Train Epoch: 4 [1000/60000 (2%)]\tLoss: 0.263589\n", + "Train Epoch: 4 [2000/60000 (3%)]\tLoss: 0.220770\n", + "Train Epoch: 4 [3000/60000 (5%)]\tLoss: 0.241027\n", + "Train Epoch: 4 [4000/60000 (7%)]\tLoss: 0.278266\n", + "Train Epoch: 4 [5000/60000 (8%)]\tLoss: 0.288942\n", + "Train Epoch: 4 [6000/60000 (10%)]\tLoss: 0.377470\n", + "Train Epoch: 4 [7000/60000 (12%)]\tLoss: 0.256117\n", + "Train Epoch: 4 [8000/60000 (13%)]\tLoss: 0.269593\n", + "Train Epoch: 4 [9000/60000 (15%)]\tLoss: 0.337227\n", + "Train Epoch: 4 [10000/60000 (17%)]\tLoss: 0.269241\n", + "Train Epoch: 4 [11000/60000 (18%)]\tLoss: 0.314784\n", + "Train Epoch: 4 [12000/60000 (20%)]\tLoss: 0.231463\n", + "Train Epoch: 4 [13000/60000 (22%)]\tLoss: 0.248236\n", + "Train Epoch: 4 [14000/60000 (23%)]\tLoss: 0.381541\n", + "Train Epoch: 4 [15000/60000 (25%)]\tLoss: 0.243203\n", + "Train Epoch: 4 [16000/60000 (27%)]\tLoss: 0.257698\n", + "Train Epoch: 4 [17000/60000 (28%)]\tLoss: 0.255600\n", + "Train Epoch: 4 [18000/60000 (30%)]\tLoss: 0.164215\n", + "Train Epoch: 4 [19000/60000 (32%)]\tLoss: 0.403348\n", + "Train Epoch: 4 [20000/60000 (33%)]\tLoss: 0.312329\n", + "Train Epoch: 4 
[21000/60000 (35%)]\tLoss: 0.148644\n", + "Train Epoch: 4 [22000/60000 (37%)]\tLoss: 0.207623\n", + "Train Epoch: 4 [23000/60000 (38%)]\tLoss: 0.298741\n", + "Train Epoch: 4 [24000/60000 (40%)]\tLoss: 0.224506\n", + "Train Epoch: 4 [25000/60000 (42%)]\tLoss: 0.236052\n", + "Train Epoch: 4 [26000/60000 (43%)]\tLoss: 0.341681\n", + "Train Epoch: 4 [27000/60000 (45%)]\tLoss: 0.231967\n", + "Train Epoch: 4 [28000/60000 (47%)]\tLoss: 0.260894\n", + "Train Epoch: 4 [29000/60000 (48%)]\tLoss: 0.251320\n", + "Train Epoch: 4 [30000/60000 (50%)]\tLoss: 0.335226\n", + "Train Epoch: 4 [31000/60000 (52%)]\tLoss: 0.184827\n", + "Train Epoch: 4 [32000/60000 (53%)]\tLoss: 0.192683\n", + "Train Epoch: 4 [33000/60000 (55%)]\tLoss: 0.181942\n", + "Train Epoch: 4 [34000/60000 (57%)]\tLoss: 0.427882\n", + "Train Epoch: 4 [35000/60000 (58%)]\tLoss: 0.440395\n", + "Train Epoch: 4 [36000/60000 (60%)]\tLoss: 0.249151\n", + "Train Epoch: 4 [37000/60000 (62%)]\tLoss: 0.251662\n", + "Train Epoch: 4 [38000/60000 (63%)]\tLoss: 0.263578\n", + "Train Epoch: 4 [39000/60000 (65%)]\tLoss: 0.269635\n", + "Train Epoch: 4 [40000/60000 (67%)]\tLoss: 0.174087\n", + "Train Epoch: 4 [41000/60000 (68%)]\tLoss: 0.307767\n", + "Train Epoch: 4 [42000/60000 (70%)]\tLoss: 0.276867\n", + "Train Epoch: 4 [43000/60000 (72%)]\tLoss: 0.267041\n", + "Train Epoch: 4 [44000/60000 (73%)]\tLoss: 0.265095\n", + "Train Epoch: 4 [45000/60000 (75%)]\tLoss: 0.211275\n", + "Train Epoch: 4 [46000/60000 (77%)]\tLoss: 0.256053\n", + "Train Epoch: 4 [47000/60000 (78%)]\tLoss: 0.293347\n", + "Train Epoch: 4 [48000/60000 (80%)]\tLoss: 0.413967\n", + "Train Epoch: 4 [49000/60000 (82%)]\tLoss: 0.300236\n", + "Train Epoch: 4 [50000/60000 (83%)]\tLoss: 0.140499\n", + "Train Epoch: 4 [51000/60000 (85%)]\tLoss: 0.336364\n", + "Train Epoch: 4 [52000/60000 (87%)]\tLoss: 0.176903\n", + "Train Epoch: 4 [53000/60000 (88%)]\tLoss: 0.226507\n", + "Train Epoch: 4 [54000/60000 (90%)]\tLoss: 0.235510\n", + "Train Epoch: 4 [55000/60000 (92%)]\tLoss: 0.330527\n", + "Train Epoch: 4 [56000/60000 (93%)]\tLoss: 0.343374\n", + "Train Epoch: 4 [57000/60000 (95%)]\tLoss: 0.282549\n", + "Train Epoch: 4 [58000/60000 (97%)]\tLoss: 0.242879\n", + "Train Epoch: 4 [59000/60000 (98%)]\tLoss: 0.190787\n", + "Training is finished\n" + ] + } + ], + "source": [ + "import os\n", + "\n", + "# Set the Torch Distributed env variables so the training function can be run in the notebook\n", + "# See https://pytorch.org/docs/stable/elastic/run.html#environment-variables\n", + "os.environ[\"RANK\"] = \"0\"\n", + "os.environ[\"LOCAL_RANK\"] = \"0\"\n", + "os.environ[\"WORLD_SIZE\"] = \"1\"\n", + "os.environ[\"MASTER_ADDR\"] = \"localhost\"\n", + "os.environ[\"MASTER_PORT\"] = \"1234\"\n", + "\n", + "# Run the training function locally\n", + "train_fashion_mnist()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Scale PyTorch DDP with Kubeflow TrainJob" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can use `TrainingClient()` from the Kubeflow SDK to communicate with Kubeflow APIs and scale your training function across multiple PyTorch training nodes.\n", + "\n", + "Kubeflow Trainer creates a `TrainJob` resource and automatically sets the appropriate environment variables to set up PyTorch in a distributed environment.\n"
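, + "\n", + "As a minimal sketch, each trainer node starts with the standard PyTorch distributed variables, the same ones the local dry-run sets manually above (the placeholder values below are filled in per node by Kubeflow Trainer):\n", + "\n", + "```\n", + "MASTER_ADDR=<address of the rank 0 trainer node>\n", + "MASTER_PORT=<port of the rendezvous endpoint>\n", + "WORLD_SIZE=<total number of distributed workers>\n", + "RANK=<global rank of the current worker>\n", + "LOCAL_RANK=<rank of the worker within its node>\n", + "```"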
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "from kubeflow.training import Trainer, TrainingClient\n", + "\n", + "client = TrainingClient()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## List the Training Runtimes" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can get the list of available Training Runtimes to start your TrainJob:" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Runtime(name='torch-distributed', phase='pre-training', accelerator='Unknown', accelerator_count='Unknown')\n" + ] + } + ], + "source": [ + "for runtime in client.list_runtimes():\n", + "    print(runtime)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Each Training Runtime shows whether you can use it for pre-training or post-training.\n", + "Additionally, it shows the available accelerator type and the number of available accelerators." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run the distributed TrainJob" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "job_name = client.train(\n", + "    # Use one of the training runtimes installed on your Kubernetes cluster\n", + "    runtime_ref=\"torch-distributed\",\n", + "    trainer=Trainer(\n", + "        func=train_fashion_mnist,\n", + "        func_args={\n", + "            \"epochs\": 10,\n", + "        },\n", + "        # Set how many worker Pods you want the job to be distributed into\n", + "        num_nodes=4,\n", + "        # Set the resources for each worker Pod\n", + "        resources_per_node={\n", + "            \"cpu\": 8,\n", + "            \"memory\": \"16Gi\",\n", + "            \"nvidia.com/gpu\": 1,\n", + "        },\n", + "    ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Check the TrainJob components" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can check the details of the TrainJob that's created:" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TrainJob(name='d1018a6a6f7c', runtime_ref='torch-distributed', creation_timestamp=datetime.datetime(2025, 1, 29, 16, 41, 28, tzinfo=tzutc()), components=[Component(name='trainer-node-0', status='Running', device='gpu', device_count='1', pod_name='d1018a6a6f7c-trainer-node-0-0-pmzdh'), Component(name='trainer-node-1', status='Running', device='gpu', device_count='1', pod_name='d1018a6a6f7c-trainer-node-0-1-57t57'), Component(name='trainer-node-2', status='Running', device='gpu', device_count='1', pod_name='d1018a6a6f7c-trainer-node-0-2-pk8cb'), Component(name='trainer-node-3', status='Running', device='gpu', device_count='1', pod_name='d1018a6a6f7c-trainer-node-0-3-vcm88')], status='Created')\n" + ] + } + ], + "source": [ + "job = client.get_job(job_name)\n", + "print(job)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since the TrainJob is distributed across 4 nodes, it creates 4 components: `trainer-node-0`, ..., `trainer-node-3`, and you can get the individual status of each of these components.\n"
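, + "\n", + "For example, you can print the status of each component using only the fields visible in the output above (a minimal sketch; the attribute names are assumed to match the `Component` repr):\n", + "\n", + "```python\n", + "for component in client.get_job(job_name).components:\n", + "    print(component.name, component.status, component.pod_name)\n", + "```"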
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Watch the TrainJob logs" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[trainer-node]: Using Device: cuda, Backend: nccl\n", + "[trainer-node]: Distributed Training for WORLD_SIZE: 4, RANK: 0, LOCAL_RANK: 0\n", + "[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\n", + "[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz\n", + "100%|██████████| 26.4M/26.4M [00:01<00:00, 14.6MB/s]\n", + "[trainer-node]: Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw\n", + "[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\n", + "[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz\n", + "100%|██████████| 29.5k/29.5k [00:00<00:00, 327kB/s]\n", + "[trainer-node]: Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw\n", + "[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\n", + "[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz\n", + "100%|██████████| 4.42M/4.42M [00:00<00:00, 5.87MB/s]\n", + "[trainer-node]: Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw\n", + "[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\n", + "[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz\n", + "100%|██████████| 5.15k/5.15k [00:00<00:00, 37.1MB/s]\n", + "[trainer-node]: Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw\n", + "[trainer-node]: Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.297739\n", + "[trainer-node]: Train Epoch: 1 [1000/60000 (7%)]\tLoss: 1.720075\n", + "[trainer-node]: Train Epoch: 1 [2000/60000 (13%)]\tLoss: 1.435773\n", + "[trainer-node]: Train Epoch: 1 [3000/60000 (20%)]\tLoss: 1.635886\n", + "[trainer-node]: Train Epoch: 1 [4000/60000 (27%)]\tLoss: 0.979150\n", + "[trainer-node]: Train Epoch: 1 [5000/60000 (33%)]\tLoss: 0.812853\n", + "[trainer-node]: Train Epoch: 1 [6000/60000 (40%)]\tLoss: 0.795117\n", + "[trainer-node]: Train Epoch: 1 [7000/60000 (47%)]\tLoss: 0.769644\n", + "[trainer-node]: Train Epoch: 1 [8000/60000 (53%)]\tLoss: 0.474975\n", + "[trainer-node]: Train Epoch: 1 [9000/60000 (60%)]\tLoss: 0.560368\n", + "[trainer-node]: Train Epoch: 1 [10000/60000 (67%)]\tLoss: 0.425427\n", + "[trainer-node]: Train Epoch: 1 [11000/60000 (73%)]\tLoss: 0.490059\n", + "[trainer-node]: Train Epoch: 1 [12000/60000 (80%)]\tLoss: 0.446036\n", + "[trainer-node]: Train Epoch: 1 [13000/60000 (87%)]\tLoss: 0.514898\n", + "[trainer-node]: Train Epoch: 1 [14000/60000 (93%)]\tLoss: 0.399690\n", + "[trainer-node]: Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.516461\n", + "[trainer-node]: Train Epoch: 2 [1000/60000 (7%)]\tLoss: 0.437594\n", + "[trainer-node]: Train Epoch: 2 [2000/60000 
(13%)]\tLoss: 0.383412\n", + "[trainer-node]: Train Epoch: 2 [3000/60000 (20%)]\tLoss: 0.318116\n", + "[trainer-node]: Train Epoch: 2 [4000/60000 (27%)]\tLoss: 0.457998\n", + "[trainer-node]: Train Epoch: 2 [5000/60000 (33%)]\tLoss: 0.385253\n", + "[trainer-node]: Train Epoch: 2 [6000/60000 (40%)]\tLoss: 0.555350\n", + "[trainer-node]: Train Epoch: 2 [7000/60000 (47%)]\tLoss: 0.455115\n", + "[trainer-node]: Train Epoch: 2 [8000/60000 (53%)]\tLoss: 0.363895\n", + "[trainer-node]: Train Epoch: 2 [9000/60000 (60%)]\tLoss: 0.441691\n", + "[trainer-node]: Train Epoch: 2 [10000/60000 (67%)]\tLoss: 0.315430\n", + "[trainer-node]: Train Epoch: 2 [11000/60000 (73%)]\tLoss: 0.385001\n", + "[trainer-node]: Train Epoch: 2 [12000/60000 (80%)]\tLoss: 0.313463\n", + "[trainer-node]: Train Epoch: 2 [13000/60000 (87%)]\tLoss: 0.338070\n", + "[trainer-node]: Train Epoch: 2 [14000/60000 (93%)]\tLoss: 0.305447\n", + "[trainer-node]: Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.375701\n", + "[trainer-node]: Train Epoch: 3 [1000/60000 (7%)]\tLoss: 0.419893\n", + "[trainer-node]: Train Epoch: 3 [2000/60000 (13%)]\tLoss: 0.362021\n", + "[trainer-node]: Train Epoch: 3 [3000/60000 (20%)]\tLoss: 0.246983\n", + "[trainer-node]: Train Epoch: 3 [4000/60000 (27%)]\tLoss: 0.328127\n", + "[trainer-node]: Train Epoch: 3 [5000/60000 (33%)]\tLoss: 0.299359\n", + "[trainer-node]: Train Epoch: 3 [6000/60000 (40%)]\tLoss: 0.484671\n", + "[trainer-node]: Train Epoch: 3 [7000/60000 (47%)]\tLoss: 0.463402\n", + "[trainer-node]: Train Epoch: 3 [8000/60000 (53%)]\tLoss: 0.277818\n", + "[trainer-node]: Train Epoch: 3 [9000/60000 (60%)]\tLoss: 0.384348\n", + "[trainer-node]: Train Epoch: 3 [10000/60000 (67%)]\tLoss: 0.271459\n", + "[trainer-node]: Train Epoch: 3 [11000/60000 (73%)]\tLoss: 0.305670\n", + "[trainer-node]: Train Epoch: 3 [12000/60000 (80%)]\tLoss: 0.249658\n", + "[trainer-node]: Train Epoch: 3 [13000/60000 (87%)]\tLoss: 0.311472\n", + "[trainer-node]: Train Epoch: 3 [14000/60000 (93%)]\tLoss: 0.264840\n", + "[trainer-node]: Train Epoch: 4 [0/60000 (0%)]\tLoss: 0.346031\n", + "[trainer-node]: Train Epoch: 4 [1000/60000 (7%)]\tLoss: 0.362948\n", + "[trainer-node]: Train Epoch: 4 [2000/60000 (13%)]\tLoss: 0.324963\n", + "[trainer-node]: Train Epoch: 4 [3000/60000 (20%)]\tLoss: 0.233126\n", + "[trainer-node]: Train Epoch: 4 [4000/60000 (27%)]\tLoss: 0.256574\n", + "[trainer-node]: Train Epoch: 4 [5000/60000 (33%)]\tLoss: 0.264268\n", + "[trainer-node]: Train Epoch: 4 [6000/60000 (40%)]\tLoss: 0.451931\n", + "[trainer-node]: Train Epoch: 4 [7000/60000 (47%)]\tLoss: 0.475285\n", + "[trainer-node]: Train Epoch: 4 [8000/60000 (53%)]\tLoss: 0.250741\n", + "[trainer-node]: Train Epoch: 4 [9000/60000 (60%)]\tLoss: 0.252023\n", + "[trainer-node]: Train Epoch: 4 [10000/60000 (67%)]\tLoss: 0.268340\n", + "[trainer-node]: Train Epoch: 4 [11000/60000 (73%)]\tLoss: 0.270340\n", + "[trainer-node]: Train Epoch: 4 [12000/60000 (80%)]\tLoss: 0.222117\n", + "[trainer-node]: Train Epoch: 4 [13000/60000 (87%)]\tLoss: 0.292431\n", + "[trainer-node]: Train Epoch: 4 [14000/60000 (93%)]\tLoss: 0.273884\n", + "[trainer-node]: Training is finished\n" + ] + } + ], + "source": [ + "_ = client.get_job_logs(job_name, follow=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Each node processes its assigned shard of the Fashion-MNIST dataset.\n", + "As the `TrainJob` is distributed across 4 nodes and the dataset contains a total of 60,000 samples, each node processes 15,000 samples.\n"
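, + "\n", + "As a quick sanity check, with the default `DistributedSampler` behaviour the shards are even and non-overlapping: 60,000 samples / 4 nodes = 15,000 samples per node, i.e. 150 mini-batches of 100 per epoch, which matches the per-node progress (up to 14000/60000) reported in the logs above."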
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Delete the TrainJob" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "client.delete_job(job_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}