diff --git a/python/lsst/pipe/tasks/fit_coadd_multiband.py b/python/lsst/pipe/tasks/fit_coadd_multiband.py index 4af35fd98..d6e649df6 100644 --- a/python/lsst/pipe/tasks/fit_coadd_multiband.py +++ b/python/lsst/pipe/tasks/fit_coadd_multiband.py @@ -155,39 +155,19 @@ def adjustQuantum(self, inputs, outputs, label, data_id): bands_needed_set = set(bands_needed) adjusted_inputs = {} - bands_found, connection_first = None, None + inputs_to_adjust = {} + bands_found = bands_needed_set for connection_name, (connection, dataset_refs) in inputs.items(): # Datasets without bands in their dimensions should be fine if 'band' in connection.dimensions: datasets_by_band = {dref.dataId['band']: dref for dref in dataset_refs} bands_set = set(datasets_by_band.keys()) if self.config.allow_missing_bands: - # Use the first dataset found as the reference since all - # dataset types with band should have the same bands - # This will only break if one of the calexp/meas datasets - # is missing from a given band, which would surely be an - # upstream problem anyway - if bands_found is None: - bands_found, connection_first = bands_set, connection_name - if len(bands_found) == 0: - raise pipeBase.NoWorkFound( - f'DatasetRefs={dataset_refs} for {connection_name=} is empty' - ) - elif not set(bands_read_only).issubset(bands_set): - raise pipeBase.NoWorkFound( - f'DatasetRefs={dataset_refs} has {bands_set=} which is missing at least one' - f' of {bands_read_only=}' - ) - # Put the bands to fit first, then any other bands - # needed for initialization/priors only last - bands_needed = [band for band in bands_fit if band in bands_found] + [ - band for band in bands_read_only if band not in bands_found - ] - elif bands_found != bands_set: - raise RuntimeError( - f'DatasetRefs={dataset_refs} with {connection_name=} has {bands_set=} !=' - f' {bands_found=} from {connection_first=}' + if len(bands_found) == 0: + raise pipeBase.NoWorkFound( + f'DatasetRefs={dataset_refs} for {connection_name=} is empty' ) + bands_found &= bands_set # All configured bands are treated as necessary elif not bands_needed_set.issubset(bands_set): raise pipeBase.NoWorkFound( @@ -201,10 +181,21 @@ def adjustQuantum(self, inputs, outputs, label, data_id): ) # Adjust all datasets with band dimensions to include just # the needed bands, in consistent order. - adjusted_inputs[connection_name] = ( - connection, - [datasets_by_band[band] for band in bands_needed] + inputs_to_adjust[connection_name] = (connection, datasets_by_band) + + if self.config.allow_missing_bands: + bands_needed = [band for band in bands_fit if band in bands_found] + [ + band for band in bands_read_only if band not in bands_found + ] + if len(bands_needed) == 0: + raise pipeBase.NoWorkFound( + f'No common bands remaining for inputs {",".join(inputs_to_adjust.keys())}' ) + for connection_name, (connection, datasets_by_band) in inputs_to_adjust.items(): + adjusted_inputs[connection_name] = ( + connection, + [datasets_by_band[band] for band in bands_needed] + ) # Delegate to super for more checks. inputs.update(adjusted_inputs) @@ -253,6 +244,7 @@ def bands_read_only(self) -> set: ------- The set of such bands. """ + return set() class CoaddMultibandFitSubTask(pipeBase.Task, ABC): diff --git a/tests/test_fit_coadd_multiband.py b/tests/test_fit_coadd_multiband.py new file mode 100644 index 000000000..d9efe78a9 --- /dev/null +++ b/tests/test_fit_coadd_multiband.py @@ -0,0 +1,130 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import unittest + +import lsst.utils.tests +from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse +from lsst.pipe.base import NoWorkFound +from lsst.pipe.tasks.fit_coadd_multiband import ( + CoaddMultibandFitConfig, CoaddMultibandFitConnections, + CoaddMultibandFitSubConfig, CoaddMultibandFitSubTask, +) + + +class CoaddMultibandFitDummySubTask(CoaddMultibandFitSubTask): + ConfigClass = CoaddMultibandFitSubConfig + _DefaultName = "test" + + def run(self, catexps, cat_ref): + return None + + +class CoaddMultibandFitTestCase(lsst.utils.tests.TestCase): + """Tests adjustQuantum for now. Could run the task with mock data.""" + def setUp(self): + self.config = CoaddMultibandFitConfig() + self.config.fit_coadd_multiband.retarget(CoaddMultibandFitDummySubTask) + self.config.fit_coadd_multiband.bands_fit = ("g", "r") + self.connections = CoaddMultibandFitConnections(config=self.config) + self.config.freeze() + + self.universe = DimensionUniverse() + self.datasetType_coadd, self.datasetType_cat_meas = ( + DatasetType( + name=connection.name, + dimensions=connection.dimensions, + storageClass=connection.storageClass, + universe=self.universe, + ) + for connection in (self.connections.coadds, self.connections.cats_meas) + ) + self.skymap = "test" + self.tract = 0 + self.patch = 0 + self.run = "test" + kwargs_patch = {"skymap": self.skymap, "tract": self.tract, "patch": self.patch} + + self.inputs = { + "coadds": ( + self.connections.coadds, + tuple( + DatasetRef( + self.datasetType_coadd, + DataCoordinate.standardize(universe=self.universe, band=band, **kwargs_patch), + self.run, + ) + for band in ("g", "r") + ), + ), + "cats_meas": ( + self.connections.cats_meas, + tuple( + DatasetRef( + self.datasetType_cat_meas, + DataCoordinate.standardize(universe=self.universe, band=band, **kwargs_patch), + self.run, + ) + for band in ("r",) + ), + ) + } + self.universe = DimensionUniverse() + self.dataId = DataCoordinate.standardize(universe=self.universe, **kwargs_patch) + + def testAdjustQuantum(self): + inputs, outputs = self.connections.adjustQuantum( + self.inputs, outputs={}, label="test", data_id=self.dataId, + ) + self.assertEqual(len(outputs), 0) + + for name, (connection, refs) in inputs.items(): + self.assertEqual(len(refs), 1) + self.assertEqual(refs[0].dataId["band"], "r") + + def testAdjustQuantumMissingAll(self): + inputs = { + "coadds": self.inputs["coadds"], + "cats_meas": (self.connections.cats_meas, tuple()), + } + with self.assertRaises(NoWorkFound): + self.connections.adjustQuantum(inputs, outputs={}, label="test", data_id=self.dataId) + + def testAdjustQuantumStrict(self): + config = self.config.copy() + config.allow_missing_bands = False + connections = CoaddMultibandFitConnections(config=config) + + with self.assertRaises(NoWorkFound): + connections.adjustQuantum(self.inputs, outputs={}, label="test", data_id=self.dataId) + + +class MemoryTester(lsst.utils.tests.MemoryTestCase): + pass + + +def setup_module(module): + lsst.utils.tests.init() + + +if __name__ == "__main__": + lsst.utils.tests.init() + unittest.main()