diff --git a/src/aiida/orm/nodes/data/array/trajectory.py b/src/aiida/orm/nodes/data/array/trajectory.py index 3d6356ebd2..9c807d839f 100644 --- a/src/aiida/orm/nodes/data/array/trajectory.py +++ b/src/aiida/orm/nodes/data/array/trajectory.py @@ -8,10 +8,13 @@ ########################################################################### """AiiDA class to deal with crystal structure trajectories.""" +from __future__ import annotations + import collections.abc from typing import List from aiida.common.pydantic import MetadataField +from aiida.common.warnings import warn_deprecation from .array import ArrayData @@ -33,7 +36,7 @@ def __init__(self, structurelist=None, **kwargs): if structurelist is not None: self.set_structurelist(structurelist) - def _internal_validate(self, stepids, cells, symbols, positions, times, velocities): + def _internal_validate(self, stepids, cells, symbols, positions, times, velocities, pbc): """Internal function to validate the type and shape of the arrays. See the documentation of py:meth:`.set_trajectory` for a description of the valid shape and type of the parameters. @@ -82,8 +85,14 @@ def _internal_validate(self, stepids, cells, symbols, positions, times, velociti 'have shape (s,n,3), ' 'with s=number of steps and n=number of symbols' ) - - def set_trajectory(self, symbols, positions, stepids=None, cells=None, times=None, velocities=None): + if not (isinstance(pbc, (list, tuple)) and len(pbc) == 3 and all(isinstance(val, bool) for val in pbc)): + raise ValueError('`pbc` must be a list/tuple of length three with boolean values.') + if cells is None and list(pbc) != [False, False, False]: + raise ValueError('Periodic boundary conditions are only possible when a cell is defined.') + + def set_trajectory( + self, symbols, positions, stepids=None, cells=None, times=None, velocities=None, pbc: None | list | tuple = None + ): r"""Store the whole trajectory, after checking that types and dimensions are correct. @@ -131,14 +140,28 @@ def set_trajectory(self, symbols, positions, stepids=None, cells=None, times=Non :param velocities: if specified, must be a float array with the same dimensions of the ``positions`` array. The array contains the velocities in the atoms. + :param pbc: periodic boundary conditions of the structure. Should be a list of + length three with booleans indicating if the structure is periodic in that + direction. The same periodic boundary conditions are set for each step. .. todo :: Choose suitable units for velocities """ import numpy - self._internal_validate(stepids, cells, symbols, positions, times, velocities) - # set symbols as attribute for easier querying + if cells is None: + pbc = pbc or [False, False, False] + elif pbc is None: + warn_deprecation( + "When 'cells' is not None, the periodic boundary conditions should be explicitly specified via" + "the 'pbc' keyword argument. Defaulting to '[True, True, True]', but this will raise in v3.0.0.", + version=3, + ) + pbc = [True, True, True] + + self._internal_validate(stepids, cells, symbols, positions, times, velocities, pbc) + # set symbols/pbc as attributes for easier querying self.base.attributes.set('symbols', list(symbols)) + self.base.attributes.set('pbc', tuple(pbc)) self.set_array('positions', positions) if stepids is not None: # use input stepids self.set_array('steps', stepids) @@ -189,7 +212,12 @@ def set_structurelist(self, structurelist): raise ValueError('Symbol lists have to be the same for all of the supplied structures') symbols = list(symbols_first) positions = numpy.array([[list(s.position) for s in x.sites] for x in structurelist]) - self.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions) + pbc_set = {structure.pbc for structure in structurelist} + if len(pbc_set) == 1: + pbc = pbc_set.pop() + else: + raise ValueError(f'All structures should have the same `pbc`, found: {pbc_set}') + self.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions, pbc=pbc) def _validate(self): """Verify that the required arrays are present and that their type and @@ -206,6 +234,7 @@ def _validate(self): self.get_positions(), self.get_times(), self.get_velocities(), + self.pbc, ) # Should catch TypeErrors, ValueErrors, and KeyErrors for missing arrays except Exception as exception: @@ -264,6 +293,15 @@ def symbols(self) -> List[str]: """ return self.base.attributes.get('symbols') + @property + def pbc(self) -> tuple[bool]: + """Return the tuple of periodic boundary conditions. + + Returns a tuple of length three with booleans indicating if the structure is + periodic in that direction. + """ + return self.base.attributes.get('pbc') + def get_positions(self): """Return the array of positions, if it has already been set. @@ -384,7 +422,7 @@ def get_step_structure(self, index, custom_kinds=None): 'passed {}, but the symbols are {}'.format(sorted(kind_names), sorted(symbols)) ) - struc = StructureData(cell=cell) + struc = StructureData(cell=cell, pbc=self.pbc) if custom_kinds is not None: for _k in custom_kinds: struc.append_kind(_k) diff --git a/tests/orm/nodes/data/test_trajectory.py b/tests/orm/nodes/data/test_trajectory.py index eb0a78384a..00540e3dd9 100644 --- a/tests/orm/nodes/data/test_trajectory.py +++ b/tests/orm/nodes/data/test_trajectory.py @@ -3,6 +3,7 @@ import numpy as np import pytest +from aiida.common.warnings import AiidaDeprecationWarning from aiida.orm import StructureData, TrajectoryData, load_node @@ -10,6 +11,7 @@ def trajectory_data(): """Return a dictionary of data to create a ``TrajectoryData``.""" symbols = ['H'] * 5 + ['Cl'] * 5 + pbc = [True, True, True] stepids = np.arange(1000, 3000, 10) times = stepids * 0.01 positions = np.arange(6000, dtype=float).reshape((200, 10, 3)) @@ -23,6 +25,7 @@ def trajectory_data(): 'cells': cells, 'times': times, 'velocities': velocities, + 'pbc': pbc, } @@ -107,10 +110,11 @@ def test_trajectory_get_step_data(self, trajectory_data): stepid, time, cell, symbols, positions, velocities = trajectory.get_step_data(-2) assert stepid == trajectory_data['stepids'][-2] assert time == trajectory_data['times'][-2] - np.array_equal(cell, trajectory_data['cells'][-2, :, :]) - np.array_equal(symbols, trajectory_data['symbols']) - np.array_equal(positions, trajectory_data['positions'][-2, :, :]) - np.array_equal(velocities, trajectory_data['velocities'][-2, :, :]) + assert np.array_equal(cell, trajectory_data['cells'][-2, :, :]) + assert np.array_equal(symbols, trajectory_data['symbols']) + assert np.array_equal(trajectory.pbc, trajectory_data['pbc']) + assert np.array_equal(positions, trajectory_data['positions'][-2, :, :]) + assert np.array_equal(velocities, trajectory_data['velocities'][-2, :, :]) def test_trajectory_get_step_data_empty(self, trajectory_data): """Test the `get_step_data` method when some arrays are not defined.""" @@ -123,6 +127,8 @@ def test_trajectory_get_step_data_empty(self, trajectory_data): assert np.array_equal(symbols, trajectory_data['symbols']) assert np.array_equal(positions, trajectory_data['positions'][3, :, :]) assert velocities is None + # In case the cell is not defined, there should be no periodic boundary conditions + assert np.array_equal(trajectory.pbc, [False, False, False]) def test_trajectory_get_step_structure(self, trajectory_data): """Test the `get_step_structure` method.""" @@ -141,3 +147,90 @@ def test_trajectory_get_step_structure(self, trajectory_data): with pytest.raises(IndexError): trajectory.get_step_structure(500) + + def test_trajectory_pbc_structures(self, trajectory_data): + """Test the `pbc` for the `TrajectoryData` using structure inputs.""" + # Test non-pbc structure with no cell + structure = StructureData(cell=None, pbc=[False, False, False]) + structure.append_atom(position=[0.0, 0.0, 0.0], symbols='H') + + trajectory = TrajectoryData(structurelist=(structure,)) + + trajectory.get_step_structure(0).store() # Verify that the `StructureData` can be stored + assert trajectory.get_step_structure(0).pbc == structure.pbc + + # Test failure for incorrect pbc + trajectory_data_incorrect = trajectory_data.copy() + trajectory_data_incorrect['pbc'] = [0, 0, 0] + with pytest.raises(ValueError, match='`pbc` must be a list/tuple of length three with boolean values'): + trajectory = TrajectoryData() + trajectory.set_trajectory(**trajectory_data_incorrect) + + # Test failure when structures have different pbc + cell = [[3.0, 0.1, 0.3], [-0.05, 3.0, -0.2], [0.02, -0.08, 3.0]] + structure_periodic = StructureData(cell=cell) + structure_periodic.append_atom(position=[0.0, 0.0, 0.0], symbols='H') + structure_non_periodic = StructureData(cell=cell, pbc=[False, False, False]) + structure_non_periodic.append_atom(position=[0.0, 0.0, 0.0], symbols='H') + + with pytest.raises(ValueError, match='All structures should have the same `pbc`'): + TrajectoryData(structurelist=(structure_periodic, structure_non_periodic)) + + def test_trajectory_pbc_set_trajectory(self): + """Test the `pbc` for the `TrajectoryData` using `set_trajectory`.""" + data = { + 'symbols': ['H'], + 'positions': np.array( + [ + [ + [0.0, 0.0, 0.0], + ] + ] + ), + } + trajectory = TrajectoryData() + + data.update( + { + 'cells': None, + 'pbc': None, + } + ) + trajectory.set_trajectory(**data) + assert trajectory.get_step_structure(0).pbc == (False, False, False) + + data.update( + { + 'cells': None, + 'pbc': [False, False, False], + } + ) + trajectory.set_trajectory(**data) + assert trajectory.get_step_structure(0).pbc == (False, False, False) + + data.update( + { + 'cells': None, + 'pbc': [True, False, False], + } + ) + with pytest.raises(ValueError, match='Periodic boundary conditions are only possible when a cell is defined'): + trajectory.set_trajectory(**data) + + data.update( + { + 'cells': np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]]), + 'pbc': None, + } + ) + with pytest.warns(AiidaDeprecationWarning, match="When 'cells' is not None, the periodic"): + trajectory.set_trajectory(**data) + + data.update( + { + 'cells': np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]]), + 'pbc': (True, False, False), + } + ) + trajectory.set_trajectory(**data) + assert trajectory.get_step_structure(0).pbc == (True, False, False)