forked from iricchi/Brain_GLMM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathglmm_final.m
144 lines (123 loc) · 4.78 KB
/
glmm_final.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
function [L, gamma_hat, mu] = glmm_final(y, iterations, classes, spread, regul, norm_par, delta)
%GLMM_FINAL Cluster the rows of Y with an EM loop whose per-class covariance
%is derived from a learned graph Laplacian.
%
%   [L, gamma_hat, mu] = GLMM_FINAL(y, iterations, classes)
%   [L, gamma_hat, mu] = GLMM_FINAL(y, iterations, classes, spread, regul, norm_par, delta)
%
%   E-step: for each class k, the centred data (y - mu_k') is projected onto
%   eigenvectors 2..n of L_k and scored with a zero-mean Gaussian whose
%   covariance is inv(Lambda_k(2:n,2:n) + regul*I), symmetrised.
%   M-step: mu_k is the responsibility-weighted mean of y; L_k is rebuilt from
%   the weights returned by gsp_learn_graph_log_degrees on the
%   responsibility-weighted, centred data. The mixing weights are held fixed
%   at 1/classes (their update is disabled).
%
%   Inputs
%     y          m-by-n data; rows are samples, columns are variables (nodes).
%     iterations number of E/M iterations.
%     classes    number of mixture components.
%     spread     scale of the initial Laplacian spread*(eye(n) - ones(n)/n).
%                Default 0.1.
%     regul      ridge added to the Laplacian eigenvalues before inversion.
%                Default 0.15.
%     norm_par   the squared-distance matrix Z is divided by mean(Z(:))/norm_par.
%                Default 1.2.
%     delta      multiplier applied to the learned graph weights. Default 2.
%     Optional arguments that are omitted or passed as [] take their default;
%     values supplied by the caller are always used.
%
%   Outputs
%     L          n-by-n-by-classes Laplacians, one per component.
%     gamma_hat  m-by-classes responsibilities (posterior class probabilities).
%     mu         n-by-classes component means.
%
%   Uses kmeans, mvnpdf, datasample and cholcov (Statistics and Machine
%   Learning Toolbox) and gsp_distanz / gsp_learn_graph_log_degrees
%   (presumably GSPBox - confirm they are on the path).

% Per-argument defaults: never overwrite a value the caller supplied.
if nargin < 4 || isempty(spread),   spread   = 0.1;  end
if nargin < 5 || isempty(regul),    regul    = 0.15; end
if nargin < 6 || isempty(norm_par), norm_par = 1.2;  end
if nargin < 7 || isempty(delta),    delta    = 2;    end

[m, n] = size(y);

%% initialise
% Every class starts from the same scaled "complete graph" Laplacian.
L = repmat(spread*eye(n) - spread/n*ones(n), [1, 1, classes]);
W = zeros(n, n, classes);           % learned edge weights (thresholded copy)
p = ones(classes, 1) / classes;     % mixing weights, kept uniform
gamma_hat = zeros(m, classes);
[~, mu] = kmeans(y, classes, 'MaxIter', 300);
mu = mu';                           % n-by-classes

%% start the algorithm
for it = 1:iterations
    % Expectation step in each Laplacian's (n-1)-dim eigenvector space.
    [sigma, yl] = glmm_project(y, L, mu, regul);
    gamma_hat = glmm_responsibilities(yl, sigma, p);

    % Empty-cluster handling: reseed the mean of any class with (almost) no
    % responsibility mass from distinct data rows, and give it the average
    % covariance of the non-empty classes.
    inds = find(sum(gamma_hat, 1) <= 1e-9);
    if ~isempty(inds)
        disp(inds)
        [mut, idx] = datasample(y, numel(inds), 'Replace', false);
        mu(:, inds) = mut';
        disp(idx)
        ninds = setdiff(1:classes, inds);
        if isempty(ninds)
            error('glmm_final:AllClassesEmpty', ...
                'All mixture components are empty; no covariance to borrow.');
        end
        [sigma, yl] = glmm_project(y, L, mu, regul);
        sigma(:, :, inds) = repmat(mean(sigma(:, :, ninds), 3), [1, 1, numel(inds)]);
        for class = 1:classes
            [~, err] = cholcov(sigma(:, :, class), 0);
            if err ~= 0
                error(message('stats:mvnpdf:BadMatrixSigma'));
            end
        end
        gamma_hat = glmm_responsibilities(yl, sigma, p);
    end

    % Maximisation step: update mu and L (p stays fixed).
    for class = 1:classes
        g = gamma_hat(:, class);
        mu(:, class) = (g' * y) / sum(g);
        % Responsibility-weighted, centred data (rows scaled by sqrt(gamma)).
        yc = sqrt(g) .* (y - mu(:, class)');
        Z = gsp_distanz(yc).^2;
        theta = mean(Z(:)) / norm_par;
        Zs = Z ./ theta;
        if ~all(diag(Zs) == 0)
            error(message('stats:squareform:BadInputMatrix'));
        end
        W_curr = delta * gsp_learn_graph_log_degrees(Zs, 1, 1);
        % NOTE(review): L is built from the unthresholded weights; the
        % thresholded copy stored in W is not used again in this function.
        L(:, :, class) = diag(sum(W_curr, 2)) - W_curr;
        W_curr(W_curr < 1e-3) = 0;
        W(:, :, class) = W_curr;
    end
end
end

function [sigma, yl] = glmm_project(y, L, mu, regul)
%GLMM_PROJECT Project centred data onto eigenvectors 2..n of each class's
%Laplacian and build the matching (n-1)-by-(n-1) covariance
%inv(Lambda(2:n,2:n) + regul*I), symmetrised.
%   Dropping column 1 relies on eig returning the (near-)zero Laplacian
%   eigenvalue first, as it does in practice for symmetric input - confirm.
[m, n] = size(y);
classes = size(L, 3);
sigma = zeros(n-1, n-1, classes);
yl = zeros(m, n-1, classes);
for class = 1:classes
    [V, D] = eig(L(:, :, class));
    S = inv(D(2:n, 2:n) + regul*eye(n-1));
    sigma(:, :, class) = (S + S') / 2;
    yl(:, :, class) = (y - mu(:, class)') * V(:, 2:n);
end
end

function gamma_hat = glmm_responsibilities(yl, sigma, p)
%GLMM_RESPONSIBILITIES Posterior class probabilities
%   p(k)*N(yl_k; 0, sigma_k) / sum_j p(j)*N(yl_j; 0, sigma_j).
%   Entries that come out NaN (e.g. every density underflows to 0) are set
%   to 1/classes in their own column.
[m, d, classes] = size(yl);
dens = zeros(m, classes);
for class = 1:classes
    dens(:, class) = p(class) * mvnpdf(yl(:, :, class), zeros(1, d), sigma(:, :, class));
end
gamma_hat = dens ./ sum(dens, 2);
for class = 1:classes
    nans = isnan(gamma_hat(:, class));
    if any(nans)
        disp("Gammas have NaNs")
    end
    gamma_hat(nans, class) = 1/classes;
end
end