Skip to content

Commit c12d540

Browse files
authored
Merge pull request #223 from zhi-yi-huang/main
fix: fix issue#221
2 parents 2eab72e + ff0f0c1 commit c12d540

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

causallearn/score/LocalScoreFunction.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def local_score_cv_general(
206206
Thresh = 1e-5
207207

208208
if len(PAi):
209-
PA = Data[:, PAi].reshape(-1, 1)
209+
PA = Data[:, PAi]
210210

211211
# set the kernel for X
212212
GX = np.multiply(X, X).reshape(-1, 1)
@@ -354,7 +354,7 @@ def local_score_cv_general(
354354
CV = CV / k
355355
else:
356356
# set the kernel for X
357-
GX = np.sum(np.multiply(X, X), axis=1)
357+
GX = np.sum(np.multiply(X, X), axis=1).reshape(-1, 1)
358358
Q = np.tile(GX, (1, T))
359359
R = np.tile(GX.T, (T, 1))
360360
dists = Q + R - 2 * X * X.T
@@ -416,8 +416,8 @@ def local_score_cv_general(
416416
- 1
417417
/ (gamma * n1)
418418
* Kx_tr_te.T
419-
* pdinv(np.eye(n1) + 1 / (gamma * n1) * Kx_tr)
420-
* Kx_tr_te
419+
@ pdinv(np.eye(n1) + 1 / (gamma * n1) * Kx_tr)
420+
@ Kx_tr_te
421421
) / gamma
422422
B = 1 / (gamma * n1) * Kx_tr + np.eye(n1)
423423
L = np.linalg.cholesky(B)
@@ -604,7 +604,7 @@ def local_score_cv_multi(
604604
CV = CV / k
605605
else:
606606
# set the kernel for X
607-
GX = np.sum(np.multiply(X, X), axis=1)
607+
GX = np.sum(np.multiply(X, X), axis=1).reshape(-1, 1)
608608
Q = np.tile(GX, (1, T))
609609
R = np.tile(GX.T, (T, 1))
610610
dists = Q + R - 2 * X * X.T

0 commit comments

Comments
 (0)