Skip to content

Commit

Permalink
Fix crash in Transitions.occupancy if input sites are disordered (#342
Browse files Browse the repository at this point in the history
)

* Add work-around for disordered sites listing

#339

* Add warning for disordered structures

* Fix atom locations and add occupancy_by_site_type
  • Loading branch information
stefsmeets authored Nov 18, 2024
1 parent 6f8375e commit 34a0cc0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 10 deletions.
42 changes: 32 additions & 10 deletions src/gemdat/transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import typing
from collections import defaultdict
from itertools import pairwise
from warnings import warn

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -67,6 +68,15 @@ def __init__(
inner_states : np.ndarray
Input states for inner sites
"""
if not (sites.is_ordered):
warn(
'Input `sites` are disordered! '
'Although the code may work, it was written under the assumption '
'that an ordered structure would be passed. '
'See https://github.com/GEMDAT-repos/GEMDAT/issues/339 for more information.',
stacklevel=2,
)

self.sites = sites
self.trajectory = trajectory
self.diff_trajectory = diff_trajectory
Expand Down Expand Up @@ -252,7 +262,10 @@ def occupancy(self) -> Structure:
counts = counts / len(states)
occupancies = dict(zip(unq, counts))

species = [{site.specie.name: occupancies.get(i, 0)} for i, site in enumerate(sites)]
species = [
{site.species.elements[0].name: occupancies.get(i, 0)}
for i, site in enumerate(sites)
]

return Structure(
lattice=sites.lattice,
Expand All @@ -262,27 +275,36 @@ def occupancy(self) -> Structure:
labels=sites.labels,
)

def atom_locations(self):
def occupancy_by_site_type(self) -> dict[str, float]:
"""Calculate average occupancy per a type of site.
Returns
-------
occupancy : dict[str, float]
Return dict with average occupancy per site type
"""
compositions_by_label = defaultdict(list)

for site in self.occupancy():
compositions_by_label[site.label].append(site.species.num_atoms)

return {k: sum(v) / len(v) for k, v in compositions_by_label.items()}

def atom_locations(self) -> dict[str, float]:
"""Calculate fraction of time atoms spent at a type of site.
Returns
-------
dict[str, float]
Return dict with the fraction of time atoms spent at a site
"""
multiplier = len(self.sites) / self.n_floating

n = self.n_floating
compositions_by_label = defaultdict(list)

for site in self.occupancy():
compositions_by_label[site.label].append(site.species.num_atoms)

ret = {}

for k, v in compositions_by_label.items():
ret[k] = (sum(v) / len(v)) * multiplier

return ret
return {k: sum(v) / n for k, v in compositions_by_label.items()}

def split(self, n_parts: int = 10) -> list[Transitions]:
"""Split data into equal parts in time for statistics.
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/transitions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def test_occupancy_parts(self, vasp_transitions):
35.43733333333334,
]

def test_occupancy_by_site_type(self, vasp_transitions):
occ = vasp_transitions.occupancy_by_site_type()
assert occ == {'48h': 0.3806277777777776}

def test_atom_locations(self, vasp_transitions):
dct = vasp_transitions.atom_locations()
assert dct == {'48h': 0.7612555555555552}
Expand Down

0 comments on commit 34a0cc0

Please sign in to comment.