|
934 | 934 | "By varying `K` we can study how performance improves as we provide more labelled examples to the model.\n" |
935 | 935 | ] |
936 | 936 | }, |
937 | | - { |
938 | | - "cell_type": "code", |
939 | | - "execution_count": 18, |
940 | | - "metadata": {}, |
941 | | - "outputs": [], |
942 | | - "source": [ |
943 | | - "# ============================================================\n", |
944 | | - "# 9a. K-shot inference on test images\n", |
945 | | - "# ============================================================\n", |
946 | | - "\n", |
947 | | - "import numpy as np\n", |
948 | | - "import torch\n", |
949 | | - "\n", |
950 | | - "\n", |
951 | | - "def evaluate_kshot_iou(encoder, train_dataset, test_dataset, K=5, num_samples=None):\n", |
952 | | - " \"\"\"\n", |
953 | | - " Evaluate K-shot IoU on 'num_samples' random test images.\n", |
954 | | - " For each test image, randomly sample K *distinct* support images from the train set.\n", |
955 | | - " \"\"\"\n", |
956 | | - " encoder.eval()\n", |
957 | | - " rng = np.random.default_rng(42)\n", |
958 | | - "\n", |
959 | | - " # If num_samples is None, evaluate on all test samples\n", |
960 | | - " if num_samples is None:\n", |
961 | | - " num_samples = len(test_dataset)\n", |
962 | | - "\n", |
963 | | - " ious = []\n", |
964 | | - " for _ in range(num_samples):\n", |
965 | | - " # pick random test index\n", |
966 | | - " ti = rng.integers(0, len(test_dataset))\n", |
967 | | - " img_q, mask_q = test_dataset[ti]\n", |
968 | | - "\n", |
969 | | - " # pick K distinct support indices\n", |
970 | | - " support_indices = rng.choice(len(train_dataset), size=K, replace=False)\n", |
971 | | - " support_imgs = []\n", |
972 | | - " support_masks = []\n", |
973 | | - " for si in support_indices:\n", |
974 | | - " img_s, mask_s = train_dataset[si]\n", |
975 | | - " support_imgs.append(img_s)\n", |
976 | | - " support_masks.append(mask_s)\n", |
977 | | - " support_imgs = torch.stack(support_imgs, dim=0) # [K,3,H,W]\n", |
978 | | - " support_masks = torch.stack(support_masks, dim=0) # [K,1,H,W]\n", |
979 | | - "\n", |
980 | | - " # run K-shot prediction\n", |
981 | | - " logits = k_shot_predict(encoder, support_imgs, support_masks, img_q) # [1,2,H,W]\n", |
982 | | - "\n", |
983 | | - " iou = iou_from_logits(logits, mask_q.unsqueeze(0))\n", |
984 | | - " ious.append(iou)\n", |
985 | | - "\n", |
986 | | - " ious = np.array(ious)\n", |
987 | | - " print(f\"{K}-shot mean IoU over {num_samples} test samples: {ious.mean():.3f}\")\n", |
988 | | - " return ious" |
989 | | - ] |
990 | | - }, |
991 | 937 | { |
992 | 938 | "cell_type": "markdown", |
993 | 939 | "metadata": {}, |
|
0 commit comments