From d5b588dd9a3520081198ffb9f918505319dac651 Mon Sep 17 00:00:00 2001
From: Brendan Moloney <moloney.brendan@gmail.com>
Date: Sat, 1 Aug 2020 15:54:10 -0700
Subject: [PATCH 1/7] WIP: Initial meta summary work

---
 nibabel/metasum.py            | 464 ++++++++++++++++++++++++++++++++++
 nibabel/tests/test_metasum.py |   0
 2 files changed, 464 insertions(+)
 create mode 100644 nibabel/metasum.py
 create mode 100644 nibabel/tests/test_metasum.py

diff --git a/nibabel/metasum.py b/nibabel/metasum.py
new file mode 100644
index 0000000000..9dc5dfe5af
--- /dev/null
+++ b/nibabel/metasum.py
@@ -0,0 +1,464 @@
+# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
+# vi: set ft=python sts=4 ts=4 sw=4 et:
+### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
+#
+#   See COPYING file distributed along with the NiBabel package for the
+#   copyright and license terms.
+#
+### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
+'''Aggregate information for mutliple images
+'''
+from bitarray import bitarray, frozenbitarray
+from bitarry.utils import zeroes
+
+
+class FloatCanon(object):
+    '''Look up a canonical float that we compare equal to'''
+    def __init__(self, n_digits=6):
+        self._n_digits = n_digits
+        self._offset = 0.5 * (10 ** -n_digits)
+        self._canon_vals = set()
+        self._rounded = {}
+
+    def get(self, val):
+        '''Get a canonical value that at least compares equal to `val`'''
+        res = val if val in self._canon_vals else None
+        if res is not None:
+            return res
+        lb = round(val, self._n_digits)
+        res = self._rounded.get(lb)
+        if res is not None:
+            return res
+        ub = round(val + self._offset, self._n_digits)
+        res = self._rounded.get(ub)
+        if res is not None:
+            return res
+
+
+_NoValue = object()
+
+# TODO: Integrate some value canonicalization filtering? Or just require the
+#       user to do that themselves?
+class ValueIndices(object):
+    """Track indices of values in sequence.
+
+    If values repeat frequently then memory usage can be dramatically improved.
+    It can be thought of as the inverse to a list.
+
+    >>> values = ['a', 'a', 'b', 'a', 'b']
+    >>> vidx = ValueIndices(values)
+    >>> vidx['a']
+    [0, 1, 3]
+    >>> vidx['b']
+    [2, 4]
+    """
+
+    def __init__(self, values=None):
+        """Initialize a ValueIndices instance.
+
+        Parameters
+        ----------
+        values : sequence
+            The sequence of values to track indices on
+        """
+
+        self._n_input = 0
+
+        # The values can be constant, unique to specific indices, or
+        # arbitrarily varying
+        self._const_val = _NoValue
+        self._unique_vals = {}
+        self._val_bitarrs = {}
+
+        if values is not None:
+            self.extend(values)
+
+    @property
+    def n_input(self):
+        '''The number of inputs we are indexing'''
+        return self._n_input
+
+    def __len__(self):
+        '''Number of unique values being tracked'''
+        if self._const_val is not _NoValue:
+            return 1
+        return len(self._unique_vals) + len(self._val_bitarrs)
+
+    def __getitem__(self, value):
+        '''Return list of indices for the given value'''
+        if self._const_val == value:
+            return list(range(self._n_input))
+        idx = self._unique_vals.get(value)
+        if idx is not None:
+            return [idx]
+        ba = self._val_bitarrs[value]
+        return list(self._extract_indices(ba))
+
+    def values(self):
+        '''Generate each unique value that has been seen'''
+        if self._const_val is not _NoValue:
+            yield self._const_val
+            return
+        for val in self._unique_vals.keys():
+            yield val
+        for val in self._val_bitarrs.keys():
+            yield val
+
+    def get_mask(self, value):
+        '''Get bitarray mask of indices with this value'''
+        if self._const_val is not _NoValue:
+            if self._const_val != value:
+                raise KeyError()
+            res = bitarray(self._n_input)
+            res.setall(1)
+            return res
+        idx = self._unique_vals.get(value)
+        if idx is not None:
+            res = zeroes(self._n_inpuf)
+            res[idx] = 1
+            return res
+        return self._val_bitarrs[value].copy()
+
+    def num_indices(self, value):
+        '''Number of indices for the given `value`'''
+        if self._const_val is not _NoValue:
+            if self._const_val != value:
+                raise KeyError()
+            return self._n_input
+        if value in self._unique_vals:
+            return 1
+        return self._val_bitarrs[value].count()
+
+    def get_value(self, idx):
+        '''Get the value at `idx`'''
+        if not 0 <= idx < self._n_input:
+            raise IndexError()
+        if self._const_val is not _NoValue:
+            return self._const_val
+        for val, vidx in self._unique_vals.items():
+            if vidx == idx:
+                return val
+        bit_idx = zeroes(self._n_input)
+        bit_idx[idx] = 1
+        for val, ba in self._val_bitarrs.items():
+            if (ba | bit_idx).any():
+                return val
+        assert False
+
+    def extend(self, values):
+        '''Add more values to the end of any existing ones'''
+        curr_size = self._n_input
+        if isinstance(values, ValueIndices):
+            other_is_vi = True
+            other_size = values._n_input
+        else:
+            other_is_vi = False
+            other_size = len(values)
+        final_size = curr_size + other_size
+        for ba in self._val_bitarrs.values():
+            ba.extend(zeroes(other_size))
+        if other_is_vi:
+            if self._const_val is not _NoValue:
+                if values._const_val is not _NoValue:
+                    self._extend_const(values)
+                    return
+                else:
+                    self._rm_const()
+            elif values._const_val is not _NoValue:
+                cval = values._const_val
+                other_unique = {}
+                other_bitarrs = {}
+                if values._n_input == 1:
+                    other_unique[cval] = 0
+                else:
+                    other_bitarrs[cval] = bitarray(values._n_input)
+                    other_bitarrs[cval].setall(1)
+            else:
+                other_unique = values._unique_vals
+                other_bitarrs = values._val_bitarrs
+            for val, other_idx in other_unique.items():
+                self._ingest_single(val, final_size, curr_size, other_idx)
+            for val, other_ba in other_bitarrs.items():
+                curr_ba = self._val_bitarrs.get(val)
+                if curr_ba is None:
+                    curr_idx = self._unique_vals.get(val)
+                    if curr_idx is None:
+                        if curr_size == 0:
+                            new_ba = other_ba.copy()
+                        else:
+                            new_ba = zeroes(curr_size)
+                            new_ba.extend(other_ba)
+                    else:
+                        new_ba = zeroes(curr_size)
+                        new_ba[curr_idx] = True
+                        new_ba.extend(other_ba)
+                        del self._unique_vals[val]
+                    self._val_bitarrs[val] = new_ba
+                else:
+                    curr_ba[curr_size:] |= other_ba
+        else:
+            for other_idx, val in enumerate(values):
+                self._ingest_single(val, final_size, curr_size, other_idx)
+        self._n_input = final_size
+
+    def append(self, value):
+        '''Append another value as input'''
+        if self._const_val == value:
+            self._n_input += 1
+            return
+        elif self._const_val is not _NoValue:
+            self._rm_const()
+        curr_size = self._n_input
+        found = False
+        for val, bitarr in self._val_bitarrs.items():
+            if val == value:
+                found = True
+                bitarr.append(True)
+            else:
+                bitarr.append(False)
+        if not found:
+            curr_idx = self._unique_vals.get(value)
+            if curr_idx is None:
+                self._unique_vals[value] = curr_size
+            else:
+                new_ba = zeroes(curr_size + 1)
+                new_ba[curr_idx] = True
+                new_ba[curr_size] = True
+                self._val_bitarrs[value] = new_ba
+                del self._unique_vals[value]
+        self._n_input += 1
+
+    def argsort(self, reverse=False):
+        '''Return array of indices in order that sorts the values'''
+        if self._const_val is not _NoValue:
+            return np.arange(self._n_input)
+        res = np.empty(self._n_input, dtype=np.int64)
+        vals = list(self._unique_vals.keys()) + list(self._val_bitarrs.keys())
+        vals.sort(reverse=reverse)
+        res_idx = 0
+        for val in vals:
+            idx = self._unique_vals.get(val)
+            if idx is not None:
+                res[res_idx] = idx
+                res_idx += 1
+                continue
+            ba = self._val_bitarrs[val]
+            for idx in self._extract_indices(ba):
+                res[res_idx] = idx
+                res_idx += 1
+        return res
+
+    def is_covariant(self, other):
+        '''True if `other` has values that vary the same way ours do
+
+        The actual values themselves are ignored
+        '''
+        if self._n_input != other._n_input or len(self) != len(other):
+            return False
+        if self._const_val is not _NoValue:
+            return other._const_val is not _NoValue
+        if self._n_input == len(self):
+            return other._n_input == len(other)
+        self_ba_set = set(frozenbitarray(ba) for ba in self._val_bitarrs.values())
+        other_ba_set = set(frozenbitarray(ba) for ba in other._val_bitarrs.values())
+        if self_ba_set != other_ba_set:
+            return False
+        if len(self._unique_vals) != len(other._unique_vals):
+            return False
+        return True
+
+    def is_blocked(self, block_factor=None):
+        '''True if each value has the same number of indices
+
+        If `block_factor` is not None we also test that it evenly divides the
+        block size.
+        '''
+        block_size, rem = divmod(self._n_input, len(self))
+        if rem != 0:
+            return False
+        if block_factor is not None and block_size % block_factor != 0:
+            return False
+        for val in self.values():
+            if self.num_indices(val) != block_size:
+                return False
+        return True
+
+    def is_subpartition(self, other):
+        '''True if we have more values and they nest within values from other
+
+
+        '''
+
+    def _extract_indices(self, ba):
+        '''Generate integer indices from bitarray representation'''
+        start = 0
+        while True:
+            try:
+                # TODO: Is this the most efficient approach?
+                curr_idx = ba.index(True, start=start)
+            except ValueError:
+                return
+            yield curr_idx
+            start = curr_idx
+
+    def _ingest_single(self, val, final_size, curr_size, other_idx):
+        '''Helper to ingest single value from another collection'''
+        curr_ba = self._val_bitarrs.get(val)
+        if curr_ba is None:
+            curr_idx = self._unique_vals.get(val)
+            if curr_idx is None:
+                self._unique_vals[val] = curr_size + other_idx
+            else:
+                new_ba = zeroes(final_size)
+                new_ba[curr_idx] = True
+                new_ba[curr_size + other_idx] = True
+                self._val_bitarrs = new_ba
+                del self._unique_vals[val]
+        else:
+            curr_ba[curr_size + other_idx] = True
+
+    def _rm_const(self):
+        assert self._const_val is not _NoValue
+        if self._n_input == 1:
+            self._unique_vals[self._const_val] = 0
+        else:
+            self._val_bitarrs[self._const_val] = bitarray(self._n_input)
+            self._val_bitarrs[self._const_val].setall(1)
+        self._const_val = _NoValue
+
+    def _extend_const(self, other):
+        if self._const_val != other._const_val:
+            if self._n_input == 1:
+                self._unique_vals[self._const_val] = 0
+            else:
+                self_ba = bitarray(self._n_input)
+                other_ba = bitarray(other._n_input)
+                self_ba.setall(1)
+                other_ba.setall(0)
+                self._val_bitarrs[self._const_val] = self_ba + other_ba
+            if other._n_input == 1:
+                self._unique_vals[other._const_val] = self._n_input
+            else:
+                self_ba = bitarray(self._n_input)
+                other_ba = bitarray(other._n_input)
+                self_ba.setall(0)
+                other_ba.setall(1)
+                self._val_bitarrs[other._const_val] = self_ba + other_ba
+            self._const_val = _NoValue
+        self._n_input += other._n_input
+
+
+_MissingKey = object()
+
+
+class MetaSummary:
+    '''Summarize a sequence of dicts, tracking how individual keys vary
+
+    The assumption is that for any key many values will be constant, or at
+    least repeated, and thus we can reduce memory consumption by only storing
+    the value once along with the indices it appears at.
+    '''
+    def __init__(self):
+        self._v_idxs = {}
+        self._n_input = 0
+
+    @property
+    def n_input(self):
+        return self._n_input
+
+    def append(self, meta):
+        seen = set()
+        for key, v_idx in self._v_idxs.items():
+            val = meta.get(key, _MissingKey)
+            v_idx.append(val)
+            seen.add(key)
+        for key, val in meta.items():
+            if key in seen:
+                continue
+            v_idx = ValueIndices([_MissingKey for _ in range(self._n_input)])
+            v_idx.append(val)
+            self._v_idxs[key] = v_idx
+        self._n_input += 1
+
+    def extend(self, metas):
+        pass # TODO
+
+    def keys(self):
+        '''Generate all known keys'''
+        return self._v_idxs.keys()
+
+    def const_keys(self):
+        '''Generate keys with a constant value across all inputs'''
+        for key, v_idx in self._v_idxs.items():
+            if len(v_idx) == 1:
+                yield key
+
+    def unique_keys(self):
+        '''Generate keys with a unique value in each input'''
+        n_input = self._n_input
+        if n_input <= 1:
+            return
+        for key, v_idx in self._v_idxs.items():
+            if len(v_idx) == n_input:
+                yield key
+
+    def repeating_keys(self):
+        '''Generate keys that have some repeating component but are not const
+        '''
+        n_input = self._n_input
+        if n_input <= 1:
+            return
+        for key, v_idx in self._v_idxs.items():
+            if 1 < len(v_idx) < n_input:
+                yield key
+
+    def repeating_groups(self, block_only=False, block_factor=None):
+        '''Generate groups of repeating keys that vary with the same pattern
+        '''
+        n_input = self._n_input
+        if n_input <= 1:
+            # If there is only one element, consider all keys as const
+            return
+        # TODO: Can we sort so grouped v_idxs are sequential?
+        #         - Sort by num values isn't sufficient
+        curr_group = []
+        for key, v_idx in self._v_idxs.items():
+            if 1 < len(v_idx) < n_input:
+                if v_idx.is_even(block_factor):
+                pass # TODO
+
+    def get_meta(self, idx):
+        '''Get the full dict at the given index'''
+        res = {}
+        for key, v_idx in self._v_idxs.items():
+            val = v_idx.get_value(idx)
+            if val is _MissingKey:
+                continue
+            res[key] = val
+        return res
+
+    def get_val(self, idx, key, default=None):
+        '''Get the value at `idx` for the `key`, or return `default``'''
+        res = self._v_idxs[key].get_value(key)
+        if res is _MissingKey:
+            return default
+        return res
+
+    def nd_sort(self, dim_keys=None):
+        '''Produce indices ordered so as to fill an n-D array'''
+
+class SummaryTree:
+    '''Groups incoming meta data and creates hierarchy of related groups
+
+    Each leaf node in the tree is a `MetaSummary`
+    '''
+    def __init__(self, group_keys):
+        self._group_keys = group_keys
+        self._group_summaries= {}
+
+    def add(self, meta):
+        pass
+
+    def groups(self):
+        '''Generate the groups and their meta summaries'''
+
diff --git a/nibabel/tests/test_metasum.py b/nibabel/tests/test_metasum.py
new file mode 100644
index 0000000000..e69de29bb2

From cb3222bb1b489a17296357a115c27d44fbbd5d61 Mon Sep 17 00:00:00 2001
From: Brendan Moloney <moloney.brendan@gmail.com>
Date: Fri, 9 Jul 2021 10:00:06 -0700
Subject: [PATCH 2/7] WIP: Basics mostly working, needs more testing and finish
 ndSort

---
 nibabel/metasum.py            | 245 +++++++++++++++++++++++++---------
 nibabel/tests/test_metasum.py |  63 +++++++++
 2 files changed, 245 insertions(+), 63 deletions(-)

diff --git a/nibabel/metasum.py b/nibabel/metasum.py
index 9dc5dfe5af..7daeb8ac04 100644
--- a/nibabel/metasum.py
+++ b/nibabel/metasum.py
@@ -6,14 +6,18 @@
 #   copyright and license terms.
 #
 ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
-'''Aggregate information for mutliple images
+'''Memory efficient tracking of meta data dicts with repeating elements
 '''
+from dataclasses import dataclass
+from enum import IntEnum
+
 from bitarray import bitarray, frozenbitarray
-from bitarry.utils import zeroes
+from bitarray.util import zeros
 
 
-class FloatCanon(object):
+class FloatCanon:
     '''Look up a canonical float that we compare equal to'''
+
     def __init__(self, n_digits=6):
         self._n_digits = n_digits
         self._offset = 0.5 * (10 ** -n_digits)
@@ -39,7 +43,9 @@ def get(self, val):
 
 # TODO: Integrate some value canonicalization filtering? Or just require the
 #       user to do that themselves?
-class ValueIndices(object):
+
+
+class ValueIndices:
     """Track indices of values in sequence.
 
     If values repeat frequently then memory usage can be dramatically improved.
@@ -114,19 +120,31 @@ def get_mask(self, value):
             return res
         idx = self._unique_vals.get(value)
         if idx is not None:
-            res = zeroes(self._n_inpuf)
+            res = zeros(self._n_input)
             res[idx] = 1
             return res
         return self._val_bitarrs[value].copy()
 
-    def num_indices(self, value):
+    def num_indices(self, value, mask=None):
         '''Number of indices for the given `value`'''
+        if mask is not None:
+            if len(mask) != self.n_input:
+                raise ValueError("Mask length must match input length")
         if self._const_val is not _NoValue:
             if self._const_val != value:
                 raise KeyError()
-            return self._n_input
-        if value in self._unique_vals:
+            if mask is None:
+                return self._n_input
+            return mask.count()
+        unique_idx = self._unique_vals.get(value, _NoValue)
+        if unique_idx is not _NoValue:
+            if mask is not None:
+                if mask[unique_idx]:
+                    return 1
+                return 0
             return 1
+        if mask is not None:
+            return (self._val_bitarrs[value] & mask).count()
         return self._val_bitarrs[value].count()
 
     def get_value(self, idx):
@@ -138,13 +156,17 @@ def get_value(self, idx):
         for val, vidx in self._unique_vals.items():
             if vidx == idx:
                 return val
-        bit_idx = zeroes(self._n_input)
+        bit_idx = zeros(self._n_input)
         bit_idx[idx] = 1
         for val, ba in self._val_bitarrs.items():
-            if (ba | bit_idx).any():
+            if (ba & bit_idx).any():
                 return val
         assert False
 
+    def to_list(self):
+        '''Convert back to a list of values'''
+        return [self.get_value(i) for i in range(self.n_input)]
+
     def extend(self, values):
         '''Add more values to the end of any existing ones'''
         curr_size = self._n_input
@@ -156,7 +178,7 @@ def extend(self, values):
             other_size = len(values)
         final_size = curr_size + other_size
         for ba in self._val_bitarrs.values():
-            ba.extend(zeroes(other_size))
+            ba.extend(zeros(other_size))
         if other_is_vi:
             if self._const_val is not _NoValue:
                 if values._const_val is not _NoValue:
@@ -186,10 +208,10 @@ def extend(self, values):
                         if curr_size == 0:
                             new_ba = other_ba.copy()
                         else:
-                            new_ba = zeroes(curr_size)
+                            new_ba = zeros(curr_size)
                             new_ba.extend(other_ba)
                     else:
-                        new_ba = zeroes(curr_size)
+                        new_ba = zeros(curr_size)
                         new_ba[curr_idx] = True
                         new_ba.extend(other_ba)
                         del self._unique_vals[val]
@@ -221,13 +243,20 @@ def append(self, value):
             if curr_idx is None:
                 self._unique_vals[value] = curr_size
             else:
-                new_ba = zeroes(curr_size + 1)
+                new_ba = zeros(curr_size + 1)
                 new_ba[curr_idx] = True
                 new_ba[curr_size] = True
                 self._val_bitarrs[value] = new_ba
                 del self._unique_vals[value]
         self._n_input += 1
 
+    def reverse(self):
+        '''Reverse the indices in place'''
+        for val, idx in self._unique_vals.items():
+            self._unique_vals[val] = self._n_input - idx - 1
+        for val, bitarr in self._val_bitarrs.items():
+            bitarr.reverse()
+
     def argsort(self, reverse=False):
         '''Return array of indices in order that sorts the values'''
         if self._const_val is not _NoValue:
@@ -248,6 +277,18 @@ def argsort(self, reverse=False):
                 res_idx += 1
         return res
 
+    def reorder(self, order):
+        '''Reorder the indices in place'''
+        if len(order) != self._n_input:
+            raise ValueError("The 'order' has the incorrect length")
+        for val, idx in self._unique_vals.items():
+            self._unique_vals[val] = order.index(idx)
+        for val, bitarr in self._val_bitarrs.items():
+            new_ba = zeros(self._n_input)
+            for idx in self._extract_indices(bitarr):
+                new_ba[order.index(idx)] = True
+            self._val_bitarrs[val] = new_ba
+
     def is_covariant(self, other):
         '''True if `other` has values that vary the same way ours do
 
@@ -267,27 +308,22 @@ def is_covariant(self, other):
             return False
         return True
 
-    def is_blocked(self, block_factor=None):
-        '''True if each value has the same number of indices
+    def get_block_size(self):
+        '''Return size of even blocks of values, or None if values aren't "blocked"
 
-        If `block_factor` is not None we also test that it evenly divides the
-        block size.
+        The number of values must evenly divide the number of inputs into the block size,
+        with each value appearing that same number of times.
         '''
         block_size, rem = divmod(self._n_input, len(self))
         if rem != 0:
-            return False
-        if block_factor is not None and block_size % block_factor != 0:
-            return False
+            return None
         for val in self.values():
             if self.num_indices(val) != block_size:
-                return False
-        return True
+                return None
+        return block_size
 
     def is_subpartition(self, other):
-        '''True if we have more values and they nest within values from other
-
-
-        '''
+        ''''''
 
     def _extract_indices(self, ba):
         '''Generate integer indices from bitarray representation'''
@@ -295,7 +331,7 @@ def _extract_indices(self, ba):
         while True:
             try:
                 # TODO: Is this the most efficient approach?
-                curr_idx = ba.index(True, start=start)
+                curr_idx = ba.index(True, start)
             except ValueError:
                 return
             yield curr_idx
@@ -309,10 +345,10 @@ def _ingest_single(self, val, final_size, curr_size, other_idx):
             if curr_idx is None:
                 self._unique_vals[val] = curr_size + other_idx
             else:
-                new_ba = zeroes(final_size)
+                new_ba = zeros(final_size)
                 new_ba[curr_idx] = True
                 new_ba[curr_size + other_idx] = True
-                self._val_bitarrs = new_ba
+                self._val_bitarrs[val] = new_ba
                 del self._unique_vals[val]
         else:
             curr_ba[curr_size + other_idx] = True
@@ -351,6 +387,25 @@ def _extend_const(self, other):
 _MissingKey = object()
 
 
+class DimTypes(IntEnum):
+    '''Enumerate the three types of nD dimensions'''
+    SLICE = 1
+    TIME = 2
+    PARAM = 3
+
+
+@dataclass
+class DimIndex:
+    '''Specify an nD index'''
+    dim_type: DimTypes
+
+    key: str
+
+
+class NdSortError(Exception):
+    '''Raised when the data cannot be sorted into an nD array as specified'''
+
+
 class MetaSummary:
     '''Summarize a sequence of dicts, tracking how individual keys vary
 
@@ -358,6 +413,7 @@ class MetaSummary:
     least repeated, and thus we can reduce memory consumption by only storing
     the value once along with the indices it appears at.
     '''
+
     def __init__(self):
         self._v_idxs = {}
         self._n_input = 0
@@ -380,9 +436,6 @@ def append(self, meta):
             self._v_idxs[key] = v_idx
         self._n_input += 1
 
-    def extend(self, metas):
-        pass # TODO
-
     def keys(self):
         '''Generate all known keys'''
         return self._v_idxs.keys()
@@ -412,20 +465,26 @@ def repeating_keys(self):
             if 1 < len(v_idx) < n_input:
                 yield key
 
-    def repeating_groups(self, block_only=False, block_factor=None):
-        '''Generate groups of repeating keys that vary with the same pattern
+    def covariant_groups(self, keys=None, block_only=False):
+        '''Generate groups of keys that vary with the same pattern
         '''
-        n_input = self._n_input
-        if n_input <= 1:
-            # If there is only one element, consider all keys as const
-            return
-        # TODO: Can we sort so grouped v_idxs are sequential?
-        #         - Sort by num values isn't sufficient
-        curr_group = []
-        for key, v_idx in self._v_idxs.items():
-            if 1 < len(v_idx) < n_input:
-                if v_idx.is_even(block_factor):
-                pass # TODO
+        if keys is None:
+            keys = self.keys()
+        groups = []
+        for key in keys:
+            v_idx = self._v_idxs[key]
+            if len(groups) == 0:
+                groups.append([(key, v_idx)])
+                continue
+            for group in groups:
+                if group[0][1].is_covariant(v_idx):
+                    group.append(key)
+                    break
+            else:
+                groups.append([(key, v_idx)])
+        for group in groups:
+            group[0] = group[0][0]
+        return groups
 
     def get_meta(self, idx):
         '''Get the full dict at the given index'''
@@ -439,26 +498,86 @@ def get_meta(self, idx):
 
     def get_val(self, idx, key, default=None):
         '''Get the value at `idx` for the `key`, or return `default``'''
-        res = self._v_idxs[key].get_value(key)
+        res = self._v_idxs[key].get_value(idx)
         if res is _MissingKey:
             return default
         return res
 
-    def nd_sort(self, dim_keys=None):
-        '''Produce indices ordered so as to fill an n-D array'''
+    def reorder(self, order):
+        '''Reorder indices in place'''
+        for v_idx in self._v_idxs.values():
+            v_idx.reorder(order)
 
-class SummaryTree:
-    '''Groups incoming meta data and creates hierarchy of related groups
-
-    Each leaf node in the tree is a `MetaSummary`
-    '''
-    def __init__(self, group_keys):
-        self._group_keys = group_keys
-        self._group_summaries= {}
-
-    def add(self, meta):
-        pass
-
-    def groups(self):
-        '''Generate the groups and their meta summaries'''
+    def nd_sort(self, dims):
+        '''Produce linear indices to fill nD array as specified by `dims`
 
+        Assumes each input corresponds to a 2D or 3D array, and the combined
+        array is 3D+
+        '''
+        # Make sure dims aren't completely invalid
+        if len(dims) == 0:
+            raise ValueError("At least one dimension must be specified")
+        last_dim = None
+        for dim in dims:
+            if last_dim is not None:
+                if last_dim.dim_type > dim.dim_type:
+                    # TODO: This only allows PARAM dimensions at the end, which I guess is reasonable?
+                    raise ValueError("Invalid dimension order")
+                elif last_dim.dim_type == dim.dim_type and dim.dim_type != DimTypes.PARAM:
+                    raise ValueError("There can be at most one each of SLICE and TIME dimensions")
+            last_dim = dim
+
+        # Pull out info about different types of dims
+        n_slices = None
+        n_vol = None
+        time_dim = None
+        param_dims = []
+        n_params = []
+        total_params = 1
+        shape = []
+        curr_size = 1
+        for dim in dims:
+            dim_vidx = self._v_idxs[dim.key]
+            dim_type = dim.dim_type
+            if dim_type is DimTypes.SLICE:
+                n_slices = len(dim_vidx)
+                n_vol = dim_vidx.get_block_size()
+                if n_vol is None:
+                    raise NdSortError("There are missing or extra slices")
+                shape.append(n_slices)
+                curr_size *= n_slices
+            elif dim_type is DimTypes.TIME:
+                time_dim = dim
+            elif dim_type is DimTypes.PARAM:
+                if dim_vidx.get_block_size() is None:
+                    raise NdSortError(f"The parameter {dim.key} doesn't evenly divide inputs")
+                param_dims.append(dim)
+                n_param = len(dim_vidx)
+                n_params.append(n_param)
+                total_params *= n_param
+        if n_vol is None:
+            n_vol = self._n_input
+
+        # Size of the time dimension must be inferred from the size of the other dims
+        n_time = 1
+        if time_dim is not None:
+            n_time, rem = divmod(n_vol, total_params)
+            if rem != 0:
+                raise NdSortError("The combined parameters don't evenly divide inputs")
+            shape.append(n_time)
+            curr_size *= n_time
+
+        # Complete the "shape", and do a more detailed check that our param dims make sense
+        for dim, n_param in zip(param_dims, n_params):
+            dim_vidx = self._v_idxs[dim.key]
+            if dim_vidx.get_block_size() != curr_size:
+                raise NdSortError(f"The parameter {dim.key} doesn't evenly divide inputs")
+            shape.append(n_param)
+            curr_size *= n_param
+
+        # Extract dim keys for each input and do the actual sort
+        sort_keys = [(idx, tuple(self.get_val(idx, dim.key) for dim in reversed(dims)))
+                     for idx in range(self._n_input)]
+        sort_keys.sort(key=lambda x: x[1])
+
+        # TODO: Finish this
diff --git a/nibabel/tests/test_metasum.py b/nibabel/tests/test_metasum.py
index e69de29bb2..c654e82614 100644
--- a/nibabel/tests/test_metasum.py
+++ b/nibabel/tests/test_metasum.py
@@ -0,0 +1,63 @@
+from ..metasum import MetaSummary, ValueIndices
+
+import pytest
+
+
+vidx_test_patterns = ([0] * 8,
+                      ([0] * 4) + ([1] * 4),
+                      [0, 0, 1, 2, 3, 3, 3, 4],
+                      list(range(8)),
+                      list(range(6)) + [6] * 2,
+                      ([0] * 2) + list(range(2, 8)),
+                      )
+
+
+@pytest.mark.parametrize("in_list", vidx_test_patterns)
+def test_value_indices_rt(in_list):
+    '''Test we can roundtrip list -> ValueIndices -> list'''
+    vidx = ValueIndices(in_list)
+    out_list = vidx.to_list()
+    assert in_list == out_list
+
+
+@pytest.mark.parametrize("in_list", vidx_test_patterns)
+def test_value_indices_append_extend(in_list):
+    '''Test that append/extend are equivalent'''
+    vidx_list = [ValueIndices() for _ in range(4)]
+    vidx_list[0].extend(in_list)
+    vidx_list[0].extend(in_list)
+    for val in in_list:
+        vidx_list[1].append(val)
+    for val in in_list:
+        vidx_list[1].append(val)
+    vidx_list[2].extend(in_list)
+    for val in in_list:
+        vidx_list[2].append(val)
+    for val in in_list:
+        vidx_list[3].append(val)
+    vidx_list[3].extend(in_list)
+    for vidx in vidx_list:
+        assert vidx.to_list() == in_list + in_list
+
+
+metasum_test_dicts = (({'key1': 0, 'key2': 'a', 'key3': 3.0},
+                       {'key1': 2, 'key2': 'c', 'key3': 1.0},
+                       {'key1': 1, 'key2': 'b', 'key3': 2.0},
+                       ),
+                      ({'key1': 0, 'key2': 'a', 'key3': 3.0},
+                       {'key1': 2, 'key2': 'c'},
+                       {'key1': 1, 'key2': 'b', 'key3': 2.0},
+                       ),
+                      )
+
+
+@pytest.mark.parametrize("in_dicts", metasum_test_dicts)
+def test_meta_summary_rt(in_dicts):
+    msum = MetaSummary()
+    for in_dict in in_dicts:
+        msum.append(in_dict)
+    for in_idx in range(len(in_dicts)):
+        out_dict = msum.get_meta(in_idx)
+        assert out_dict == in_dicts[in_idx]
+        for key, in_val in in_dicts[in_idx].items():
+            assert in_val == msum.get_val(in_idx, key)

From c21a8fd46014daec584d8133cf7b22acb0dcbec2 Mon Sep 17 00:00:00 2001
From: Brendan Moloney <moloney.brendan@gmail.com>
Date: Fri, 9 Jul 2021 16:32:58 -0700
Subject: [PATCH 3/7] TST+BF: Expand tests and fix bugs

---
 nibabel/metasum.py            | 68 ++++++++++++++++++++++-------------
 nibabel/tests/test_metasum.py | 33 ++++++++++++-----
 2 files changed, 69 insertions(+), 32 deletions(-)

diff --git a/nibabel/metasum.py b/nibabel/metasum.py
index 7daeb8ac04..d1e84dad55 100644
--- a/nibabel/metasum.py
+++ b/nibabel/metasum.py
@@ -125,7 +125,7 @@ def get_mask(self, value):
             return res
         return self._val_bitarrs[value].copy()
 
-    def num_indices(self, value, mask=None):
+    def count(self, value, mask=None):
         '''Number of indices for the given `value`'''
         if mask is not None:
             if len(mask) != self.n_input:
@@ -136,7 +136,7 @@ def num_indices(self, value, mask=None):
             if mask is None:
                 return self._n_input
             return mask.count()
-        unique_idx = self._unique_vals.get(_NoValue)
+        unique_idx = self._unique_vals.get(value, _NoValue)
         if unique_idx is not _NoValue:
             if mask is not None:
                 if mask[unique_idx]:
@@ -144,7 +144,7 @@ def num_indices(self, value, mask=None):
                 return 0
             return 1
         if mask is not None:
-            return (self._val_bitarrs[value] & mask).count
+            return (self._val_bitarrs[value] & mask).count()
         return self._val_bitarrs[value].count()
 
     def get_value(self, idx):
@@ -169,14 +169,14 @@ def to_list(self):
 
     def extend(self, values):
         '''Add more values to the end of any existing ones'''
-        curr_size = self._n_input
+        init_size = self._n_input
         if isinstance(values, ValueIndices):
             other_is_vi = True
             other_size = values._n_input
         else:
             other_is_vi = False
             other_size = len(values)
-        final_size = curr_size + other_size
+        final_size = init_size + other_size
         for ba in self._val_bitarrs.values():
             ba.extend(zeros(other_size))
         if other_is_vi:
@@ -185,7 +185,7 @@ def extend(self, values):
                     self._extend_const(values)
                     return
                 else:
-                    self._rm_const()
+                    self._rm_const(final_size)
             elif values._const_val is not _NoValue:
                 cval = values._const_val
                 other_unique = {}
@@ -199,29 +199,30 @@ def extend(self, values):
                 other_unique = values._unique_vals
                 other_bitarrs = values._val_bitarrs
             for val, other_idx in other_unique.items():
-                self._ingest_single(val, final_size, curr_size, other_idx)
+                self._ingest_single(val, final_size, init_size, other_idx)
             for val, other_ba in other_bitarrs.items():
                 curr_ba = self._val_bitarrs.get(val)
                 if curr_ba is None:
                     curr_idx = self._unique_vals.get(val)
                     if curr_idx is None:
-                        if curr_size == 0:
+                        if init_size == 0:
                             new_ba = other_ba.copy()
                         else:
-                            new_ba = zeros(curr_size)
+                            new_ba = zeros(init_size)
                             new_ba.extend(other_ba)
                     else:
-                        new_ba = zeros(curr_size)
+                        new_ba = zeros(init_size)
                         new_ba[curr_idx] = True
                         new_ba.extend(other_ba)
                         del self._unique_vals[val]
                     self._val_bitarrs[val] = new_ba
                 else:
-                    curr_ba[curr_size:] |= other_ba
+                    curr_ba[init_size:] |= other_ba
+                self._n_input += other_ba.count()
         else:
             for other_idx, val in enumerate(values):
-                self._ingest_single(val, final_size, curr_size, other_idx)
-        self._n_input = final_size
+                self._ingest_single(val, final_size, init_size, other_idx)
+        assert self._n_input == final_size
 
     def append(self, value):
         '''Append another value as input'''
@@ -229,10 +230,18 @@ def append(self, value):
             self._n_input += 1
             return
         elif self._const_val is not _NoValue:
-            self._rm_const()
+            self._rm_const(self._n_input + 1)
+            self._unique_vals[value] = self._n_input
+            self._n_input += 1
+            return
+        if self._n_input == 0:
+            self._const_val = value
+            self._n_input += 1
+            return
         curr_size = self._n_input
         found = False
         for val, bitarr in self._val_bitarrs.items():
+            assert len(bitarr) == self._n_input
             if val == value:
                 found = True
                 bitarr.append(True)
@@ -318,7 +327,7 @@ def get_block_size(self):
         if rem != 0:
             return None
         for val in self.values():
-            if self.num_indices(val) != block_size:
+            if self.count(val) != block_size:
                 return None
         return block_size
 
@@ -335,32 +344,43 @@ def _extract_indices(self, ba):
             except ValueError:
                 return
             yield curr_idx
-            start = curr_idx
+            start = curr_idx + 1
 
-    def _ingest_single(self, val, final_size, curr_size, other_idx):
+    def _ingest_single(self, val, final_size, init_size, other_idx):
         '''Helper to ingest single value from another collection'''
+        if val == self._const_val:
+            self._n_input += 1
+            return
+        elif self._const_val is not _NoValue:
+            self._rm_const(final_size)
+        if self._n_input == 0:
+            self._const_val = val
+            self._n_input += 1
+            return
+
         curr_ba = self._val_bitarrs.get(val)
         if curr_ba is None:
             curr_idx = self._unique_vals.get(val)
             if curr_idx is None:
-                self._unique_vals[val] = curr_size + other_idx
+                self._unique_vals[val] = init_size + other_idx
             else:
                 new_ba = zeros(final_size)
                 new_ba[curr_idx] = True
-                new_ba[curr_size + other_idx] = True
+                new_ba[init_size + other_idx] = True
                 self._val_bitarrs[val] = new_ba
                 del self._unique_vals[val]
         else:
-            curr_ba[curr_size + other_idx] = True
+            curr_ba[init_size + other_idx] = True
+        self._n_input += 1
 
-    def _rm_const(self):
+    def _rm_const(self, final_size):
         assert self._const_val is not _NoValue
         if self._n_input == 1:
             self._unique_vals[self._const_val] = 0
         else:
-            self._val_bitarrs[self._const_val] = bitarray(self._n_input)
-            self._val_bitarrs[self._const_val].setall(1)
-        self._const_val == _NoValue
+            self._val_bitarrs[self._const_val] = zeros(final_size)
+            self._val_bitarrs[self._const_val][:self._n_input] = True
+        self._const_val = _NoValue
 
     def _extend_const(self, other):
         if self._const_val != other._const_val:
diff --git a/nibabel/tests/test_metasum.py b/nibabel/tests/test_metasum.py
index c654e82614..c0aced4d2a 100644
--- a/nibabel/tests/test_metasum.py
+++ b/nibabel/tests/test_metasum.py
@@ -13,9 +13,16 @@
 
 
 @pytest.mark.parametrize("in_list", vidx_test_patterns)
-def test_value_indices_rt(in_list):
+def test_value_indices_basics(in_list):
     '''Test we can roundtrip list -> ValueIndices -> list'''
     vidx = ValueIndices(in_list)
+    assert vidx.n_input == len(in_list)
+    assert len(vidx) == len(set(in_list))
+    assert sorted(vidx.values()) == sorted(list(set(in_list)))
+    for val in vidx.values():
+        assert vidx.count(val) == in_list.count(val)
+        for in_idx in vidx[val]:
+            assert in_list[in_idx] == val
     out_list = vidx.to_list()
     assert in_list == out_list
 
@@ -40,22 +47,32 @@ def test_value_indices_append_extend(in_list):
         assert vidx.to_list() == in_list + in_list
 
 
-metasum_test_dicts = (({'key1': 0, 'key2': 'a', 'key3': 3.0},
-                       {'key1': 2, 'key2': 'c', 'key3': 1.0},
-                       {'key1': 1, 'key2': 'b', 'key3': 2.0},
+metasum_test_dicts = (({'u1': 0, 'u2': 'a', 'u3': 3.0, 'c1': True, 'r1': 5},
+                       {'u1': 2, 'u2': 'c', 'u3': 1.0, 'c1': True, 'r1': 5},
+                       {'u1': 1, 'u2': 'b', 'u3': 2.0, 'c1': True, 'r1': 7},
                        ),
-                      ({'key1': 0, 'key2': 'a', 'key3': 3.0},
-                       {'key1': 2, 'key2': 'c'},
-                       {'key1': 1, 'key2': 'b', 'key3': 2.0},
+                      ({'u1': 0, 'u2': 'a', 'u3': 3.0, 'c1': True, 'r1': 5},
+                       {'u1': 2, 'u2': 'c', 'c1': True, 'r1': 5},
+                       {'u1': 1, 'u2': 'b', 'u3': 2.0, 'c1': True},
                        ),
                       )
 
 
 @pytest.mark.parametrize("in_dicts", metasum_test_dicts)
-def test_meta_summary_rt(in_dicts):
+def test_meta_summary_basics(in_dicts):
     msum = MetaSummary()
+    all_keys = set()
     for in_dict in in_dicts:
         msum.append(in_dict)
+        for key in in_dict.keys():
+            all_keys.add(key)
+    assert all_keys == set(msum.keys())
+    for key in msum.const_keys():
+        assert key.startswith('c')
+    for key in msum.unique_keys():
+        assert key.startswith('u')
+    for key in msum.repeating_keys():
+        assert key.startswith('r')
     for in_idx in range(len(in_dicts)):
         out_dict = msum.get_meta(in_idx)
         assert out_dict == in_dicts[in_idx]

From 4ba6a733b0e4e711914f2acfe32ec0928880dc3d Mon Sep 17 00:00:00 2001
From: Brendan Moloney <moloney.brendan@gmail.com>
Date: Fri, 9 Jul 2021 16:35:57 -0700
Subject: [PATCH 4/7] BF: Add bitarray dependency

---
 setup.cfg | 1 +
 1 file changed, 1 insertion(+)

diff --git a/setup.cfg b/setup.cfg
index 85aebfee7d..23ac6fce0b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -31,6 +31,7 @@ python_requires = >=3.6
 install_requires =
     numpy >=1.13
     packaging >=14.3
+    bitarray
 zip_safe = False
 packages = find:
 

From 583e0aa551cf1c790e5d299180356cdebbe46b22 Mon Sep 17 00:00:00 2001
From: Brendan Moloney <moloney.brendan@gmail.com>
Date: Fri, 9 Jul 2021 16:58:24 -0700
Subject: [PATCH 5/7] ENH: Make ValueIndices.to_list much more efficient

---
 nibabel/metasum.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/nibabel/metasum.py b/nibabel/metasum.py
index d1e84dad55..2312dc2df4 100644
--- a/nibabel/metasum.py
+++ b/nibabel/metasum.py
@@ -165,7 +165,15 @@ def get_value(self, idx):
 
     def to_list(self):
         '''Convert back to a list of values'''
-        return [self.get_value(i) for i in range(self.n_input)]
+        if self._const_val is not _NoValue:
+            return [self._const_val] * self._n_input
+        res = [_NoValue] * self._n_input
+        for val, idx in self._unique_vals.items():
+            res[idx] = val
+        for val, ba in self._val_bitarrs.items():
+            for idx in self._extract_indices(ba):
+                res[idx] = val
+        return res
 
     def extend(self, values):
         '''Add more values to the end of any existing ones'''

From 8946f163e5b9cfe4a140d353d5f37dcfdd31fe8d Mon Sep 17 00:00:00 2001
From: moloney <moloney.brendan@gmail.com>
Date: Mon, 12 Jul 2021 12:24:38 -0700
Subject: [PATCH 6/7] BF: Add dataclasses backport for 3.6

Co-authored-by: Chris Markiewicz <effigies@gmail.com>
---
 setup.cfg | 1 +
 1 file changed, 1 insertion(+)

diff --git a/setup.cfg b/setup.cfg
index 23ac6fce0b..c84e8cc894 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -32,6 +32,7 @@ install_requires =
     numpy >=1.13
     packaging >=14.3
     bitarray
+    dataclasses ; python_version < "3.7"
 zip_safe = False
 packages = find:
 

From 59fab276d0730989c516158672cd54be7a12bf06 Mon Sep 17 00:00:00 2001
From: Brendan Moloney <moloney.brendan@gmail.com>
Date: Mon, 12 Jul 2021 20:29:25 -0700
Subject: [PATCH 7/7] ENH: Get the nd_sort method mostly working w/ basic tests

---
 nibabel/metasum.py            | 60 +++++++++++++++++++-------
 nibabel/tests/test_metasum.py | 79 +++++++++++++++++++++++++++++++++--
 2 files changed, 120 insertions(+), 19 deletions(-)

diff --git a/nibabel/metasum.py b/nibabel/metasum.py
index 2312dc2df4..2183cc249a 100644
--- a/nibabel/metasum.py
+++ b/nibabel/metasum.py
@@ -100,6 +100,15 @@ def __getitem__(self, value):
         ba = self._val_bitarrs[value]
         return list(self._extract_indices(ba))
 
+    def first(self, value):
+        '''Return the first index where this value appears'''
+        if self._const_val == value:
+            return 0
+        idx = self._unique_vals.get(value)
+        if idx is not None:
+            return idx
+        return self._val_bitarrs[value].index(True)
+
     def values(self):
         '''Generate each unique value that has been seen'''
         if self._const_val is not _NoValue:
@@ -339,8 +348,15 @@ def get_block_size(self):
                 return None
         return block_size
 
-    def is_subpartition(self, other):
-        ''''''
+    def is_orthogonal(self, other, size=1):
+        '''Check that each of our values' indices overlaps each value from
+        `other` exactly `size` times'''
+        other_bas = {v: other.get_mask(v) for v in other.values()}
+        for val in self.values():
+            for other_val, other_ba in other_bas.items():
+                if self.count(val, mask=other_ba) != size:
+                    return False
+        return True
 
     def _extract_indices(self, ba):
         '''Generate integer indices from bitarray representation'''
@@ -416,7 +432,7 @@ def _extend_const(self, other):
 
 
 class DimTypes(IntEnum):
-    '''Enmerate the three types of nD dimensions'''
+    '''Enumerate the three types of nD dimensions'''
     SLICE = 1
     TIME = 2
     PARAM = 3
@@ -556,8 +572,9 @@ def nd_sort(self, dims):
             last_dim = dim
 
         # Pull out info about different types of dims
-        n_slices = None
-        n_vol = None
+        n_input = self._n_input
+        total_vol = None
+        slice_dim = None
         time_dim = None
         param_dims = []
         n_params = []
@@ -568,9 +585,10 @@ def nd_sort(self, dims):
             dim_vidx = self._v_idxs[dim.key]
             dim_type = dim.dim_type
             if dim_type is DimTypes.SLICE:
+                slice_dim = dim
                 n_slices = len(dim_vidx)
-                n_vol = dim_vidx.get_block_size()
-                if n_vol is None:
+                total_vol = dim_vidx.get_block_size()
+                if total_vol is None:
                     raise NdSortError("There are missing or extra slices")
                 shape.append(n_slices)
                 curr_size *= n_slices
@@ -583,29 +601,39 @@ def nd_sort(self, dims):
                 n_param = len(dim_vidx)
                 n_params.append(n_param)
                 total_params *= n_param
-        if n_vol is None:
-            n_vol = self._n_input
+        if total_vol is None:
+            total_vol = n_input
 
-        # Size of the time dimension must be infered from the size of the other dims
+        # Size of the time dimension must be inferred from the size of the other dims
         n_time = 1
+        prev_dim = slice_dim
         if time_dim is not None:
-            n_time, rem = divmod(n_vol, total_params)
+            n_time, rem = divmod(total_vol, total_params)
             if rem != 0:
-                raise NdSortError(f"The combined parameters don't evenly divide inputs")
+                raise NdSortError("The combined parameters don't evenly divide inputs")
             shape.append(n_time)
             curr_size *= n_time
+            prev_dim = time_dim
 
-        # Complete the "shape", and do a more detailed check that our param dims make sense
+        # Complete the "shape", and do a more detailed check that our dims make sense
         for dim, n_param in zip(param_dims, n_params):
             dim_vidx = self._v_idxs[dim.key]
-            if dim_vidx.get_block_size() != curr_size:
+            if dim_vidx.get_block_size() != n_input // n_param:
                 raise NdSortError(f"The parameter {dim.key} doesn't evenly divide inputs")
+            if prev_dim is not None and prev_dim.dim_type != DimTypes.TIME:
+                count_per = (curr_size // shape[-1]) * (n_input // (curr_size * n_param))
+                if not self._v_idxs[prev_dim.key].is_orthogonal(dim_vidx, count_per):
+                    raise NdSortError("The dimensions are not orthogonal")
             shape.append(n_param)
             curr_size *= n_param
+            prev_dim = dim
 
         # Extract dim keys for each input and do the actual sort
         sort_keys = [(idx, tuple(self.get_val(idx, dim.key) for dim in reversed(dims)))
-                     for idx in range(self._n_input)]
+                     for idx in range(n_input)]
         sort_keys.sort(key=lambda x: x[1])
 
-        # TODO: Finish this
+        # TODO: If we have a non-singular time dimension we need to do some additional
+        #       validation checks here after sorting.
+
+        return tuple(shape), [x[0] for x in sort_keys]
diff --git a/nibabel/tests/test_metasum.py b/nibabel/tests/test_metasum.py
index c0aced4d2a..258fb76e87 100644
--- a/nibabel/tests/test_metasum.py
+++ b/nibabel/tests/test_metasum.py
@@ -1,6 +1,9 @@
-from ..metasum import MetaSummary, ValueIndices
+import random
 
 import pytest
+import numpy as np
+
+from ..metasum import DimIndex, DimTypes, MetaSummary, ValueIndices
 
 
 vidx_test_patterns = ([0] * 8,
@@ -14,7 +17,7 @@
 
 @pytest.mark.parametrize("in_list", vidx_test_patterns)
 def test_value_indices_basics(in_list):
-    '''Test we can roundtrip list -> ValueIndices -> list'''
+    '''Test basic ValueIndices behavior'''
     vidx = ValueIndices(in_list)
     assert vidx.n_input == len(in_list)
     assert len(vidx) == len(set(in_list))
@@ -22,7 +25,7 @@ def test_value_indices_basics(in_list):
     for val in vidx.values():
         assert vidx.count(val) == in_list.count(val)
         for in_idx in vidx[val]:
-            assert in_list[in_idx] == val
+            assert in_list[in_idx] == val == vidx.get_value(in_idx)
     out_list = vidx.to_list()
     assert in_list == out_list
 
@@ -78,3 +81,73 @@ def test_meta_summary_basics(in_dicts):
         assert out_dict == in_dicts[in_idx]
         for key, in_val in in_dicts[in_idx].items():
             assert in_val == msum.get_val(in_idx, key)
+
+
+def _make_nd_meta(shape, dim_info, const_meta=None):
+    if const_meta is None:
+        const_meta = {'series_number': '5'}
+    meta_seq = []
+    for nd_idx in np.ndindex(*shape):
+        curr_meta = {}
+        curr_meta.update(const_meta)
+        for dim, dim_idx in zip(dim_info, nd_idx):
+            curr_meta[dim.key] = dim_idx
+        meta_seq.append(curr_meta)
+    return meta_seq
+
+
+ndsort_test_args = (((3,),
+                     (DimIndex(DimTypes.SLICE, 'slice_location'),),
+                     None),
+                    ((3, 5),
+                     (DimIndex(DimTypes.SLICE, 'slice_location'),
+                      DimIndex(DimTypes.TIME, 'acq_time')),
+                     None),
+                    ((3, 5),
+                     (DimIndex(DimTypes.SLICE, 'slice_location'),
+                      DimIndex(DimTypes.PARAM, 'inversion_time')),
+                     None),
+                    ((3, 5, 7),
+                     (DimIndex(DimTypes.SLICE, 'slice_location'),
+                      DimIndex(DimTypes.TIME, 'acq_time'),
+                      DimIndex(DimTypes.PARAM, 'echo_time')),
+                     None),
+                    ((3, 5, 7),
+                     (DimIndex(DimTypes.SLICE, 'slice_location'),
+                      DimIndex(DimTypes.PARAM, 'inversion_time'),
+                      DimIndex(DimTypes.PARAM, 'echo_time')),
+                     None),
+                    ((5, 3),
+                     (DimIndex(DimTypes.TIME, 'acq_time'),
+                      DimIndex(DimTypes.PARAM, 'echo_time')),
+                     None),
+                    ((3, 5, 7),
+                     (DimIndex(DimTypes.TIME, 'acq_time'),
+                      DimIndex(DimTypes.PARAM, 'inversion_time'),
+                      DimIndex(DimTypes.PARAM, 'echo_time')),
+                     None),
+                    ((5, 7),
+                     (DimIndex(DimTypes.PARAM, 'inversion_time'),
+                      DimIndex(DimTypes.PARAM, 'echo_time')),
+                     None),
+                    ((5, 7, 3),
+                     (DimIndex(DimTypes.PARAM, 'inversion_time'),
+                      DimIndex(DimTypes.PARAM, 'echo_time'),
+                      DimIndex(DimTypes.PARAM, 'repetition_time')),
+                     None),
+                    )
+
+
+@pytest.mark.parametrize("shape,dim_info,const_meta", ndsort_test_args)
+def test_ndsort(shape, dim_info, const_meta):
+    meta_seq = _make_nd_meta(shape, dim_info, const_meta)
+    rand_idx_seq = [(i, m) for i, m in enumerate(meta_seq)]
+    # TODO: Use some pytest plugin to manage randomness?  Just use fixed seed?
+    random.shuffle(rand_idx_seq)
+    rand_idx = [x[0] for x in rand_idx_seq]
+    rand_seq = [x[1] for x in rand_idx_seq]
+    msum = MetaSummary()
+    for meta in rand_seq:
+        msum.append(meta)
+    out_shape, out_idxs = msum.nd_sort(dim_info)
+    assert shape == out_shape