diff --git a/main.py b/main.py new file mode 100644 index 0000000..10baff2 --- /dev/null +++ b/main.py @@ -0,0 +1,78 @@ +import cv2 +from PIL import Image +import torch +from transformers import AutoModelForImageClassification, ViTImageProcessor + +model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection") +processor = ViTImageProcessor.from_pretrained('Falconsai/nsfw_image_detection') + +def getimage(): + print("Enter the file path of the image: ") + while True: + path = input() + if path: + break + if path.endswith(".png") or path.endswith(".jpg") or path.endswith(".jpeg"): + try: + img = Image.open(path) + except Exception as e: + print("Invalid file path. Error: ", e) + return + with torch.no_grad(): + inputs = processor(images=img, return_tensors="pt") + outputs = model(**inputs) + logits = outputs.logits + + predicted_label = logits.argmax(-1).item() + if predicted_label: + print("NSFW") + else: + print("Not NSFW") + + elif path.endswith(".mp4") or path.endswith(".webm"): + videoShit(path) + + else: + print("Invalid file format") + +def capture_screenshot(path): + vidObj = cv2.VideoCapture(path) + fps = vidObj.get(cv2.CAP_PROP_FPS) + frames_to_skip = int(fps * 10) + + count = 0 + success = 1 + saved_image_names = [] + + while success: + success, image = vidObj.read() + if frames_to_skip > 0 and count % frames_to_skip == 0: + image_name = f"image_{count // frames_to_skip}.png" + cv2.imwrite(image_name, image) + saved_image_names.append(image_name) + + count += 1 + + vidObj.release() + + return saved_image_names + + +def videoShit(video_path): + imageName = capture_screenshot(video_path) + for cum in imageName: + img = Image.open(cum) + with torch.no_grad(): + inputs = processor(images=img, return_tensors="pt") + outputs = model(**inputs) + logits = outputs.logits + + predicted_label = logits.argmax(-1).item() + if predicted_label: + print("NSFW") + else: + print("Not NSFW") + + +if __name__ == "__main__": + getimage() \ No newline at end of file diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..74d71b0 --- /dev/null +++ b/readme.md @@ -0,0 +1,90 @@ +# Enhanced Version: NSFW Detection Telegram Bot + +Welcome to the NSFW Detection Telegram Bot, an advanced tool designed to identify Not Safe for Work (NSFW) content in images through cutting-edge machine learning algorithms. The bot is written in Python using pyrogram, torch, transformers, TensorFlow, OpenCV, Pillow, and MongoDB. + +## Acknowledgments + +This project leverages the powerful `Falconsai/nsfw_image_detection` pre-trained model and dataset. We extend our gratitude to them for their contributions, enabling the functionality of this bot. + +## Getting Started + +
+ Python 3.9 | + Telegram API Key | + Telegram Bot Token | + MongoDB URI +
+ +Follow these simple steps to unleash the power of the NSFW Detection Telegram Bot: + +1. Begin by ensuring you have Git installed. If not, you can install it by running: + + ```bash + sudo apt-get update + sudo apt-get install git + ``` + + Then, clone the repository into your terminal: + + ```bash + git clone https://github.com/ArshCypherZ/NSFWDetection + ``` + +2. Now navigate into the directory: + + ```bash + cd NSFWDetection + ``` + +3. Install the necessary dependencies. Execute the following command: + + ```bash + pip3 install -U -r requirements.txt + ``` + +4. Acquire a Telegram Bot API token by creating a new bot through [Telegram BotFather](https://core.telegram.org/bots#botfather). + +5. Personalize the `telegram/__init__.py` script by replacing the variables with your Telegram Bot API token. + +6. Launch the bot using the following command: + + ```bash + python3 -m telegram + ``` + +7. Integrate the bot into your Telegram group or chat, and send an image for analysis. The bot will promptly provide you with the results. + +## Dependencies + +Ensure you have the following dependencies installed to run the NSFW Detection Telegram Bot seamlessly: + +- Python 3.x +- TensorFlow +- Pillow +- pyrogram 2.x +- motor +- OpenCV +- torch +- transformers + +## Script Testing (Unrelated to Telegram) + +Evaluate the script's performance by executing the command below in your terminal and supplying the image file path: + +```bash +pip3 install -U -r requirements.txt +``` + +```bash +python3 main.py +``` + +## Support the Project + +If you find the NSFW Detection project useful, consider supporting the project through a donation. Your contributions help us maintain and improve the service. + +- **UPI**: `arsh-j@paytm` diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4babe01 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +transformers +pillow +pyrogram +torch +opencv-python +uvloop +cryptg +tgcrypto +motor \ No newline at end of file diff --git a/telegram/__init__.py b/telegram/__init__.py new file mode 100644 index 0000000..523f26b --- /dev/null +++ b/telegram/__init__.py @@ -0,0 +1,11 @@ +from pyrogram import Client +from uvloop import install + +api_id = '681' # Your api_id from my.telegram.org +api_hash = '453a' # Your api_hash from my.telegram.org +bot_token = '681:AAv1OJQhamQ' # Your bot token from @BotFather +db_url = 'mongodb://localhost:27017' # Your MongoDB URL from mongodb.com + +install() +client = Client("antinsfw", api_id, api_hash, bot_token=bot_token) + diff --git a/telegram/__main__.py b/telegram/__main__.py new file mode 100644 index 0000000..ee7b779 --- /dev/null +++ b/telegram/__main__.py @@ -0,0 +1,23 @@ +import asyncio +import importlib +from telegram import client +from uvloop import install +from pyrogram import idle +import logging + +loop = asyncio.get_event_loop() + + +imported_module = importlib.import_module("antinsfw.antinsfw") +imported_module = importlib.import_module("antinsfw.stats") +imported_module = importlib.import_module("antinsfw.db") + +async def gae(): + install() + await client.start() + await idle() + await client.stop() + +if __name__ == "__main__": + logging.info("Bot Started! Powered By @SpiralTechDivision") + loop.run_until_complete(gae()) \ No newline at end of file diff --git a/telegram/antinsfw.py b/telegram/antinsfw.py new file mode 100644 index 0000000..a00a617 --- /dev/null +++ b/telegram/antinsfw.py @@ -0,0 +1,158 @@ +import cv2 +import os +import logging +from PIL import Image +import torch +from telegram import client +from pyrogram import filters +from telegram.db import is_nsfw, add_chat, add_user, add_nsfw, remove_nsfw +from transformers import AutoModelForImageClassification, ViTImageProcessor +from pyrogram.enums import ChatType +from pyrogram.types import InlineKeyboardButton, InlineKeyboardMarkup + +model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection") +processor = ViTImageProcessor.from_pretrained('Falconsai/nsfw_image_detection') + +@client.on_message(filters.photo | filters.sticker | filters.animation | filters.video) +async def getimage(client, event): + if event.photo: + file_id = event.photo.file_id + if (await is_nsfw(file_id)): + await send_msg(event) + return + try: + await client.download_media(event.photo, os.path.join(os.getcwd(), "image.png")) + except Exception as e: + logging.error(f"Failed to download image. Error: {e}") + return + + elif event.sticker: + file_id = event.sticker.file_id + if (await is_nsfw(file_id)): + await send_msg(event) + return + if event.sticker.mime_type == "video/webm": + try: + await client.download_media(event.sticker, os.path.join(os.getcwd(), "animated.mp4")) + except Exception as e: + logging.error(f"Failed to download animated sticker. Error: {e}") + return + await videoShit(event, "animated.mp4", file_id) + + else: + try: + await client.download_media(event.sticker, os.path.join(os.getcwd(), "image.png")) + except Exception as e: + logging.error(f"Failed to download sticker. Error: {e}") + return + + elif event.animation: + file_id = event.animation.file_id + if (await is_nsfw(file_id)): + await send_msg(event) + return + try: + await client.download_media(event.animation, os.path.join(os.getcwd(), "gif.mp4")) + except Exception as e: + logging.error(f"Failed to download GIF. Error: {e}") + return + await videoShit(event, "gif.mp4", file_id) + + elif event.video: + file_id = event.video.file_id + if (await is_nsfw(file_id)): + await send_msg(event) + return + try: + await client.download_media(event.video, os.path.join(os.getcwd(), "video.mp4")) + except Exception as e: + logging.error(f"Failed to download video. Error: {e}") + return + await videoShit(event, "video.mp4", file_id) + else: + return + + img = Image.open("image.png") + with torch.no_grad(): + inputs = processor(images=img, return_tensors="pt") + outputs = model(**inputs) + logits = outputs.logits + + predicted_label = logits.argmax(-1).item() + if predicted_label: + await add_nsfw(file_id) + await send_msg(event) + else: + await remove_nsfw(file_id) + return + +@client.on_message(filters.command("start")) +async def start(_, event): + buttons = [[InlineKeyboardButton("Support Chat", url="t.me/SpiralTechDivision"), InlineKeyboardButton("News Channel", url="t.me/SpiralUpdates")]] + reply_markup = InlineKeyboardMarkup(buttons) + await event.reply_text("Hello, I am a bot that detects NSFW (Not Safe for Work) images. Send me an image to check if it is NSFW or not. In groups, just make me an admin with delete message rights and I will delete all NSFW images sent by anyone.", reply_markup=reply_markup) + if event.from_user.username: + await add_user(event.from_user.id, event.from_user.username) + else: + await add_user(event.from_user.id, "None") + + +async def send_msg(event): + if event.chat.type == ChatType.SUPERGROUP: + try: + await event.delete() + except: + pass + try: + await client.send_message(event.chat.id, "NSFW image detected :)") + except: + pass + await add_chat(event.chat.id) + else: + await event.reply("NSFW Image.") + + + +def capture_screenshot(path): + vidObj = cv2.VideoCapture(path) + fps = vidObj.get(cv2.CAP_PROP_FPS) + frames_to_skip = int(fps * 10) + + count = 0 + success = 1 + saved_image_names = [] + + while success: + success, image = vidObj.read() + if frames_to_skip > 0 and count % frames_to_skip == 0: + image_name = f"image_{count // frames_to_skip}.png" + cv2.imwrite(image_name, image) + saved_image_names.append(image_name) + + count += 1 + + vidObj.release() + + return saved_image_names + + +async def videoShit(event, video_path, file_id): + if (await is_nsfw(file_id)): + await send_msg(event) + return + imageName = capture_screenshot(video_path) + for cum in imageName: + img = Image.open(cum) + with torch.no_grad(): + inputs = processor(images=img, return_tensors="pt") + outputs = model(**inputs) + logits = outputs.logits + + predicted_label = logits.argmax(-1).item() + if predicted_label: + await add_nsfw(file_id) + await send_msg(event) + return + else: + await remove_nsfw(file_id) + return \ No newline at end of file diff --git a/telegram/db.py b/telegram/db.py new file mode 100644 index 0000000..6811bdf --- /dev/null +++ b/telegram/db.py @@ -0,0 +1,27 @@ +import motor.motor_asyncio +from telegram import db_url + +client = motor.motor_asyncio.AsyncIOMotorClient(db_url) +db = client['nsfw'] + +userdb = db.users +chatdb = db.chats +files = db.files + +async def add_user(user_id: int, username: str): + await userdb.update_one({'user_id': user_id}, {'$set': {'username': username}}, upsert=True) + +async def add_chat(chat_id: int): + await chatdb.update_one({'chat_id': chat_id}, {'$set': {'chat_id': chat_id}}, upsert=True) + +async def is_nsfw(file_id: str): + m = await files.find_one({'file_id': file_id}) + if m: + return m['nsfw'] + return False + +async def add_nsfw(file_id: str): + await files.update_one({'file_id': file_id}, {'$set': {'nsfw': True}}, upsert=True) + +async def remove_nsfw(file_id: str): + await files.update_one({'file_id': file_id}, {'$set': {'nsfw': False}}, upsert=True) \ No newline at end of file diff --git a/telegram/stats.py b/telegram/stats.py new file mode 100644 index 0000000..c20ff58 --- /dev/null +++ b/telegram/stats.py @@ -0,0 +1,12 @@ +from telegram import client +from pyrogram import filters +from telegram.db import db + +@client.on_message(filters.command("stats")) +async def stats(_, message): + user_count = await db.users.count_documents({}) + chat_count = await db.chats.count_documents({}) + nsfw_count = await db.files.count_documents({"nsfw": True}) + await message.reply_text( + f"**Stats:**\n\nUsers: {user_count}\nChats: {chat_count}\nNSFW Files: {nsfw_count}" + )