You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
!pip install transformers
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTFeatureExtractor, ViTModel, ViTConfig, AutoConfig
# Freeze the first seven child modules of model_Res (set requires_grad=False on
# their parameters); children from the eighth onward are left untouched.
# NOTE(review): this appears to duplicate the freeze step further down. As pasted,
# it runs before the torch.hub.load that defines model_Res, so it raises NameError
# unless model_Res already exists (e.g. from an earlier notebook cell).
count = 0
for count, child in enumerate(model_Res.children(), start=1):
    if count >= 8:
        continue
    for param in child.parameters():
        param.requires_grad = False
# Modify the model - ViT model
# Load the pretrained ViT backbone, then freeze every child module from the
# fourth one onward; the first three children stay trainable.
# NOTE(review): with transformers' ViTModel the children are presumably
# embeddings, encoder, layernorm, pooler, so only the pooler gets frozen —
# confirm that is intended.
model_trans = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
count = 0
for count, child in enumerate(model_trans.children(), start=1):
    if count < 4:
        continue
    for param in child.parameters():
        param.requires_grad = False
!pip install transformers
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTFeatureExtractor, ViTModel, ViTConfig, AutoConfig
# Modify the model - ResNet
# Load ResNet-50 (pretrained=True) from the pinned torchvision v0.10.0 hub repo.
model_Res = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)

# Remove the last layer of model_Res: rebuild it as a Sequential holding every
# child module except the final one (presumably the `fc` classifier — confirm).
layers_Res = [*model_Res.children()]
model_Res = nn.Sequential(*layers_Res[:len(layers_Res) - 1])
# Make the first seven child modules (the earliest ones in forward order) not
# trainable; children from the eighth onward keep their requires_grad setting.
# NOTE(review): in torchvision's ResNet child order this is presumably
# conv1, bn1, relu, maxpool, layer1, layer2, layer3 — confirm.
count = 0
for count, child in enumerate(model_Res.children(), start=1):
    if count >= 8:
        continue
    for param in child.parameters():
        param.requires_grad = False
# Modify the model - ViT model
# Load the pretrained ViT backbone, then freeze every child module from the
# fourth one onward; the first three children stay trainable.
# NOTE(review): with transformers' ViTModel the children are presumably
# embeddings, encoder, layernorm, pooler, so only the pooler gets frozen —
# confirm that is intended.
model_trans = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
count = 0
for count, child in enumerate(model_trans.children(), start=1):
    if count < 4:
        continue
    for param in child.parameters():
        param.requires_grad = False
# Keep all but the last two children of the ViT, so the compared network stops
# before them (presumably ViTModel's layernorm and pooler — confirm).
layers_trans = list(model_trans.children())
model_trans_top = nn.Sequential(*layers_trans[:-2])

# ViT's patch embedding rejects any batch whose spatial size differs from
# config.image_size, raising
#   ValueError: Input image size (32*32) doesn't match model (224*224).
# for the 32x32 batches this dataloader yields. Upsample every batch to the ViT's
# expected size in front of BOTH networks, so the two models see identical pixels
# and the CKA comparison stays like-for-like. (If you control the dataset, adding
# transforms.Resize to its transform pipeline is an alternative.)
# NOTE(review): the resize module is itself a submodule, so CKA may list it as an
# extra layer (index 0) for each model.
vit_image_size = model_trans.config.image_size
model1 = nn.Sequential(
    nn.Upsample(size=vit_image_size, mode='bilinear', align_corners=False),
    model_Res,
)
model2 = nn.Sequential(
    nn.Upsample(size=vit_image_size, mode='bilinear', align_corners=False),
    model_trans_top,
)

cka = CKA(model1, model2,
          model1_name="ResNet50", model2_name="ViT",
          device='cuda')
cka.compare(dataloader)
cka.plot_results(save_path="/content/drive/MyDrive/resnet-ViTcompare.png")
I got this error: `ValueError: Input image size (32*32) doesn't match model (224*224).`
The text was updated successfully, but these errors were encountered: