Skip to content

Commit f1fbae0

Browse files
committed
Update YOLO
1 parent 732f927 commit f1fbae0

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

Object_Detection/YOLOv1/src/loss.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
295295
obj_class_mask = (target_conf_class > 0.5); // target_conf_class{N,G,G,CN} ===> obj_class_mask{N,G,G,CN}
296296
input_obj_class = input_class.masked_select(/*mask=*/obj_class_mask); // input_class{N,G,G,CN} ===> input_noobj_conf{CN}
297297
target_obj_class = target_class.masked_select(/*mask=*/obj_class_mask); // target_class{N,G,G,CN} ===> target_noobj_conf{object class}
298-
loss_class = criterion(input_obj_class, target_obj_class) * 0.5 / (float)mini_batch_size;
298+
loss_class = criterion(input_obj_class, target_obj_class) * 0.5 / (float)this->class_num / (float)mini_batch_size;
299299
}
300300
else{
301301
loss_class = torch::full({}, /*value=*/0.0, torch::TensorOptions().dtype(torch::kFloat)).to(device);

Object_Detection/YOLOv2/src/loss.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
374374
response_class_mask = response_mask.unsqueeze(/*dim=*/-1).expand_as(input_class); // response_mask{N,G,G,A} ===> response_class_mask{N,G,G,A,CN}
375375
input_response_class = input_class.masked_select(/*mask=*/response_class_mask); // input_class{N,G,G,A,CN} ===> input_response_class{response*CN}
376376
target_response_class = target_class.masked_select(/*mask=*/response_class_mask); // target_class{N,G,G,A,CN} ===> target_response_class{response*CN}
377-
loss_class = criterion(input_response_class, target_response_class) * 0.5 / (float)mini_batch_size;
377+
loss_class = criterion(input_response_class, target_response_class) * 0.5 / (float)this->class_num / (float)mini_batch_size;
378378
}
379379
else{
380380
loss_class = torch::full({}, /*value=*/0.0, torch::TensorOptions().dtype(torch::kFloat)).to(device);

Object_Detection/YOLOv3/src/loss.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
352352
response_class_mask = response_mask.unsqueeze(/*dim=*/-1).expand_as(input_class); // response_mask{N,G,G,A} ===> response_class_mask{N,G,G,A,CN}
353353
input_response_class = input_class.masked_select(/*mask=*/response_class_mask); // input_class{N,G,G,A,CN} ===> input_response_class{response*CN}
354354
target_response_class = target_class.masked_select(/*mask=*/response_class_mask); // target_class{N,G,G,A,CN} ===> target_response_class{response*CN}
355-
loss_class = loss_class + criterion(input_response_class, target_response_class) * 0.5 / (float)mini_batch_size;
355+
loss_class = loss_class + criterion(input_response_class, target_response_class) * 0.5 / (float)this->class_num / (float)mini_batch_size;
356356
}
357357

358358
}

0 commit comments

Comments
 (0)