-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathGLAM_oos_prediction.py
208 lines (165 loc) · 9.34 KB
/
GLAM_oos_prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import glam
import numpy as np
import pymc3 as pm
import matplotlib.pyplot as plt
import os
import pandas as pd
import argparse
def check_convergence(summary, parameters=('v', 's', 'tau'),
                      n_eff_required=100, gelman_rubin_criterion=0.05):
    """
    Check model convergence based on the number of effective samples
    and the Gelman-Rubin statistic from a pymc3 model summary table.

    Parameters
    ----------
    summary : pandas.DataFrame
        Summary table indexed by variable name; must provide the
        columns 'n_eff' and 'Rhat'.
    parameters : sequence of str
        Base parameter names. The suffix '__0_0' is appended to each
        name to obtain the row label that is looked up in `summary`.
    n_eff_required : number
        Every selected row must have 'n_eff' strictly above this value.
    gelman_rubin_criterion : float
        Every selected row must satisfy |Rhat - 1| < this value.

    Returns
    -------
    bool
        True only if both criteria hold for all selected parameters.

    Raises
    ------
    KeyError
        If a requested row or one of the two columns is missing
        (raised by the pandas lookup).
    """
    # Build a fresh list of row labels; the caller's sequence is untouched.
    rows = [parameter + '__0_0' for parameter in parameters]
    diagnostics = summary.loc[rows]

    enough_eff_samples = np.all(diagnostics['n_eff'] > n_eff_required)
    good_gelman = np.all(
        np.abs(diagnostics['Rhat'] - 1.0) < gelman_rubin_criterion)

    # Convert the numpy booleans into a plain Python bool for callers.
    return bool(enough_eff_samples and good_gelman)
def fitModel(model, relevant_parameters=('v', 's', 'tau'),
             n_tuning_initial=1000, n_tuning_increase=1000,
             seed_start=10, seed_increment=1, n_tries_max=1,
             n_advi=200000, fallback='Metropolis',
             progressbar=True):
    """
    Fit a GLAM model with NUTS until convergence is diagnosed or the
    maximum number of tries is reached, then use the fallback method.

    Parameters
    ----------
    model : object
        Model exposing `fit(method=..., ...)` and, after a NUTS fit,
        `trace` (whose first element is summarised with `pm.summary`).
    relevant_parameters : sequence of str
        Parameter names handed to `check_convergence`.
    n_tuning_initial, n_tuning_increase : int
        Number of tuning steps for the first NUTS attempt and the
        amount added for every further attempt.
    seed_start, seed_increment : int
        NumPy random seed for the first attempt and the amount added
        for every further attempt.
    n_tries_max : int
        Maximum number of NUTS attempts. With 0, NUTS is skipped and
        the fallback method is used right away.
    n_advi : int
        Passed as `n_advi` to `model.fit` when falling back to ADVI.
    fallback : str
        'ADVI' or 'Metropolis'. Any other value leaves the model as
        it is after the last NUTS attempt.
    progressbar : bool
        Forwarded to the NUTS fit.

    Returns
    -------
    The same `model` object that was passed in.
    """
    converged = False
    n_tuning = n_tuning_initial
    seed = seed_start
    n_tries = 0

    while (not converged) and (n_tries < n_tries_max):
        np.random.seed(seed)
        model.fit(method='NUTS', tune=n_tuning, progressbar=progressbar)
        summary = pm.summary(model.trace[0])
        converged = check_convergence(summary, parameters=relevant_parameters)
        # Prepare a different seed and a longer tuning phase for a retry.
        seed += seed_increment
        n_tuning += n_tuning_increase
        n_tries += 1

    if not converged:
        # Compare by value: identity checks on strings (`is`) only succeed
        # when both objects happen to be interned.
        if fallback == 'ADVI':
            print("Falling back to ADVI...")
            model.fit(method='ADVI', n_advi=n_advi)
        elif fallback == 'Metropolis':
            print("Falling back to Metropolis...")
            model.fit(method='Metropolis', n_samples=10000)

    return model
def _oos_path(kind, variant, filename, *subdirs):
    """Return results/<kind>/out_of_sample/<variant>[/<subdirs>...]/<filename>."""
    return os.path.join('results', kind, 'out_of_sample', variant, *subdirs, filename)


def _fit_predict_variant(even, odd, subject, variant, title, drift,
                         model_kwargs, parameters,
                         n_repeats, n_tries, progressbar):
    """Fit one GLAM variant on `even`, save its outputs, predict `odd`."""
    print('\tS{}: {}'.format(subject, title))

    estimates_file = _oos_path(
        'estimates', variant, 'estimates_{}_{}_oos.csv'.format(subject, variant))
    trace_file = _oos_path(
        'traces', variant, 'trace_{}_{}_oos.csv'.format(subject, variant))
    plot_file = _oos_path(
        'traces', variant, 'traceplot_{}_{}_oos.png'.format(subject, variant), 'plots')
    prediction_file = _oos_path(
        'predictions', variant, 'prediction_{}_{}_oos.csv'.format(subject, variant))

    # Create the output folders before sampling so that a missing
    # directory cannot throw away a finished fit.
    for path in (estimates_file, trace_file, plot_file, prediction_file):
        os.makedirs(os.path.dirname(path), exist_ok=True)

    model = glam.GLAM(even, drift=drift)
    model.make_model('individual', t0_val=0, **model_kwargs)
    model = fitModel(model, relevant_parameters=parameters,
                     n_tries_max=n_tries, progressbar=progressbar)

    # Summary table, extended by a 'MAP' column taken from model.estimates.
    summary = pm.summary(model.trace[0])
    for parameter in parameters:
        summary.loc[parameter + '__0_0', 'MAP'] = model.estimates[parameter].values[0]
    # NOTE(review): this file is what fitPredictOOS checks to skip a
    # variant; it is written before the trace and the predictions, so
    # a crash further down leaves a variant that is skipped next time.
    summary.to_csv(estimates_file)

    model.model[0].name = variant
    trace = model.trace[0]
    pm.trace_to_dataframe(trace).to_csv(trace_file)
    pm.traceplot(trace)
    plt.savefig(plot_file)
    plt.close()

    # out of sample prediction
    model.exchange_data(odd)
    model.predict(n_repeats=n_repeats)
    model.prediction['subject'] = subject
    model.prediction.to_csv(prediction_file)


def fitPredictOOS(data, subject, n_repeats=50, n_tries=1, overwrite=False, progressbar=True):
    """
    Fit the additive, multiplicative and no-bias GLAM variants
    on even numbered trials of one subject, exchange the data and
    predict the odd numbered trials.

    Parameters
    ----------
    data : pandas.DataFrame
        Data with at least the columns 'subject', 'n_items' and
        'trial' (plus whatever glam.GLAM needs).
    subject : value found in data['subject']
        Subject to process; also used in all output file names.
    n_repeats : int
        Passed to the model's `predict` call.
    n_tries : int
        Maximum number of NUTS attempts (`n_tries_max` of fitModel).
    overwrite : bool
        If False, a variant whose estimates file already exists is
        skipped.
    progressbar : bool
        Forwarded to fitModel.

    Returns
    -------
    None. Results are written below 'results/'.
    """
    print("Processing subject {}...".format(subject))

    # Subset data
    subject_data = data[data['subject'] == subject].copy()
    n_items = subject_data['n_items'].values[0]
    if n_items == 2:
        subject_data = subject_data.drop(['item_value_2', 'gaze_2'], axis=1)
    # The subject's rows are relabelled to 0 before fitting.
    subject_data['subject'] = 0

    # split into even and odd trials
    even = subject_data[(subject_data['trial'] % 2) == 0].copy().reset_index(drop=True)
    odd = subject_data[(subject_data['trial'] % 2) == 1].copy().reset_index(drop=True)

    # (variant, printed title, label in skip message, drift,
    #  extra make_model arguments, parameters to check and store)
    variants = (
        ('additive', 'Additive', 'additive', 'additive',
         {'gamma_bounds': (-100, 100)}, ['v', 's', 'tau', 'gamma']),
        ('multiplicative', 'Multiplicative', 'multiplicative', 'multiplicative',
         {'gamma_bounds': (-10, 1)}, ['v', 's', 'tau', 'gamma']),
        ('nobias', 'No Bias', 'no-bias', 'additive',
         {'gamma_val': 0.0}, ['v', 's', 'tau']),
    )

    for variant, title, skip_label, drift, model_kwargs, parameters in variants:
        estimates_file = _oos_path(
            'estimates', variant, 'estimates_{}_{}_oos.csv'.format(subject, variant))
        if overwrite or not os.path.isfile(estimates_file):
            _fit_predict_variant(even, odd, subject, variant, title, drift,
                                 model_kwargs, parameters,
                                 n_repeats, n_tries, progressbar)
        else:
            print("Previous estimates found for {} model (Subject {}). Skipping...".format(skip_label, subject))

    return
def fitSubjects(first=0, last=-1, n_repeats=50, n_tries=1, overwrite=False, progressbar=True):
    """
    Run fitPredictOOS for a range of subjects from the aggregate data.

    Reads 'data/data_aggregate.csv' and iterates over the unique
    values of its 'subject' column, in order of first appearance.

    Parameters
    ----------
    first : int
        Index (into the unique subject list) of the first subject.
    last : int or None
        Slice stop index (exclusive) into the unique subject list.
        The default -1, or None, means "through the final subject".
        Other values, including other negative ones, keep normal
        slice semantics.
    n_repeats, n_tries, overwrite, progressbar
        Forwarded unchanged to fitPredictOOS.
    """
    data = pd.read_csv(os.path.join('data', 'data_aggregate.csv'))
    subjects = data['subject'].unique()

    # Used directly as a slice stop, the default -1 would silently
    # leave out the final subject; map it to an open-ended slice.
    stop = None if (last is None or last == -1) else last

    for subject in subjects[first:stop]:
        fitPredictOOS(data, subject, n_repeats=n_repeats, n_tries=n_tries,
                      overwrite=overwrite, progressbar=progressbar)
if __name__ == '__main__':
    # Command line entry point: parse the options and hand them to
    # fitSubjects.
    cli = argparse.ArgumentParser()

    # Boolean switches (off unless given).
    switches = (
        ("--overwrite", "Overwrite previous results."),
        ("--silent", "Run without progressbar."),
    )
    for flag, description in switches:
        cli.add_argument(flag, action="store_true", default=False,
                         help=description)

    # Integer options with their defaults.
    integer_options = (
        ("--n-tries", 1,
         "Number of tries for NUTS fitting, before falling back to fallback method."),
        ("--first", 0, "First subject index to use."),
        ("--last", -1, "Last subject index to use."),
        ("--n-prediction-repeats", 50,
         "Number of trial repetitions in prediction."),
    )
    for flag, default_value, description in integer_options:
        cli.add_argument(flag, type=int, default=default_value,
                         help=description)

    options = cli.parse_args()

    show_progress = not options.silent
    fitSubjects(
        first=options.first,
        last=options.last,
        n_repeats=options.n_prediction_repeats,
        n_tries=options.n_tries,
        overwrite=options.overwrite,
        progressbar=show_progress,
    )