-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
78 lines (62 loc) · 2.11 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import cv2
from PIL import Image
import torch
from transformers import AutoModelForImageClassification, ViTImageProcessor
# Hugging Face Hub checkpoint shared by the classifier and its preprocessor,
# named once so the two can never drift apart.
_MODEL_ID = "Falconsai/nsfw_image_detection"
model = AutoModelForImageClassification.from_pretrained(_MODEL_ID)
processor = ViTImageProcessor.from_pretrained(_MODEL_ID)
def getimage():
    """Prompt for a media file path and print an NSFW verdict for it.

    Images (.png/.jpg/.jpeg) are classified directly with the module-level
    ``model``/``processor``. Videos (.mp4/.webm) are handed to ``videoShit``,
    which samples frames. Extension matching is case-insensitive.

    Prints "NSFW" or "Not NSFW" for an image. Prints an error message and
    returns ``None`` if the image cannot be opened or the format is
    unsupported.
    """
    print("Enter the file path of the image: ")
    path = ""
    while not path:
        # Drop surrounding whitespace and the quotes terminals add when a file
        # is dragged in, so "  'a.png' " resolves to a.png.
        path = input().strip().strip("\"'")

    lowered = path.lower()
    if lowered.endswith((".png", ".jpg", ".jpeg")):
        try:
            # The context manager closes the file handle Image.open keeps.
            # convert() forces the lazy pixel load inside the try, so decoding
            # errors are reported here too. RGBA/palette/grayscale inputs are
            # normalized to the 3-channel RGB the processor expects.
            with Image.open(path) as img:
                rgb_img = img.convert("RGB")
        except Exception as e:  # CLI boundary: report any open/decode failure
            print("Invalid file path. Error: ", e)
            return
        with torch.no_grad():
            inputs = processor(images=rgb_img, return_tensors="pt")
            logits = model(**inputs).logits
        predicted_label = logits.argmax(-1).item()
        # NOTE(review): assumes class index 1 is the NSFW label and 0 is safe;
        # confirm against model.config.id2label for this checkpoint.
        print("NSFW" if predicted_label else "Not NSFW")
    elif lowered.endswith((".mp4", ".webm")):
        videoShit(path)
    else:
        print("Invalid file format")
def capture_screenshot(path, interval_seconds=10):
    """Save one frame per ``interval_seconds`` of video time as PNG files.

    Frames are written to the current working directory as ``image_0.png``,
    ``image_1.png``, ... The very first frame of the video is always
    ``image_0.png``. Existing files with those names are overwritten.

    Args:
        path: Path to a video file that OpenCV can open.
        interval_seconds: Spacing between sampled frames in seconds. It is
            converted to a frame count using the FPS the container reports.

    Returns:
        The names of the files actually written, in order. The list is empty
        when the video cannot be opened, reports no usable frame rate, or
        contains no frames.
    """
    vid = cv2.VideoCapture(path)
    saved_image_names = []
    try:
        if not vid.isOpened():
            return saved_image_names
        fps = vid.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            # FPS unknown: no sampling interval can be derived.
            return saved_image_names
        # At least 1 so very short intervals sample every frame instead of none.
        frames_to_skip = max(1, int(fps * interval_seconds))

        count = 0
        # grab() advances one frame without decoding it. Only sampled frames
        # are decoded via retrieve(). Stopping as soon as grab() fails means a
        # missing (None) frame is never passed to imwrite.
        while vid.grab():
            if count % frames_to_skip == 0:
                ok, image = vid.retrieve()
                if ok:
                    image_name = f"image_{count // frames_to_skip}.png"
                    # imwrite returns False on failure; only report files that exist.
                    if cv2.imwrite(image_name, image):
                        saved_image_names.append(image_name)
            count += 1
    finally:
        vid.release()
    return saved_image_names
def videoShit(video_path):
    """Classify sampled frames of a video, printing one verdict per frame.

    Frames come from ``capture_screenshot``, which writes them as PNG files in
    the working directory; they are left on disk afterwards. For each frame,
    in order, prints "NSFW" or "Not NSFW". If no frame could be extracted, a
    single explanatory message is printed instead of silently doing nothing.
    """
    image_names = capture_screenshot(video_path)
    if not image_names:
        print("Could not extract any frames from the video")
        return
    for image_name in image_names:
        # Close each frame file promptly and normalize to 3-channel RGB for the
        # processor. convert() loads the pixels before the file is closed.
        with Image.open(image_name) as frame:
            rgb_frame = frame.convert("RGB")
        with torch.no_grad():
            inputs = processor(images=rgb_frame, return_tensors="pt")
            logits = model(**inputs).logits
        predicted_label = logits.argmax(-1).item()
        # NOTE(review): assumes class index 1 is the NSFW label; confirm
        # against model.config.id2label for this checkpoint.
        print("NSFW" if predicted_label else "Not NSFW")
if __name__ == "__main__":
    # Only prompt when run as a script; importing this module (which still
    # loads the model) does not start the interactive prompt.
    getimage()