Skip to content

Commit 7d50536

Browse files
committed
update inference for cls
1 parent 2033d86 commit 7d50536

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

pymic/net_run/agent_cls.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,11 @@ def infer(self):
304304
out_prob = nn.Sigmoid()(out_digit).detach().cpu().numpy()
305305
out_lab = np.asarray(out_prob > 0.5, np.uint8)
306306
for i in range(len(names)):
307-
print(names[i], out_lab[i], len(out_lab[i]))
308-
out_lab_list.append([names[i]] + out_lab[i].tolist())
307+
print(names[i], out_lab[i])
308+
if(self.task_type == "cls"):
309+
out_lab_list.append([names[i]] + [out_lab[i]])
310+
else:
311+
out_lab_list.append([names[i]] + out_lab[i].tolist())
309312
out_prob_list.append([names[i]] + out_prob[i].tolist())
310313

311314
with open(output_csv, mode='w') as csv_file:

0 commit comments

Comments
 (0)