Skip to content

Commit

Permalink
Merge pull request #2052 from desihub/read_few_spectra
Browse files Browse the repository at this point in the history
add read_single_spectrum function
  • Loading branch information
sbailey authored Aug 16, 2023
2 parents 98044b2 + 629bbce commit d401035
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 27 deletions.
166 changes: 139 additions & 27 deletions py/desispec/io/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,31 @@ def write_spectra(outfile, spec, units=None):

return outfile


def read_spectra(infile, single=False):
def _read_image(hdus, extname, dtype, rows=None):
    """
    Read a single image HDU and return it as a native-endian `dtype` array.

    Args:
        hdus: open fitsio.FITS object to read from
        extname: key used to index `hdus` to select the HDU
        dtype: numpy dtype that the returned image is cast to

    Options:
        rows: if not None, index applied to the image after it has been read
            (``image[rows]``)

    Returns the image after the optional row selection, conversion with
    ``native_endian``, and casting to `dtype`.
    """
    image = hdus[extname].read()

    #- the full image is read first; any row selection happens in memory
    selected = image if rows is None else image[rows]

    return native_endian(selected).astype(dtype)


def read_spectra(
infile,
single=False,
targetids=None,
rows=None,
skip_hdus=None,
select_columns={
"FIBERMAP": None,
"EXP_FIBERMAP": None,
"SCORES": None,
"EXTRA_CATALOG": None,
},
):
"""
Read Spectra object from FITS file.
Expand All @@ -196,10 +219,21 @@ def read_spectra(infile, single=False):
Args:
infile (str): path to read
single (bool): if True, keep spectra as single precision in memory.
targetids (list): Optional, list of targetids to read from file, if present.
rows (list): Optional, list of rows to read from file
skip_hdus (list): Optional, list/set/tuple of HDUs to skip
select_columns (dict): Optional, dictionary to select column names to be read. Default, all columns are read.
Returns (Spectra):
The object containing the data read from disk.
`skip_hdus` options are FIBERMAP, EXP_FIBERMAP, SCORES, EXTRA_CATALOG, MASK, RESOLUTION;
where MASK and RESOLUTION mean to skip those for all cameras.
Note that WAVE, FLUX, and IVAR are always required.
If a table HDU is not listed in `select_columns`, all of its columns will be read
User can optionally specify targetids OR rows, but not both
"""
log = get_logger()
infile = checkgzip(infile)
Expand All @@ -212,11 +246,32 @@ def read_spectra(infile, single=False):
raise IOError("{} is not a file".format(infile))

t0 = time.time()
hdus = fitsio.FITS(infile, mode='r')
hdus = fitsio.FITS(infile, mode="r")
nhdu = len(hdus)

# load the metadata.
if targetids is not None and rows is not None:
raise ValueError('Set rows or targetids but not both')

if targetids is not None:
targetids = np.atleast_1d(targetids)
file_targetids = hdus["FIBERMAP"].read(columns="TARGETID")
rows = np.where(np.isin(file_targetids, targetids))[0]
if len(rows) == 0:
return Spectra()
elif rows is not None:
rows = np.asarray(rows)

if skip_hdus is None:
skip_hdus = set() #- empty set, include everything

if select_columns is None:
select_columns = dict()

for extname in ("FIBERMAP", "EXP_FIBERMAP", "SCORES", "EXTRA_CATALOG"):
if extname not in select_columns:
select_columns[extname] = None

# load the metadata.
meta = dict(hdus[0].read_header())

# initialize data objects
Expand All @@ -240,64 +295,121 @@ def read_spectra(infile, single=False):

for h in range(1, nhdu):
name = hdus[h].read_header()["EXTNAME"]
if name == "FIBERMAP":
fmap = encode_table(Table(hdus[h].read(), copy=True).as_array())
elif name == "EXP_FIBERMAP":
expfmap = encode_table(Table(hdus[h].read(), copy=True).as_array())
elif name == "SCORES":
scores = encode_table(Table(hdus[h].read(), copy=True).as_array())
elif name == 'EXTRA_CATALOG':
extra_catalog = encode_table(Table(hdus[h].read(), copy=True).as_array())
log.debug('Reading %s', name)
if name == "FIBERMAP" and name not in skip_hdus:
fmap = encode_table(
Table(
hdus[h].read(rows=rows, columns=select_columns["FIBERMAP"]),
copy=True,
).as_array()
)
elif name == "EXP_FIBERMAP" and name not in skip_hdus:
expfmap = encode_table(
Table(
hdus[h].read(rows=rows, columns=select_columns["EXP_FIBERMAP"]),
copy=True,
).as_array()
)
elif name == "SCORES" and name not in skip_hdus:
scores = encode_table(
Table(
hdus[h].read(rows=rows, columns=select_columns["SCORES"]),
copy=True,
).as_array()
)
elif name == "EXTRA_CATALOG" and name not in skip_hdus:
extra_catalog = encode_table(
Table(
hdus[h].read(
rows=rows, columns=select_columns["EXTRA_CATALOG"]
),
copy=True,
).as_array()
)
else:
# Find the band based on the name
mat = re.match(r"(.*)_(.*)", name)
if mat is None:
raise RuntimeError("FITS extension name {} does not contain the band".format(name))
raise RuntimeError(
"FITS extension name {} does not contain the band".format(name)
)
band = mat.group(1).lower()
type = mat.group(2)
if band not in bands:
bands.append(band)
if type == "WAVELENGTH":
if wave is None:
wave = {}
#- Note: keep original float64 resolution for wavelength
# - Note: keep original float64 resolution for wavelength
wave[band] = native_endian(hdus[h].read())
elif type == "FLUX":
if flux is None:
flux = {}
flux[band] = native_endian(hdus[h].read().astype(ftype))
flux[band] = _read_image(hdus, h, ftype, rows=rows)
elif type == "IVAR":
if ivar is None:
ivar = {}
ivar[band] = native_endian(hdus[h].read().astype(ftype))
elif type == "MASK":
ivar[band] = _read_image(hdus, h, ftype, rows=rows)
elif type == "MASK" and type not in skip_hdus:
if mask is None:
mask = {}
mask[band] = native_endian(hdus[h].read().astype(np.uint32))
elif type == "RESOLUTION":
mask[band] = _read_image(hdus, h, np.uint32, rows=rows)
elif type == "RESOLUTION" and type not in skip_hdus:
if res is None:
res = {}
res[band] = native_endian(hdus[h].read().astype(ftype))
else:
res[band] = _read_image(hdus, h, ftype, rows=rows)
elif type != "MASK" and type != "RESOLUTION" and type not in skip_hdus:
# this must be an "extra" HDU
if extra is None:
extra = {}
if band not in extra:
extra[band] = {}
extra[band][type] = native_endian(hdus[h].read().astype(ftype))

extra[band][type] = _read_image(hdus, h, ftype, rows=rows)

hdus.close()
duration = time.time() - t0
log.info(iotime.format('read', infile, duration))
log.info(iotime.format("read", infile, duration))

# Construct the Spectra object from the data. If there are any
# inconsistencies in the sizes of the arrays read from the file,
# they will be caught by the constructor.

spec = Spectra(bands, wave, flux, ivar, mask=mask, resolution_data=res,
fibermap=fmap, exp_fibermap=expfmap,
meta=meta, extra=extra, extra_catalog=extra_catalog,
single=single, scores=scores)
spec = Spectra(
bands,
wave,
flux,
ivar,
mask=mask,
resolution_data=res,
fibermap=fmap,
exp_fibermap=expfmap,
meta=meta,
extra=extra,
extra_catalog=extra_catalog,
single=single,
scores=scores,
)

# if needed, sort spectra to match order of targetids, which could be
# different than the order they appear in the file
if targetids is not None:
from desispec.util import ordered_unique
#- Input targetids that we found in the file, in the order they appear in targetids
ii = np.isin(targetids, spec.fibermap['TARGETID'])
found_targetids = ordered_unique(targetids[ii])
log.debug('found_targetids=%s', found_targetids)

#- Unique targetids of input file in the order they first appear
input_targetids = ordered_unique(spec.fibermap['TARGETID'])
log.debug('input_targetids=%s', np.asarray(input_targetids))

#- Only reorder if needed
if not np.all(input_targetids == found_targetids):
rows = np.concatenate([np.where(spec.fibermap['TARGETID'] == tid)[0] for tid in targetids])
log.debug("spec.fibermap['TARGETID'] = %s", np.asarray(spec.fibermap['TARGETID']))
log.debug("rows for subselection=%s", rows)
spec = spec[rows]

return spec

Expand Down
93 changes: 93 additions & 0 deletions py/desispec/test/test_spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,99 @@ def test_io(self):
else:
raise ValueError(f'Unrecognized extension for {self.fileio=}')

def test_read_targetids(self):
    """Check read_spectra(targetids=...) for ordered, reordered and repeated ids"""

    #- build the reference spectra in memory and write them to self.fileio
    spec = Spectra(
        bands=self.bands,
        wave=self.wave,
        flux=self.flux,
        ivar=self.ivar,
        mask=self.mask,
        resolution_data=self.res,
        fibermap=self.fmap1,
        meta=self.meta,
        extra=self.extra,
    )
    write_spectra(self.fileio, spec)

    def check_subset(indices):
        """Read spec[indices] back by TARGETID and compare with the in-memory subset"""
        expected = spec[indices]
        requested = expected.fibermap['TARGETID']
        result = read_spectra(self.fileio, targetids=requested)
        same_ids = np.all(expected.fibermap['TARGETID'] == result.fibermap['TARGETID'])
        self.assertTrue(same_ids)
        self.assertTrue(np.allclose(expected.flux['b'], result.flux['b']))
        self.assertTrue(np.allclose(expected.ivar['r'], result.ivar['r']))
        self.assertTrue(np.all(expected.mask['z'] == result.mask['z']))
        return result

    #- subset requested in the same order as it appears in the file;
    #- also check the resolution matrices of the result
    in_order = [2, 3]
    comp_subset = check_subset(in_order)
    self.assertEqual(len(comp_subset.R['b']), len(in_order))
    self.assertEqual(comp_subset.R['b'][0].shape, (self.nwave, self.nwave))

    #- subset requested in a different order than the file
    check_subset([3, 1])

    #- different order than the file, with repeated and missing targetids
    spec.fibermap['TARGETID'] = (np.arange(self.nspec) // 2) * 2  # [0,0,2,2,4] for nspec=5
    write_spectra(self.fileio, spec)
    comp_subset = read_spectra(self.fileio, targetids=[2, 10, 4, 4, 4, 0, 0])

    #- expected output, in order of the request:
    #-   2 appears 2x because it is in the input file twice
    #-   4 appears 3x because it was requested three times
    #-   0 appears 4x (twice in the file, requested twice) and comes last,
    #-     unlike in the file where it comes first
    #-   10 is absent because it is not in the input file, which is ok
    expected_ids = np.array([2, 2, 4, 4, 4, 0, 0, 0, 0])
    self.assertTrue(np.all(comp_subset.fibermap['TARGETID'] == expected_ids))

def test_read_rows(self):
    """Check read_spectra(rows=...) and that rows + targetids is rejected"""

    #- build the reference spectra in memory and write them to self.fileio
    spec = Spectra(
        bands=self.bands,
        wave=self.wave,
        flux=self.flux,
        ivar=self.ivar,
        mask=self.mask,
        resolution_data=self.res,
        fibermap=self.fmap1,
        meta=self.meta,
        extra=self.extra,
    )
    write_spectra(self.fileio, spec)

    #- fibermap read with a row selection must match the same rows in memory
    rows = [1, 3]
    subset = read_spectra(self.fileio, rows=rows)
    matches = np.all(spec.fibermap[rows] == subset.fibermap)
    self.assertTrue(matches)

    #- rows and targetids are mutually exclusive
    self.assertRaises(
        ValueError, read_spectra, self.fileio, rows=rows, targetids=[1, 2]
    )



def test_read_columns(self):
    """Check that select_columns limits which FIBERMAP columns are read"""

    #- build the reference spectra in memory and write them to self.fileio
    spec = Spectra(
        bands=self.bands,
        wave=self.wave,
        flux=self.flux,
        ivar=self.ivar,
        mask=self.mask,
        resolution_data=self.res,
        fibermap=self.fmap1,
        meta=self.meta,
    )
    write_spectra(self.fileio, spec)

    wanted = ('TARGETID', 'FIBER')
    test = read_spectra(self.fileio, select_columns={'FIBERMAP': wanted})

    #- requested columns are present ...
    for colname in wanted:
        self.assertIn(colname, test.fibermap.colnames)

    #- ... while a column that exists in the written fibermap but was
    #- not requested is absent
    self.assertIn('FLUX_R', spec.fibermap.colnames)
    self.assertNotIn('FLUX_R', test.fibermap.colnames)

def test_read_skip_hdus(self):
    """Check that HDUs listed in skip_hdus are not loaded by read_spectra"""

    #- build the reference spectra in memory and write them to self.fileio
    spec = Spectra(
        bands=self.bands,
        wave=self.wave,
        flux=self.flux,
        ivar=self.ivar,
        mask=self.mask,
        resolution_data=self.res,
        fibermap=self.fmap1,
        meta=self.meta,
    )
    write_spectra(self.fileio, spec)

    skipped = ('MASK', 'RESOLUTION')
    test = read_spectra(self.fileio, skip_hdus=skipped)

    #- skipped HDUs come back as None
    self.assertIsNone(test.mask)
    self.assertIsNone(test.R)

    #- fibermap was not skipped, so it is still read
    self.assertIsNotNone(test.fibermap)


def test_empty(self):

Expand Down

0 comments on commit d401035

Please sign in to comment.