diff --git a/DPF/filters/data_filter.py b/DPF/filters/data_filter.py
index 7d981f4..56e35e3 100644
--- a/DPF/filters/data_filter.py
+++ b/DPF/filters/data_filter.py
@@ -5,7 +5,7 @@
 import pandas as pd
 from torch.utils.data import DataLoader, Dataset
-from tqdm import tqdm
+from tqdm.auto import tqdm
 
 from DPF.dataloaders.dataloader_utils import identical_collate_fn
 from DPF.modalities import ModalityName
diff --git a/DPF/filters/images/face_focus_filter.py b/DPF/filters/images/face_focus_filter.py
new file mode 100644
index 0000000..c30ca97
--- /dev/null
+++ b/DPF/filters/images/face_focus_filter.py
@@ -0,0 +1,202 @@
+from typing import Any
+
+import cv2
+import numpy as np
+import torch
+from PIL import UnidentifiedImageError
+from retinaface.pre_trained_models import get_model
+from torch.multiprocessing import set_start_method
+
+from DPF.types import ModalityToDataMapping
+from DPF.utils import read_image_rgb_from_bytes
+
+from .img_filter import ImageFilter
+
+# The RetinaFace detector may run on CUDA inside dataloader workers,
+# so the 'spawn' start method must be set before any workers are created.
+try:
+    set_start_method('spawn')
+except RuntimeError:
+    pass
+
+
+class FaceFocusFilter(ImageFilter):
+    """
+    Filter that detects the most confident face in an image and checks whether
+    the face region and the whole image pass a Tenengrad-variance focus threshold.
+
+    Parameters
+    ----------
+    threshold: float = 2000.0
+        Threshold on the Tenengrad variance focus measure.
+    detect_face: bool = True
+        Whether to run face detection; if False, only the whole-image focus is measured.
+    workers: int = 1
+        Number of processes to use for reading data.
+    batch_size: int = 1
+        Batch size for processing images.
+    pbar: bool = True
+        Whether to use a progress bar.
+    """
+
+    def __init__(
+        self,
+        threshold: float = 2000.0,
+        detect_face: bool = True,
+        workers: int = 1,
+        batch_size: int = 1,
+        pbar: bool = True,
+        device=None,
+        _pbar_position: int = 0
+    ):
+        super().__init__(pbar, _pbar_position)
+        self.threshold = threshold
+        self.detect_face = detect_face
+        self.num_workers = workers
+        self.batch_size = batch_size
+        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
+        self.face_detector = get_model(
+            "resnet50_2020-07-20",
+            max_size=2048,
+            device=self.device
+        )
+        self.face_detector.eval()
+
+    @property
+    def result_columns(self) -> list[str]:
+        return [
+            "face_focus_measure", "bg_focus_measure", "bbox", "faces_count",
+            "confidence", "face_focus_pass", "focus_pass"
+        ]
+
+    @property
+    def dataloader_kwargs(self) -> dict[str, Any]:
+        return {
+            "num_workers": self.num_workers,
+            "batch_size": self.batch_size,
+            "drop_last": False,
+        }
+
+    def preprocess_data(
+        self,
+        modality2data: ModalityToDataMapping,
+        metadata: dict[str, Any]
+    ) -> Any:
+        key = metadata[self.key_column]
+        try:
+            pil_image = read_image_rgb_from_bytes(modality2data['image'])
+            numpy_image = np.array(pil_image)
+            opencv_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR)
+            return key, opencv_image
+        except (OSError, UnidentifiedImageError, ValueError) as e:
+            print(f"Error processing image for key {key}: {str(e)}")
+            return key, None
+
+    def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
+        df_batch_labels = self._get_dict_from_schema()
+
+        for key, image in batch:
+            info = self.process_image(image) if image is not None else None
+
+            if info:
+                df_batch_labels["face_focus_measure"].append(info["face_focus_measure"])
+                df_batch_labels["bg_focus_measure"].append(info["bg_focus_measure"])
+                df_batch_labels["bbox"].append(info["bbox"])
+                df_batch_labels["faces_count"].append(info["faces_count"])
+                df_batch_labels["confidence"].append(info["confidence"])
+                df_batch_labels["face_focus_pass"].append(info["face_focus_pass"])
+                df_batch_labels["focus_pass"].append(info["focus_pass"])
+            else:
+                df_batch_labels["face_focus_measure"].append(0)
+                df_batch_labels["bg_focus_measure"].append(0)
+                df_batch_labels["bbox"].append(None)
+                df_batch_labels["faces_count"].append(0)
+                df_batch_labels["confidence"].append(0.0)
+                df_batch_labels["face_focus_pass"].append(False)
+                df_batch_labels["focus_pass"].append(False)
+
+            df_batch_labels[self.key_column].append(key)
+
+        return df_batch_labels
+
+    def tenengrad_variance(self, image):
+        """
+        Calculate the Tenengrad variance focus measure for the given image.
+        """
+        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+        gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
+        gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
+        return np.mean(np.square(gx) + np.square(gy))
+
+    def process_image(self, image):
+        # Focus measure for the entire image
+        bg_focus_measure = self.tenengrad_variance(image)
+        focus_pass = bg_focus_measure > self.threshold
+
+        if not self.detect_face:
+            return {
+                "face_focus_measure": 0,
+                "bg_focus_measure": bg_focus_measure,
+                "bbox": None,
+                "faces_count": 0,
+                "confidence": 0,
+                "face_focus_pass": None,
+                "focus_pass": focus_pass
+            }
+
+        # Detect faces in the image
+        faces = self.face_detector.predict_jsons(image)
+
+        # No faces found
+        if faces is None or len(faces) == 0 or faces[0]['score'] == -1 or not faces[0]['bbox']:
+            return {
+                "face_focus_measure": 0,
+                "bg_focus_measure": bg_focus_measure,
+                "bbox": None,
+                "faces_count": 0,
+                "confidence": 0,
+                "face_focus_pass": False,
+                "focus_pass": focus_pass
+            }
+
+        # Take the face with the highest confidence; count only confident detections
+        face = max(faces, key=lambda x: x['score'])
+        faces = [x for x in faces if x['score'] > 0.5]
+
+        # Extract the face region
+        bbox = face['bbox']
+        x1, y1, x2, y2 = map(int, bbox)
+        face_region = image[y1:y2, x1:x2]
+
+        if face_region.size == 0:
+            return {
+                "face_focus_measure": 0,
+                "bg_focus_measure": bg_focus_measure,
+                "bbox": None,
+                "faces_count": len(faces),
+                "confidence": face["score"],
+                "face_focus_pass": False,
+                "focus_pass": focus_pass
+            }
+
+        # Focus measure for the face region
+        face_focus_measure = self.tenengrad_variance(face_region)
+        in_focus = face_focus_measure > self.threshold
+
+        return {
+            "face_focus_measure": face_focus_measure,
+            "bg_focus_measure": bg_focus_measure,
+            "bbox": bbox,
+            "faces_count": len(faces),
+            "confidence": face["score"],
+            "focus_pass": focus_pass,
+            "face_focus_pass": (len(faces) == 1) and in_focus and face['score'] > 0.5
+        }
diff --git a/DPF/filters/images/focus_peaking_filter.py b/DPF/filters/images/focus_peaking_filter.py
new file mode 100644
index 0000000..671a462
--- /dev/null
+++ b/DPF/filters/images/focus_peaking_filter.py
@@ -0,0 +1,133 @@
+from typing import Any
+
+import cv2
+import numpy as np
+from deepface import DeepFace
+
+from DPF.types import ModalityToDataMapping
+from DPF.utils import read_image_rgb_from_bytes
+
+from .img_filter import ImageFilter
+
+
+class FocusFilter(ImageFilter):
+    """
+    Filter for detecting faces and checking whether the face is in focus.
+
+    Parameters
+    ----------
+    threshold: float = 2000.0
+        Threshold on the Tenengrad variance focus measure used to decide whether the face is in focus.
+    workers: int = 1
+        Number of processes to use for reading data and calculating focus scores.
+    batch_size: int = 1
+        Batch size for processing images.
+    pbar: bool = True
+        Whether to use a progress bar.
+    detect_face: bool = True
+        Whether to run face detection; if False, only the whole-image focus is measured.
+    """
+
+    def __init__(
+        self,
+        threshold: float = 2000.0,
+        workers: int = 1,
+        batch_size: int = 1,
+        pbar: bool = True,
+        _pbar_position: int = 0,
+        detect_face: bool = True
+    ):
+        super().__init__(pbar, _pbar_position)
+        self.threshold = threshold
+        self.num_workers = workers
+        self.batch_size = batch_size
+        self.detect_face = detect_face
+
+    @property
+    def result_columns(self) -> list[str]:
+        return ["face_detected", "face_in_focus", "face_focus_measure"]
+
+    @property
+    def dataloader_kwargs(self) -> dict[str, Any]:
+        return {
+            "num_workers": self.num_workers,
+            "batch_size": self.batch_size,
+            "drop_last": False,
+        }
+
+    def preprocess_data(
+        self,
+        modality2data: ModalityToDataMapping,
+        metadata: dict[str, Any]
+    ) -> Any:
+        key = metadata[self.key_column]
+        pil_image = read_image_rgb_from_bytes(modality2data['image'])
+        image = np.array(pil_image)
+        return key, image
+
+    def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
+        df_batch_labels = self._get_dict_from_schema()
+
+        for key, image in batch:
+            face_info = process_image(image, threshold=self.threshold, detect_faces=self.detect_face)
+            if face_info:
+                df_batch_labels["face_detected"].append(self.detect_face)
+                df_batch_labels["face_in_focus"].append(face_info["face_in_focus"])
+                df_batch_labels["face_focus_measure"].append(face_info["face_focus_measure"])
+            else:
+                df_batch_labels["face_detected"].append(False)
+                df_batch_labels["face_in_focus"].append(False)
+                df_batch_labels["face_focus_measure"].append(0.0)
+            df_batch_labels[self.key_column].append(key)
+
+        return df_batch_labels
+
+
+def tenengrad_variance(image):
+    """
+    Calculate the Tenengrad variance focus measure for the given image.
+    """
+    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
+    gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
+    return np.mean(np.square(gx) + np.square(gy))
+
+
+def process_image(image, threshold=2000.0, detect_faces=True):
+    # Focus measure for the entire image
+    focus_measure = tenengrad_variance(image)
+
+    if not detect_faces:
+        # Without face detection, report only the whole-image focus
+        return {"face_in_focus": focus_measure > threshold, "face_focus_measure": focus_measure}
+
+    # Detect faces in the image
+    faces = DeepFace.extract_faces(image,
+                                   enforce_detection=False,
+                                   detector_backend='retinaface')
+
+    # Keep only reasonably confident detections
+    filtered_faces = [face for face in faces if face['confidence'] > 0.1]
+    if not filtered_faces:
+        return None
+
+    face = max(filtered_faces, key=lambda x: x['confidence'])
+
+    # Require exactly one confident face with both eyes localized
+    if len(filtered_faces) != 1 or face['confidence'] <= 0.5:
+        return None
+    if face['facial_area'].get('left_eye') is None or face['facial_area'].get('right_eye') is None:
+        return None
+
+    face['facial_area']['confidence'] = face['confidence']
+
+    # Extract the face region
+    x, y, w, h = face['facial_area']['x'], face['facial_area']['y'], face['facial_area']['w'], face['facial_area']['h']
+    face_region = image[y:y + h, x:x + w]
+
+    # Calculate the focus measure for the face region and check it against the threshold
+    face_focus_measure = tenengrad_variance(face_region)
+    face['facial_area']['face_in_focus'] = face_focus_measure > threshold
+    face['facial_area']['face_focus_measure'] = face_focus_measure
+
+    return face['facial_area']
+ """ + + def __init__( + self, + workers: int = 1, + batch_size: int = 1, + pbar: bool = True, + _pbar_position: int = 0 + ): + super().__init__(pbar, _pbar_position) + self.num_workers = workers + self.batch_size = batch_size + + @property + def result_columns(self) -> list[str]: + return ["grayscale_pass"] + + @property + def dataloader_kwargs(self) -> dict[str, Any]: + return { + "num_workers": self.num_workers, + "batch_size": self.batch_size, + "drop_last": False, + } + + def preprocess_data( + self, + modality2data: ModalityToDataMapping, + metadata: dict[str, Any] + ) -> Any: + key = metadata[self.key_column] + image = read_image_rgb_from_bytes(modality2data['image']) + return key, image + + def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]: + df_batch_labels = self._get_dict_from_schema() + + for key, image in batch: + try: + image = np.array(image) + result = self.process_image(image) + df_batch_labels["grayscale_pass"].append(not result) + except Exception as e: + print(f"Error processing image: {str(e)}") + df_batch_labels["grayscale_pass"].append(False) + df_batch_labels[self.key_column].append(key) + + return df_batch_labels + + + def process_image(self, image): + """ + Detects if an image is grayscale (black and white) or not. + + Args: + image_path (str): Path to the image file. + + Returns: + bool: True if the image is grayscale, False otherwise. + """ + + # Check if the image has only one channel + if len(image.shape) == 2: + return True + + # Convert the image to the RGB color space + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # Check if the histograms for all channels are identical + hist_r = cv2.calcHist([image_rgb], [0], None, [256], [0, 256]) + hist_g = cv2.calcHist([image_rgb], [1], None, [256], [0, 256]) + hist_b = cv2.calcHist([image_rgb], [2], None, [256], [0, 256]) + if np.array_equal(hist_r, hist_g) and np.array_equal(hist_r, hist_b): + return True + + # Check if the histogram is concentrated along the diagonal in the RGB color cube + hist_3d = cv2.calcHist([image_rgb], [0, 1, 2], None, [256, 256, 256], [0, 256, 0, 256, 0, 256]) + diagonal_sum = np.sum([hist_3d[i, i, i] for i in range(256)]) + total_sum = np.sum(hist_3d) + if diagonal_sum / total_sum > 0.9: + return True + + # Check for low saturation + image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + saturation = image_hsv[:, :, 1] + mean_saturation = np.mean(saturation) + std_saturation = np.std(saturation) + if mean_saturation < 10 and std_saturation < 5: + return True + + return False diff --git a/DPF/filters/images/hash_filters.py b/DPF/filters/images/hash_filters.py index dc410e5..63f2129 100644 --- a/DPF/filters/images/hash_filters.py +++ b/DPF/filters/images/hash_filters.py @@ -9,6 +9,7 @@ from ...types import ModalityToDataMapping from .img_filter import ImageFilter +Image.MAX_IMAGE_PIXELS = None def get_phash(pil_img: Image.Image, hash_size: int = 8, highfreq_factor: int = 4) -> str: img_size = hash_size * highfreq_factor diff --git a/DPF/filters/images/noise_estimation_filter.py b/DPF/filters/images/noise_estimation_filter.py new file mode 100644 index 0000000..e855ea5 --- /dev/null +++ b/DPF/filters/images/noise_estimation_filter.py @@ -0,0 +1,140 @@ +import cv2 +import numpy as np +import joblib +from typing import Any, List, Dict +from skimage.feature import graycomatrix, graycoprops, local_binary_pattern +from scipy.fftpack import fft2, fftshift +from DPF.types import ModalityToDataMapping +from DPF.utils import read_image_rgb_from_bytes +from .img_filter import 
diff --git a/DPF/filters/images/hash_filters.py b/DPF/filters/images/hash_filters.py
index dc410e5..63f2129 100644
--- a/DPF/filters/images/hash_filters.py
+++ b/DPF/filters/images/hash_filters.py
@@ -9,6 +9,7 @@
 from ...types import ModalityToDataMapping
 from .img_filter import ImageFilter
 
+Image.MAX_IMAGE_PIXELS = None
 
 def get_phash(pil_img: Image.Image, hash_size: int = 8, highfreq_factor: int = 4) -> str:
     img_size = hash_size * highfreq_factor
diff --git a/DPF/filters/images/noise_estimation_filter.py b/DPF/filters/images/noise_estimation_filter.py
new file mode 100644
index 0000000..e855ea5
--- /dev/null
+++ b/DPF/filters/images/noise_estimation_filter.py
@@ -0,0 +1,140 @@
+from typing import Any
+
+import cv2
+import joblib
+import numpy as np
+from scipy.fftpack import fft2, fftshift
+from skimage.feature import graycomatrix, graycoprops, local_binary_pattern
+
+from DPF.types import ModalityToDataMapping
+from DPF.utils import read_image_rgb_from_bytes
+
+from .img_filter import ImageFilter
+
+
+class NoiseEstimationFilter(ImageFilter):
+    """
+    Filter for estimating noise levels in images.
+
+    Parameters
+    ----------
+    model_path: str
+        Path to the trained noise estimation model (joblib file).
+    params_path: str
+        Path to the feature extraction parameters (joblib file).
+    workers: int = 1
+        Number of processes to use for reading data and calculating noise levels.
+    batch_size: int = 1
+        Batch size for processing images.
+    pbar: bool = True
+        Whether to use a progress bar.
+    """
+
+    def __init__(
+        self,
+        model_path: str,
+        params_path: str,
+        workers: int = 1,
+        batch_size: int = 1,
+        pbar: bool = True,
+        _pbar_position: int = 0
+    ):
+        super().__init__(pbar, _pbar_position)
+        self.num_workers = workers
+        self.batch_size = batch_size
+        self.model = joblib.load(model_path)
+        self.params = joblib.load(params_path)
+
+    @property
+    def result_columns(self) -> list[str]:
+        return ["estimated_noise_level", "noise_filter_pass"]
+
+    @property
+    def dataloader_kwargs(self) -> dict[str, Any]:
+        return {
+            "num_workers": self.num_workers,
+            "batch_size": self.batch_size,
+            "drop_last": False,
+        }
+
+    def preprocess_data(
+        self,
+        modality2data: ModalityToDataMapping,
+        metadata: dict[str, Any]
+    ) -> Any:
+        key = metadata[self.key_column]
+        image = read_image_rgb_from_bytes(modality2data['image'])
+        return key, image
+
+    def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
+        df_batch_labels = self._get_dict_from_schema()
+
+        for key, image in batch:
+            try:
+                noise_level = self.process_image(image)
+                df_batch_labels["estimated_noise_level"].append(noise_level)
+                df_batch_labels["noise_filter_pass"].append(noise_level < 1.0)
+            except Exception as e:
+                print(f"Error processing image: {str(e)}")
+                df_batch_labels["estimated_noise_level"].append(None)
+                df_batch_labels["noise_filter_pass"].append(False)
+            df_batch_labels[self.key_column].append(key)
+
+        return df_batch_labels
+
+    def process_image(self, image):
+        image = np.array(image)
+        cropped_image = self.crop_image(image)
+        features = self.extract_features(cropped_image)
+        noise_level = self.model.predict([features])[0]
+        return noise_level
+
+    def crop_image(self, image):
+        # Resize so the longer side equals target_size, then center-crop to a square
+        target_size = self.params['target_size']
+        height, width = image.shape[:2]
+        if height > width:
+            scale = target_size / height
+            new_height = target_size
+            new_width = int(width * scale)
+        else:
+            scale = target_size / width
+            new_width = target_size
+            new_height = int(height * scale)
+
+        resized = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
+
+        start_y = max(0, (new_height - target_size) // 2)
+        start_x = max(0, (new_width - target_size) // 2)
+        cropped = resized[start_y:start_y + target_size, start_x:start_x + target_size]
+
+        return cropped
+
+    def extract_features(self, image):
+        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+
+        # Statistical features
+        std_dev = np.std(gray)
+        hist, _ = np.histogram(gray, bins=256, density=True)
+        entropy = -np.sum(hist * np.log2(hist + 1e-7))
+
+        # Edge detection
+        edges = cv2.Canny(gray, self.params['canny_low'], self.params['canny_high'])
+        edge_density = np.sum(edges) / (edges.shape[0] * edges.shape[1])
+
+        # Texture analysis - GLCM
+        glcm = graycomatrix(gray, self.params['glcm_distances'], self.params['glcm_angles'],
+                            256, symmetric=True, normed=True)
+        contrast = graycoprops(glcm, 'contrast')[0, 0]
+        homogeneity = graycoprops(glcm, 'homogeneity')[0, 0]
+
+        # Texture analysis - LBP
+        lbp = local_binary_pattern(gray, self.params['lbp_n_points'], self.params['lbp_radius'], method='uniform')
+        lbp_hist, _ = np.histogram(lbp.ravel(), bins=np.arange(0, self.params['lbp_n_points'] + 3),
+                                   range=(0, self.params['lbp_n_points'] + 2))
+        lbp_hist = lbp_hist.astype("float")
+        lbp_hist /= (lbp_hist.sum() + 1e-7)
+
+        # Frequency domain analysis
+        f_transform = fft2(gray)
+        f_shift = fftshift(f_transform)
+        magnitude_spectrum = 20 * np.log(np.abs(f_shift) + 1)
+        mean_magnitude = np.mean(magnitude_spectrum)
+
+        return np.concatenate(([std_dev, entropy, edge_density, contrast, homogeneity, mean_magnitude], lbp_hist))
\ No newline at end of file
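Two of the hand-crafted noise features above, computed standalone (not part of the diff) on a synthetic noisy image. The LBP parameters (lbp_n_points=24, lbp_radius=3) are only illustrative; in the filter they come from the params joblib file.

```python
# Sketch: grayscale entropy and uniform-LBP histogram features, as used by
# NoiseEstimationFilter.extract_features, on synthetic uniform noise.
import numpy as np
from skimage.feature import local_binary_pattern

rng = np.random.default_rng(0)
gray = rng.integers(0, 256, (128, 128)).astype(np.uint8)

hist, _ = np.histogram(gray, bins=256, density=True)
entropy = -np.sum(hist * np.log2(hist + 1e-7))

lbp_n_points, lbp_radius = 24, 3  # illustrative values
lbp = local_binary_pattern(gray, lbp_n_points, lbp_radius, method="uniform")
lbp_hist, _ = np.histogram(lbp.ravel(),
                           bins=np.arange(0, lbp_n_points + 3),
                           range=(0, lbp_n_points + 2))
lbp_hist = lbp_hist.astype(float) / (lbp_hist.sum() + 1e-7)

print(entropy, lbp_hist.shape)  # ~8 bits for uniform noise, 26-bin histogram
```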
diff --git a/DPF/filters/images/ocr_model/dataset.py b/DPF/filters/images/ocr_model/dataset.py
index 2d75d6f..d30ec44 100644
--- a/DPF/filters/images/ocr_model/dataset.py
+++ b/DPF/filters/images/ocr_model/dataset.py
@@ -11,6 +11,7 @@
 # TODO(review) - dependency is missing from requirements.txt
 from PIL import Image
 
+Image.MAX_IMAGE_PIXELS = None
 
 # TODO(review) - why inherit from object?
 class ResizeNormalize:
diff --git a/DPF/filters/images/oily_skin_filter.py b/DPF/filters/images/oily_skin_filter.py
new file mode 100644
index 0000000..a74b3ae
--- /dev/null
+++ b/DPF/filters/images/oily_skin_filter.py
@@ -0,0 +1,100 @@
+from typing import Any
+
+import torch
+from transformers import ViTForImageClassification, ViTImageProcessor
+
+from DPF.types import ModalityToDataMapping
+from DPF.utils import read_image_rgb_from_bytes
+
+from .img_filter import ImageFilter
+
+
+class OilySkinFilter(ImageFilter):
+    """
+    Filter for skin type detection using a ViT model.
+
+    Parameters
+    ----------
+    model_name: str
+        Name or path of the pre-trained model to use
+    device: str = "cuda:0"
+        Device to use
+    workers: int = 16
+        Number of processes to use for reading data and calculating scores
+    batch_size: int = 64
+        Batch size for the model
+    pbar: bool = True
+        Whether to use a progress bar
+    """
+
+    def __init__(
+        self,
+        model_name: str = "dima806/skin_types_image_detection",
+        device: str = "cuda:0",
+        workers: int = 16,
+        batch_size: int = 64,
+        pbar: bool = True,
+        _pbar_position: int = 0
+    ):
+        super().__init__(pbar, _pbar_position)
+
+        self.num_workers = workers
+        self.batch_size = batch_size
+        self.device = device
+
+        self.processor = ViTImageProcessor.from_pretrained(model_name)
+        self.model = ViTForImageClassification.from_pretrained(model_name)
+        self.model.to(self.device)
+
+        self.id2label = self.model.config.id2label
+
+    @property
+    def result_columns(self) -> list[str]:
+        return ["skin_type", "confidence_score"]
+
+    @property
+    def dataloader_kwargs(self) -> dict[str, Any]:
+        return {
+            "num_workers": self.num_workers,
+            "batch_size": self.batch_size,
+            "drop_last": False,
+        }
+
+    def preprocess_data(
+        self,
+        modality2data: ModalityToDataMapping,
+        metadata: dict[str, Any]
+    ) -> Any:
+        key = metadata[self.key_column]
+        pil_image = read_image_rgb_from_bytes(modality2data['image'])
+
+        # Apply ViT preprocessing
+        inputs = self.processor(images=pil_image, return_tensors="pt")
+        pixel_values = inputs['pixel_values'].squeeze()
+
+        return key, pixel_values
+
+    def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
+        df_batch_labels = self._get_dict_from_schema()
+
+        keys, pixel_values = list(zip(*batch))
+        pixel_values = torch.stack(pixel_values).to(self.device)
+
+        with torch.no_grad():
+            outputs = self.model(pixel_values=pixel_values)
+
+        logits = outputs.logits
+        probabilities = torch.nn.functional.softmax(logits, dim=-1)
+        predictions = torch.argmax(logits, dim=-1)
+
+        for key, pred, prob in zip(keys, predictions, probabilities):
+            skin_type = self.id2label[pred.item()]
+            confidence = prob[pred].item()
+
+            df_batch_labels[self.key_column].append(key)
+            df_batch_labels["skin_type"].append(skin_type)
+            df_batch_labels["confidence_score"].append(confidence)
+
+        return df_batch_labels
\ No newline at end of file
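As a sanity check for the new skin-type filter, a minimal single-image inference sketch against the same checkpoint (not part of the diff; the placeholder image is synthetic and the checkpoint name is simply the filter's default):

```python
# Hedged sketch: one-off inference with the default OilySkinFilter checkpoint.
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

model_name = "dima806/skin_types_image_detection"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
model.eval()

image = Image.new("RGB", (224, 224), color=(200, 170, 150))  # placeholder image
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

probs = torch.nn.functional.softmax(logits, dim=-1)[0]
pred = int(torch.argmax(probs))
print(model.config.id2label[pred], float(probs[pred]))
```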
diff --git a/DPF/filters/images/text_detection_filter.py b/DPF/filters/images/text_detection_filter.py
index b44602a..ba77084 100644
--- a/DPF/filters/images/text_detection_filter.py
+++ b/DPF/filters/images/text_detection_filter.py
@@ -13,7 +13,7 @@ class CRAFTFilter(ImageFilter):
 
     def __init__(
         self,
-        weights_folder: str,
+        weights_folder: str = "/tmp/datasets_utils",
        use_refiner: bool = False,
         device: str = "cuda:0",
         workers: int = 16,
@@ -31,7 +31,7 @@ def __init__(
 
     @property
     def result_columns(self) -> list[str]:
-        return ["text_boxes", "num_text_boxes", "text_area"]
+        return ["text_boxes", "num_text_boxes", "text_area", "text_detection_pass"]
 
     @property
     def dataloader_kwargs(self) -> dict[str, Any]:
@@ -61,5 +61,6 @@ def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
             df_batch_labels["num_text_boxes"].append(len(boxes))
             df_batch_labels["text_area"].append(boxes_area(boxes)/(orig_size[0]*orig_size[1]))
             df_batch_labels[self.key_column].append(key)
+            df_batch_labels["text_detection_pass"].append(len(boxes) == 0)
 
         return df_batch_labels
diff --git a/DPF/filters/images/watermarks_filter.py b/DPF/filters/images/watermarks_filter.py
index 1a57440..aad4f33 100644
--- a/DPF/filters/images/watermarks_filter.py
+++ b/DPF/filters/images/watermarks_filter.py
@@ -11,7 +11,7 @@
 except ImportError:
     from torch.utils.data import default_collate
 
-from huggingface_hub import cached_download, hf_hub_url
+from huggingface_hub import hf_hub_download, hf_hub_url
 from torchvision import models, transforms
 
 from DPF.filters.utils import FP16Module
@@ -46,11 +46,14 @@ def get_watermarks_detection_model(
     num_ftrs = model_ft.fc.in_features
     model_ft.fc = nn.Linear(num_ftrs, 2)
 
-    config_file_url = hf_hub_url(repo_id=config["repo_id"], filename=config["filename"])
-    cached_download(
-        config_file_url, cache_dir=cache_dir, force_filename=config["filename"]
+    # Download weights directly using hf_hub_download
+    weights_path = hf_hub_download(
+        repo_id=config["repo_id"],
+        filename=config["filename"],
+        cache_dir=cache_dir
     )
-    weights = torch.load(os.path.join(cache_dir, config["filename"]), device)
+
+    weights = torch.load(weights_path, device)
     model_ft.load_state_dict(weights)
 
     if fp16:
@@ -113,7 +116,7 @@ def __init__(
 
     @property
     def result_columns(self) -> list[str]:
-        return [f"watermark_{self.watermarks_model}"]
+        return ["watermark_filter_pass"]
 
     @property
     def dataloader_kwargs(self) -> dict[str, Any]:
@@ -141,8 +144,10 @@ def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
         with torch.no_grad():
             outputs = self.model(batch)
 
-        df_batch_labels[f"watermark_{self.watermarks_model}"].extend(
-            torch.max(outputs, 1)[1].cpu().reshape(-1).tolist()
+        # Get predictions (0 or 1) and convert to boolean (True if no watermark, False if watermark)
+        predictions = torch.max(outputs, 1)[1].cpu().reshape(-1).tolist()
+        df_batch_labels["watermark_filter_pass"].extend(
+            [pred == 0 for pred in predictions]
         )
 
         df_batch_labels[self.key_column].extend(keys)
diff --git a/DPF/filters/texts/google_translate_filter.py b/DPF/filters/texts/google_translate_filter.py
index 7d2212e..c8f9d89 100644
--- a/DPF/filters/texts/google_translate_filter.py
+++ b/DPF/filters/texts/google_translate_filter.py
@@ -4,7 +4,7 @@
 import pandas as pd
 from deep_translator import GoogleTranslator
 from deep_translator.base import BaseTranslator
-from tqdm import tqdm
+from tqdm.auto import tqdm
 
 from DPF.filters import ColumnFilter
diff --git a/DPF/filters/videos/image_filter_adapter.py b/DPF/filters/videos/image_filter_adapter.py
index 3684412..2f09564 100644
--- a/DPF/filters/videos/image_filter_adapter.py
+++ b/DPF/filters/videos/image_filter_adapter.py
@@ -10,6 +10,7 @@
 from .video_filter import VideoFilter
 
+Image.MAX_IMAGE_PIXELS = None
 
 
 class ImageFilterAdapter(VideoFilter):
     """
diff --git a/DPF/processors/processor.py b/DPF/processors/processor.py
index 1c21820..55c1015 100644
--- a/DPF/processors/processor.py
+++ b/DPF/processors/processor.py
@@ -4,7 +4,7 @@
 import pandas as pd
 from torch.utils.data import DataLoader, Dataset
-from tqdm import tqdm
+from tqdm.auto import tqdm
 
 from DPF.configs import DatasetConfig, config2format
 from DPF.connectors import Connector, LocalConnector
diff --git a/DPF/transforms/image_resize_transforms.py b/DPF/transforms/image_resize_transforms.py
index 469a77d..d6a1e38 100644
--- a/DPF/transforms/image_resize_transforms.py
+++ b/DPF/transforms/image_resize_transforms.py
@@ -8,6 +8,7 @@
 )
 from DPF.transforms.resizer import Resizer
 
+Image.MAX_IMAGE_PIXELS = None
 
 
 class ImageResizeTransforms(BaseFilesTransforms):
diff --git a/DPF/utils/image_utils.py b/DPF/utils/image_utils.py
index ff6d145..623641d 100644
--- a/DPF/utils/image_utils.py
+++ b/DPF/utils/image_utils.py
@@ -2,6 +2,7 @@
 
 from PIL import Image
 
+Image.MAX_IMAGE_PIXELS = None
 
 def read_image_rgb(path: str, force_rgb: bool = True) -> Image.Image:
     pil_img = Image.open(path)
diff --git a/pyproject.toml b/pyproject.toml
index b59d1e8..94d2d70 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,7 +38,11 @@ filters = [
     'py3langid',
     'deep_translator',
     'huggingface_hub',
-    'videohash'
+    'scikit-image',
+    'scikit-learn',
+    'joblib',
+    'retinaface_pytorch',
+    'videohash'
 ]
 nsfw_detector = ['tensorflow', 'autokeras']
 llava = [