-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
158 lines (124 loc) · 5.52 KB
/
Copy pathapp.py
File metadata and controls
158 lines (124 loc) · 5.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import io
import numpy as np
import streamlit as st
from PIL import Image
import tensorflow as tf
# Streamlit has historically required set_page_config to be the first st.*
# command executed in a script run, so keep it ahead of every other st call.
st.set_page_config(page_title="Facial Emotion Classification", layout="centered")
st.title("Facial Emotion Classification")
# Output labels, indexed by position in each model's output vector
# (predict_ensemble maps argmax indices through this list).
# NOTE(review): presumably this matches the label order used in training
# (it is alphabetical) — confirm against the training pipeline.
CLASS_NAMES = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
# Model name -> saved .keras file. The names also select the preprocessing
# pipeline (substring match in preprocess_for_model), so keep them in sync.
# The paths are relative, so they resolve against the working directory the
# app is launched from, not against this file's directory.
MODEL_PATHS = {
    "enhanced_cnn": "ML_project_/enhanced_cnn_20251127_115643/best_model.keras",
    "stanford_cnn": "ML_project_/stanford_cnn_20251127_125730/stanford_five_layer_final_0.7158.keras",
    "resnet50": "ML_project_/resnet50_fixed_20251204_100110/resnet50_fer_final.keras",
}
# Load the models (cached so Streamlit reruns reuse them).
@st.cache_resource
def load_models(paths: dict):
    """Load every saved Keras model listed in ``paths``.

    Wrapped in ``st.cache_resource`` so the loaded models are reused across
    Streamlit reruns instead of being reloaded each time the script runs.

    Args:
        paths: Mapping of model name -> filesystem path of a ``.keras`` file.

    Returns:
        dict mapping the same names, in the same order, to the loaded models.
        Any loading error propagates to the caller unchanged.
    """
    return {model_name: tf.keras.models.load_model(model_path)
            for model_name, model_path in paths.items()}
# The app is useless without its models: report the failure and halt this
# script run rather than continuing with `models` undefined.
try:
    models = load_models(MODEL_PATHS)
except Exception as load_error:
    st.error(f"Failed to load one of the .keras models.\n\n{load_error}")
    st.stop()
# Pre-processing helpers mirroring the transforms used to train and test the models.
def pil_to_rgb_np(pil_img):
    """Return ``pil_img`` as a uint8 NumPy array in RGB mode.

    Converting to "RGB" first means grayscale, palette, or RGBA uploads all
    come out with three colour channels.
    """
    rgb_img = pil_img.convert("RGB")
    return np.array(rgb_img, dtype=np.uint8)
def resize_np_uint8(img_uint8, size_hw):
    """Resize a uint8 image array to ``size_hw`` = (height, width) via PIL.

    PIL's ``resize`` expects (width, height), so the pair is swapped before
    the call. Returns a new uint8 array.
    """
    target_wh = (size_hw[1], size_hw[0])
    resized = Image.fromarray(img_uint8).resize(target_wh)
    return np.array(resized, dtype=np.uint8)
def preprocess_for_model(model_name: str, img_uint8_rgb: np.ndarray) -> np.ndarray:
    """Convert an RGB uint8 image into the float32 input batch a model expects.

    The pipeline is chosen by substring match on ``model_name``, checked in
    this order:

    * ``"resnet50"``: resize to 224x224 RGB, then Keras'
      ``resnet50.preprocess_input`` on raw 0-255 values.
    * ``"stanford_cnn"``: resize to 48x48, collapse to one grayscale channel,
      and scale to [0, 1].
    * anything else (this is the ``"enhanced_cnn"`` pipeline): resize to
      48x48 RGB and scale to [0, 1].

    Args:
        model_name: Key of the model in ``MODEL_PATHS``.
        img_uint8_rgb: uint8 RGB image array, e.g. as returned by
            ``pil_to_rgb_np``.

    Returns:
        float32 array with a leading batch axis of size 1.
    """
    if "resnet50" in model_name:
        x = resize_np_uint8(img_uint8_rgb, (224, 224)).astype(np.float32)
        # preprocess_input works on raw 0-255 pixels, so there is no /255 here.
        x = tf.keras.applications.resnet50.preprocess_input(x)
        return np.expand_dims(x, axis=0)

    if "stanford_cnn" in model_name:
        x = resize_np_uint8(img_uint8_rgb, (48, 48)).astype(np.float32)
        # Weighted sum of R, G, B with the common luma coefficients
        # (0.2989, 0.5870, 0.1140) -> one grayscale channel.
        gray = np.dot(x[..., :3], [0.2989, 0.5870, 0.1140])
        gray = np.expand_dims(gray, axis=-1) / 255.0
        return np.expand_dims(gray, axis=0).astype(np.float32)

    # "enhanced_cnn" and any unrecognised model name share this pipeline:
    # 48x48 RGB scaled to [0, 1].
    x = resize_np_uint8(img_uint8_rgb, (48, 48)).astype(np.float32) / 255.0
    return np.expand_dims(x, axis=0).astype(np.float32)
def ensure_probs(y):
    """Convert a model's raw output into a 1-D float32 probability vector.

    A 2-D ``(batch, C)`` output is reduced to its first row. If the values do
    not already look like a probability distribution (all in [0, 1] and
    summing to 1 within 1e-3), they are treated as logits and passed through
    a softmax.

    Args:
        y: Array-like model output of shape ``(C,)`` or ``(batch, C)``; only
            the first row of a 2-D output is used.

    Returns:
        ``np.ndarray`` of shape ``(C,)`` and dtype float32.
    """
    y = np.asarray(y)
    if y.ndim == 2:
        y = y[0]
    if y.size == 0:
        # Nothing to normalise; the caller's class-count check reports this.
        return y.astype(np.float32)
    if not (np.all(y >= 0) and np.all(y <= 1.0) and np.isclose(np.sum(y), 1.0, atol=1e-3)):
        # Numerically stable softmax in NumPy: subtracting the max before
        # exp() prevents overflow for large logits. This avoids building a
        # TensorFlow tensor just to normalise a C-element vector.
        exp_shifted = np.exp(y - np.max(y))
        y = exp_shifted / np.sum(exp_shifted)
    return y.astype(np.float32)
# Prediction function
def predict_ensemble(pil_img):
    """Run every loaded model on ``pil_img`` and select the most confident one.

    Despite the name, predictions are not averaged. The "ensemble" result is
    the single model whose top-class probability is highest; on a tie, the
    model that comes first in ``models`` wins.
    NOTE(review): raw confidences from differently trained models are not
    necessarily calibrated against each other — confirm this is intended.

    Args:
        pil_img: PIL image; converted to RGB before per-model preprocessing.

    Returns:
        Tuple ``(per_model_results, ensemble_result, best_model)``:
        ``per_model_results`` maps model name -> dict with keys ``'probs'``,
        ``'predicted_class'``, ``'max_confidence'`` and ``'predicted_idx'``;
        ``ensemble_result`` is the winning model's entry; ``best_model`` is
        the winning model's name.

    Raises:
        ValueError: If a model's output length differs from ``len(CLASS_NAMES)``.
    """
    rgb = pil_to_rgb_np(pil_img)
    per_model_results = {}
    for name, model in models.items():
        batch = preprocess_for_model(name, rgb)
        probs = ensure_probs(model.predict(batch, verbose=0))
        if probs.shape[0] != len(CLASS_NAMES):
            raise ValueError(
                f"{name} outputs {probs.shape[0]} classes, but CLASS_NAMES has {len(CLASS_NAMES)}."
            )
        top_idx = int(np.argmax(probs))
        per_model_results[name] = {
            'probs': probs,
            'predicted_class': CLASS_NAMES[top_idx],
            'max_confidence': float(probs[top_idx]),
            'predicted_idx': top_idx,
        }
    # max() returns the first key with the highest confidence.
    best_model = max(per_model_results, key=lambda k: per_model_results[k]['max_confidence'])
    return per_model_results, per_model_results[best_model], best_model
# Basic UI: upload an image, then run every model on it when "Predict" is clicked.
uploaded = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg", "webp", "bmp"])
if uploaded is not None:
    try:
        # getvalue() returns the whole upload regardless of the current
        # stream position, unlike read(), which yields b"" once consumed.
        img = Image.open(io.BytesIO(uploaded.getvalue()))
        # Image.open is lazy (it only parses the header); decode now so a
        # corrupt or non-image upload is reported here instead of raising
        # later inside st.image with a raw traceback.
        img.load()
    except (OSError, Image.DecompressionBombError) as e:
        # PIL.UnidentifiedImageError is a subclass of OSError.
        st.error(f"Could not read the uploaded file as an image: {e}")
        st.stop()
    st.image(img, caption="Uploaded image", use_container_width=True)
    if st.button("Predict", type="primary"):
        try:
            per_model_results, ensemble_result, best_model = predict_ensemble(img)
            # "Ensemble" result = the single most confident model.
            st.subheader("Ensemble Prediction")
            col1, col2 = st.columns([2, 1])
            with col1:
                st.success(f"**{ensemble_result['predicted_class']}**")
            with col2:
                st.metric("Confidence", f"{ensemble_result['max_confidence']:.3f}")
            st.subheader("All Model Predictions")
            # Iterate the results (in MODEL_PATHS order) so any model added to
            # MODEL_PATHS is listed without editing this loop.
            for name, result in per_model_results.items():
                st.write(f"**{name}** → {result['predicted_class']} (conf: {result['max_confidence']:.3f})")
            st.subheader(f"Detailed Probabilities ({best_model})")
            winning_probs = ensemble_result['probs']
            for i, cname in enumerate(CLASS_NAMES):
                prob_val = float(winning_probs[i])
                if i == ensemble_result['predicted_idx']:
                    st.write(f"**{cname}**: {prob_val:.4f}")
                else:
                    st.write(f"{cname}: {prob_val:.4f}")
                # st.progress only accepts floats in [0.0, 1.0]; clamp rounding noise.
                st.progress(min(max(prob_val, 0.0), 1.0))
        except Exception as e:
            st.error(f"Prediction failed: {e}")