Skip to content

Commit 0980e30

Browse files
committed
update
1 parent 363ac13 commit 0980e30

File tree

11 files changed

+216
-90
lines changed

11 files changed

+216
-90
lines changed

cal_metrics.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def cal_mAP(self):
7272
sum_delay = 0.0
7373
filtered_out = 0
7474
i = 0
75-
75+
last_result = None
7676
while i < len(result):
7777
index, start_time, end_time, res, log = result[i]
7878
gap = end_time-start_time
@@ -84,6 +84,7 @@ def cal_mAP(self):
8484
pred = eval(res)
8585
else:
8686
pred = eval(res)
87+
last_result = pred
8788
ground_truth = ground_truths['{}'.format(index)]
8889
map = calculate_map(ground_truth, pred)
8990
sum_map += map

cloud_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(self, config):
3434

3535
# start the thread for local process
3636
self.local_queue = Queue(config.local_queue_maxsize)
37-
self.local_processor = threading.Thread(target=self.cloud_local, )
37+
self.local_processor = threading.Thread(target=self.cloud_local, daemon=True)
3838
self.local_processor.start()
3939

4040

config/config.yaml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,21 @@ client:
4747
#The task from another edge node is offloaded to the cloud after local inference, default 85.
4848
another_cloud: 90
4949
#server_ip
50-
server_ip: '127.0.0.1:50051'
50+
server_ip: '43.137.51.180:50051'
5151
#edge nodes
5252
edge_id: 1
5353
destinations: {'id': [2], 'ip':['127.0.0.1:50050']}
5454
#database config
5555
database:
56-
connection: { 'user': 'root', 'password': 'root', 'host': '127.0.0.1', 'raise_on_warnings': True }
56+
connection: { 'user': 'root', 'password': 'root', 'host': '43.137.51.180', 'raise_on_warnings': True }
5757
database_name: 'mydatabase'
5858
# retrain
5959
retrain:
60-
num_epoch: 4
60+
num_epoch: 2
6161
cache_path: './cache'
62-
collect_num: 20
63-
select_num: 20
62+
collect_num: 10
63+
select_num: 5
64+
interval: 2
6465

6566
server:
6667
server_id: 0

edge/edge_worker.py

Lines changed: 55 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
from database.database import DataBase
1616
from difference.diff import DiffProcessor
1717
from edge.info import TASK_STATE
18+
from edge.resample import history_sample, annotion_process
1819
from edge.task import Task
1920
from edge.transmit import get_cloud_target
2021
from grpc_server import message_transmission_pb2_grpc, message_transmission_pb2
2122
from grpc_server.rpc_server import MessageTransmissionServicer
2223

2324
from tools.convert_tool import cv2_to_base64
24-
from tools.file_op import clear_folder, creat_folder
25+
from tools.file_op import clear_folder, creat_folder, sample_files
2526
from tools.preprocess import frame_resize
2627
from model_management.object_detection import Object_Detection
2728
from apscheduler.schedulers.background import BackgroundScheduler
@@ -53,53 +54,37 @@ def __init__(self, config):
5354
self.avg_scores = []
5455
self.select_index = []
5556
self.annotations = []
56-
# read video
57-
self.video_reader = threading.Thread(target=self.video_read,)
58-
self.video_reader.start()
57+
5958
# start the thread for diff
6059
self.diff = 0
6160
self.key_task = None
62-
self.diff_processor = threading.Thread(target=self.diff_worker,)
61+
self.diff_processor = threading.Thread(target=self.diff_worker,daemon=True)
6362
self.diff_processor.start()
6463

6564
# start the thread for edge server
66-
#self.edge_server = threading.Thread(target=self.start_edge_server, )
65+
#self.edge_server = threading.Thread(target=self.start_edge_server, daemon=True)
6766
#self.edge_server.start()
6867

6968
# start the thread for local process
70-
self.local_processor = threading.Thread(target=self.local_worker,)
69+
self.local_processor = threading.Thread(target=self.local_worker,daemon=True)
7170
self.local_processor.start()
7271

7372
# start the thread for retrain process
7473
self.retrain_flag = False
7574
self.collect_flag = True
7675
self.cache_count = 0
77-
self.retrain_processor = threading.Thread(target=self.retrain_worker,)
76+
77+
self.use_history = True
78+
self.test_only = False
79+
self.retrain_no = 0
80+
81+
self.retrain_processor = threading.Thread(target=self.retrain_worker,daemon=True)
7882
self.retrain_processor.start()
7983

8084
# start the thread pool for offload
8185
self.offloading_executor = futures.ThreadPoolExecutor(max_workers=config.offloading_max_worker,)
8286

83-
def video_read(self):
84-
with VideoProcessor(self.config.source) as video:
85-
video_fps = video.fps
86-
logger.info("the video fps is {}".format(video_fps))
87-
index = 0
88-
if self.config.interval == 0:
89-
logger.error("the interval error")
90-
sys.exit(1)
91-
logger.info("Take the frame interval is {}".format(self.config.interval))
92-
while True:
93-
frame = next(video)
94-
if frame is None:
95-
logger.info("the video finished")
96-
break
97-
index += 1
98-
if index % self.config.interval == 0:
99-
start_time = time.time()
100-
task = Task(self.edge_id, index, frame, start_time, frame.shape)
101-
self.frame_cache.put(task, block=True)
102-
time.sleep((self.config.interval * 1.0) / video_fps)
87+
10388

10489

10590
def diff_worker(self):
@@ -110,7 +95,6 @@ def diff_worker(self):
11095
self.pre_frame_feature = self.edge_processor.get_frame_feature(frame)
11196
self.key_task = task
11297
# Create an entry for the task in the database table
113-
logger.debug("start time {}".format(task.start_time))
11498
data = (
11599
task.frame_index,
116100
task.start_time,
@@ -151,9 +135,6 @@ def diff_worker(self):
151135
task.state = TASK_STATE.FINISHED
152136
self.update_table(task)
153137

154-
else:
155-
pass
156-
157138
def update_table(self, task):
158139
state = "Finished" if task.state == TASK_STATE.FINISHED else ""
159140
if task.ref is not None:
@@ -207,7 +188,7 @@ def local_worker(self):
207188
task.frame_cloud = offloading_image
208189
self.offloading_executor.submit(self.offload_worker, task)
209190
end_time = time.time()
210-
task.set_end_time(end_time)
191+
task.end_time = end_time
211192
task.state = TASK_STATE.FINISHED
212193
# upload the result to database
213194
self.update_table(task)
@@ -298,55 +279,74 @@ def offload_worker(self, task, destination_edge_id=None):
298279
else:
299280
logger.info(str(res))
300281

301-
#
282+
# collect data for retrain
302283
def collect_data(self, task, frame ,detection_boxes, detection_class, detection_score):
303-
self.select_index = []
304-
creat_folder(self.config.retrain.cache_path)
305-
cv2.imwrite(os.path.join(self.config.retrain.cache_path,'frames', str(task.frame_index) + '.jpg'), frame)
306-
self.avg_scores.append({task.frame_index:np.mean(detection_score)})
307-
self.cache_count += 1
308-
logger.debug("count {}".format(self.cache_count))
309-
for score, label, box in zip(detection_score, detection_class, detection_boxes):
310-
self.pred_res.append((task.frame_index, label, box[0], box[1], box[2], box[3], score))
284+
if detection_score is not None:
285+
creat_folder(self.config.retrain.cache_path)
286+
cv2.imwrite(os.path.join(self.config.retrain.cache_path,'frames', str(task.frame_index) + '.jpg'), frame)
287+
self.avg_scores.append({task.frame_index: np.mean(detection_score)})
288+
self.cache_count += 1
289+
logger.debug("count {}".format(self.cache_count))
311290
if self.cache_count >= self.config.retrain.collect_num:
291+
self.retrain_no += 1
312292
logger.debug("enough")
313293
smallest_elements = sorted(self.avg_scores, key=lambda d: list(d.values())[0])[:self.config.retrain.select_num]
314294
self.select_index = [list(d.keys())[0] for d in smallest_elements]
315-
print(self.select_index)
316-
317-
np.savetxt(os.path.join(self.config.retrain.cache_path,'pred_res.txt'), self.pred_res,
318-
fmt=['%d', '%d', '%f', '%f', '%f', '%f', '%f'], delimiter=',')
319-
295+
logger.debug(self.select_index)
320296
self.pred_res = []
321-
self.retrain_flag = True
322297
self.collect_flag = False
323298
self.cache_count = 0
299+
if self.retrain_no % self.config.retrain.interval == 0:
300+
self.test_only = True
301+
else:
302+
self.test_only = False
303+
self.retrain_flag = True
324304

325305

326306
# retrain
327307
def retrain_worker(self):
308+
self.annotations = []
328309
while True:
329310
if self.retrain_flag:
330311
logger.debug("retrain")
331-
self.annotations = []
312+
332313
for index in self.select_index:
333314
path = os.path.join(self.config.retrain.cache_path, 'frames', '{}.jpg'.format(index))
315+
logger.debug(path)
334316
frame = cv2.imread(path)
335-
logger.debug("get index {}".format(index))
317+
logger.debug("get index {} {}".format(index, time.time()))
336318
target_res = get_cloud_target(self.config.server_ip, frame)
337-
logger.debug("get target {}".format(target_res))
319+
logger.debug("get target {} {}".format(target_res, time.time()))
338320
for score, label, box in zip(target_res['scores'], target_res['labels'], target_res['boxes']):
339321
self.annotations.append((index, label, box[0], box[1], box[2], box[3], score))
340322
if len(self.annotations):
341323
np.savetxt(os.path.join(self.config.retrain.cache_path,'annotation.txt'), self.annotations,
342324
fmt=['%d', '%d', '%f', '%f', '%f', '%f', '%f'], delimiter=',')
343325

344326
logger.debug("select num {}".format(int(self.config.retrain.select_num*0.8)))
345-
self.small_object_detection.retrain(self.config.retrain.cache_path, self.select_index[:int(self.config.retrain.select_num*0.8)])
346-
self.small_object_detection.model_evaluation(self.config.retrain.cache_path, self.select_index[int(self.config.retrain.select_num*0.8):])
347-
self.collect_flag = True
327+
if self.test_only:
328+
logger.debug("test only")
329+
self.small_object_detection.model_evaluation(
330+
self.config.retrain.cache_path, self.select_index)
331+
else:
332+
self.small_object_detection.model_evaluation(
333+
self.config.retrain.cache_path, self.select_index[int(self.config.retrain.select_num * 0.8):])
334+
self.small_object_detection.retrain(
335+
self.config.retrain.cache_path, self.select_index[:int(self.config.retrain.select_num*0.8)])
336+
self.small_object_detection.model_evaluation(
337+
self.config.retrain.cache_path, self.select_index[int(self.config.retrain.select_num*0.8):])
348338
self.retrain_flag = False
349-
clear_folder(self.config.retrain.cache_path)
339+
if self.use_history:
340+
self.select_index,self.avg_scores = history_sample(self.select_index,self.avg_scores)
341+
self.annotations = annotion_process(self.annotations, self.select_index)
342+
sample_files(os.path.join(self.config.retrain.cache_path, 'frames') ,self.select_index)
343+
self.cache_count = len(self.select_index)
344+
else:
345+
clear_folder(self.config.retrain.cache_path)
346+
self.select_index = []
347+
self.avg_scores = []
348+
self.annotations = []
349+
self.collect_flag = True
350350
time.sleep(1)
351351

352352

edge/resample.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import random
2+
3+
def history_sample(index_list, score_list):
4+
num_elements = min(len(index_list), len(score_list)) // 2
5+
index_list_c = random.sample(index_list, num_elements)
6+
score_list_c = []
7+
for item in score_list:
8+
if list(item.keys())[0] in index_list_c:
9+
score_list_c.append(item)
10+
return index_list_c, score_list_c
11+
12+
def annotion_process(annotations, index_list):
13+
new_annotations = []
14+
for sublist in annotations:
15+
if sublist[0] in index_list:
16+
new_annotations.append(sublist)
17+
return new_annotations

edge_client.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
11
import argparse
2+
import signal
3+
import sys
4+
import threading
5+
import time
6+
27
import munch
38
import yaml
49

510
from edge.edge_worker import EdgeWorker
611
from loguru import logger
712

13+
from edge.task import Task
14+
from tools.file_op import clear_folder
15+
from tools.video_processor import VideoProcessor
816

9-
17+
def signal_handler(signal, frame):
18+
logger.debug("Received Ctrl+C. Cleaning up...")
19+
clear_folder(config.retrain.cache_path)
20+
sys.exit(0)
1021

1122

1223
if __name__ == '__main__':
@@ -18,10 +29,31 @@
1829
#provide class-like access for dict
1930
config = munch.munchify(config)
2031
config = config.client
32+
signal.signal(signal.SIGINT, signal_handler)
33+
event = threading.Event()
2134
edge = EdgeWorker(config)
2235
logger.add("log/client/client_{time}.log", level="INFO", rotation="500 MB")
23-
24-
25-
26-
36+
try:
37+
with VideoProcessor(config.source) as video:
38+
video_fps = video.fps
39+
logger.info("the video fps is {}".format(video_fps))
40+
index = 0
41+
if config.interval == 0:
42+
logger.error("the interval error")
43+
sys.exit(1)
44+
logger.info("Take the frame interval is {}".format(config.interval))
45+
while True:
46+
frame = next(video)
47+
if frame is None:
48+
logger.debug("The video finished")
49+
break
50+
index += 1
51+
if index % config.interval == 0:
52+
start_time = time.time()
53+
task = Task(config.edge_id, index, frame, start_time, frame.shape)
54+
edge.frame_cache.put(task, block=True)
55+
time.sleep((config.interval * 1.0) / video_fps)
56+
event.wait()
57+
except KeyboardInterrupt:
58+
pass
2759

grpc_server/rpc_server.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,9 @@ def task_processor(self, request, context):
2323
task.other = True
2424

2525
if request.part_result != "":
26-
logger.debug("put1")
2726
part_result = eval(request.part_result)
2827
if len(request.part_result['boxes']) != 0:
29-
logger.debug("put2")
3028
task.add_result(part_result['boxes'], part_result['labels'], part_result['scores'])
31-
logger.debug("put3")
3229

3330
if request.note == "edge process":
3431
task.edge_process = True
@@ -49,13 +46,13 @@ def frame_processor(self, request, context):
4946
frame_shape = tuple(int(s) for s in request.frame_shape[1:-1].split(","))
5047
frame = base64_to_cv2(base64_frame).reshape(frame_shape)
5148
pred_boxes, pred_class, pred_score = self.object_detection.large_inference(frame)
52-
res_dict = {
49+
res = {
5350
'boxes': pred_boxes,
5451
'labels': pred_class,
5552
'scores': pred_score
5653
}
5754
reply = message_transmission_pb2.FrameReply(
58-
response=str(res_dict),
55+
response=str(res),
5956
frame_shape=str(frame_shape),
6057
)
6158
return reply

0 commit comments

Comments
 (0)