Skip to content

Commit 5a8ab31

Browse files
committed
routers
1 parent fa1775d commit 5a8ab31

File tree

7 files changed

+164
-194
lines changed

7 files changed

+164
-194
lines changed

app.py renamed to bll.py

+35-30
Original file line numberDiff line numberDiff line change
@@ -11,44 +11,48 @@
1111

1212
from models.sample_model.labels import cars_models
1313

14-
# Initialize the models
15-
model_sample_model = YOLO("./models/sample_model/yolov8n.pt")
14+
# Initialize the model
15+
model_sample_model = YOLO('./models/sample_model/yolov8n.pt')
1616

1717

18+
# region Common func
19+
1820
def get_image_from_bytes(binary_image: bytes) -> Image:
19-
""" Convert image from bytes to PIL RGB format. """
20-
input_image = Image.open(io.BytesIO(binary_image)).convert("RGB")
21+
''' Convert image from bytes to PIL RGB format. '''
22+
input_image = Image.open(io.BytesIO(binary_image)).convert('RGB')
2123
return input_image
2224

2325

2426
def get_bytes_from_image(image: Image) -> bytes:
25-
""" Convert PIL image to Bytes. """
27+
''' Convert PIL image to Bytes. '''
2628
return_image = io.BytesIO()
2729
image.save(return_image, format='JPEG', quality=100)
2830
return_image.seek(0)
2931
return return_image
3032

33+
# endregion
34+
35+
# region Yolo func
36+
3137

3238
def transform_predict_to_df(results: list, labeles_dict: dict) -> pd.DataFrame:
33-
"""
34-
Transform predict from yolov8 (torch.Tensor) to pandas DataFrame.
35-
"""
39+
''' Transform predict from yolov8 (torch.Tensor) to pandas DataFrame. '''
3640
predict_bbox = pd.DataFrame(
37-
results[0].to("cpu").numpy().boxes.xyxy, columns=[
38-
'xmin', 'ymin', 'xmax', 'ymax']
41+
results[0].to('cpu').numpy().boxes.xyxy, columns=[
42+
'xmin', "ymin", 'xmax', 'ymax']
3943
)
40-
predict_bbox['confidence'] = results[0].to("cpu").numpy().boxes.conf
44+
predict_bbox['confidence'] = results[0].to('cpu').numpy().boxes.conf
4145
predict_bbox['class'] = (
42-
results[0].to("cpu").numpy().boxes.cls).astype(int)
43-
predict_bbox['name'] = predict_bbox["class"].replace(labeles_dict)
46+
results[0].to('cpu').numpy().boxes.cls).astype(int)
47+
predict_bbox['name'] = predict_bbox['class'].replace(labeles_dict)
4448
return predict_bbox
4549

4650

4751
def get_model_predict(
4852
model: YOLO, input_image: Image,
4953
save: bool = False, image_size: int = 1248,
5054
conf: float = 0.5, augment: bool = False) -> pd.DataFrame:
51-
""" Get the predictions of a model on an input image. """
55+
''' Get the predictions of a model on an input image. '''
5256
predictions = model.predict(
5357
imgsz=image_size,
5458
source=input_image,
@@ -63,10 +67,8 @@ def get_model_predict(
6367
return predictions
6468

6569

66-
# ----------------------- BBOX Func -----------------------
67-
6870
def add_bboxs_on_img(image: Image, predict: pd.DataFrame()) -> Image:
69-
""" Add a bounding box on the image. """
71+
''' Add a bounding box on the image. '''
7072
annotator = Annotator(np.array(image))
7173
predict = predict.sort_values(by=['xmin'], ascending=True)
7274
for i, row in predict.iterrows():
@@ -77,7 +79,7 @@ def add_bboxs_on_img(image: Image, predict: pd.DataFrame()) -> Image:
7779

7880

7981
def detect_sample_model(input_image: Image) -> pd.DataFrame:
80-
""" Predict from sample_model. Base on YoloV8. """
82+
''' Predict from sample_model. Base on YoloV8. '''
8183
predict = get_model_predict(
8284
model=model_sample_model,
8385
input_image=input_image,
@@ -88,14 +90,16 @@ def detect_sample_model(input_image: Image) -> pd.DataFrame:
8890
)
8991
return predict
9092

93+
# endregion
94+
95+
# region Car classification func
9196

92-
#------------------------MODELS CLASSIFICATION------------------------
9397

9498
class YOLOSegmentation:
95-
"""
99+
'''
96100
Useful class to get bboxes, classes, segmentations, scores in correct
97101
format to pass them to cv2 image processed functions.
98-
"""
102+
'''
99103
def __init__(self, model_path):
100104
self.model = YOLO(model_path)
101105

@@ -117,7 +121,7 @@ def detect(self, img):
117121
return bboxes, class_ids, segmentation_contours_idx, scores
118122

119123

120-
#model for segmentation
124+
# Initialize model for segmentation
121125
model_for_classify = YOLOSegmentation("./models/sample_model/yolov8m-seg.pt")
122126

123127
# for model classification project we will segment only cars, busses, truckes
@@ -129,7 +133,7 @@ def detect(self, img):
129133

130134

131135
def extract_segment_image(img, segmentator=model_for_classify) -> Image:
132-
""" Extract segmented and cropped car, truck, bus and return as PIL"""
136+
''' Extract segmented and cropped car, truck, bus and return as PIL. '''
133137
open_cv_image = np.array(img)
134138
open_cv_image = open_cv_image[:, :, ::-1].copy()
135139

@@ -151,16 +155,16 @@ def extract_segment_image(img, segmentator=model_for_classify) -> Image:
151155

152156

153157
def load_model():
154-
""" Load and evaluate saved model. """
158+
''' Load and evaluate saved model. '''
155159
model = torch.load(
156160
'./models/sample_model/model_mob_netv3_79_perc.pth',
157161
map_location=torch.device('cpu'))
158162
model.eval()
159163
return model
160164

161165

162-
def image_to_tensor(cv2_img: np.ndarray) -> torch.Tensor:
163-
""" Converting cv2 output to torch tensor. """
166+
def image_to_tensor(cv2_img: np.ndarray) -> torch.Tensor:
167+
''' Converting cv2 output to torch tensor. '''
164168
image = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
165169
transform = transforms.Compose([
166170
transforms.ToTensor(),
@@ -180,6 +184,7 @@ def draw_predictions_on_image(img: Image, label: str) -> Image:
180184

181185

182186
def predict_brand_and_model(img, segmentator=model_for_classify):
187+
''' Predict car's brand and model using PyTorch model. '''
183188
open_cv_image = np.array(img)
184189
open_cv_image = open_cv_image[:, :, ::-1].copy()
185190

@@ -201,7 +206,7 @@ def predict_brand_and_model(img, segmentator=model_for_classify):
201206
answer = predict.argmax(-1)
202207
name = cars_models.get(answer.item()).split('_')
203208
name = f'Brand: {name[0]}, model: {name[1]}'
204-
im_rgb = cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB)
205-
# PIL_im = Image.fromarray(im_rgb)
206-
PIL_im = draw_predictions_on_image(img, name)
207-
return PIL_im
209+
im_PIL = draw_predictions_on_image(img, name)
210+
return im_PIL
211+
212+
# endregion

logger_cfg.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
config = {
2+
"handlers": [
3+
{"sink": "log.log",
4+
"format": "<green>{time:HH:mm:ss}</green> | <level>{message}</level>",
5+
'rotation': "1 MB", 'level': "DEBUG", 'compression': "zip"},
6+
],
7+
}
8+
9+
10+
# logger.remove()
11+
# logger.add(
12+
# colorize=True,
13+
# format="<green>{time:HH:mm:ss}</green> | <level>{message}</level>",
14+
# level=10,
15+
# )
16+
# logger.add("log.log", rotation="1 MB", level="DEBUG", compression="zip")

main.py

+14-163
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,22 @@
11
import json
2-
import pandas as pd
3-
from PIL import Image
4-
from loguru import logger
5-
import sys
62

7-
from fastapi import FastAPI, File, status
3+
from fastapi import FastAPI
84
from fastapi.responses import RedirectResponse
9-
from fastapi.responses import StreamingResponse
105
from fastapi.middleware.cors import CORSMiddleware
11-
from fastapi.exceptions import HTTPException
126

13-
from app import get_image_from_bytes
14-
from app import detect_sample_model
15-
from app import add_bboxs_on_img
16-
from app import get_bytes_from_image
17-
from app import extract_segment_image
18-
from app import predict_brand_and_model
7+
from routers import check, classify, try_yolo
198

209

21-
from models.sample_model.labels import cars_models
22-
from models.sample_model.yolo_labels import yolo_classes
10+
# region FastAPI SetUp
2311

24-
25-
logger.remove()
26-
logger.add(
27-
sys.stderr,
28-
colorize=True,
29-
format="<green>{time:HH:mm:ss}</green> | <level>{message}</level>",
30-
level=10,
31-
)
32-
logger.add("log.log", rotation="1 MB", level="DEBUG", compression="zip")
33-
34-
# ------------------ FastAPI Setup ------------------
35-
36-
# title
3712
app = FastAPI(
38-
title="Object Detection FastAPI Template",
13+
title="Car models Classification FastAPI",
3914
description="""Obtain object value out of image
40-
and return image and json result""",
41-
version="2023.1.31",
15+
and return image with label""",
16+
version="2023.7.16",
4217
)
4318

44-
# This function is needed if you want to allow client requests
45-
# from specific domains (specified in the origins argument)
46-
# to access resources from the FastAPI server,
47-
# and the client and server are hosted on different domains.
19+
4820
origins = [
4921
"http://localhost",
5022
"http://localhost:8008",
@@ -62,144 +34,23 @@
6234

6335
@app.on_event("startup")
6436
def save_openapi_json():
65-
'''This function is used to save the OpenAPI documentation
37+
''' This function is used to save the OpenAPI documentation
6638
data of the FastAPI application to a JSON file.
6739
The purpose of saving the OpenAPI documentation data is to
6840
which can be used for documentation purposes or
69-
to generate client libraries. It is not necessarily needed,
70-
but can be helpful in certain scenarios.'''
41+
to generate client libraries. '''
7142
openapi_data = app.openapi()
72-
# Change "openapi.json" to desired filename
7343
with open("openapi.json", "w") as file:
7444
json.dump(openapi_data, file)
7545

46+
# endregion
47+
7648

77-
# redirect
7849
@app.get("/", include_in_schema=False)
7950
async def redirect():
8051
return RedirectResponse("/docs")
8152

8253

83-
@app.get('/healthcheck', status_code=status.HTTP_200_OK)
84-
def perform_healthcheck():
85-
'''
86-
It basically sends a GET request to the route & hopes to get a "200"
87-
response code. Failing to return a 200 response code just enables
88-
the GitHub Actions to rollback to the last version the project was
89-
found in a "working condition". It acts as a last line of defense in
90-
case something goes south.
91-
Additionally, it also returns a JSON response in the form of:
92-
{
93-
'healtcheck': 'Everything OK!'
94-
}
95-
'''
96-
return {'healthcheck': 'Everything OK!'}
97-
98-
99-
# ------------------ Support Func ------------------
100-
101-
def crop_image_by_predict(
102-
image: Image,
103-
predict: pd.DataFrame(),
104-
crop_class_name: str,) -> Image:
105-
"""Crop an image based on the detection of a certain object in the image.
106-
107-
Args:
108-
image: Image to be cropped.
109-
predict (pd.DataFrame): Dataframe containing the prediction results
110-
of object detection model.
111-
crop_class_name (str, optional): The name of the object class to crop
112-
the image by. if not provided, function returns the first object found
113-
in the image.
114-
115-
Returns:
116-
Image: Cropped image or None
117-
"""
118-
crop_predicts = predict[(predict['name'] == crop_class_name)]
119-
120-
if crop_predicts.empty:
121-
raise HTTPException(
122-
status_code=400, detail=f"{crop_class_name} not found in photo")
123-
124-
# if there are several detections, choose the one with more confidence
125-
if len(crop_predicts) > 1:
126-
crop_predicts = crop_predicts.sort_values(
127-
by=['confidence'], ascending=False)
128-
129-
crop_bbox = crop_predicts[['xmin', 'ymin', 'xmax', 'ymax']].iloc[0].values
130-
img_crop = image.crop(crop_bbox)
131-
return img_crop
132-
133-
134-
# ------------------ MAIN Func ------------------
135-
136-
@app.get('/available_models', status_code=status.HTTP_200_OK)
137-
def get_all_models():
138-
'''
139-
To check the list of cars' models, which are avalable for recognition.
140-
'''
141-
return cars_models
142-
143-
@app.get('/yolo_classes', status_code=status.HTTP_200_OK)
144-
def get_yolo_classes():
145-
'''
146-
To check the list of cars' models, which are avalable for recognition.
147-
'''
148-
return yolo_classes
149-
150-
@app.post("/object_detection_with_yolo")
151-
def object_detection_with_yolo(file: bytes = File(...)):
152-
"""
153-
Object Detection from an image plot bbox on image. Using Yolo8.
154-
155-
Args:
156-
file (bytes): The image file in bytes format.
157-
Returns:
158-
Image: Image in bytes with bbox annotations.
159-
"""
160-
# get image from bytes
161-
input_image = get_image_from_bytes(file)
162-
163-
# model predict
164-
predict = detect_sample_model(input_image)
165-
166-
# add bbox on image
167-
final_image = add_bboxs_on_img(image=input_image, predict=predict)
168-
169-
# return image in bytes format
170-
return StreamingResponse(
171-
content=get_bytes_from_image(final_image), media_type="image/jpeg")
172-
173-
174-
@logger.catch
175-
@app.post("/car_brand_model_classification")
176-
def car_brand_model_classification(file: bytes = File(...)):
177-
"""
178-
Object Detection from an image plot bbox on image. Using Yolo8.
179-
180-
Args:
181-
file (bytes): The image file in bytes format.
182-
Returns:
183-
Image: Image in bytes with bbox annotations.
184-
"""
185-
input_image = get_image_from_bytes(file)
186-
predict = predict_brand_and_model(input_image)
187-
return StreamingResponse(
188-
content=get_bytes_from_image(predict), media_type="image/jpeg")
189-
190-
191-
@app.post("/car_model_segment_and_crop")
192-
def car_model_segment_and_crop(file: bytes = File(...)):
193-
"""
194-
Object Detection from an image plot bbox on image. Using Yolo8.
195-
196-
Args:
197-
file (bytes): The image file in bytes format.
198-
Returns:
199-
Image: Image in bytes with bbox annotations.
200-
"""
201-
input_image = get_image_from_bytes(file)
202-
predict = extract_segment_image(input_image)
203-
204-
return StreamingResponse(
205-
content=get_bytes_from_image(predict), media_type="image/jpeg")
54+
app.include_router(check.router)
55+
app.include_router(classify.router)
56+
app.include_router(try_yolo.router)

0 commit comments

Comments
 (0)