diff --git a/py/desispec/io/spectra.py b/py/desispec/io/spectra.py index 266336436..a6b6443f7 100644 --- a/py/desispec/io/spectra.py +++ b/py/desispec/io/spectra.py @@ -185,8 +185,31 @@ def write_spectra(outfile, spec, units=None): return outfile - -def read_spectra(infile, single=False): +def _read_image(hdus, extname, dtype, rows=None): + """ + Helper function to read extname from fitsio.FITS hdus, filter by rows, + convert to native endian, and cast to dtype. Returns image. + """ + data = hdus[extname].read() + if rows is not None: + data = data[rows] + + return native_endian(data).astype(dtype) + + +def read_spectra( + infile, + single=False, + targetids=None, + rows=None, + skip_hdus=None, + select_columns={ + "FIBERMAP": None, + "EXP_FIBERMAP": None, + "SCORES": None, + "EXTRA_CATALOG": None, + }, +): """ Read Spectra object from FITS file. @@ -196,10 +219,21 @@ def read_spectra(infile, single=False): Args: infile (str): path to read single (bool): if True, keep spectra as single precision in memory. + targetids (list): Optional, list of targetids to read from file, if present. + rows (list): Optional, list of rows to read from file + skip_hdus (list): Optional, list/set/tuple of HDUs to skip + select_columns (dict): Optional, dictionary to select column names to be read. Default, all columns are read. Returns (Spectra): The object containing the data read from disk. + `skip_hdus` options are FIBERMAP, EXP_FIBERMAP, SCORES, EXTRA_CATALOG, MASK, RESOLUTION; + where MASK and RESOLUTION mean to skip those for all cameras. + Note that WAVE, FLUX, and IVAR are always required. + + If a table HDU is not listed in `select_columns`, all of its columns will be read + + User can optionally specify targetids OR rows, but not both """ log = get_logger() infile = checkgzip(infile) @@ -212,11 +246,32 @@ def read_spectra(infile, single=False): raise IOError("{} is not a file".format(infile)) t0 = time.time() - hdus = fitsio.FITS(infile, mode='r') + hdus = fitsio.FITS(infile, mode="r") nhdu = len(hdus) - # load the metadata. + if targetids is not None and rows is not None: + raise ValueError('Set rows or targetids but not both') + if targetids is not None: + targetids = np.atleast_1d(targetids) + file_targetids = hdus["FIBERMAP"].read(columns="TARGETID") + rows = np.where(np.isin(file_targetids, targetids))[0] + if len(rows) == 0: + return Spectra() + elif rows is not None: + rows = np.asarray(rows) + + if skip_hdus is None: + skip_hdus = set() #- empty set, include everything + + if select_columns is None: + select_columns = dict() + + for extname in ("FIBERMAP", "EXP_FIBERMAP", "SCORES", "EXTRA_CATALOG"): + if extname not in select_columns: + select_columns[extname] = None + + # load the metadata. meta = dict(hdus[0].read_header()) # initialize data objects @@ -240,19 +295,44 @@ def read_spectra(infile, single=False): for h in range(1, nhdu): name = hdus[h].read_header()["EXTNAME"] - if name == "FIBERMAP": - fmap = encode_table(Table(hdus[h].read(), copy=True).as_array()) - elif name == "EXP_FIBERMAP": - expfmap = encode_table(Table(hdus[h].read(), copy=True).as_array()) - elif name == "SCORES": - scores = encode_table(Table(hdus[h].read(), copy=True).as_array()) - elif name == 'EXTRA_CATALOG': - extra_catalog = encode_table(Table(hdus[h].read(), copy=True).as_array()) + log.debug('Reading %s', name) + if name == "FIBERMAP" and name not in skip_hdus: + fmap = encode_table( + Table( + hdus[h].read(rows=rows, columns=select_columns["FIBERMAP"]), + copy=True, + ).as_array() + ) + elif name == "EXP_FIBERMAP" and name not in skip_hdus: + expfmap = encode_table( + Table( + hdus[h].read(rows=rows, columns=select_columns["EXP_FIBERMAP"]), + copy=True, + ).as_array() + ) + elif name == "SCORES" and name not in skip_hdus: + scores = encode_table( + Table( + hdus[h].read(rows=rows, columns=select_columns["SCORES"]), + copy=True, + ).as_array() + ) + elif name == "EXTRA_CATALOG" and name not in skip_hdus: + extra_catalog = encode_table( + Table( + hdus[h].read( + rows=rows, columns=select_columns["EXTRA_CATALOG"] + ), + copy=True, + ).as_array() + ) else: # Find the band based on the name mat = re.match(r"(.*)_(.*)", name) if mat is None: - raise RuntimeError("FITS extension name {} does not contain the band".format(name)) + raise RuntimeError( + "FITS extension name {} does not contain the band".format(name) + ) band = mat.group(1).lower() type = mat.group(2) if band not in bands: @@ -260,44 +340,76 @@ def read_spectra(infile, single=False): if type == "WAVELENGTH": if wave is None: wave = {} - #- Note: keep original float64 resolution for wavelength + # - Note: keep original float64 resolution for wavelength wave[band] = native_endian(hdus[h].read()) elif type == "FLUX": if flux is None: flux = {} - flux[band] = native_endian(hdus[h].read().astype(ftype)) + flux[band] = _read_image(hdus, h, ftype, rows=rows) elif type == "IVAR": if ivar is None: ivar = {} - ivar[band] = native_endian(hdus[h].read().astype(ftype)) - elif type == "MASK": + ivar[band] = _read_image(hdus, h, ftype, rows=rows) + elif type == "MASK" and type not in skip_hdus: if mask is None: mask = {} - mask[band] = native_endian(hdus[h].read().astype(np.uint32)) - elif type == "RESOLUTION": + mask[band] = _read_image(hdus, h, np.uint32, rows=rows) + elif type == "RESOLUTION" and type not in skip_hdus: if res is None: res = {} - res[band] = native_endian(hdus[h].read().astype(ftype)) - else: + res[band] = _read_image(hdus, h, ftype, rows=rows) + elif type != "MASK" and type != "RESOLUTION" and type not in skip_hdus: # this must be an "extra" HDU if extra is None: extra = {} if band not in extra: extra[band] = {} - extra[band][type] = native_endian(hdus[h].read().astype(ftype)) + + extra[band][type] = _read_image(hdus, h, ftype, rows=rows) hdus.close() duration = time.time() - t0 - log.info(iotime.format('read', infile, duration)) + log.info(iotime.format("read", infile, duration)) # Construct the Spectra object from the data. If there are any # inconsistencies in the sizes of the arrays read from the file, # they will be caught by the constructor. - spec = Spectra(bands, wave, flux, ivar, mask=mask, resolution_data=res, - fibermap=fmap, exp_fibermap=expfmap, - meta=meta, extra=extra, extra_catalog=extra_catalog, - single=single, scores=scores) + spec = Spectra( + bands, + wave, + flux, + ivar, + mask=mask, + resolution_data=res, + fibermap=fmap, + exp_fibermap=expfmap, + meta=meta, + extra=extra, + extra_catalog=extra_catalog, + single=single, + scores=scores, + ) + + # if needed, sort spectra to match order of targetids, which could be + # different than the order they appear in the file + if targetids is not None: + from desispec.util import ordered_unique + #- Input targetids that we found in the file, in the order they appear in targetids + ii = np.isin(targetids, spec.fibermap['TARGETID']) + found_targetids = ordered_unique(targetids[ii]) + log.debug('found_targetids=%s', found_targetids) + + #- Unique targetids of input file in the order they first appear + input_targetids = ordered_unique(spec.fibermap['TARGETID']) + log.debug('input_targetids=%s', np.asarray(input_targetids)) + + #- Only reorder if needed + if not np.all(input_targetids == found_targetids): + rows = np.concatenate([np.where(spec.fibermap['TARGETID'] == tid)[0] for tid in targetids]) + log.debug("spec.fibermap['TARGETID'] = %s", np.asarray(spec.fibermap['TARGETID'])) + log.debug("rows for subselection=%s", rows) + spec = spec[rows] return spec diff --git a/py/desispec/test/test_spectra.py b/py/desispec/test/test_spectra.py index 72cb91076..b20701074 100644 --- a/py/desispec/test/test_spectra.py +++ b/py/desispec/test/test_spectra.py @@ -182,6 +182,99 @@ def test_io(self): else: raise ValueError(f'Unrecognized extension for {self.fileio=}') + def test_read_targetids(self): + """Test reading while filtering by targetid""" + + # manually create the spectra and write + spec = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, + ivar=self.ivar, mask=self.mask, resolution_data=self.res, + fibermap=self.fmap1, meta=self.meta, extra=self.extra) + + write_spectra(self.fileio, spec) + + # read subset in same order as file + ii = [2,3] + spec_subset = spec[ii] + targetids = spec_subset.fibermap['TARGETID'] + comp_subset = read_spectra(self.fileio, targetids=targetids) + self.assertTrue(np.all(spec_subset.fibermap['TARGETID'] == comp_subset.fibermap['TARGETID'])) + self.assertTrue(np.allclose(spec_subset.flux['b'], comp_subset.flux['b'])) + self.assertTrue(np.allclose(spec_subset.ivar['r'], comp_subset.ivar['r'])) + self.assertTrue(np.all(spec_subset.mask['z'] == comp_subset.mask['z'])) + self.assertEqual(len(comp_subset.R['b']), len(ii)) + self.assertEqual(comp_subset.R['b'][0].shape, (self.nwave, self.nwave)) + + # read subset in different order than original file + ii = [3, 1] + spec_subset = spec[ii] + targetids = spec_subset.fibermap['TARGETID'] + comp_subset = read_spectra(self.fileio, targetids=targetids) + self.assertTrue(np.all(spec_subset.fibermap['TARGETID'] == comp_subset.fibermap['TARGETID'])) + self.assertTrue(np.allclose(spec_subset.flux['b'], comp_subset.flux['b'])) + self.assertTrue(np.allclose(spec_subset.ivar['r'], comp_subset.ivar['r'])) + self.assertTrue(np.all(spec_subset.mask['z'] == comp_subset.mask['z'])) + + # read subset in different order than original file, with repeats and missing targetids + spec.fibermap['TARGETID'] = (np.arange(self.nspec) // 2) * 2 # [0,0,2,2,4] for nspec=5 + write_spectra(self.fileio, spec) + targetids = [2,10,4,4,4,0,0] + comp_subset = read_spectra(self.fileio, targetids=targetids) + + # targetid 2 appears 2x because it is in the input file twice + # targetid 4 appears 3x because it was requested 3 times + # targetid 0 appears 4x because it was in the input file twice and requested twice + # and targetid 0 is at the end of comp_subset, not the beginning like the file + # targetid 10 doesn't appear because it wasn't in the input file, ok + self.assertTrue(np.all(comp_subset.fibermap['TARGETID'] == np.array([2,2,4,4,4,0,0,0,0]))) + + def test_read_rows(self): + """Test reading specific rows""" + + # manually create the spectra and write + spec = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, + ivar=self.ivar, mask=self.mask, resolution_data=self.res, + fibermap=self.fmap1, meta=self.meta, extra=self.extra) + + write_spectra(self.fileio, spec) + + rows = [1,3] + subset = read_spectra(self.fileio, rows=rows) + self.assertTrue(np.all(spec.fibermap[rows] == subset.fibermap)) + + with self.assertRaises(ValueError): + subset = read_spectra(self.fileio, rows=rows, targetids=[1,2]) + + + + def test_read_columns(self): + """test reading while subselecting columns""" + # manually create the spectra and write + spec = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, + ivar=self.ivar, mask=self.mask, resolution_data=self.res, + fibermap=self.fmap1, meta=self.meta) + + write_spectra(self.fileio, spec) + + test = read_spectra(self.fileio, select_columns=dict(FIBERMAP=('TARGETID', 'FIBER'))) + self.assertIn('TARGETID', test.fibermap.colnames) + self.assertIn('FIBER', test.fibermap.colnames) + self.assertIn('FLUX_R', spec.fibermap.colnames) + self.assertNotIn('FLUX_R', test.fibermap.colnames) + + def test_read_skip_hdus(self): + """test reading while skipping some HDUs""" + # manually create the spectra and write + spec = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, + ivar=self.ivar, mask=self.mask, resolution_data=self.res, + fibermap=self.fmap1, meta=self.meta) + + write_spectra(self.fileio, spec) + + test = read_spectra(self.fileio, skip_hdus=('MASK', 'RESOLUTION')) + self.assertIsNone(test.mask) + self.assertIsNone(test.R) + self.assertIsNotNone(test.fibermap) #- fibermap not skipped + def test_empty(self):