Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 21 additions & 29 deletions python/lsst/pipe/tasks/fit_coadd_multiband.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,39 +155,19 @@ def adjustQuantum(self, inputs, outputs, label, data_id):
bands_needed_set = set(bands_needed)

adjusted_inputs = {}
bands_found, connection_first = None, None
inputs_to_adjust = {}
bands_found = bands_needed_set
for connection_name, (connection, dataset_refs) in inputs.items():
# Datasets without bands in their dimensions should be fine
if 'band' in connection.dimensions:
datasets_by_band = {dref.dataId['band']: dref for dref in dataset_refs}
bands_set = set(datasets_by_band.keys())
if self.config.allow_missing_bands:
# Use the first dataset found as the reference since all
# dataset types with band should have the same bands
# This will only break if one of the calexp/meas datasets
# is missing from a given band, which would surely be an
# upstream problem anyway
if bands_found is None:
bands_found, connection_first = bands_set, connection_name
if len(bands_found) == 0:
raise pipeBase.NoWorkFound(
f'DatasetRefs={dataset_refs} for {connection_name=} is empty'
)
elif not set(bands_read_only).issubset(bands_set):
raise pipeBase.NoWorkFound(
f'DatasetRefs={dataset_refs} has {bands_set=} which is missing at least one'
f' of {bands_read_only=}'
)
# Put the bands to fit first, then any other bands
# needed for initialization/priors only last
bands_needed = [band for band in bands_fit if band in bands_found] + [
band for band in bands_read_only if band not in bands_found
]
elif bands_found != bands_set:
raise RuntimeError(
f'DatasetRefs={dataset_refs} with {connection_name=} has {bands_set=} !='
f' {bands_found=} from {connection_first=}'
if len(bands_found) == 0:
raise pipeBase.NoWorkFound(
f'DatasetRefs={dataset_refs} for {connection_name=} is empty'
)
bands_found &= bands_set
# All configured bands are treated as necessary
elif not bands_needed_set.issubset(bands_set):
raise pipeBase.NoWorkFound(
Expand All @@ -201,10 +181,21 @@ def adjustQuantum(self, inputs, outputs, label, data_id):
)
# Adjust all datasets with band dimensions to include just
# the needed bands, in consistent order.
adjusted_inputs[connection_name] = (
connection,
[datasets_by_band[band] for band in bands_needed]
inputs_to_adjust[connection_name] = (connection, datasets_by_band)

if self.config.allow_missing_bands:
bands_needed = [band for band in bands_fit if band in bands_found] + [
band for band in bands_read_only if band not in bands_found
]
if len(bands_needed) == 0:
raise pipeBase.NoWorkFound(
f'No common bands remaining for inputs {",".join(inputs_to_adjust.keys())}'
)
for connection_name, (connection, datasets_by_band) in inputs_to_adjust.items():
adjusted_inputs[connection_name] = (
connection,
[datasets_by_band[band] for band in bands_needed]
)

# Delegate to super for more checks.
inputs.update(adjusted_inputs)
Expand Down Expand Up @@ -253,6 +244,7 @@ def bands_read_only(self) -> set:
-------
The set of such bands.
"""
return set()


class CoaddMultibandFitSubTask(pipeBase.Task, ABC):
Expand Down
130 changes: 130 additions & 0 deletions tests/test_fit_coadd_multiband.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import unittest

import lsst.utils.tests
from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse
from lsst.pipe.base import NoWorkFound
from lsst.pipe.tasks.fit_coadd_multiband import (
CoaddMultibandFitConfig, CoaddMultibandFitConnections,
CoaddMultibandFitSubConfig, CoaddMultibandFitSubTask,
)


class CoaddMultibandFitDummySubTask(CoaddMultibandFitSubTask):
ConfigClass = CoaddMultibandFitSubConfig
_DefaultName = "test"

def run(self, catexps, cat_ref):
return None


class CoaddMultibandFitTestCase(lsst.utils.tests.TestCase):
"""Tests adjustQuantum for now. Could run the task with mock data."""
def setUp(self):
self.config = CoaddMultibandFitConfig()
self.config.fit_coadd_multiband.retarget(CoaddMultibandFitDummySubTask)
self.config.fit_coadd_multiband.bands_fit = ("g", "r")
self.connections = CoaddMultibandFitConnections(config=self.config)
self.config.freeze()

self.universe = DimensionUniverse()
self.datasetType_coadd, self.datasetType_cat_meas = (
DatasetType(
name=connection.name,
dimensions=connection.dimensions,
storageClass=connection.storageClass,
universe=self.universe,
)
for connection in (self.connections.coadds, self.connections.cats_meas)
)
self.skymap = "test"
self.tract = 0
self.patch = 0
self.run = "test"
kwargs_patch = {"skymap": self.skymap, "tract": self.tract, "patch": self.patch}

self.inputs = {
"coadds": (
self.connections.coadds,
tuple(
DatasetRef(
self.datasetType_coadd,
DataCoordinate.standardize(universe=self.universe, band=band, **kwargs_patch),
self.run,
)
for band in ("g", "r")
),
),
"cats_meas": (
self.connections.cats_meas,
tuple(
DatasetRef(
self.datasetType_cat_meas,
DataCoordinate.standardize(universe=self.universe, band=band, **kwargs_patch),
self.run,
)
for band in ("r",)
),
)
}
self.universe = DimensionUniverse()
self.dataId = DataCoordinate.standardize(universe=self.universe, **kwargs_patch)

def testAdjustQuantum(self):
inputs, outputs = self.connections.adjustQuantum(
self.inputs, outputs={}, label="test", data_id=self.dataId,
)
self.assertEqual(len(outputs), 0)

for name, (connection, refs) in inputs.items():
self.assertEqual(len(refs), 1)
self.assertEqual(refs[0].dataId["band"], "r")

def testAdjustQuantumMissingAll(self):
inputs = {
"coadds": self.inputs["coadds"],
"cats_meas": (self.connections.cats_meas, tuple()),
}
with self.assertRaises(NoWorkFound):
self.connections.adjustQuantum(inputs, outputs={}, label="test", data_id=self.dataId)

def testAdjustQuantumStrict(self):
config = self.config.copy()
config.allow_missing_bands = False
connections = CoaddMultibandFitConnections(config=config)

with self.assertRaises(NoWorkFound):
connections.adjustQuantum(self.inputs, outputs={}, label="test", data_id=self.dataId)


class MemoryTester(lsst.utils.tests.MemoryTestCase):
pass


def setup_module(module):
lsst.utils.tests.init()


if __name__ == "__main__":
lsst.utils.tests.init()
unittest.main()