Skip to content

Commit b1c8ad5

Browse files
committed
update
1 parent dfd5439 commit b1c8ad5

File tree

5 files changed

+36
-31
lines changed

5 files changed

+36
-31
lines changed

config/config.yaml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ client:
22
# the source video
33
source:
44
video_path: ./video_data/dayroad.mp4
5-
max_count: 5
5+
max_count: 1000
66
rtsp:
77
flag: False
88
account: your account
@@ -21,13 +21,13 @@ client:
2121
# the queue threshold for offloading
2222
queue_thresh: 5
2323
# the max wait time
24-
wait_thresh: 10
24+
wait_thresh: 100
2525
# Number of worker threads to offloading thread pool
2626
offloading_max_worker: 1
2727
frame_cache_maxsize: 100
2828
small_model_name: fasterrcnn_mobilenet_v3_large_fpn
2929
# select the offloading policy
30-
policy: Edge-Shortest
30+
policy: Edge-Local
3131
# change frame resolution using frame new height
3232
new_height:
3333
# offload to another edge node [1080, 720], default 720
@@ -52,24 +52,24 @@ client:
5252
server_ip: '127.0.0.1:50051'
5353
# edge nodes
5454
edge_id: 1
55-
edge_num: 2
56-
destinations: {'id': [2], 'ip':['192.168.0.185:50050']}
55+
edge_num: 1
56+
destinations: {'id': [], 'ip':[]}
5757
# database config
5858
database:
5959
connection: { 'user': 'root', 'password': 'root', 'host': '127.0.0.1', 'raise_on_warnings': True }
6060
database_name: 'mydatabase'
6161
# retrain
6262
retrain:
63-
flag: False
63+
flag: True
6464
num_epoch: 2
6565
cache_path: './cache'
66-
collect_num: 10
67-
select_num: 5
68-
interval: 2
66+
collect_num: 20
67+
select_num: 15
68+
window: 90
6969

7070
server:
7171
server_id: 0
72-
edge_ids: [1,2]
72+
edge_ids: [1]
7373
large_model_name: fasterrcnn_resnet50_fpn
7474
#the queue maxsize
7575
local_queue_maxsize: 10

edge/edge_worker.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,21 @@ def __init__(self, config):
6363
self.diff_processor.start()
6464

6565
# start the thread for edge server
66-
self.edge_server = threading.Thread(target=self.start_edge_server, daemon=True)
67-
self.edge_server.start()
66+
#self.edge_server = threading.Thread(target=self.start_edge_server, daemon=True)
67+
#self.edge_server.start()
6868

6969
# start the thread for local process
7070
self.local_processor = threading.Thread(target=self.local_worker,daemon=True)
7171
self.local_processor.start()
7272

7373
# start the thread for retrain process
7474
self.collect_flag = self.config.retrain.flag
75+
self.collect_time = None
76+
self.collect_time_flag = True
7577
self.retrain_flag = False
7678
self.cache_count = 0
7779

7880
self.use_history = True
79-
self.test_only = False
8081
self.retrain_no = 0
8182

8283
self.retrain_processor = threading.Thread(target=self.retrain_worker,daemon=True)
@@ -180,7 +181,12 @@ def local_worker(self):
180181

181182
# collect data for retrain
182183
if self.collect_flag:
183-
self.collect_data(task, current_frame ,detection_boxes, detection_class, detection_score)
184+
if self.collect_time_flag:
185+
self.collect_time = time.time()
186+
self.collect_time_flag = False
187+
duration = time.time() - self.collect_time
188+
if duration > self.config.retrain.window:
189+
self.collect_data(task, current_frame ,detection_boxes, detection_class, detection_score)
184190

185191
if detection_boxes is not None:
186192
task.add_result(detection_boxes, detection_class, detection_score)
@@ -308,10 +314,6 @@ def collect_data(self, task, frame ,detection_boxes, detection_class, detection_
308314
self.pred_res = []
309315
self.collect_flag = False
310316
self.cache_count = 0
311-
if self.retrain_no % self.config.retrain.interval == 0:
312-
self.test_only = True
313-
else:
314-
self.test_only = False
315317
self.retrain_flag = True
316318

317319

@@ -332,17 +334,10 @@ def retrain_worker(self):
332334
np.savetxt(os.path.join(self.config.retrain.cache_path,'annotation.txt'), self.annotations,
333335
fmt=['%d', '%d', '%f', '%f', '%f', '%f', '%f'], delimiter=',')
334336

335-
if self.test_only:
336-
logger.debug("test only")
337-
self.small_object_detection.model_evaluation(
338-
self.config.retrain.cache_path, self.select_index)
339-
else:
340-
self.small_object_detection.model_evaluation(
341-
self.config.retrain.cache_path, self.select_index[int(self.config.retrain.select_num * 0.8):])
342-
self.small_object_detection.retrain(
343-
self.config.retrain.cache_path, self.select_index[:int(self.config.retrain.select_num*0.8)])
344-
self.small_object_detection.model_evaluation(
345-
self.config.retrain.cache_path, self.select_index[int(self.config.retrain.select_num*0.8):])
337+
self.small_object_detection.model_evaluation(
338+
self.config.retrain.cache_path, self.select_index)
339+
self.small_object_detection.retrain(
340+
self.config.retrain.cache_path, self.select_index[:int(self.config.retrain.select_num*0.8)])
346341
self.retrain_flag = False
347342
if self.use_history:
348343
self.select_index,self.avg_scores = history_sample(self.select_index,self.avg_scores)
@@ -354,8 +349,10 @@ def retrain_worker(self):
354349
self.select_index = []
355350
self.avg_scores = []
356351
self.annotations = []
352+
self.collect_time_flag = True
357353
self.collect_flag = True
358-
time.sleep(1)
354+
time.sleep(0.2)
355+
359356

360357

361358
def start_edge_server(self):

edge_client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import threading
55
import time
66

7+
8+
import wandb
79
import munch
810
import yaml
911

@@ -28,6 +30,7 @@ def signal_handler(signal, frame):
2830
config = yaml.load(f, Loader=yaml.SafeLoader)
2931
#provide class-like access for dict
3032
config = munch.munchify(config)
33+
wandb.init(project="filter", config=config)
3134
config = config.client
3235
signal.signal(signal.SIGINT, signal_handler)
3336
event = threading.Event()

model_management/object_detection.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,13 @@ def retrain(self, path, select_index):
107107
# Update the learning rate
108108
lr_scheduler.step()
109109
torch.save(tmp_model.state_dict(), "./model_management/tmp_model.pth")
110+
111+
if torch.cuda.is_available():
112+
torch.cuda.empty_cache()
110113
state_dict = torch.load("./model_management/tmp_model.pth", map_location=device)
111114
with self.model_lock:
112115
self.model.load_state_dict(state_dict)
113-
self.model.eval()
116+
self.model.eval()
114117

115118
def model_evaluation(self,cache_path, select_index):
116119
map = []
@@ -222,6 +225,8 @@ def large_inference(self, img):
222225
return pred_boxes, pred_class, pred_score
223226

224227
def get_model_prediction(self, img, threshold, model=None):
228+
if torch.cuda.is_available():
229+
torch.cuda.empty_cache()
225230
#process the image
226231
img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
227232
transform = transforms.Compose([transforms.ToTensor()])

retrain/road.mp4

31.5 MB
Binary file not shown.

0 commit comments

Comments
 (0)