Skip to content

Commit 65da839

Browse files
authored
Linking to dem (#49)
* linking two YADE (two particle collision) * linking to YADE (triaxial compression)
1 parent c25b10f commit 65da839

27 files changed

+2034
-98
lines changed

docs/source/tutorials.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ and `sim_data_file_ext` is correct such that GrainLearning can find the data in
152152
"param_max": [1, 10],
153153
"param_names": ['a', 'b'],
154154
"num_samples": 20,
155-
"obs_data_file": 'linearObs.dat',
155+
"obs_data_file": 'linear_obs.dat',
156156
"obs_names": ['f'],
157157
"ctrl_name": 'u',
158158
"sim_name": 'linear',

grainlearning/dynamic_systems.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ def __init__(
463463
inv_obs_weight: List[float] = None,
464464
sim_data: np.ndarray = None,
465465
callback: Callable = None,
466-
param_data_file: str = None,
466+
param_data_file: str = '',
467467
param_data: np.ndarray = None,
468468
param_names: List[str] = None,
469469
):
@@ -539,7 +539,7 @@ def from_dict(cls: Type["IODynamicSystem"], obj: dict):
539539
obs_data_file=obj["obs_data_file"],
540540
obs_names=obj["obs_names"],
541541
ctrl_name=obj["ctrl_name"],
542-
param_data_file=obj.get("param_data_file", None),
542+
param_data_file=obj.get("param_data_file", ''),
543543
obs_data=obj.get("obs_data", None),
544544
num_samples=obj.get("num_samples", None),
545545
param_min=obj.get("param_min", None),
@@ -566,9 +566,7 @@ def get_obs_data(self):
566566
self.num_steps = len(self.ctrl_data)
567567
# remove the data not used by Bayesian filtering
568568
self.num_obs = len(self.obs_names)
569-
for key in keys_and_data:
570-
if key not in self.obs_names:
571-
keys_and_data.pop(key)
569+
keys_and_data = {key: keys_and_data[key] for key in self.obs_names}
572570
# assign the obs_data array
573571
self.obs_data = np.zeros([self.num_obs, self.num_steps])
574572
for i, key in enumerate(self.obs_names):
@@ -613,24 +611,17 @@ def load_sim_data(self):
613611
for i, sim_data_file in enumerate(self.sim_data_files):
614612
if self.sim_data_file_ext != '.npy':
615613
data = get_keys_and_data(sim_data_file)
616-
param_data = np.genfromtxt(sim_data_file.split('_sim')[0] + f'_param{self.sim_data_file_ext}')
617-
for j, key in enumerate(self.param_names):
618-
data[key] = param_data[j]
614+
param_data = get_keys_and_data(sim_data_file.split('_sim')[0] + f'_param{self.sim_data_file_ext}')
615+
for key in self.param_names:
616+
data[key] = param_data[key][0]
619617
else:
620618
data = np.load(sim_data_file, allow_pickle=True).item()
621619

622620
for j, key in enumerate(self.obs_names):
623621
self.sim_data[i, j, :] = data[key]
624622

625623
params = np.array([data[key] for key in self.param_names])
626-
if not (np.abs((params - self.param_data[i, :])
627-
/ self.param_data[i, :] < 1e-5).all()):
628-
raise RuntimeError(
629-
"Parameters [" + ", ".join(
630-
[f"{v}" for v in self.param_data[i, :]])
631-
+ '] vs [' +
632-
", ".join(f"{v}" for v in params) +
633-
f"] from the simulation data file {sim_data_file} and the parameter table do not match")
624+
np.testing.assert_allclose(params, self.param_data[i, :], rtol=1e-5)
634625

635626
def load_param_data(self, curr_iter: int = 0):
636627
"""
@@ -643,15 +634,17 @@ def load_param_data(self, curr_iter: int = 0):
643634
self.param_data = np.genfromtxt(self.param_data_file, comments='!')[:, -self.num_params:]
644635
self.num_samples = self.param_data.shape[0]
645636
else:
646-
# if param_data_file does not exit, get parameter daa from simulation data files
647-
files = glob(self.sim_data_dir + f'/iter{curr_iter}/{self.sim_name}*{self.sim_data_file_ext}')
637+
# if param_data_file does not exit, get parameter data from text files
638+
files = glob(self.sim_data_dir + f'/iter{curr_iter}/{self.sim_name}*_param*{self.sim_data_file_ext}')
648639
self.num_samples = len(files)
649640
self.sim_data_files = sorted(files)
650641
self.param_data = np.zeros([self.num_samples, self.num_params])
651642
for i, sim_data_file in enumerate(self.sim_data_files):
652-
# TODO: this is still for npy, support text file formats
653-
data = np.load(sim_data_file, allow_pickle=True).item()
654-
params = [data[key] for key in self.param_names]
643+
if self.sim_data_file_ext == '.npy':
644+
data = np.load(sim_data_file, allow_pickle=True).item()
645+
else:
646+
data = get_keys_and_data(sim_data_file)
647+
params = [data[key][0] for key in self.param_names]
655648
self.param_data[i, :] = params
656649

657650
def run(self, **kwargs):
@@ -692,7 +685,7 @@ def write_params_to_table(self, curr_iter: int):
692685
:return param_data_file: The name of the parameter data file
693686
"""
694687
self.param_data_file = write_to_table(
695-
f'{os.getcwd()}/{self.sim_name}',
688+
self.sim_name,
696689
self.param_data,
697690
self.param_names,
698691
curr_iter)

grainlearning/tools.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import subprocess
88
from typing import List, Callable
99
import numpy as np
10-
import matplotlib.pylab as plt
1110
from sklearn.mixture import BayesianGaussianMixture
1211
from scipy.spatial import Voronoi, ConvexHull
1312

@@ -60,14 +59,14 @@ def write_to_table(sim_name, table, names, curr_iter=0, threads=8):
6059
"""
6160

6261
# Computation of decimal number for unique key
63-
table_file_name = f'{sim_name}_Iter{curr_iter}_Samples.txt'
62+
table_file_name = f'{os.getcwd()}/{sim_name}_Iter{curr_iter}_Samples.txt'
6463

6564
with open(table_file_name, 'w') as f_out:
6665
num, dim = table.shape
6766
mag = math.floor(math.log(num, 10)) + 1
6867
f_out.write(' '.join(['!OMP_NUM_THREADS', 'description', 'key'] + names + ['\n']))
6968
for j in range(num):
70-
description = 'Iter' + str(curr_iter) + '-Sample' + str(j).zfill(mag)
69+
description = f'{sim_name}_Iter' + str(curr_iter) + '-Sample' + str(j).zfill(mag)
7170
f_out.write(' '.join(
7271
[f'{threads:2d}'] + [description] +
7372
[f'{j:9d}'] + [f'{table[j][i]:20.10e}' for i in range(dim)] + ['\n']))
@@ -84,14 +83,7 @@ def get_keys_and_data(file_name: str, delimiters=None):
8483
"""
8584
if delimiters is None:
8685
delimiters = ['\t', ' ', ',']
87-
data = np.genfromtxt(file_name)
88-
89-
try:
90-
nc_ols = data.shape[1]
91-
except IndexError:
92-
n_rows = data.shape[0]
93-
nc_ols = 1
94-
data = data.reshape([n_rows, 1])
86+
data = np.genfromtxt(file_name, ndmin=2)
9587

9688
with open(file_name, 'r') as f_open:
9789
first_line = f_open.read().splitlines()[0]
@@ -102,7 +94,7 @@ def get_keys_and_data(file_name: str, delimiters=None):
10294
keys.remove('#')
10395
# remove empty strings from the list
10496
keys = list(filter(None, keys))
105-
if len(keys) == nc_ols:
97+
if len(keys) == data.shape[1]:
10698
break
10799

108100
# store data in a dictionary
@@ -409,6 +401,7 @@ def plot_param_stats(fig_name, param_names, means, covs, save_fig=0):
409401
:param covs: ndarray
410402
:param save_fig: bool defaults to False
411403
"""
404+
import matplotlib.pylab as plt
412405
num = len(param_names)
413406
n_cols = int(np.ceil(num / 2))
414407
plt.figure('Posterior means of the parameters')
@@ -449,6 +442,10 @@ def plot_posterior(fig_name, param_names, param_data, posterior, save_fig=0):
449442
:param posterior: ndarray
450443
:param save_fig: bool defaults to False
451444
"""
445+
try:
446+
import matplotlib.pylab as plt
447+
except ImportError:
448+
print('matplotlib is not installed, cannot plot posterior distribution. Please install with grainlearning[plot]')
452449
num_steps = posterior.shape[0]
453450
for i, name in enumerate(param_names):
454451
plt.figure(f'Posterior distribution of {name}')
@@ -468,6 +465,7 @@ def plot_posterior(fig_name, param_names, param_data, posterior, save_fig=0):
468465

469466

470467
def plot_param_data(fig_name, param_names, param_data_list, save_fig=0):
468+
import matplotlib.pylab as plt
471469
num = len(param_names)
472470
n_cols = int(np.ceil(num / 2))
473471
num = num - 1
@@ -501,6 +499,7 @@ def plot_obs_and_sim(fig_name, ctrl_name, obs_names, ctrl_data, obs_data, sim_da
501499
:param posterior: ndarray
502500
:param save_fig: bool defaults to False
503501
"""
502+
import matplotlib.pylab as plt
504503
ensemble_mean = np.einsum('ijk, ki->jk', sim_data, posteriors)
505504
ensemble_std = np.einsum('ijk, ki->jk', (sim_data - ensemble_mean) ** 2, posteriors)
506505
ensemble_std = np.sqrt(ensemble_std)
@@ -549,6 +548,10 @@ def write_dict_to_file(data, file_name):
549548
with open(file_name, 'w') as f:
550549
keys = data.keys()
551550
f.write('# ' + ' '.join(keys) + '\n')
552-
num = len(data[list(keys)[0]])
553-
for i in range(num):
554-
f.write(' '.join([str(data[key][i]) for key in keys]) + '\n')
551+
# check if data[list(keys)[0]] is a list
552+
if isinstance(data[list(keys)[0]], list):
553+
num = len(data[list(keys)[0]])
554+
for i in range(num):
555+
f.write(' '.join([str(data[key][i]) for key in keys]) + '\n')
556+
else:
557+
f.write(' '.join([str(data[key]) for key in keys]) + '\n')

pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,16 @@ pytest = {version = "^6.2.4", optional = true}
2828
pytest-cov = {version = "^2.12.1", optional = true}
2929
prospector = {version = "^1.7.6", optional = true, extras = ["with_pyroma"]}
3030
pyroma = {version = "^4.0", optional = true}
31+
h5py = {version ="^3.7.0", optional = true}
32+
wandb = {version ="^0.13.4", optional = true}
33+
tensorflow = {version ="2.10.0", optional = true}
34+
ipykernel = {version = "*", optional = true}
3135

3236
[tool.poetry.extras]
3337
docs = ["Sphinx", "sphinx-autodoc-typehints", "sphinx-mdinclude", "sphinx-rtd-theme"]
34-
dev = ["pytest", "pytest-cov", "prospector", "pyroma"]
38+
dev = ["pytest", "pytest-cov", "prospector", "pyroma", "h5py"]
39+
rnn = ["wandb", "tensorflow"]
40+
tutorials = ["ipykernel"]
3541

3642
[build-system]
3743
requires = ["poetry-core>=1.0.0"]

tests/integration/test_gmm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_gmm():
1818
"num_iter": 0,
1919
"system": {
2020
"system_type": IODynamicSystem,
21-
"obs_data_file": f'{sim_data_dir}/linearObs.dat',
21+
"obs_data_file": f'{sim_data_dir}/linear_obs.dat',
2222
"obs_names": ['f'],
2323
"ctrl_name": 'u',
2424
"sim_name": 'linear',

tests/integration/test_lenreg_IO.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
def run_sim(model, **kwargs):
1616
"""
17-
Runs the external executable and passes the parameter sample to generate the output file.
17+
Run the external executable and passes the parameter sample to generate the output file.
1818
"""
1919
# keep the naming convention consistent between iterations
2020
mag = floor(log(model.num_samples, 10)) + 1
@@ -40,7 +40,7 @@ def test_lenreg_IO():
4040
"param_names": ['a', 'b'],
4141
"num_samples": 20,
4242
"obs_data_file": os.path.abspath(
43-
os.path.join(__file__, "../..")) + '/data/linear_sim_data/linearObs.dat',
43+
os.path.join(__file__, "../..")) + '/data/linear_sim_data/linear_obs.dat',
4444
"obs_names": ['f'],
4545
"ctrl_name": 'u',
4646
"sim_name": 'linear',

tests/integration/test_smc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def test_smc():
1515
"num_iter": 0,
1616
"system": {
1717
"system_type": IODynamicSystem,
18-
"obs_data_file": f'{sim_data_dir}/linearObs.dat',
18+
"obs_data_file": f'{sim_data_dir}/linear_obs.dat',
1919
"obs_names": ['f'],
2020
"ctrl_name": 'u',
2121
"sim_name": 'linear',

tests/unit/test_dynamic_systems.py

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -193,26 +193,20 @@ class TestIODynamicSystem:
193193

194194
def test_init(self):
195195
"""Test if the class is initialized correctly"""
196-
system_cls = IODynamicSystem(
197-
param_min=[0.001, 0.001],
198-
param_max=[1, 10],
199-
param_names=['a', 'b'],
200-
num_samples=20,
201-
obs_data_file=path.abspath(path.join(__file__, "../..")) + '/data/linear_sim_data/linearObs.dat',
202-
obs_names=['f'],
203-
ctrl_name='u',
204-
sim_name='linear',
205-
sim_data_dir=path.abspath(path.join(__file__, "../..")) + '/data/linear_sim_data/',
206-
sim_data_file_ext='.txt',
207-
)
196+
system_cls = IODynamicSystem(sim_name='linear',
197+
sim_data_dir=path.abspath(path.join(__file__, "../..")) + '/data/linear_sim_data/',
198+
sim_data_file_ext='.txt', obs_data_file=path.abspath(
199+
path.join(__file__, "../..")) + '/data/linear_sim_data/linear_obs.dat', obs_names=['f'], ctrl_name='u',
200+
num_samples=20, param_min=[0.001, 0.001], param_max=[1, 10],
201+
param_names=['a', 'b'])
208202

209203
config = {
210204
"system_type": IODynamicSystem,
211205
"param_min": [0.001, 0.001],
212206
"param_max": [1, 10],
213207
"param_names": ['a', 'b'],
214208
"num_samples": 20,
215-
"obs_data_file": path.abspath(path.join(__file__, "../..")) + '/data/linear_sim_data/linearObs.dat',
209+
"obs_data_file": path.abspath(path.join(__file__, "../..")) + '/data/linear_sim_data/linear_obs.dat',
216210
"obs_names": ['f'],
217211
"ctrl_name": 'u',
218212
"sim_name": 'linear',
@@ -246,27 +240,20 @@ def run_sim(system, **kwargs):
246240
data.append(np.array(y, ndmin=2))
247241
# Write the data to a file
248242
data_file_name = f'{sim_name}_' + description + '_sim.txt'
249-
write_dict_to_file({'f': y}, data_file_name)
243+
write_dict_to_file({'f': list(y)}, data_file_name)
250244
# Write the parameters to a file
251245
data_param_name = f'{sim_name}_' + description + '_param.txt'
252-
param_data = {'param0': [param[0]], 'param1': [param[1]], 'param2': [param[2]], 'param3': [param[3]]}
246+
param_data = {'a': [param[0]], 'b': [param[1]], 'c': [param[2]], 'd': [param[3]]}
253247
write_dict_to_file(param_data, data_param_name)
254248
# Set the simulation data
255249
system.set_sim_data(data)
256250

257-
system_cls = IODynamicSystem(
258-
param_min=[None, None, None, None],
259-
param_max=[None, None, None, None],
260-
param_names=['a', 'b', 'c', 'd'],
261-
num_samples=10,
262-
obs_data_file=path.abspath(path.join(__file__, "../..")) + '/data/linear_sim_data/linearObs.dat',
263-
obs_names=['f'],
264-
ctrl_name='u',
265-
sim_name='test',
266-
sim_data_dir=PATH + '/sim_data/',
267-
sim_data_file_ext='.txt',
268-
callback=run_sim,
269-
)
251+
system_cls = IODynamicSystem(sim_name='test', sim_data_dir=PATH + '/sim_data/', sim_data_file_ext='.txt',
252+
obs_data_file=path.abspath(
253+
path.join(__file__, "../..")) + '/data/linear_sim_data/linear_obs.dat',
254+
obs_names=['f'], ctrl_name='u', num_samples=10, param_min=[None, None, None, None],
255+
param_max=[None, None, None, None], callback=run_sim,
256+
param_names=['a', 'b', 'c', 'd'])
270257

271258
system_cls.param_data = np.arange(1, system_cls.num_samples * 4 + 1, dtype=float).reshape(
272259
system_cls.num_samples, 4)
@@ -301,7 +288,7 @@ def test_get_obs_data(self):
301288
"system_type": IODynamicSystem,
302289
"param_min": [0.001, 0.001],
303290
"param_max": [1, 10],
304-
"obs_data_file": path.abspath(path.join(__file__, "../../data/linear_sim_data/linearObs.dat")),
291+
"obs_data_file": path.abspath(path.join(__file__, "../../data/linear_sim_data/linear_obs.dat")),
305292
"obs_names": ['f'],
306293
"ctrl_name": 'u',
307294
"sim_name": 'linear',

tests/unit/test_iterative_bayesian_filter.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -171,20 +171,12 @@ def test_run_inference():
171171
def test_save_and_load_proposal():
172172
"""Test if the proposal density can be loaded from a file"""
173173
#: Initialize a system object (note the observed data is not used in this test)
174-
system_cls = IODynamicSystem(
175-
sim_name='test_ibf',
176-
sim_data_dir=PATH + '/sim_data/',
177-
sim_data_file_ext='.txt',
178-
obs_names=['f'],
179-
ctrl_name='u',
180-
num_samples=10,
181-
param_min=[1e6, 0.2],
182-
param_max=[1e7, 0.5],
183-
obs_data=[[12, 3, 4, 4], [12, 4, 5, 4]],
184-
ctrl_data=[1, 2, 3, 4],
185-
param_names=['a', 'b'],
186-
obs_data_file=os.path.abspath(os.path.join(__file__, "../..")) + '/data/linear_sim_data/linearObs.dat',
187-
)
174+
system_cls = IODynamicSystem(sim_name='test_ibf', sim_data_dir=PATH + '/sim_data/', sim_data_file_ext='.txt',
175+
obs_data_file=os.path.abspath(
176+
os.path.join(__file__, "../..")) + '/data/linear_sim_data/linear_obs.dat',
177+
obs_names=['f'], ctrl_name='u', num_samples=10, param_min=[1e6, 0.2],
178+
param_max=[1e7, 0.5], obs_data=[[12, 3, 4, 4], [12, 4, 5, 4]], ctrl_data=[1, 2, 3, 4],
179+
param_names=['a', 'b'])
188180

189181
#: Assert that the inference runs correctly if a proposal density is provided
190182
ibf_cls = IterativeBayesianFilter.from_dict(

0 commit comments

Comments
 (0)