From 34a0cc014d5ee7574dc469d8967afce2b1435281 Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Mon, 18 Nov 2024 16:51:30 +0100 Subject: [PATCH] Fix crash in `Transitions.occupancy` if input sites are disordered (#342) * Add work-around for disordered sites listing https://github.com/GEMDAT-repos/GEMDAT/issues/339 * Add warning for disordered structures * Fix atom locations and add occupancy_by_site_type --- src/gemdat/transitions.py | 42 ++++++++++++++++++++------- tests/integration/transitions_test.py | 4 +++ 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/src/gemdat/transitions.py b/src/gemdat/transitions.py index f5d08cd6..78e40306 100644 --- a/src/gemdat/transitions.py +++ b/src/gemdat/transitions.py @@ -6,6 +6,7 @@ import typing from collections import defaultdict from itertools import pairwise +from warnings import warn import numpy as np import pandas as pd @@ -67,6 +68,15 @@ def __init__( inner_states : np.ndarray Input states for inner sites """ + if not (sites.is_ordered): + warn( + 'Input `sites` are disordered! ' + 'Although the code may work, it was written under the assumption ' + 'that an ordered structure would be passed. ' + 'See https://github.com/GEMDAT-repos/GEMDAT/issues/339 for more information.', + stacklevel=2, + ) + self.sites = sites self.trajectory = trajectory self.diff_trajectory = diff_trajectory @@ -252,7 +262,10 @@ def occupancy(self) -> Structure: counts = counts / len(states) occupancies = dict(zip(unq, counts)) - species = [{site.specie.name: occupancies.get(i, 0)} for i, site in enumerate(sites)] + species = [ + {site.species.elements[0].name: occupancies.get(i, 0)} + for i, site in enumerate(sites) + ] return Structure( lattice=sites.lattice, @@ -262,7 +275,22 @@ def occupancy(self) -> Structure: labels=sites.labels, ) - def atom_locations(self): + def occupancy_by_site_type(self) -> dict[str, float]: + """Calculate average occupancy per a type of site. + + Returns + ------- + occupancy : dict[str, float] + Return dict with average occupancy per site type + """ + compositions_by_label = defaultdict(list) + + for site in self.occupancy(): + compositions_by_label[site.label].append(site.species.num_atoms) + + return {k: sum(v) / len(v) for k, v in compositions_by_label.items()} + + def atom_locations(self) -> dict[str, float]: """Calculate fraction of time atoms spent at a type of site. Returns @@ -270,19 +298,13 @@ def atom_locations(self): dict[str, float] Return dict with the fraction of time atoms spent at a site """ - multiplier = len(self.sites) / self.n_floating - + n = self.n_floating compositions_by_label = defaultdict(list) for site in self.occupancy(): compositions_by_label[site.label].append(site.species.num_atoms) - ret = {} - - for k, v in compositions_by_label.items(): - ret[k] = (sum(v) / len(v)) * multiplier - - return ret + return {k: sum(v) / n for k, v in compositions_by_label.items()} def split(self, n_parts: int = 10) -> list[Transitions]: """Split data into equal parts in time for statistics. diff --git a/tests/integration/transitions_test.py b/tests/integration/transitions_test.py index 97942539..27a5e752 100644 --- a/tests/integration/transitions_test.py +++ b/tests/integration/transitions_test.py @@ -133,6 +133,10 @@ def test_occupancy_parts(self, vasp_transitions): 35.43733333333334, ] + def test_occupancy_by_site_type(self, vasp_transitions): + occ = vasp_transitions.occupancy_by_site_type() + assert occ == {'48h': 0.3806277777777776} + def test_atom_locations(self, vasp_transitions): dct = vasp_transitions.atom_locations() assert dct == {'48h': 0.7612555555555552}