From 516f51970d7e83a9d4cc72b5c64ab74fbbfbb7bc Mon Sep 17 00:00:00 2001 From: Abhijit Satyaki Date: Mon, 5 Jan 2026 23:53:42 +0530 Subject: [PATCH] fix: use context managers for all SQLite database connections - Refactored 24 database connections across 5 files - Prevents connection leaks if exceptions occur - Follows Python best practices (PEP 343) --- auth/auth_utils.py | 188 +++++++++++++++++++--------------------- core/utils.py | 205 +++++++++++++++++++++----------------------- migrate_db.py | 23 +++-- pages/Journaling.py | 105 +++++++++++------------ setup_database.py | 73 ++++++++-------- 5 files changed, 281 insertions(+), 313 deletions(-) diff --git a/auth/auth_utils.py b/auth/auth_utils.py index 7c3d42d8..f2c93876 100644 --- a/auth/auth_utils.py +++ b/auth/auth_utils.py @@ -4,45 +4,44 @@ from auth.password_validator import PasswordValidator def init_db(): - conn = sqlite3.connect("users.db") - cursor = conn.cursor() - cursor.execute(""" - CREATE TABLE IF NOT EXISTS users ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL, - email TEXT UNIQUE NOT NULL, - password TEXT, - updated_at TEXT NOT NULL, - provider TEXT DEFAULT 'email', - provider_id TEXT, - profile_picture TEXT, - verified BOOLEAN DEFAULT 0 - ) - """) - - # Add new columns if they don't exist (for existing databases) - try: - cursor.execute("ALTER TABLE users ADD COLUMN provider TEXT DEFAULT 'email'") - except sqlite3.OperationalError: - pass # Column already exists - - try: - cursor.execute("ALTER TABLE users ADD COLUMN provider_id TEXT") - except sqlite3.OperationalError: - pass # Column already exists - - try: - cursor.execute("ALTER TABLE users ADD COLUMN profile_picture TEXT") - except sqlite3.OperationalError: - pass # Column already exists - - try: - cursor.execute("ALTER TABLE users ADD COLUMN verified BOOLEAN DEFAULT 0") - except sqlite3.OperationalError: - pass # Column already exists - - conn.commit() - conn.close() + with sqlite3.connect("users.db") as conn: + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + email TEXT UNIQUE NOT NULL, + password TEXT, + updated_at TEXT NOT NULL, + provider TEXT DEFAULT 'email', + provider_id TEXT, + profile_picture TEXT, + verified BOOLEAN DEFAULT 0 + ) + """) + + # Add new columns if they don't exist (for existing databases) + try: + cursor.execute("ALTER TABLE users ADD COLUMN provider TEXT DEFAULT 'email'") + except sqlite3.OperationalError: + pass # Column already exists + + try: + cursor.execute("ALTER TABLE users ADD COLUMN provider_id TEXT") + except sqlite3.OperationalError: + pass # Column already exists + + try: + cursor.execute("ALTER TABLE users ADD COLUMN profile_picture TEXT") + except sqlite3.OperationalError: + pass # Column already exists + + try: + cursor.execute("ALTER TABLE users ADD COLUMN verified BOOLEAN DEFAULT 0") + except sqlite3.OperationalError: + pass # Column already exists + + conn.commit() def hash_password(password): return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() @@ -51,56 +50,52 @@ def check_password(password, hashed): return bcrypt.checkpw(password.encode(), hashed.encode()) def register_user(name, email, password, provider='email', provider_id=None, profile_picture=None, verified=False): - conn = sqlite3.connect("users.db") - cursor = conn.cursor() - # Hash password only if provided (OAuth users don't need passwords) hashed_pw = hash_password(password) if password else None current_time = datetime.now().isoformat() - try: - cursor.execute(""" - INSERT INTO users (name, email, password, updated_at, provider, provider_id, profile_picture, verified) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, (name, email, hashed_pw, current_time, provider, provider_id, profile_picture, verified)) - conn.commit() - return True, "User registered successfully" - except sqlite3.IntegrityError: - return False, "Email already registered" - finally: - conn.close() + with sqlite3.connect("users.db") as conn: + cursor = conn.cursor() + try: + cursor.execute(""" + INSERT INTO users (name, email, password, updated_at, provider, provider_id, profile_picture, verified) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, (name, email, hashed_pw, current_time, provider, provider_id, profile_picture, verified)) + conn.commit() + return True, "User registered successfully" + except sqlite3.IntegrityError: + return False, "Email already registered" def authenticate_user(email, password): - conn = sqlite3.connect("users.db") - cursor = conn.cursor() - cursor.execute("SELECT name, password FROM users WHERE email = ?", (email,)) - result = cursor.fetchone() - conn.close() + with sqlite3.connect("users.db") as conn: + cursor = conn.cursor() + cursor.execute("SELECT name, password FROM users WHERE email = ?", (email,)) + result = cursor.fetchone() + if result and check_password(password, result[1]): user = {"name": result[0], "email": email} return True, user return False, None def check_user(email): - conn = sqlite3.connect("users.db") - cursor = conn.cursor() - cursor.execute("SELECT * FROM users WHERE email = ?", (email,)) - result = cursor.fetchone() - conn.close() + with sqlite3.connect("users.db") as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE email = ?", (email,)) + result = cursor.fetchone() + if result: - return True , result[4] - return False , None + return True, result[4] + return False, None def get_user_by_email(email): """Get user data by email for OAuth authentication""" - conn = sqlite3.connect("users.db") - cursor = conn.cursor() - cursor.execute(""" - SELECT id, name, email, provider, provider_id, profile_picture, verified, updated_at - FROM users WHERE email = ? - """, (email,)) - result = cursor.fetchone() - conn.close() + with sqlite3.connect("users.db") as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT id, name, email, provider, provider_id, profile_picture, verified, updated_at + FROM users WHERE email = ? + """, (email,)) + result = cursor.fetchone() if result: return { @@ -118,43 +113,38 @@ def get_user_by_email(email): def reset_password(email, new_password): hashed_pw = hash_password(new_password) current_time = datetime.now().isoformat() + try: - conn = sqlite3.connect("users.db") - cursor = conn.cursor() - cursor.execute("SELECT * FROM users WHERE email = ?", (email,)) - result = cursor.fetchone() + with sqlite3.connect("users.db") as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE email = ?", (email,)) + result = cursor.fetchone() - if not result: - conn.close() - return False, "User with this email does not exist." + if not result: + return False, "User with this email does not exist." - cursor.execute("UPDATE users SET password = ? , updated_at = ? WHERE email = ?", (hashed_pw, current_time, email)) - conn.commit() - conn.close() - return True, "Password updated successfully." + cursor.execute("UPDATE users SET password = ? , updated_at = ? WHERE email = ?", (hashed_pw, current_time, email)) + conn.commit() + return True, "Password updated successfully." except sqlite3.Error as e: - conn.close() return False, f"Database error: {str(e)}" def verify_token_count(email, token_updated_at): try: - conn = sqlite3.connect("users.db") - cursor = conn.cursor() - cursor.execute("SELECT updated_at FROM users WHERE email = ?", (email,)) - result = cursor.fetchone() - if not result: - conn.close() - return False, "User with this email does not exist." + with sqlite3.connect("users.db") as conn: + cursor = conn.cursor() + cursor.execute("SELECT updated_at FROM users WHERE email = ?", (email,)) + result = cursor.fetchone() + + if not result: + return False, "User with this email does not exist." - db_updated_at = result[0] + db_updated_at = result[0] - if str(db_updated_at) != str(token_updated_at): - conn.close() - return False, "Reset link is no longer valid (token outdated)." + if str(db_updated_at) != str(token_updated_at): + return False, "Reset link is no longer valid (token outdated)." - conn.close() - return True, None + return True, None except sqlite3.Error as e: - conn.close() return False, f"Database error: {str(e)}" \ No newline at end of file diff --git a/core/utils.py b/core/utils.py index 9b5fa9d5..880ae875 100644 --- a/core/utils.py +++ b/core/utils.py @@ -466,39 +466,34 @@ def save_feedback(convo_id, message, feedback, comment=None): Comment: {comment if comment else "No comment"} """) - conn = None try: - conn = sqlite3.connect("feedback.db") - c = conn.cursor() - - c.execute(''' - SELECT id FROM feedback WHERE user_email = ? AND convo_id = ? AND message = ? - ''', (hashed_email, convo_id, message)) - row = c.fetchone() + with sqlite3.connect("feedback.db") as conn: + c = conn.cursor() - if row: - c.execute(''' - UPDATE feedback - SET feedback = ?, comment = ?, timestamp = CURRENT_TIMESTAMP - WHERE id = ? - ''', (feedback, comment, row[0])) - print(f"[save_feedback] Updated existing feedback (id={row[0]})") - else: c.execute(''' - INSERT INTO feedback (user_email, convo_id, message, feedback, comment) - VALUES (?, ?, ?, ?, ?) - ''', (hashed_email, convo_id, message, feedback, comment)) - print("[save_feedback] Inserted new feedback") - - conn.commit() + SELECT id FROM feedback WHERE user_email = ? AND convo_id = ? AND message = ? + ''', (hashed_email, convo_id, message)) + row = c.fetchone() + + if row: + c.execute(''' + UPDATE feedback + SET feedback = ?, comment = ?, timestamp = CURRENT_TIMESTAMP + WHERE id = ? + ''', (feedback, comment, row[0])) + print(f"[save_feedback] Updated existing feedback (id={row[0]})") + else: + c.execute(''' + INSERT INTO feedback (user_email, convo_id, message, feedback, comment) + VALUES (?, ?, ?, ?, ?) + ''', (hashed_email, convo_id, message, feedback, comment)) + print("[save_feedback] Inserted new feedback") + + conn.commit() except Exception as e: print(f"[save_feedback] Exception while saving feedback: {e}") - finally: - if conn: - conn.close() - def get_feedback(convo_id, message): """ @@ -517,21 +512,19 @@ def get_feedback(convo_id, message): hashed_email = hash_email(user_email) try: - conn = sqlite3.connect("feedback.db") - c = conn.cursor() - c.execute(''' - SELECT feedback FROM feedback WHERE user_email = ? AND convo_id = ? AND message = ? - ''', (hashed_email, convo_id, message)) - row = c.fetchone() - if row: - return row[0] - else: - return None + with sqlite3.connect("feedback.db") as conn: + c = conn.cursor() + c.execute(''' + SELECT feedback FROM feedback WHERE user_email = ? AND convo_id = ? AND message = ? + ''', (hashed_email, convo_id, message)) + row = c.fetchone() + if row: + return row[0] + else: + return None except Exception as e: print(f"[get_feedback] Exception while fetching feedback: {e}") return None - finally: - conn.close() def get_feedback_per_message(convo_id=None): @@ -542,25 +535,24 @@ def get_feedback_per_message(convo_id=None): Returns: list: List of feedback dicts. """ - conn = sqlite3.connect("feedback.db") - c = conn.cursor() + with sqlite3.connect("feedback.db") as conn: + c = conn.cursor() - if convo_id is None: - c.execute(''' - SELECT user_email, convo_id, message, feedback, comment, timestamp - FROM feedback - ORDER BY timestamp DESC - ''') - else: - c.execute(''' - SELECT user_email, convo_id, message, feedback, comment, timestamp - FROM feedback - WHERE convo_id = ? - ORDER BY timestamp DESC - ''', (convo_id,)) + if convo_id is None: + c.execute(''' + SELECT user_email, convo_id, message, feedback, comment, timestamp + FROM feedback + ORDER BY timestamp DESC + ''') + else: + c.execute(''' + SELECT user_email, convo_id, message, feedback, comment, timestamp + FROM feedback + WHERE convo_id = ? + ORDER BY timestamp DESC + ''', (convo_id,)) - rows = c.fetchall() - conn.close() + rows = c.fetchall() return [ { @@ -582,26 +574,24 @@ def get_feedback_statistics(): dict: Statistics including total, positive, negative counts and percentage. """ try: - conn = sqlite3.connect("feedback.db") - c = conn.cursor() - - c.execute("SELECT COUNT(*) FROM feedback WHERE feedback = 'positive'") - positive = c.fetchone()[0] - - c.execute("SELECT COUNT(*) FROM feedback WHERE feedback = 'negative'") - negative = c.fetchone()[0] - - total = positive + negative - positive_pct = (positive / total * 100) if total > 0 else 0 - - conn.close() - - return { - "total": total, - "positive": positive, - "negative": negative, - "positive_percentage": round(positive_pct, 1) - } + with sqlite3.connect("feedback.db") as conn: + c = conn.cursor() + + c.execute("SELECT COUNT(*) FROM feedback WHERE feedback = 'positive'") + positive = c.fetchone()[0] + + c.execute("SELECT COUNT(*) FROM feedback WHERE feedback = 'negative'") + negative = c.fetchone()[0] + + total = positive + negative + positive_pct = (positive / total * 100) if total > 0 else 0 + + return { + "total": total, + "positive": positive, + "negative": negative, + "positive_percentage": round(positive_pct, 1) + } except Exception as e: print(f"[get_feedback_statistics] Error: {e}") return {"total": 0, "positive": 0, "negative": 0, "positive_percentage": 0} @@ -1208,19 +1198,18 @@ def clean_database(): int: Number of entries deleted. """ try: - conn = sqlite3.connect("feedback.db") - c = conn.cursor() - - cutoff_date = (datetime.now() - timedelta(days=90)).isoformat() - - c.execute("SELECT COUNT(*) FROM feedback WHERE timestamp < ?", (cutoff_date,)) - count = c.fetchone()[0] - - c.execute("DELETE FROM feedback WHERE timestamp < ?", (cutoff_date,)) - conn.commit() - conn.close() - - return count + with sqlite3.connect("feedback.db") as conn: + c = conn.cursor() + + cutoff_date = (datetime.now() - timedelta(days=90)).isoformat() + + c.execute("SELECT COUNT(*) FROM feedback WHERE timestamp < ?", (cutoff_date,)) + count = c.fetchone()[0] + + c.execute("DELETE FROM feedback WHERE timestamp < ?", (cutoff_date,)) + conn.commit() + + return count except Exception as e: print(f"[clean_database] Error: {e}") return 0 @@ -1253,22 +1242,21 @@ def export_user_data(): # Get feedback hashed_email = hash_email(user_email) try: - conn = sqlite3.connect("feedback.db") - c = conn.cursor() - c.execute("SELECT * FROM feedback WHERE user_email = ?", (hashed_email,)) - rows = c.fetchall() - - data_package["feedback"] = [ - { - "convo_id": r[2], - "message": r[3], - "feedback": r[4], - "comment": r[5], - "timestamp": r[6] - } - for r in rows - ] - conn.close() + with sqlite3.connect("feedback.db") as conn: + c = conn.cursor() + c.execute("SELECT * FROM feedback WHERE user_email = ?", (hashed_email,)) + rows = c.fetchall() + + data_package["feedback"] = [ + { + "convo_id": r[2], + "message": r[3], + "feedback": r[4], + "comment": r[5], + "timestamp": r[6] + } + for r in rows + ] except Exception as e: print(f"[export_user_data] Error: {e}") @@ -1293,11 +1281,10 @@ def delete_user_data(): # Delete feedback hashed_email = hash_email(user_email) - conn = sqlite3.connect("feedback.db") - c = conn.cursor() - c.execute("DELETE FROM feedback WHERE user_email = ?", (hashed_email,)) - conn.commit() - conn.close() + with sqlite3.connect("feedback.db") as conn: + c = conn.cursor() + c.execute("DELETE FROM feedback WHERE user_email = ?", (hashed_email,)) + conn.commit() # Clear session state st.session_state.conversations = [] diff --git a/migrate_db.py b/migrate_db.py index 12b38364..52e47392 100644 --- a/migrate_db.py +++ b/migrate_db.py @@ -15,21 +15,20 @@ } def migrate(): - conn = sqlite3.connect(DB_FILE) - cursor = conn.cursor() + with sqlite3.connect(DB_FILE) as conn: + cursor = conn.cursor() - # Get current columns in the table - cursor.execute(f"PRAGMA table_info({TABLE_NAME})") - current_columns = [col[1] for col in cursor.fetchall()] + # Get current columns in the table + cursor.execute(f"PRAGMA table_info({TABLE_NAME})") + current_columns = [col[1] for col in cursor.fetchall()] - # Add missing columns - for col_name, col_type in EXPECTED_COLUMNS.items(): - if col_name not in current_columns: - cursor.execute(f"ALTER TABLE {TABLE_NAME} ADD COLUMN {col_name} {col_type}") - print(f"✅ Added missing column: {col_name}") + # Add missing columns + for col_name, col_type in EXPECTED_COLUMNS.items(): + if col_name not in current_columns: + cursor.execute(f"ALTER TABLE {TABLE_NAME} ADD COLUMN {col_name} {col_type}") + print(f"✅ Added missing column: {col_name}") - conn.commit() - conn.close() + conn.commit() print("✅ Migration complete!") if __name__ == "__main__": diff --git a/pages/Journaling.py b/pages/Journaling.py index a102b188..b17a1109 100644 --- a/pages/Journaling.py +++ b/pages/Journaling.py @@ -148,70 +148,65 @@ def analyze_sentiment(entry: str) -> str: DB_PATH = "journals.db" def init_journal_db(): - conn = sqlite3.connect(DB_PATH) - cursor = conn.cursor() - cursor.execute(""" - CREATE TABLE IF NOT EXISTS journal_entries ( - id TEXT PRIMARY KEY, - email TEXT, - entry TEXT, - sentiment TEXT, - date TEXT, - tags TEXT - ) - """) - conn.commit() - conn.close() + with sqlite3.connect(DB_PATH) as conn: + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS journal_entries ( + id TEXT PRIMARY KEY, + email TEXT, + entry TEXT, + sentiment TEXT, + date TEXT, + tags TEXT + ) + """) + conn.commit() def save_entry(email, entry, sentiment, tags): - conn = sqlite3.connect(DB_PATH) - cursor = conn.cursor() - cursor.execute(""" - INSERT INTO journal_entries (id, email, entry, sentiment, date, tags) - VALUES (?, ?, ?, ?, ?, ?) - """, (str(uuid4()), email, entry, sentiment, str(date.today()), tags)) - conn.commit() - conn.close() + with sqlite3.connect(DB_PATH) as conn: + cursor = conn.cursor() + cursor.execute(""" + INSERT INTO journal_entries (id, email, entry, sentiment, date, tags) + VALUES (?, ?, ?, ?, ?, ?) + """, (str(uuid4()), email, entry, sentiment, str(date.today()), tags)) + conn.commit() def fetch_entries(email, sentiment_filter=None, start_date=None, end_date=None, tag_filter=None, search_query=None): - conn = sqlite3.connect(DB_PATH) - cursor = conn.cursor() - query = """ - SELECT id, entry, sentiment, date, tags FROM journal_entries - WHERE email = ? - """ - params = [email] - if sentiment_filter and sentiment_filter != "All": - query += " AND sentiment = ?" - params.append(sentiment_filter) - if start_date and end_date: - query += " AND date BETWEEN ? AND ?" - params.extend([start_date.strftime("%Y-%m-%d"), end_date.strftime("%Y-%m-%d")]) - if tag_filter: - for tag in tag_filter: - query += f" AND tags LIKE '%{tag}%'" - if search_query: - query += f" AND (entry LIKE '%{search_query}%' OR tags LIKE '%{search_query}%')" - - query += " ORDER BY date ASC" - rows = cursor.execute(query, params).fetchall() - conn.close() + with sqlite3.connect(DB_PATH) as conn: + cursor = conn.cursor() + query = """ + SELECT id, entry, sentiment, date, tags FROM journal_entries + WHERE email = ? + """ + params = [email] + if sentiment_filter and sentiment_filter != "All": + query += " AND sentiment = ?" + params.append(sentiment_filter) + if start_date and end_date: + query += " AND date BETWEEN ? AND ?" + params.extend([start_date.strftime("%Y-%m-%d"), end_date.strftime("%Y-%m-%d")]) + if tag_filter: + for tag in tag_filter: + query += f" AND tags LIKE '%{tag}%'" + if search_query: + query += f" AND (entry LIKE '%{search_query}%' OR tags LIKE '%{search_query}%')" + + query += " ORDER BY date ASC" + rows = cursor.execute(query, params).fetchall() return rows def update_entry(entry_id, new_text, new_tags): - conn = sqlite3.connect(DB_PATH) - cursor = conn.cursor() - new_sentiment = analyze_sentiment(new_text) - cursor.execute("UPDATE journal_entries SET entry = ?, sentiment = ?, tags = ? WHERE id = ?", (new_text, new_sentiment, new_tags, entry_id)) - conn.commit() - conn.close() + with sqlite3.connect(DB_PATH) as conn: + cursor = conn.cursor() + new_sentiment = analyze_sentiment(new_text) + cursor.execute("UPDATE journal_entries SET entry = ?, sentiment = ?, tags = ? WHERE id = ?", (new_text, new_sentiment, new_tags, entry_id)) + conn.commit() def delete_entry(entry_id): - conn = sqlite3.connect(DB_PATH) - cursor = conn.cursor() - cursor.execute("DELETE FROM journal_entries WHERE id = ?", (entry_id,)) - conn.commit() - conn.close() + with sqlite3.connect(DB_PATH) as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM journal_entries WHERE id = ?", (entry_id,)) + conn.commit() def create_mood_trend_chart(entries): if not entries: diff --git a/setup_database.py b/setup_database.py index 9603b412..61f2a568 100644 --- a/setup_database.py +++ b/setup_database.py @@ -12,52 +12,49 @@ def setup_journals_db(): """Initialize the journals database""" - conn = sqlite3.connect("journals.db") - cursor = conn.cursor() - cursor.execute(""" - CREATE TABLE IF NOT EXISTS journal_entries ( - id TEXT PRIMARY KEY, - email TEXT, - entry TEXT, - sentiment TEXT, - date TEXT - ) - """) - conn.commit() - conn.close() + with sqlite3.connect("journals.db") as conn: + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS journal_entries ( + id TEXT PRIMARY KEY, + email TEXT, + entry TEXT, + sentiment TEXT, + date TEXT + ) + """) + conn.commit() print("Journals database initialized successfully") def setup_feedback_db(): """Initialize the feedback database""" - conn = sqlite3.connect("feedback.db") - cursor = conn.cursor() - cursor.execute(""" - CREATE TABLE IF NOT EXISTS feedback ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_email TEXT, - convo_id INTEGER, - message TEXT, - feedback TEXT, - comment TEXT, - timestamp DATETIME DEFAULT CURRENT_TIMESTAMP - ) - """) - conn.commit() - conn.close() + with sqlite3.connect("feedback.db") as conn: + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS feedback ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_email TEXT, + convo_id INTEGER, + message TEXT, + feedback TEXT, + comment TEXT, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP + ) + """) + conn.commit() print("Feedback database initialized successfully") def update_feedback_table(): """Add user_email column if it doesn't exist (for old DBs)""" - conn = sqlite3.connect("feedback.db") - c = conn.cursor() - try: - c.execute("ALTER TABLE feedback ADD COLUMN user_email TEXT") - print("✅ user_email column added to feedback table") - except sqlite3.OperationalError: - # Column already exists - print("ℹ️ user_email column already exists in feedback table") - conn.commit() - conn.close() + with sqlite3.connect("feedback.db") as conn: + c = conn.cursor() + try: + c.execute("ALTER TABLE feedback ADD COLUMN user_email TEXT") + print("✅ user_email column added to feedback table") + except sqlite3.OperationalError: + # Column already exists + print("ℹ️ user_email column already exists in feedback table") + conn.commit() def main(): """Main setup function"""