15
15
from database .database import DataBase
16
16
from difference .diff import DiffProcessor
17
17
from edge .info import TASK_STATE
18
+ from edge .resample import history_sample , annotion_process
18
19
from edge .task import Task
19
20
from edge .transmit import get_cloud_target
20
21
from grpc_server import message_transmission_pb2_grpc , message_transmission_pb2
21
22
from grpc_server .rpc_server import MessageTransmissionServicer
22
23
23
24
from tools .convert_tool import cv2_to_base64
24
- from tools .file_op import clear_folder , creat_folder
25
+ from tools .file_op import clear_folder , creat_folder , sample_files
25
26
from tools .preprocess import frame_resize
26
27
from model_management .object_detection import Object_Detection
27
28
from apscheduler .schedulers .background import BackgroundScheduler
@@ -53,53 +54,37 @@ def __init__(self, config):
53
54
self .avg_scores = []
54
55
self .select_index = []
55
56
self .annotations = []
56
- # read video
57
- self .video_reader = threading .Thread (target = self .video_read ,)
58
- self .video_reader .start ()
57
+
59
58
# start the thread for diff
60
59
self .diff = 0
61
60
self .key_task = None
62
- self .diff_processor = threading .Thread (target = self .diff_worker ,)
61
+ self .diff_processor = threading .Thread (target = self .diff_worker ,daemon = True )
63
62
self .diff_processor .start ()
64
63
65
64
# start the thread for edge server
66
- #self.edge_server = threading.Thread(target=self.start_edge_server, )
65
+ #self.edge_server = threading.Thread(target=self.start_edge_server, daemon=True )
67
66
#self.edge_server.start()
68
67
69
68
# start the thread for local process
70
- self .local_processor = threading .Thread (target = self .local_worker ,)
69
+ self .local_processor = threading .Thread (target = self .local_worker ,daemon = True )
71
70
self .local_processor .start ()
72
71
73
72
# start the thread for retrain process
74
73
self .retrain_flag = False
75
74
self .collect_flag = True
76
75
self .cache_count = 0
77
- self .retrain_processor = threading .Thread (target = self .retrain_worker ,)
76
+
77
+ self .use_history = True
78
+ self .test_only = False
79
+ self .retrain_no = 0
80
+
81
+ self .retrain_processor = threading .Thread (target = self .retrain_worker ,daemon = True )
78
82
self .retrain_processor .start ()
79
83
80
84
# start the thread pool for offload
81
85
self .offloading_executor = futures .ThreadPoolExecutor (max_workers = config .offloading_max_worker ,)
82
86
83
- def video_read (self ):
84
- with VideoProcessor (self .config .source ) as video :
85
- video_fps = video .fps
86
- logger .info ("the video fps is {}" .format (video_fps ))
87
- index = 0
88
- if self .config .interval == 0 :
89
- logger .error ("the interval error" )
90
- sys .exit (1 )
91
- logger .info ("Take the frame interval is {}" .format (self .config .interval ))
92
- while True :
93
- frame = next (video )
94
- if frame is None :
95
- logger .info ("the video finished" )
96
- break
97
- index += 1
98
- if index % self .config .interval == 0 :
99
- start_time = time .time ()
100
- task = Task (self .edge_id , index , frame , start_time , frame .shape )
101
- self .frame_cache .put (task , block = True )
102
- time .sleep ((self .config .interval * 1.0 ) / video_fps )
87
+
103
88
104
89
105
90
def diff_worker (self ):
@@ -110,7 +95,6 @@ def diff_worker(self):
110
95
self .pre_frame_feature = self .edge_processor .get_frame_feature (frame )
111
96
self .key_task = task
112
97
# Create an entry for the task in the database table
113
- logger .debug ("start time {}" .format (task .start_time ))
114
98
data = (
115
99
task .frame_index ,
116
100
task .start_time ,
@@ -151,9 +135,6 @@ def diff_worker(self):
151
135
task .state = TASK_STATE .FINISHED
152
136
self .update_table (task )
153
137
154
- else :
155
- pass
156
-
157
138
def update_table (self , task ):
158
139
state = "Finished" if task .state == TASK_STATE .FINISHED else ""
159
140
if task .ref is not None :
@@ -207,7 +188,7 @@ def local_worker(self):
207
188
task .frame_cloud = offloading_image
208
189
self .offloading_executor .submit (self .offload_worker , task )
209
190
end_time = time .time ()
210
- task .set_end_time ( end_time )
191
+ task .end_time = end_time
211
192
task .state = TASK_STATE .FINISHED
212
193
# upload the result to database
213
194
self .update_table (task )
@@ -298,55 +279,74 @@ def offload_worker(self, task, destination_edge_id=None):
298
279
else :
299
280
logger .info (str (res ))
300
281
301
- #
282
+ # collect data for retrain
302
283
def collect_data (self , task , frame ,detection_boxes , detection_class , detection_score ):
303
- self .select_index = []
304
- creat_folder (self .config .retrain .cache_path )
305
- cv2 .imwrite (os .path .join (self .config .retrain .cache_path ,'frames' , str (task .frame_index ) + '.jpg' ), frame )
306
- self .avg_scores .append ({task .frame_index :np .mean (detection_score )})
307
- self .cache_count += 1
308
- logger .debug ("count {}" .format (self .cache_count ))
309
- for score , label , box in zip (detection_score , detection_class , detection_boxes ):
310
- self .pred_res .append ((task .frame_index , label , box [0 ], box [1 ], box [2 ], box [3 ], score ))
284
+ if detection_score is not None :
285
+ creat_folder (self .config .retrain .cache_path )
286
+ cv2 .imwrite (os .path .join (self .config .retrain .cache_path ,'frames' , str (task .frame_index ) + '.jpg' ), frame )
287
+ self .avg_scores .append ({task .frame_index : np .mean (detection_score )})
288
+ self .cache_count += 1
289
+ logger .debug ("count {}" .format (self .cache_count ))
311
290
if self .cache_count >= self .config .retrain .collect_num :
291
+ self .retrain_no += 1
312
292
logger .debug ("enough" )
313
293
smallest_elements = sorted (self .avg_scores , key = lambda d : list (d .values ())[0 ])[:self .config .retrain .select_num ]
314
294
self .select_index = [list (d .keys ())[0 ] for d in smallest_elements ]
315
- print (self .select_index )
316
-
317
- np .savetxt (os .path .join (self .config .retrain .cache_path ,'pred_res.txt' ), self .pred_res ,
318
- fmt = ['%d' , '%d' , '%f' , '%f' , '%f' , '%f' , '%f' ], delimiter = ',' )
319
-
295
+ logger .debug (self .select_index )
320
296
self .pred_res = []
321
- self .retrain_flag = True
322
297
self .collect_flag = False
323
298
self .cache_count = 0
299
+ if self .retrain_no % self .config .retrain .interval == 0 :
300
+ self .test_only = True
301
+ else :
302
+ self .test_only = False
303
+ self .retrain_flag = True
324
304
325
305
326
306
# retrain
327
307
def retrain_worker (self ):
308
+ self .annotations = []
328
309
while True :
329
310
if self .retrain_flag :
330
311
logger .debug ("retrain" )
331
- self . annotations = []
312
+
332
313
for index in self .select_index :
333
314
path = os .path .join (self .config .retrain .cache_path , 'frames' , '{}.jpg' .format (index ))
315
+ logger .debug (path )
334
316
frame = cv2 .imread (path )
335
- logger .debug ("get index {}" .format (index ))
317
+ logger .debug ("get index {} {} " .format (index , time . time () ))
336
318
target_res = get_cloud_target (self .config .server_ip , frame )
337
- logger .debug ("get target {}" .format (target_res ))
319
+ logger .debug ("get target {} {} " .format (target_res , time . time () ))
338
320
for score , label , box in zip (target_res ['scores' ], target_res ['labels' ], target_res ['boxes' ]):
339
321
self .annotations .append ((index , label , box [0 ], box [1 ], box [2 ], box [3 ], score ))
340
322
if len (self .annotations ):
341
323
np .savetxt (os .path .join (self .config .retrain .cache_path ,'annotation.txt' ), self .annotations ,
342
324
fmt = ['%d' , '%d' , '%f' , '%f' , '%f' , '%f' , '%f' ], delimiter = ',' )
343
325
344
326
logger .debug ("select num {}" .format (int (self .config .retrain .select_num * 0.8 )))
345
- self .small_object_detection .retrain (self .config .retrain .cache_path , self .select_index [:int (self .config .retrain .select_num * 0.8 )])
346
- self .small_object_detection .model_evaluation (self .config .retrain .cache_path , self .select_index [int (self .config .retrain .select_num * 0.8 ):])
347
- self .collect_flag = True
327
+ if self .test_only :
328
+ logger .debug ("test only" )
329
+ self .small_object_detection .model_evaluation (
330
+ self .config .retrain .cache_path , self .select_index )
331
+ else :
332
+ self .small_object_detection .model_evaluation (
333
+ self .config .retrain .cache_path , self .select_index [int (self .config .retrain .select_num * 0.8 ):])
334
+ self .small_object_detection .retrain (
335
+ self .config .retrain .cache_path , self .select_index [:int (self .config .retrain .select_num * 0.8 )])
336
+ self .small_object_detection .model_evaluation (
337
+ self .config .retrain .cache_path , self .select_index [int (self .config .retrain .select_num * 0.8 ):])
348
338
self .retrain_flag = False
349
- clear_folder (self .config .retrain .cache_path )
339
+ if self .use_history :
340
+ self .select_index ,self .avg_scores = history_sample (self .select_index ,self .avg_scores )
341
+ self .annotations = annotion_process (self .annotations , self .select_index )
342
+ sample_files (os .path .join (self .config .retrain .cache_path , 'frames' ) ,self .select_index )
343
+ self .cache_count = len (self .select_index )
344
+ else :
345
+ clear_folder (self .config .retrain .cache_path )
346
+ self .select_index = []
347
+ self .avg_scores = []
348
+ self .annotations = []
349
+ self .collect_flag = True
350
350
time .sleep (1 )
351
351
352
352
0 commit comments