Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion TalkHeal.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def _get_query_params():
handle_oauth_callback()
st.stop()

# Restore session from cookie if not already authenticated
if not st.session_state.get("authenticated", False):
from auth.session_manager import restore_session_from_storage
restore_session_from_storage()

if not st.session_state.get("authenticated", False):
show_login_page()
st.stop()
Expand Down Expand Up @@ -97,7 +102,9 @@ def _get_query_params():
st.switch_page("pages/About.py")
with nav_cols[3]:
if st.button("Logout", key="logout_btn", help="Sign out", use_container_width=True):
for key in ["authenticated", "user_profile"]:
from auth.session_manager import clear_session_cookie
clear_session_cookie()
for key in ["authenticated", "user_profile", "user_name"]:
if key in st.session_state:
del st.session_state[key]
st.rerun()
Expand Down
9 changes: 7 additions & 2 deletions auth/oauth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ def handle_oauth_callback(provider_name: str, code: str, state: str) -> Tuple[bo
return False, user_info.get("error", "Failed to create/get user")

# Set session state
st.session_state.authenticated = True
st.session_state.user_profile = {
user_profile = {
"name": user_info["name"],
"email": user_info["email"],
"profile_picture": user_info["profile_picture"],
Expand All @@ -237,8 +236,14 @@ def handle_oauth_callback(provider_name: str, code: str, state: str) -> Tuple[bo
"provider_id": user_info["provider_id"],
"verified": user_info["verified"]
}
st.session_state.authenticated = True
st.session_state.user_profile = user_profile
st.session_state.user_name = user_info["name"]

# Save session to cookie for persistence across page refreshes
from auth.session_manager import set_session_cookie
set_session_cookie(user_info["email"], user_profile)

# Clean up OAuth state
if "oauth_states" in st.session_state and state in st.session_state.oauth_states:
del st.session_state.oauth_states[state]
Expand Down
215 changes: 215 additions & 0 deletions auth/session_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
"""
Session management using cookies for persistent authentication across page refreshes.
Uses extra-streamlit-components for reliable cookie handling.
"""
import streamlit as st
import json
from datetime import datetime, timedelta
from auth.auth_utils import get_user_by_email
from extra_streamlit_components import CookieManager

# Cookie name for storing session data
SESSION_COOKIE_NAME = "talkheal_session"
SESSION_EXPIRY_DAYS = 7 # Session expires after 7 days

# Initialize cookie manager
@st.cache_resource
def get_cookie_manager():
return CookieManager()

def set_session_cookie(email, user_data):
"""
Set a cookie with user session data.

Args:
email (str): User's email
user_data (dict): User profile data
"""
try:
cookie_manager = get_cookie_manager()

session_data = {
"email": email,
"authenticated": True,
"user_profile": user_data,
"expires_at": (datetime.now() + timedelta(days=SESSION_EXPIRY_DAYS)).isoformat()
}

# Convert to JSON string
session_json = json.dumps(session_data)

# Set cookie (expires in 7 days)
cookie_manager.set(
SESSION_COOKIE_NAME,
session_json,
expires_at=datetime.now() + timedelta(days=SESSION_EXPIRY_DAYS)
)
except Exception as e:
# If cookie manager fails, fall back to localStorage
_set_session_storage_fallback(email, user_data)


def get_session_cookie():
"""
Get session data from cookie.

Returns:
dict or None: Session data if cookie exists and is valid, None otherwise
"""
try:
cookie_manager = get_cookie_manager()
cookie_value = cookie_manager.get(SESSION_COOKIE_NAME)

if not cookie_value:
# Try fallback to localStorage
return _get_session_storage_fallback()

# Parse JSON data
session_data = json.loads(cookie_value)

# Check if session has expired
expires_at_str = session_data.get("expires_at")
if expires_at_str:
expires_at = datetime.fromisoformat(expires_at_str)
if datetime.now() > expires_at:
# Session expired, clear cookie
clear_session_cookie()
return None

return session_data
except Exception as e:
# If there's any error, try fallback
return _get_session_storage_fallback()


def clear_session_cookie():
"""
Clear the session cookie.
"""
try:
cookie_manager = get_cookie_manager()
cookie_manager.delete(SESSION_COOKIE_NAME)
except Exception:
pass
# Also clear localStorage fallback
_clear_session_storage_fallback()


def restore_session_from_storage():
"""
Restore user session from cookie if it exists.
This should be called at the start of the app.

Returns:
bool: True if session was restored, False otherwise
"""
# Only restore if not already authenticated
if st.session_state.get("authenticated", False):
return True

# Check if we've already tried to restore in this session
# This prevents infinite loops
if st.session_state.get("session_restore_attempted", False):
return False

# Mark that we've attempted restoration
st.session_state["session_restore_attempted"] = True

# Try to get session data from cookie
session_data = get_session_cookie()

if session_data and session_data.get("authenticated"):
email = session_data.get("email")
user_profile = session_data.get("user_profile", {})

# Handle guest users (they don't exist in database)
if email == "guest@talkheal.app":
# Restore guest session directly
st.session_state.authenticated = True
st.session_state.user_profile = user_profile
st.session_state.user_name = user_profile.get("name", "Guest Healer")
return True

# For regular users, verify they still exist in database
user = get_user_by_email(email)
if user:
# Restore session state
st.session_state.authenticated = True
st.session_state.user_profile = user_profile
st.session_state.user_name = user_profile.get("name", email)
return True
else:
# User doesn't exist anymore, clear cookie
clear_session_cookie()
return False

return False


# Fallback functions using localStorage (in case cookies don't work)
def _set_session_storage_fallback(email, user_data):
"""Fallback: Set localStorage using JavaScript"""
session_data = {
"email": email,
"authenticated": True,
"user_profile": user_data,
"expires_at": (datetime.now() + timedelta(days=SESSION_EXPIRY_DAYS)).isoformat()
}
session_json = json.dumps(session_data)
js_code = f"""
<script>
try {{
localStorage.setItem("{SESSION_COOKIE_NAME}", {json.dumps(session_json)});
}} catch(e) {{
console.error("Error setting session storage:", e);
}}
</script>
"""
st.markdown(js_code, unsafe_allow_html=True)


def _get_session_storage_fallback():
"""Fallback: Get localStorage using JavaScript"""
try:
from streamlit_js_eval import streamlit_js_eval
storage_value = streamlit_js_eval(
js_expressions=f'localStorage.getItem("{SESSION_COOKIE_NAME}")',
key=f"get_session_fallback_{SESSION_COOKIE_NAME}"
)

if not storage_value or storage_value == "null" or storage_value == "None" or storage_value == "":
return None

storage_str = str(storage_value) if not isinstance(storage_value, str) else storage_value
if isinstance(storage_value, dict):
storage_str = storage_value.get('value', None) or storage_value.get('result', None)
if not storage_str:
return None

session_data = json.loads(storage_str)

# Check expiration
expires_at_str = session_data.get("expires_at")
if expires_at_str:
expires_at = datetime.fromisoformat(expires_at_str)
if datetime.now() > expires_at:
_clear_session_storage_fallback()
return None

return session_data
except Exception:
return None


def _clear_session_storage_fallback():
"""Fallback: Clear localStorage"""
js_code = f"""
<script>
try {{
localStorage.removeItem("{SESSION_COOKIE_NAME}");
}} catch(e) {{
console.error("Error clearing session storage:", e);
}}
</script>
"""
st.markdown(js_code, unsafe_allow_html=True)
22 changes: 14 additions & 8 deletions components/login_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,16 +431,20 @@ def show_login_page():
try:
success, user = authenticate_user(email, password)
if success:
st.session_state.authenticated = True
st.session_state.user_profile = {
from auth.session_manager import set_session_cookie
user_profile = {
"name": user.get("name", ""),
"email": user.get("email", email),
"profile_picture": user.get("photo", None),
"join_date": user.get("join_date", datetime.now().strftime("%B %Y")),
"font_size": user.get("font_size", "Medium")
}
st.session_state.authenticated = True
st.session_state.user_profile = user_profile
# Set user_name for display purposes
st.session_state.user_name = user.get("name", email)
# Save session to cookie
set_session_cookie(email, user_profile)
st.rerun()

else:
Expand Down Expand Up @@ -493,19 +497,21 @@ def show_login_page():
# Guest Login Button with Full Logic
st.markdown('<div class="auth-button">', unsafe_allow_html=True)
if st.button("Login as Guest"):
# Set the authentication flag to True, just like in a real login
st.session_state.authenticated = True

from auth.session_manager import set_session_cookie
# Create a simple, fake user profile for the Guest
st.session_state.user_profile = {
user_profile = {
"name": "Guest Healer",
"email": "guest@talkheal.app",
"profile_picture": None,
"join_date": datetime.now().strftime("%B %Y"),
"font_size": "Medium"
}
# Fix: Use the user_profile instead of undefined 'user' variable
st.session_state.user_name = st.session_state.user_profile["name"]
# Set the authentication flag
st.session_state.authenticated = True
st.session_state.user_profile = user_profile
st.session_state.user_name = user_profile["name"]
# Save session to cookie
set_session_cookie("guest@talkheal.app", user_profile)
# Rerun the app to enter the main dashboard
st.rerun()
st.markdown('</div>', unsafe_allow_html=True)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ bcrypt
pygame
streamlit-modal
streamlit_js_eval
extra-streamlit-components
scikit-learn
joblib
numpy
Expand Down