Skip to content

Commit 445d625

Browse files
authored
Merge pull request #363 from JdeRobot/issue-362
Upgrade streamlit and increase model file limit size
2 parents 5dc3127 + c216e5f commit 445d625

5 files changed

Lines changed: 32 additions & 63 deletions

File tree

app.py

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import streamlit as st
2-
import os
32
import sys
43
import subprocess
54
from tabs.dataset_viewer import dataset_viewer_tab
@@ -67,6 +66,10 @@ def browse_folder():
6766
return None
6867

6968

69+
def browse_dataset_path():
70+
st.session_state.dataset_path = browse_folder()
71+
72+
7073
st.set_page_config(page_title="PerceptionMetrics", layout="wide")
7174

7275
PAGES = {
@@ -77,8 +80,8 @@ def browse_folder():
7780

7881
# Initialize commonly used session state keys
7982
st.session_state.setdefault("dataset_path", "")
80-
st.session_state.setdefault("dataset_type_selectbox", "COCO")
81-
st.session_state.setdefault("split_selectbox", "val")
83+
st.session_state.setdefault("dataset_type", "COCO")
84+
st.session_state.setdefault("split", "val")
8285
st.session_state.setdefault("config_option", "Manual Configuration")
8386
st.session_state.setdefault("confidence_threshold", 0.5)
8487
st.session_state.setdefault("nms_threshold", 0.5)
@@ -95,66 +98,44 @@ def browse_folder():
9598
# First row: Type and Split
9699
col1, col2 = st.columns(2)
97100
with col1:
98-
dataset_type_selectbox = st.selectbox(
101+
st.selectbox(
99102
"Type",
100103
["COCO", "YOLO"],
101-
key="dataset_type_selectbox",
104+
key="dataset_type",
102105
)
103106
with col2:
104107
st.selectbox(
105108
"Split",
106109
["train", "val", "test"],
107-
key="split_selectbox",
110+
key="split",
108111
)
109112

110113
# Second row: Path and Browse button
111114
col1, col2 = st.columns([3, 1])
112115
with col1:
113-
dataset_path_input = st.text_input(
114-
"Dataset Folder",
115-
value=st.session_state.get("dataset_path", ""),
116-
key="dataset_path_input",
117-
)
116+
st.text_input("Dataset Folder", key="dataset_path")
118117
with col2:
119118
st.markdown(
120119
"<div style='margin-bottom: 1.75rem;'></div>", unsafe_allow_html=True
121120
)
122-
if st.button("Browse", key="browse_button"):
123-
folder = browse_folder()
124-
if folder and os.path.isdir(folder):
125-
st.session_state["dataset_path"] = folder
126-
st.rerun()
127-
elif folder is not None:
128-
st.warning("Selected path is not a valid folder.")
129-
else:
130-
st.warning(
131-
"Could not open folder browser. Please enter the path manually"
132-
)
133-
134-
if dataset_path_input != st.session_state.get("dataset_path", ""):
135-
st.session_state["dataset_path"] = dataset_path_input
136-
if dataset_type_selectbox != st.session_state.get("dataset_type", ""):
137-
st.session_state["dataset_type"] = dataset_type_selectbox
121+
st.button("Browse", on_click=browse_dataset_path)
138122

139123
# Additional input for YOLO config file
140-
if dataset_type_selectbox == "YOLO":
141-
dataset_config_file_uploader = st.file_uploader(
124+
if st.session_state.get("dataset_type", "COCO") == "YOLO":
125+
st.file_uploader(
142126
"Dataset Configuration (.yaml)",
143127
type=["yaml"],
144128
key="dataset_config_file",
145129
help="Upload a YAML dataset configuration file.",
146130
)
147-
if dataset_config_file_uploader != st.session_state.get(
148-
"dataset_config_file", None
149-
):
150-
st.session_state["dataset_config_file"] = dataset_config_file_uploader
151131

152132
with st.expander("Model Inputs", expanded=False):
153133
st.file_uploader(
154134
"Model File (.pt, .onnx, .h5, .pb, .pth, .torchscript)",
155135
type=["pt", "onnx", "h5", "pb", "pth", "torchscript"],
156136
key="model_file",
157137
help="Upload your trained model file.",
138+
max_upload_size=1024, # MB
158139
)
159140
st.file_uploader(
160141
"Ontology File (.json)",
@@ -254,14 +235,15 @@ def browse_folder():
254235
key="resize_width",
255236
help="Width to resize images for inference",
256237
)
238+
257239
# Load model action in sidebar
258240
from perceptionmetrics.models.torch_detection import TorchImageDetectionModel
259241
import json, tempfile
260242

261243
load_model_btn = st.button(
262244
"Load Model",
263245
type="primary",
264-
use_container_width=True,
246+
width="stretch",
265247
help="Load and save the model for use in the Inference tab",
266248
key="sidebar_load_model_btn",
267249
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ click = "^8.1.8"
2424
tensorboard = "^2.18.0"
2525
pycocotools = { version = "^2.0.7", markers = "sys_platform != 'win32'" }
2626
pycocotools-windows = { version = "^2.0.0.2", markers = "sys_platform == 'win32'" }
27-
Streamlit = "1.46.0"
27+
Streamlit = "1.54.0"
2828
streamlit-image-select = "^0.6.0"
2929
supervision = "^0.18.0"
3030

tabs/dataset_viewer.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ def dataset_viewer_tab():
77
import tempfile
88
from perceptionmetrics.datasets.coco import CocoDataset
99
from perceptionmetrics.datasets.yolo import YOLODataset
10-
import supervision as sv
1110
import numpy as np
1211
from PIL import Image
1312
from supervision.draw.color import ColorPalette
@@ -16,8 +15,8 @@ def dataset_viewer_tab():
1615

1716
# Get inputs from session state
1817
dataset_path = st.session_state.get("dataset_path", "")
19-
dataset_type = st.session_state.get("dataset_type_selectbox", "COCO").lower()
20-
split = st.session_state.get("split_selectbox", "val")
18+
dataset_type = st.session_state.get("dataset_type", "COCO").lower()
19+
split = st.session_state.get("split", "val")
2120

2221
# Header row only
2322
st.header("Dataset Viewer")
@@ -27,7 +26,6 @@ def dataset_viewer_tab():
2726
return
2827

2928
# Setup paths and pagination
30-
# Setup paths and pagination
3129
if dataset_type == "coco":
3230
img_dir = os.path.join(dataset_path, f"images/{split}2017")
3331
ann_file = os.path.join(
@@ -162,7 +160,7 @@ def dataset_viewer_tab():
162160

163161
# Pagination
164162
IMAGES_PER_PAGE = 12
165-
total_images, total_pages = (
163+
_, total_pages = (
166164
len(image_files),
167165
(len(image_files) + IMAGES_PER_PAGE - 1) // IMAGES_PER_PAGE,
168166
)
@@ -197,7 +195,7 @@ def dataset_viewer_tab():
197195
st.rerun()
198196
with col2:
199197
st.markdown(
200-
f"<div style='text-align:center;font-weight:bold;'>Page {current_page+1} of {total_pages}</div>",
198+
f"<div style='text-align:center;font-weight:bold;'>Page {current_page + 1} of {total_pages}</div>",
201199
unsafe_allow_html=True,
202200
)
203201
with col3:
@@ -220,7 +218,7 @@ def dataset_viewer_tab():
220218
col1, col2, col3 = st.columns([4, 1, 1])
221219
with col1:
222220
selected_img = st.selectbox(
223-
"Search image:", options=image_files, key="search_image_selectbox"
221+
"Search image:", options=image_files, key="search_image"
224222
)
225223
with col2:
226224
st.markdown(
@@ -231,7 +229,7 @@ def dataset_viewer_tab():
231229
st.session_state[page_key] = new_page
232230
st.session_state[
233231
f"img_select_all_{dataset_path}_{split}_{new_page}"
234-
] = (image_files.index(selected_img) % IMAGES_PER_PAGE)
232+
] = image_files.index(selected_img) % IMAGES_PER_PAGE
235233
st.session_state["show_search_dropdown"] = False
236234
st.rerun()
237235
with col3:
@@ -252,7 +250,7 @@ def dataset_viewer_tab():
252250
label="",
253251
images=image_paths,
254252
captions=sample_images,
255-
use_container_width=True,
253+
use_container_width=False,
256254
key=img_select_key,
257255
index=img_select_index,
258256
)
@@ -314,7 +312,7 @@ def dataset_viewer_tab():
314312
except AttributeError:
315313
resample = Image.LANCZOS
316314
annotated_pil.thumbnail((500, 500), resample)
317-
st.image(annotated_pil, use_container_width=False)
315+
st.image(annotated_pil, width="content")
318316
else:
319317
st.warning("No annotation found for this image.")
320318
except Exception as e:

tabs/evaluator.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import os
33
import tempfile
44
import json
5-
import pandas as pd
6-
from perceptionmetrics.models.torch_detection import TorchImageDetectionModel
75
from perceptionmetrics.datasets.coco import CocoDataset
86

97

@@ -19,8 +17,8 @@ def evaluator_tab():
1917

2018
# Check for dataset from sidebar inputs
2119
dataset_path = st.session_state.get("dataset_path", "")
22-
dataset_type = st.session_state.get("dataset_type_selectbox", "Coco")
23-
split = st.session_state.get("split_selectbox", "val")
20+
dataset_type = st.session_state.get("dataset_type", "Coco")
21+
split = st.session_state.get("split", "val")
2422

2523
# Try to get existing dataset from session state first
2624
dataset_key = f"{dataset_path}_{split}"
@@ -134,11 +132,6 @@ def evaluator_tab():
134132
if save_predictions:
135133
predictions_outdir = tempfile.mkdtemp(prefix="eval_predictions_")
136134

137-
# Use model config as is (no confidence threshold override)
138-
eval_config = model.model_cfg.copy()
139-
140-
# Ready to evaluate
141-
142135
# Create progress bar for evaluation
143136
progress_bar = st.progress(0)
144137
status_text = st.empty()
@@ -205,7 +198,7 @@ def metrics_callback(metrics_df, processed, total):
205198

206199
with intermediate_table_placeholder.container():
207200
st.markdown("#### Per-Class Metrics (Intermediate)")
208-
st.dataframe(display_df, use_container_width=True)
201+
st.dataframe(display_df, width="stretch")
209202

210203
except Exception as e:
211204
st.error(f"Metrics callback error: {e}")
@@ -321,7 +314,7 @@ def display_evaluation_results(results):
321314
if col in display_df.columns:
322315
display_df[col] = display_df[col].round(3)
323316

324-
st.dataframe(display_df, use_container_width=True)
317+
st.dataframe(display_df, width="stretch")
325318

326319
# Now display Precision-Recall Curve
327320
if metrics_factory is not None:
@@ -334,7 +327,6 @@ def display_evaluation_results(results):
334327

335328
# Create the plot using streamlit's plotly integration
336329
import plotly.graph_objects as go
337-
from plotly.subplots import make_subplots
338330

339331
# Create the precision-recall curve
340332
fig = go.Figure()
@@ -379,7 +371,7 @@ def display_evaluation_results(results):
379371
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="lightgray")
380372
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="lightgray")
381373

382-
st.plotly_chart(fig, use_container_width=True)
374+
st.plotly_chart(fig, width="stretch")
383375

384376
except Exception as e:
385377
st.error(f"Error plotting precision-recall curve: {e}")
@@ -403,7 +395,6 @@ def display_evaluation_results(results):
403395
else None
404396
)
405397
if curve_data is not None:
406-
import io
407398
import pandas as pd
408399

409400
pr_points_df = pd.DataFrame(
@@ -439,7 +430,7 @@ def display_evaluation_results(results):
439430
st.write("Columns:", metrics_df.columns.tolist())
440431

441432
st.markdown("**Sample Data:**")
442-
st.dataframe(metrics_df.head(), use_container_width=True)
433+
st.dataframe(metrics_df.head(), width="stretch")
443434

444435
if "evaluation_config" in st.session_state:
445436
st.markdown("**Evaluation Configuration:**")

tabs/inference.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,7 @@ def inference_tab():
8181
result_img = draw_detections(image.copy(), predictions, label_map)
8282

8383
st.markdown("#### Detection Results")
84-
st.image(
85-
result_img, caption="Detection Results", use_container_width=True
86-
)
84+
st.image(result_img, caption="Detection Results", width="stretch")
8785

8886
# Display detection statistics
8987
if (

0 commit comments

Comments
 (0)