24 changes: 16 additions & 8 deletions notebooks/vjepa2_demo.ipynb
@@ -81,14 +81,14 @@
 "    return video\n",
 "\n",
 "\n",
-"def forward_vjepa_video(model_hf, model_pt, hf_transform, pt_transform):\n",
+"def forward_vjepa_video(model_hf, model_pt, hf_transform, pt_transform, device):\n",
 "    # Run a sample inference with VJEPA\n",
 "    with torch.inference_mode():\n",
 "        # Read and pre-process the image\n",
 "        video = get_video()  # T x H x W x C\n",
 "        video = torch.from_numpy(video).permute(0, 3, 1, 2)  # T x C x H x W\n",
-"        x_pt = pt_transform(video).cuda().unsqueeze(0)\n",
-"        x_hf = hf_transform(video, return_tensors=\"pt\")[\"pixel_values_videos\"].to(\"cuda\")\n",
+"        x_pt = pt_transform(video).to(device).unsqueeze(0)\n",
+"        x_hf = hf_transform(video, return_tensors=\"pt\")[\"pixel_values_videos\"].to(device)\n",
 "        # Extract the patch-wise features from the last layer\n",
 "        out_patch_features_pt = model_pt(x_pt)\n",
 "        out_patch_features_hf = model_hf.get_vision_features(x_hf)\n",
@@ -176,22 +176,30 @@
 "# Path to local PyTorch weights\n",
 "pt_model_path = \"YOUR_MODEL_PATH\"\n",
 "\n",
+"# Configure GPU acceleration for CUDA or MPS (Apple Silicon)\n",
+"if torch.cuda.is_available():\n",
+"    device = torch.device(\"cuda\")\n",
+"elif torch.backends.mps.is_available():\n",
+"    device = torch.device(\"mps\")  # Apple Silicon GPU support\n",
+"else:\n",
+"    device = torch.device(\"cpu\")\n",
+"\n",
 "# Initialize the HuggingFace model, load pretrained weights\n",
 "model_hf = AutoModel.from_pretrained(hf_model_name)\n",
-"model_hf.cuda().eval()\n",
+"model_hf.to(device).eval()\n",
 "\n",
 "# Build HuggingFace preprocessing transform\n",
 "hf_transform = AutoVideoProcessor.from_pretrained(hf_model_name)\n",
 "img_size = hf_transform.crop_size[\"height\"]  # E.g. 384, 256, etc.\n",
 "\n",
 "# Initialize the PyTorch model, load pretrained weights\n",
 "model_pt = vit_giant_xformers_rope(img_size=(img_size, img_size), num_frames=64)\n",
-"model_pt.cuda().eval()\n",
+"model_pt.to(device).eval()\n",
 "load_pretrained_vjepa_pt_weights(model_pt, pt_model_path)\n",
 "\n",
 "### Can also use torch.hub to load the model\n",
 "# model_pt, _ = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_giant_384')\n",
-"# model_pt.cuda().eval()\n",
+"# model_pt.to(device).eval()\n",
 "\n",
 "# Build PyTorch preprocessing transform\n",
 "pt_video_transform = build_pt_video_transform(img_size=img_size)"
@@ -212,7 +220,7 @@
 "source": [
 "# Inference on video to get the patch-wise features\n",
 "out_patch_features_hf, out_patch_features_pt = forward_vjepa_video(\n",
-"    model_hf, model_pt, hf_transform, pt_video_transform\n",
+"    model_hf, model_pt, hf_transform, pt_video_transform, device\n",
 ")\n",
 "\n",
 "print(\n",
@@ -246,7 +254,7 @@
 "# Initialize the classifier\n",
 "classifier_model_path = \"YOUR_ATTENTIVE_PROBE_PATH\"\n",
 "classifier = (\n",
-"    AttentiveClassifier(embed_dim=model_pt.embed_dim, num_heads=16, depth=4, num_classes=174).cuda().eval()\n",
+"    AttentiveClassifier(embed_dim=model_pt.embed_dim, num_heads=16, depth=4, num_classes=174).to(device).eval()\n",
 ")\n",
 "load_pretrained_vjepa_classifier_weights(classifier, classifier_model_path)\n",
 "\n",
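The change replaces every hard-coded `.cuda()` call with `.to(device)`, resolving the target device once up front. A minimal standalone sketch of the same selection logic, using a hypothetical `select_device` helper (not part of this PR), shows why the order matters: CUDA is preferred when present, Apple Silicon's MPS backend is the fallback, and CPU always works.

```python
import torch

def select_device() -> torch.device:
    # Same precedence as the diff: CUDA first, then MPS (Apple Silicon), then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = select_device()
model = torch.nn.Linear(8, 2).to(device).eval()  # stand-in for the VJEPA-2 models
x = torch.randn(1, 8, device=device)             # inputs must live on the same device
with torch.inference_mode():
    y = model(x)
```

Keeping the model and every input tensor on the same device is the invariant the diff maintains throughout; a CUDA or MPS model fed CPU tensors raises a runtime error.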
25 changes: 16 additions & 9 deletions notebooks/vjepa2_demo.py
@@ -63,14 +63,14 @@ def get_video():
     return video


-def forward_vjepa_video(model_hf, model_pt, hf_transform, pt_transform):
+def forward_vjepa_video(model_hf, model_pt, hf_transform, pt_transform, device):
     # Run a sample inference with VJEPA
     with torch.inference_mode():
         # Read and pre-process the image
         video = get_video()  # T x H x W x C
         video = torch.from_numpy(video).permute(0, 3, 1, 2)  # T x C x H x W
-        x_pt = pt_transform(video).cuda().unsqueeze(0)
-        x_hf = hf_transform(video, return_tensors="pt")["pixel_values_videos"].to("cuda")
+        x_pt = pt_transform(video).to(device).unsqueeze(0)
+        x_hf = hf_transform(video, return_tensors="pt")["pixel_values_videos"].to(device)
         # Extract the patch-wise features from the last layer
         out_patch_features_pt = model_pt(x_pt)
         out_patch_features_hf = model_hf.get_vision_features(x_hf)
@@ -96,7 +96,7 @@ def get_vjepa_video_classification_results(classifier, out_patch_features_pt):
     return


-def run_sample_inference():
+def run_sample_inference(device):
     # HuggingFace model repo name
     hf_model_name = (
         "facebook/vjepa2-vitg-fpc64-384"  # Replace with your favored model, e.g. facebook/vjepa2-vitg-fpc64-384
@@ -114,23 +114,23 @@ def run_sample_inference():

     # Initialize the HuggingFace model, load pretrained weights
     model_hf = AutoModel.from_pretrained(hf_model_name)
-    model_hf.cuda().eval()
+    model_hf.to(device).eval()

     # Build HuggingFace preprocessing transform
     hf_transform = AutoVideoProcessor.from_pretrained(hf_model_name)
     img_size = hf_transform.crop_size["height"]  # E.g. 384, 256, etc.

     # Initialize the PyTorch model, load pretrained weights
     model_pt = vit_giant_xformers_rope(img_size=(img_size, img_size), num_frames=64)
-    model_pt.cuda().eval()
+    model_pt.to(device).eval()
     load_pretrained_vjepa_pt_weights(model_pt, pt_model_path)

     # Build PyTorch preprocessing transform
     pt_video_transform = build_pt_video_transform(img_size=img_size)

     # Inference on video
     out_patch_features_hf, out_patch_features_pt = forward_vjepa_video(
-        model_hf, model_pt, hf_transform, pt_video_transform
+        model_hf, model_pt, hf_transform, pt_video_transform, device
     )

     print(
@@ -146,7 +146,7 @@ def run_sample_inference():
     # Initialize the classifier
     classifier_model_path = "YOUR_ATTENTIVE_PROBE_PATH"
     classifier = (
-        AttentiveClassifier(embed_dim=model_pt.embed_dim, num_heads=16, depth=4, num_classes=174).cuda().eval()
+        AttentiveClassifier(embed_dim=model_pt.embed_dim, num_heads=16, depth=4, num_classes=174).to(device).eval()
     )
     load_pretrained_vjepa_classifier_weights(classifier, classifier_model_path)

@@ -167,5 +167,12 @@ def run_sample_inference():


 if __name__ == "__main__":
+    # Configure GPU acceleration for CUDA or MPS (Apple Silicon)
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
     # Run with: `python -m notebooks.vjepa2_demo`
-    run_sample_inference()
+    run_sample_inference(device)
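Because `run_sample_inference` now receives the device explicitly, callers can also bypass auto-detection, e.g. to force CPU when debugging device-specific issues. A usage sketch under the PR's new signature (assumes the repo root is on `sys.path`, matching the `python -m notebooks.vjepa2_demo` invocation above):

```python
import torch

from notebooks.vjepa2_demo import run_sample_inference

# Override auto-detection and run the whole demo on CPU.
run_sample_inference(torch.device("cpu"))
```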