|
1 | 1 | import json
|
2 |
| -import pandas as pd |
3 |
| -from PIL import Image |
4 |
| -from loguru import logger |
5 |
| -import sys |
6 | 2 |
|
7 |
| -from fastapi import FastAPI, File, status |
| 3 | +from fastapi import FastAPI |
8 | 4 | from fastapi.responses import RedirectResponse
|
9 |
| -from fastapi.responses import StreamingResponse |
10 | 5 | from fastapi.middleware.cors import CORSMiddleware
|
11 |
| -from fastapi.exceptions import HTTPException |
12 | 6 |
|
13 |
| -from app import get_image_from_bytes |
14 |
| -from app import detect_sample_model |
15 |
| -from app import add_bboxs_on_img |
16 |
| -from app import get_bytes_from_image |
17 |
| -from app import extract_segment_image |
18 |
| -from app import predict_brand_and_model |
| 7 | +from routers import check, classify, try_yolo |
19 | 8 |
|
20 | 9 |
|
21 |
| -from models.sample_model.labels import cars_models |
22 |
| -from models.sample_model.yolo_labels import yolo_classes |
| 10 | +# region FastAPI SetUp |
23 | 11 |
|
24 |
| - |
25 |
| -logger.remove() |
26 |
| -logger.add( |
27 |
| - sys.stderr, |
28 |
| - colorize=True, |
29 |
| - format="<green>{time:HH:mm:ss}</green> | <level>{message}</level>", |
30 |
| - level=10, |
31 |
| -) |
32 |
| -logger.add("log.log", rotation="1 MB", level="DEBUG", compression="zip") |
33 |
| - |
34 |
| -# ------------------ FastAPI Setup ------------------ |
35 |
| - |
36 |
| -# title |
37 | 12 | app = FastAPI(
|
38 |
| - title="Object Detection FastAPI Template", |
| 13 | + title="Car models Classification FastAPI", |
39 | 14 | description="""Obtain object value out of image
|
40 |
| - and return image and json result""", |
41 |
| - version="2023.1.31", |
| 15 | + and return image with label""", |
| 16 | + version="2023.7.16", |
42 | 17 | )
|
43 | 18 |
|
44 |
| -# This function is needed if you want to allow client requests |
45 |
| -# from specific domains (specified in the origins argument) |
46 |
| -# to access resources from the FastAPI server, |
47 |
| -# and the client and server are hosted on different domains. |
| 19 | + |
48 | 20 | origins = [
|
49 | 21 | "http://localhost",
|
50 | 22 | "http://localhost:8008",
|
|
62 | 34 |
|
63 | 35 | @app.on_event("startup")
|
64 | 36 | def save_openapi_json():
|
65 |
| - '''This function is used to save the OpenAPI documentation |
| 37 | + ''' This function is used to save the OpenAPI documentation |
66 | 38 | data of the FastAPI application to a JSON file.
|
67 | 39 | The purpose of saving the OpenAPI documentation data is to
|
68 | 40 | which can be used for documentation purposes or
|
69 |
| - to generate client libraries. It is not necessarily needed, |
70 |
| - but can be helpful in certain scenarios.''' |
| 41 | + to generate client libraries. ''' |
71 | 42 | openapi_data = app.openapi()
|
72 |
| - # Change "openapi.json" to desired filename |
73 | 43 | with open("openapi.json", "w") as file:
|
74 | 44 | json.dump(openapi_data, file)
|
75 | 45 |
|
| 46 | +# endregion |
| 47 | + |
76 | 48 |
|
77 |
| -# redirect |
78 | 49 | @app.get("/", include_in_schema=False)
|
79 | 50 | async def redirect():
|
80 | 51 | return RedirectResponse("/docs")
|
81 | 52 |
|
82 | 53 |
|
83 |
| -@app.get('/healthcheck', status_code=status.HTTP_200_OK) |
84 |
| -def perform_healthcheck(): |
85 |
| - ''' |
86 |
| - It basically sends a GET request to the route & hopes to get a "200" |
87 |
| - response code. Failing to return a 200 response code just enables |
88 |
| - the GitHub Actions to rollback to the last version the project was |
89 |
| - found in a "working condition". It acts as a last line of defense in |
90 |
| - case something goes south. |
91 |
| - Additionally, it also returns a JSON response in the form of: |
92 |
| - { |
93 |
| - 'healtcheck': 'Everything OK!' |
94 |
| - } |
95 |
| - ''' |
96 |
| - return {'healthcheck': 'Everything OK!'} |
97 |
| - |
98 |
| - |
99 |
| -# ------------------ Support Func ------------------ |
100 |
| - |
101 |
| -def crop_image_by_predict( |
102 |
| - image: Image, |
103 |
| - predict: pd.DataFrame(), |
104 |
| - crop_class_name: str,) -> Image: |
105 |
| - """Crop an image based on the detection of a certain object in the image. |
106 |
| -
|
107 |
| - Args: |
108 |
| - image: Image to be cropped. |
109 |
| - predict (pd.DataFrame): Dataframe containing the prediction results |
110 |
| - of object detection model. |
111 |
| - crop_class_name (str, optional): The name of the object class to crop |
112 |
| - the image by. if not provided, function returns the first object found |
113 |
| - in the image. |
114 |
| -
|
115 |
| - Returns: |
116 |
| - Image: Cropped image or None |
117 |
| - """ |
118 |
| - crop_predicts = predict[(predict['name'] == crop_class_name)] |
119 |
| - |
120 |
| - if crop_predicts.empty: |
121 |
| - raise HTTPException( |
122 |
| - status_code=400, detail=f"{crop_class_name} not found in photo") |
123 |
| - |
124 |
| - # if there are several detections, choose the one with more confidence |
125 |
| - if len(crop_predicts) > 1: |
126 |
| - crop_predicts = crop_predicts.sort_values( |
127 |
| - by=['confidence'], ascending=False) |
128 |
| - |
129 |
| - crop_bbox = crop_predicts[['xmin', 'ymin', 'xmax', 'ymax']].iloc[0].values |
130 |
| - img_crop = image.crop(crop_bbox) |
131 |
| - return img_crop |
132 |
| - |
133 |
| - |
134 |
| -# ------------------ MAIN Func ------------------ |
135 |
| - |
136 |
| -@app.get('/available_models', status_code=status.HTTP_200_OK) |
137 |
| -def get_all_models(): |
138 |
| - ''' |
139 |
| - To check the list of cars' models, which are avalable for recognition. |
140 |
| - ''' |
141 |
| - return cars_models |
142 |
| - |
143 |
| -@app.get('/yolo_classes', status_code=status.HTTP_200_OK) |
144 |
| -def get_yolo_classes(): |
145 |
| - ''' |
146 |
| - To check the list of cars' models, which are avalable for recognition. |
147 |
| - ''' |
148 |
| - return yolo_classes |
149 |
| - |
150 |
| -@app.post("/object_detection_with_yolo") |
151 |
| -def object_detection_with_yolo(file: bytes = File(...)): |
152 |
| - """ |
153 |
| - Object Detection from an image plot bbox on image. Using Yolo8. |
154 |
| -
|
155 |
| - Args: |
156 |
| - file (bytes): The image file in bytes format. |
157 |
| - Returns: |
158 |
| - Image: Image in bytes with bbox annotations. |
159 |
| - """ |
160 |
| - # get image from bytes |
161 |
| - input_image = get_image_from_bytes(file) |
162 |
| - |
163 |
| - # model predict |
164 |
| - predict = detect_sample_model(input_image) |
165 |
| - |
166 |
| - # add bbox on image |
167 |
| - final_image = add_bboxs_on_img(image=input_image, predict=predict) |
168 |
| - |
169 |
| - # return image in bytes format |
170 |
| - return StreamingResponse( |
171 |
| - content=get_bytes_from_image(final_image), media_type="image/jpeg") |
172 |
| - |
173 |
| - |
174 |
| -@logger.catch |
175 |
| -@app.post("/car_brand_model_classification") |
176 |
| -def car_brand_model_classification(file: bytes = File(...)): |
177 |
| - """ |
178 |
| - Object Detection from an image plot bbox on image. Using Yolo8. |
179 |
| -
|
180 |
| - Args: |
181 |
| - file (bytes): The image file in bytes format. |
182 |
| - Returns: |
183 |
| - Image: Image in bytes with bbox annotations. |
184 |
| - """ |
185 |
| - input_image = get_image_from_bytes(file) |
186 |
| - predict = predict_brand_and_model(input_image) |
187 |
| - return StreamingResponse( |
188 |
| - content=get_bytes_from_image(predict), media_type="image/jpeg") |
189 |
| - |
190 |
| - |
191 |
| -@app.post("/car_model_segment_and_crop") |
192 |
| -def car_model_segment_and_crop(file: bytes = File(...)): |
193 |
| - """ |
194 |
| - Object Detection from an image plot bbox on image. Using Yolo8. |
195 |
| -
|
196 |
| - Args: |
197 |
| - file (bytes): The image file in bytes format. |
198 |
| - Returns: |
199 |
| - Image: Image in bytes with bbox annotations. |
200 |
| - """ |
201 |
| - input_image = get_image_from_bytes(file) |
202 |
| - predict = extract_segment_image(input_image) |
203 |
| - |
204 |
| - return StreamingResponse( |
205 |
| - content=get_bytes_from_image(predict), media_type="image/jpeg") |
| 54 | +app.include_router(check.router) |
| 55 | +app.include_router(classify.router) |
| 56 | +app.include_router(try_yolo.router) |
0 commit comments