diff --git a/README.md b/README.md index 7eced42..485903a 100644 --- a/README.md +++ b/README.md @@ -8,54 +8,64 @@ This project leverages Python, Telegram API, and data science tools to build a r ## 🏗️ 1. Project Setup and Data Collection (Task 1) -**Branch created:** `task1` for initial project setup and data scraping. - **Deliverables:** - * GitHub code for Task 1 (data ingestion and preprocessing). - * Data summary (1-2 pages) covering data preparation and labeling steps. +* GitHub code for Task 1 (data ingestion and preprocessing). +* Data summary (1-2 pages) covering data preparation and labeling steps. ### Repo/Project Structure -``` +```bash EthioMart/ - ├── src/ - │ ├── telegram_scraper.py # Collects raw data from Telegram channels - │ └── preprocessor.py # Cleans and preprocesses raw text data - ├── config/ - │ └── config.py # Stores configuration variables (e.g., API credentials, channel list) - ├── data/ - │ ├── raw/ # Stores raw scraped data (e.g., telegram_data.csv) - │ │ └── telegram_data.csv - │ ├── processed/ # Stores cleaned and preprocessed data - │ │ └── clean_telegram_data.csv - │ └── labeled/ # Will store manually and semi-automatically labeled data (for Task 2) - ├── photos/ # Stores downloaded images from Telegram messages - ├── notebooks/ # Jupyter notebooks for EDA, experimentation, and documentation - │ ├── data_ingestion_eda.ipynb - │ └── data_preprocessing_eda.ipynb - ├── outputs/ # Stores generated plots and visualizations - ├── reports/ # For interim and final project reports - ├── tests/ # Unit tests for various modules (e.g., preprocessor) - │ └── test_preprocessor.py - ├── .github/workflows/ # CI/CD pipelines (e.g., for DVC and code quality) - ├── .env # Environment variables (e.g., Telegram API keys) - ├── requirements.txt # Python package dependencies - ├── .gitignore # Files/directories to ignore in Git - ├── README.md # Project overview and setup instructions - └── DVC.md # Documentation for Data Version Control (DVC) setup 
(Future) +├── src/ +│ ├── telegram_scraper.py # Collects raw data from Telegram channels +│ ├── preprocessor.py # Cleans and preprocesses raw text data +│ ├── data_labeler.py # Rule-based labeling for NER (Task 2) +│ └── model_finetuner.py # Fine-tunes NER models (Task 3 & 4) +├── config/ +│ └── config.py # Stores configuration variables (e.g., API credentials, channel list) +├── data/ +│ ├── raw/ # Stores raw scraped data (e.g., telegram_data.csv) +│ │ └── telegram_data.csv +│ ├── processed/ # Stores cleaned and preprocessed data +│ │ └── clean_telegram_data.csv +│ └── labeled/ # Stores manually and semi-automatically labeled data +│ └── telegram_ner_data_rule_based.conll +├── models/ # Stores fine-tuned NER models (Task 3 & 4) +│ └── afro_xlmr_ner_fine_tuned/ +├── photos/ # Stores downloaded images from Telegram messages +├── notebooks/ # Jupyter notebooks for EDA, experimentation, and documentation +│ ├── data_ingestion_eda.ipynb +│ └── data_preprocessing_eda.ipynb +├── outputs/ # Stores generated plots and visualizations +│ └── plots/ +├── reports/ # For interim and final project reports +├── tests/ # Unit tests for various modules (e.g., preprocessor) +│ ├── test_preprocessor.py +│ └── test_telegram_scraper.py +├── .github/workflows/ # CI/CD pipelines (e.g., for DVC and code quality) +├── .env # Environment variables (e.g., Telegram API keys) +├── requirements.txt # Python package dependencies +├── .gitignore # Files/directories to ignore in Git +└── README.md # Project overview and setup instructions ``` ### Tech Stack - Tools Used - * **Python 3.11+** - * **Telethon:** For interacting with the Telegram API to scrape messages and metadata. - * **Pandas, NumPy:** For efficient data manipulation and analysis. - * **Matplotlib, Seaborn:** For data visualization and exploratory data analysis. - * **Jupyter Notebook:** For interactive data exploration and reproducible analysis. - * **re (Regex):** For advanced text cleaning and pattern matching. 
- * **pathlib:** For robust path management. - * **pytest:** For unit testing the project's functions. +* **Python 3.11+** +* **Telethon:** For interacting with the Telegram API to scrape messages and metadata. +* **Pandas, NumPy:** For efficient data manipulation and analysis. +* **Matplotlib, Seaborn:** For data visualization and exploratory data analysis. +* **Jupyter Notebook:** For interactive data exploration and reproducible analysis. +* **re (Regex):** For advanced text cleaning and pattern matching. +* **pathlib:** For robust path management. +* **pytest:** For unit testing the project's functions. +* **Hugging Face transformers:** For loading, fine-tuning, and evaluating transformer models. +* **Hugging Face datasets:** For efficient data loading and preprocessing for LLMs. +* **seqeval:** For evaluating NER model performance. +* **torch (PyTorch):** The deep learning framework used by the models. +* **scikit-learn:** For data splitting utilities. +* **tensorboard:** For visualizing training progress. ## 🚀 2. Usage and Data Pipeline Steps @@ -73,13 +83,13 @@ This section guides you through the process of setting up the project, collectin 2. **Set Up Environment Variables:** Create a `.env` file in the project root: - ``` + ```bash TELEGRAM_API_ID=your_api_id TELEGRAM_API_HASH=your_api_hash TELEGRAM_PHONE_NUMBER=your_phone_number ``` - Obtain `API_ID` and `API_HASH` from [my.telegram.org](https://my.telegram.org/). + Obtain `API_ID` and `API_HASH` from [my.telegram.org](https://my.telegram.org). 3. **Install Dependencies:** @@ -98,7 +108,7 @@ This section guides you through the process of setting up the project, collectin ``` *Note: You will be prompted to enter a Telegram verification code for the first run.* - *Output: `data/raw/telegram_data.csv` and images in `photos/`.* + **Output:** `data/raw/telegram_data.csv` and images in `photos/`. 2. 
**Perform Initial Data Ingestion EDA (`notebooks/data_ingestion_eda.ipynb`):** Explore the characteristics of the raw scraped data (e.g., missing values, distribution of views/reactions, presence of images). @@ -108,30 +118,26 @@ This section guides you through the process of setting up the project, collectin jupyter notebook notebooks/data_ingestion_eda.ipynb ``` - *Insights:* - - * Approximately 46% of messages have missing text, indicating the necessity for OCR on images. - * 88% of messages contain images, highlighting the importance of image analysis. - * High-engagement messages (top quartile) have 27+ reactions. + **Insights:** + * Approximately 46% of messages have missing text, indicating the necessity for OCR on images. + * 88% of messages contain images, highlighting the importance of image analysis. + * High-engagement messages (top quartile) have 27+ reactions. 3. **Run the Preprocessor (`src/preprocessor.py`):** This script cleans and normalizes the raw text data by: - - * Normalizing Amharic character variations. - * Strictly removing emojis and pictorial symbols (without converting them to text). - * Removing URLs and hashtags. - * Standardizing currency expressions (e.g., "1500ብር" to "1500 ETB"). - * Retaining Telegram usernames and phone numbers. - * Removing extra spaces and cleaning miscellaneous characters. - Run from the project root: - - + * Normalizing Amharic character variations. + * Strictly removing emojis and pictorial symbols (without converting them to text). + * Removing URLs and hashtags. + * Standardizing currency expressions (e.g., "1500ብር" to "1500 ETB"). + * Retaining Telegram usernames and phone numbers. + * Removing extra spaces and cleaning miscellaneous characters. + Run from the project root: ```bash python src/preprocessor.py ``` - *Output: `data/processed/clean_telegram_data.csv`.* + **Output:** `data/processed/clean_telegram_data.csv`. 4. 
**Perform Preprocessing EDA (`notebooks/data_preprocessing_eda.ipynb`):** Analyze the characteristics of the cleaned text data, such as text length distribution and common words. @@ -142,12 +148,11 @@ This section guides you through the process of setting up the project, collectin jupyter notebook notebooks/data_preprocessing_eda.ipynb ``` - *Insights:* - - * Confirmed loading and basic characteristics of `clean_telegram_data.csv`. - * Analyzed distribution of preprocessed text lengths and common words. - * Verified retention of Telegram usernames and phone numbers. - * Identified that \~46% of `preprocessed_text` entries are empty (corresponding to messages that were originally only emojis/images/etc.). + **Insights:** + * Confirmed loading and basic characteristics of `clean_telegram_data.csv`. + * Analyzed distribution of preprocessed text lengths and common words. + * Verified retention of Telegram usernames and phone numbers. + * Identified that ~46% of `preprocessed_text` entries are empty (corresponding to messages that were originally only emojis/images/etc.). 5. **Run Unit Tests (`tests/test_preprocessor.py`):** Verify the correctness of the `preprocessor.py` functions. @@ -157,14 +162,79 @@ This section guides you through the process of setting up the project, collectin pytest tests/test_preprocessor.py ``` -## 🎯 3. Next Steps (Task 2 onwards) +## 🎯 3. Named Entity Recognition (NER) Pipeline + +This section details the steps for labeling data and fine-tuning an NER model to extract key business entities. + +### 3.1. Data Labeling (Task 2) + +The cleaned text data is converted into a CoNLL-like format, suitable for Named Entity Recognition (NER) model training. This step involves applying rule-based labeling to identify entities such as product names, prices, locations, contact information, and delivery details. 
+ +**Script:** `src/data_labeler.py` +**Execution:** + +```bash +python src/data_labeler.py +``` + +**Output:** `data/labeled/telegram_ner_data_rule_based.conll` +**Process:** + + * Reads `clean_telegram_data.csv`. + * Applies a set of refined regex patterns to identify and extract entities. + * Resolves overlapping matches by prioritizing certain entity types and longer matches. + * Converts the identified entities into the CoNLL format (one tab-separated `Token<TAB>Tag` pair per line, with a blank line between sentences), ensuring consistency for model training. + **Status:** Completed. The script successfully generated the labeled `.conll` file. + +### 3.2. Model Fine-tuning (Task 3) + +A pre-trained multilingual transformer model (`Davlan/afro-xlmr-large`) is fine-tuned on the labeled Amharic NER dataset to accurately extract entities from new Telegram messages. + +**Script:** `src/model_finetuner.py` +**Execution:** + +```bash +python src/model_finetuner.py +``` + +**Output:** The fine-tuned model and its tokenizer are saved to `models/afro_xlmr_ner_fine_tuned/`. +**Process:** + + * **Data Loading & Splitting:** Loads the CoNLL data, parses it into sentences, and splits it into roughly 80% training, 10% validation, and 10% test sets. Stratification on whether a sentence contains any entity is attempted, and is disabled automatically when only one group exists or a group has fewer than two sentences. + * **Tokenization & Label Alignment:** Uses the `afro-xlmr-large` tokenizer to convert words into subword tokens and aligns the word-level NER labels (B-, I-, L-, U-, and O tags) to these subwords; special tokens are masked with -100 so the loss ignores them. + * **Model Initialization:** Loads `Davlan/afro-xlmr-large` for token classification, configuring its output layer for the defined NER labels (PRODUCT, PRICE, LOC, CONTACT, DELIVERY). + * **Training:** Fine-tunes the model for 5 epochs using a batch size of 8, with evaluation performed at each epoch. + * **Evaluation:** Calculates Precision, Recall, and F1-score on the validation and test sets to assess model performance. 
+ **Status:** Completed. The model was successfully fine-tuned and saved. + +### Initial Model Performance (`afro-xlmr-large` on Test Set): + +| Entity Type | Precision | Recall | F1-Score | Support | +| :---------- | :-------- | :----- | :------- | :------ | +| CONTACT | 0.00 | 0.00 | 0.00 | 1 | +| DELIVERY | 0.00 | 0.00 | 0.00 | 0 | +| LOC | 0.10 | 0.05 | 0.07 | 55 | +| PRICE | 0.01 | 0.06 | 0.01 | 16 | +| PRODUCT | 0.02 | 0.25 | 0.03 | 4 | +| **micro avg** | **0.02** | **0.07** | **0.03** | **76** | +| **macro avg** | **0.03** | **0.07** | **0.02** | **76** | +| **weighted avg** | **0.08** | **0.07** | **0.06** | **76** | + +**Summary:** The initial performance is very low across all entity types, with F1-scores close to zero. This is primarily attributed to the small training dataset (only 40 sentences for training). Transformer models require significantly more labeled data to learn robust patterns for NER. Future improvements will focus on expanding the dataset and potentially exploring data augmentation techniques. + +## 🎯 4. Model Comparison & Selection (Task 4) - Next Steps + +The next phase will involve comparing the performance of `afro-xlmr-large` with other suitable multilingual models. + +**Objective:** Fine-tune and evaluate additional models (e.g., DistilBERT, mBERT) to identify the best-performing architecture for the Amharic NER task. +**Steps:** + + * Integrate options to load and fine-tune DistilBERT or mBERT within `src/model_finetuner.py` or a new script. + * Run training and evaluation for each candidate model. + * Compare models based on precision, recall, F1-score, training speed, and resource usage. + * Select the optimal model for production. -The next phase will focus on preparing the data for the Named Entity Recognition (NER) task. +### Future Enhancements (Tasks 5 & 6) - * **Data Labeling:** Convert cleaned text and existing labeled data into the CoNLL format. 
This will involve: - * Defining and applying labels for product, price, location, contact, and delivery entities. - * Addressing any remaining text overlapping issues (e.g., "ዋጋ ስልክ አድራሻ" or "price contact") by careful tokenization and labeling strategy. - * **Data Splitting:** Divide the labeled data into training, validation, and test sets for model development. - * **Data Versioning:** Set up DVC to version `telegram_data.csv` and large image datasets in `photos/`. - * **Model Fine-Tuning:** Plan for fine-tuning Amharic LLM models for NER. - * **Model Interpretability:** Integrate tools like SHAP/LIME for understanding model predictions. \ No newline at end of file + * **Model Interpretability (Task 5):** Implement SHAP and LIME to explain model predictions, especially for difficult cases. + * **FinTech Vendor Scorecard for Micro-Lending (Task 6):** Develop an analytics engine to combine extracted NER entities with Telegram post metadata (views, timestamps) to calculate key vendor performance metrics (posting frequency, average views per post, average price point) and derive a "Lending Score." 
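A note on the training numbers quoted in the README: with 40 training sentences, a batch size of 8, and 5 epochs, the run takes very few optimizer steps, while `TRAINING_ARGS` in `src/model_finetuner.py` sets `warmup_steps=500`. The sketch below works through that arithmetic. It assumes a single device, no gradient accumulation, and the Trainer's default linear schedule with warmup; the helper names are hypothetical and are not part of this change. If those assumptions hold, the learning rate never leaves the warmup ramp, which may contribute to the low scores alongside the small dataset.

```python
import math

# Values taken from the README summary and TRAINING_ARGS in this change.
train_sentences = 40   # "only 40 sentences for training"
batch_size = 8         # per_device_train_batch_size
epochs = 5             # num_train_epochs
warmup_steps = 500     # warmup_steps
base_lr = 2e-5         # learning_rate


def total_optimizer_steps(n_examples: int, batch: int, n_epochs: int) -> int:
    """Optimizer steps on one device with no gradient accumulation (assumed)."""
    return math.ceil(n_examples / batch) * n_epochs


def linear_warmup_lr(step: int, lr: float, warmup: int) -> float:
    """Learning rate during linear warmup: ramps from 0 to `lr` over `warmup` steps."""
    return lr * min(1.0, step / max(1, warmup))


steps = total_optimizer_steps(train_sentences, batch_size, epochs)
final_lr = linear_warmup_lr(steps, base_lr, warmup_steps)

print(f"total optimizer steps: {steps}")
print(f"learning rate reached by the last step: {final_lr:.2e}")
print(f"fraction of configured learning rate: {final_lr / base_lr:.0%}")
```

Under these assumptions the run has 25 optimizer steps and reaches about 5% of the configured learning rate. A warmup expressed as a small fraction of the total steps, or a much smaller `warmup_steps`, would be one thing to try before concluding that data size alone explains the results.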
diff --git a/src/labeling.py b/src/labeling.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/model_finetuner.py b/src/model_finetuner.py new file mode 100644 index 0000000..99715a6 --- /dev/null +++ b/src/model_finetuner.py @@ -0,0 +1,348 @@ +# EthioMart/src/model_finetuner.py + +import pandas as pd +import re +from pathlib import Path +import logging +from tqdm import tqdm +import os +from sklearn.model_selection import train_test_split +from datasets import Dataset, Features, Value, ClassLabel, Sequence +import torch +from collections import Counter + +# Hugging Face imports +from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer +from transformers import DataCollatorForTokenClassification +from seqeval.metrics import precision_score, recall_score, f1_score, classification_report + +import sys + +# Add the project root to sys.path to allow importing from src and config +project_root = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(project_root)) + +# Import configurations from config/config.py +try: + from config.config import DATA_DIR +except ImportError: + logging.error("Could not import DATA_DIR from config.config. 
Please ensure the config file is correct.") + DATA_DIR = Path(__file__).parent.parent / "data" + +# Set up logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +# --- Configuration --- +MODEL_CHECKPOINT = "Davlan/afro-xlmr-large" +OUTPUT_DIR = Path(project_root / "models" / "afro_xlmr_ner_fine_tuned") +CONLL_FILE_PATH = Path(DATA_DIR.parent / "labeled" / "telegram_ner_data_rule_based.conll") + +# Training arguments (can be adjusted) +TRAINING_ARGS = TrainingArguments( + output_dir=str(OUTPUT_DIR / "training_logs"), + num_train_epochs=5, + per_device_train_batch_size=8, + per_device_eval_batch_size=8, + learning_rate=2e-5, + warmup_steps=500, + weight_decay=0.01, + eval_strategy="epoch", + logging_dir=str(OUTPUT_DIR / "runs"), + logging_steps=100, + save_strategy="epoch", + load_best_model_at_end=True, + metric_for_best_model="f1", + greater_is_better=True, + report_to="tensorboard", +) + +# --- Data Loading and Parsing --- +def read_conll_file(file_path): + """ + Reads a CoNLL-formatted file and parses it into a list of dictionaries, + where each dictionary represents a sentence with tokens and NER tags. + """ + texts = [] + tokens = [] + ner_tags = [] + + all_possible_labels = ["O", + "B-PRODUCT", "I-PRODUCT", "L-PRODUCT", "U-PRODUCT", + "B-PRICE", "I-PRICE", "L-PRICE", "U-PRICE", + "B-LOC", "I-LOC", "L-LOC", "U-LOC", + "B-CONTACT", "I-CONTACT", "L-CONTACT", "U-CONTACT", + "B-DELIVERY", "I-DELIVERY", "L-DELIVERY", "U-DELIVERY"] + + with open(file_path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + if len(parts) == 2: + tokens.append(parts[0]) + tag = parts[1] + if tag not in all_possible_labels: + logging.warning(f"Unknown tag '{tag}' found for token '{parts[0]}'. 
Changing to 'O'.") + ner_tags.append("O") + else: + ner_tags.append(tag) + else: + logging.warning(f"Skipping malformed line (too many/few parts): '{line}'") + else: + if tokens: + texts.append({"tokens": tokens, "ner_tags": ner_tags}) + tokens = [] + ner_tags = [] + if tokens: + texts.append({"tokens": tokens, "ner_tags": ner_tags}) + + logging.info(f"Loaded {len(texts)} sentences from CoNLL file.") + + all_flat_tags = [tag for sent in texts for tag in sent['ner_tags']] + logging.info(f"Class distribution: {Counter(all_flat_tags)}") + + return texts, all_possible_labels + +# --- Tokenization and Label Alignment --- +def tokenize_and_align_labels(examples, tokenizer, label_to_id, id_to_label): + """ + Tokenizes a batch of token lists and aligns the word-level NER labels with the subword tokens. + + Label mapping applied below: special tokens get -100 so the loss ignores them. + On the first subword of a word, 'B-' and 'O' tags are kept, while 'I-', 'L-', and 'U-' tags are relabeled 'B-<TYPE>'. + On later subwords of a word, 'B-' and 'U-' tags become 'I-<TYPE>', while 'I-', 'L-', and 'O' tags are kept. + + `examples["tokens"]` is expected to be a list of lists of strings, + e.g., `[["token1", "token2"], ["tokenA", "tokenB"]]`. + """ + # `examples["tokens"]` is already a list of token lists (one per sentence), + # which is the shape `is_split_into_words=True` expects. + # Sentences that tokenize to nothing are handled inside the loop below. + + # Tokenize the batch of sentences. 
+ tokenized_inputs = tokenizer( + examples["tokens"], + truncation=True, + is_split_into_words=True + ) + + labels = [] + # Iterate through each example in the batch + for i, word_label_ids in enumerate(examples["ner_tags"]): # word_label_ids will be a list of integers for current sentence + word_ids = tokenized_inputs.word_ids(batch_index=i) + previous_word_idx = None + label_ids = [] + + # Guard against empty `word_ids` if a sentence was too short or tokenized to nothing + if not word_ids: + labels.append([-100]) # Append a placeholder so `labels` list has correct length + continue + + for word_idx in word_ids: + if word_idx is None: + label_ids.append(-100) # Special tokens + elif word_idx != previous_word_idx: + # This is the first token of a new word. Get original string label. + # Ensure word_idx is within bounds of word_label_ids + if word_idx < len(word_label_ids): + original_label_id = word_label_ids[word_idx] + original_label_str = id_to_label[original_label_id] # Convert ID to string + + if original_label_str.startswith("B-"): + label_ids.append(original_label_id) + elif original_label_str.startswith("I-") or \ + original_label_str.startswith("L-") or \ + original_label_str.startswith("U-"): + entity_type = original_label_str.split("-")[1] + label_ids.append(label_to_id["B-" + entity_type]) + else: # 'O' tag + label_ids.append(original_label_id) + else: + # This implies a mismatch. Assign -100 to be safe. + logging.warning(f"Word index {word_idx} out of bounds for word labels. 
Assigning -100.") + label_ids.append(-100) + else: + # This is a subsequent token of a word that has been split into subwords + if word_idx < len(word_label_ids): + original_label_id = word_label_ids[word_idx] + original_label_str = id_to_label[original_label_id] + + if original_label_str.startswith("B-") or \ + original_label_str.startswith("U-"): + entity_type = original_label_str.split("-")[1] + label_ids.append(label_to_id["I-" + entity_type]) + elif original_label_str.startswith("I-") or \ + original_label_str.startswith("L-"): + label_ids.append(original_label_id) # Keep original I- or L- if it's not the start + else: # 'O' tag + label_ids.append(original_label_id) + else: + # Mismatch. Assign -100. + logging.warning(f"Word index {word_idx} out of bounds for word labels (subsequent token). Assigning -100.") + label_ids.append(-100) + previous_word_idx = word_idx + labels.append(label_ids) + tokenized_inputs["labels"] = labels + return tokenized_inputs + +# --- Evaluation Metrics --- +def compute_metrics(p, label_list): + """ + Computes and returns precision, recall, and f1-score for NER. + """ + predictions, labels = p + predictions = torch.argmax(torch.tensor(predictions), dim=2) + + # Remove ignored index (where label is -100) + true_predictions = [ + [label_list[p_id] for (p_id, l_id) in zip(prediction, label) if l_id != -100] + for prediction, label in zip(predictions, labels) + ] + true_labels = [ + [label_list[l_id] for (p_id, l_id) in zip(prediction, label) if l_id != -100] + for prediction, label in zip(predictions, labels) + ] + + precision = precision_score(true_labels, true_predictions) + recall = recall_score(true_labels, true_predictions) + f1 = f1_score(true_labels, true_predictions) + + logging.info(f"\n--- Classification Report ---\n{classification_report(true_labels, true_predictions)}") + + return { + "precision": precision, + "recall": recall, + "f1": f1, + } + + +def main(): + # 1. 
Load the labeled dataset + if not CONLL_FILE_PATH.exists(): + logging.error(f"❌ CoNLL file not found at {CONLL_FILE_PATH}. Please ensure data_labeler.py has run and produced this file.") + sys.exit(1) + + raw_data, all_possible_labels = read_conll_file(CONLL_FILE_PATH) + + if not raw_data: + logging.error("❌ No data loaded from CoNLL file. Exiting.") + sys.exit(1) + + # 2. Split the dataset into train, validation, and test sets + + # Create a simpler stratification key: based on presence of entities to reduce unique classes + stratify_keys = [] + for sentence in raw_data: + has_entity = any(tag != 'O' for tag in sentence['ner_tags']) + stratify_keys.append("HAS_ENTITY" if has_entity else "NO_ENTITY") + + unique_stratify_keys_counts = Counter(stratify_keys) + viable_stratify = True + if len(unique_stratify_keys_counts) > 1: + for key, count in unique_stratify_keys_counts.items(): + if count < 2: + viable_stratify = False + logging.warning(f"Stratification key '{key}' has only {count} member(s). Disabling stratification for robustness.") + break + else: + viable_stratify = False + logging.info("Only one unique stratification key found or not enough samples per class. 
Performing non-stratified split.") + + # Perform splits + all_indices = list(range(len(raw_data))) + if viable_stratify: + logging.info("Performing stratified split.") + train_val_indices, test_indices = train_test_split( + all_indices, test_size=0.1, random_state=42, stratify=stratify_keys + ) + train_indices, val_indices = train_test_split( + train_val_indices, test_size=0.111, random_state=42, stratify=[stratify_keys[i] for i in train_val_indices] + ) + else: + logging.info("Performing non-stratified split.") + train_val_indices, test_indices = train_test_split(all_indices, test_size=0.1, random_state=42) + train_indices, val_indices = train_test_split(train_val_indices, test_size=0.111, random_state=42) + + train_data = [raw_data[i] for i in train_indices] + val_data = [raw_data[i] for i in val_indices] + test_data = [raw_data[i] for i in test_indices] + + logging.info(f"Dataset split: Train={len(train_data)}, Validation={len(val_data)}, Test={len(test_data)}") + + # 3. Create Hugging Face Dataset objects + features = Features({ + "tokens": Sequence(Value(dtype="string")), # Changed to Sequence(Value(dtype="string")) for lists of tokens + "ner_tags": Sequence(ClassLabel(names=all_possible_labels)), + }) + + train_dataset = Dataset.from_list(train_data, features=features) + val_dataset = Dataset.from_list(val_data, features=features) + test_dataset = Dataset.from_list(test_data, features=features) + + # 4. 
Load Tokenizer and Model + tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT) + + label_to_id = {label: i for i, label in enumerate(all_possible_labels)} + id_to_label = {i: label for label, i in label_to_id.items()} + + model = AutoModelForTokenClassification.from_pretrained( + MODEL_CHECKPOINT, + num_labels=len(all_possible_labels), + id2label=id_to_label, + label2id=label_to_id + ) + logging.info(f"Model and tokenizer loaded from {MODEL_CHECKPOINT}.") + logging.info(f"Number of labels: {len(all_possible_labels)}") + logging.info(f"Label mappings: {label_to_id}") + + # 5. Tokenize and Align Labels for all datasets + logging.info("Tokenizing and aligning labels for datasets...") + tokenized_train_dataset = train_dataset.map( + lambda examples: tokenize_and_align_labels(examples, tokenizer, label_to_id, id_to_label), + batched=True, + remove_columns=["tokens", "ner_tags"] + ) + tokenized_val_dataset = val_dataset.map( + lambda examples: tokenize_and_align_labels(examples, tokenizer, label_to_id, id_to_label), + batched=True, + remove_columns=["tokens", "ner_tags"] + ) + tokenized_test_dataset = test_dataset.map( + lambda examples: tokenize_and_align_labels(examples, tokenizer, label_to_id, id_to_label), + batched=True, + remove_columns=["tokens", "ner_tags"] + ) + logging.info("Tokenization and alignment complete.") + + data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer) + + # 6. Set up Trainer + trainer = Trainer( + model=model, + args=TRAINING_ARGS, + train_dataset=tokenized_train_dataset, + eval_dataset=tokenized_val_dataset, + data_collator=data_collator, + tokenizer=tokenizer, + compute_metrics=lambda p: compute_metrics(p, list(id_to_label.values())) + ) + + # 7. Train the model + logging.info("Starting model training...") + trainer.train() + logging.info("Model training finished.") + + # 8. 
Evaluate the fine-tuned model on the test set + logging.info("Evaluating model on the test set...") + results = trainer.evaluate(tokenized_test_dataset) + logging.info(f"Test Set Evaluation Results: {results}") + + # 9. Save the model and tokenizer + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + model.save_pretrained(OUTPUT_DIR) + tokenizer.save_pretrained(OUTPUT_DIR) + logging.info(f"✅ Fine-tuned model and tokenizer saved to {OUTPUT_DIR}") + +if __name__ == "__main__": + main() +
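For comparison with `tokenize_and_align_labels` above, here is a minimal sketch of a common alternative alignment convention: each word's tag goes to its first subword only, and every other position is masked with -100. This is not the function in this diff; `align_labels_first_subword` and the sample data are hypothetical, and `word_ids` is written out by hand in the shape a fast tokenizer's `word_ids()` returns so the example runs without downloading a model.

```python
from typing import List, Optional

IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def align_labels_first_subword(
    word_ids: List[Optional[int]], word_labels: List[int]
) -> List[int]:
    """Give each word's label to its first subword and mask everything else."""
    aligned = []
    previous = None
    for word_id in word_ids:
        if word_id is None:
            aligned.append(IGNORE_INDEX)          # special token such as <s> or </s>
        elif word_id != previous:
            aligned.append(word_labels[word_id])  # first subword keeps the word's tag
        else:
            aligned.append(IGNORE_INDEX)          # later subwords are not scored
        previous = word_id
    return aligned


# Hypothetical sentence: three words, the second one split into two subwords.
labels = ["O", "B-PRICE", "I-PRICE", "L-PRICE"]
label_to_id = {name: i for i, name in enumerate(labels)}
word_tags = ["B-PRICE", "I-PRICE", "L-PRICE"]
word_ids = [None, 0, 1, 1, 2, None]

aligned = align_labels_first_subword(word_ids, [label_to_id[t] for t in word_tags])
print([labels[i] if i != IGNORE_INDEX else "-" for i in aligned])
```

This prints `['-', 'B-PRICE', 'I-PRICE', '-', 'L-PRICE', '-']`, so the span structure of the original tags is preserved. Reading the function in this diff, the same input would instead come out as `B-PRICE, B-PRICE, I-PRICE, B-PRICE` on the four word positions, because `I-` and `L-` tags on a first subword are rewritten to `B-`. Whether that relabeling is intended is worth confirming, since it turns each word of a multi-word entity into the start of a new entity.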