Skip to content

Commit

Permalink
Init fixes (#284)
Browse files Browse the repository at this point in the history
* removed get_best_fit_spectrum: not robust enough and not needed if set_spectrum_to_match is called

* robust linear solver when components are degenerate (closes #282)

* refactored src and obs loops
  • Loading branch information
pmelchior authored May 1, 2024
1 parent 45187fd commit 2d1f022
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 173 deletions.
13 changes: 6 additions & 7 deletions docs/0-quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Since we use a trivial `wcs` in this `Observation`, all coordinates are already in image pixels, otherwise RA/Dec pairs are expected as sky coordinates. Also:"
"Since we use a trivial `wcs` in this `Observation`, all coordinates are already in image pixels, otherwise RA/Dec pairs are expected as sky coordinates."
]
},
{
Expand Down Expand Up @@ -253,7 +253,7 @@
" else:\n",
" new_source = scarlet.ExtendedSource(model_frame, center, observation, compact=True)\n",
" sources.append(new_source)\n",
" \n",
"\n",
"for k, src in enumerate(sources):\n",
" print (f\"{k}: {src.__class__.__name__}\")"
]
Expand All @@ -262,8 +262,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"These source are initialized independently, and their spectra, i.e. the amplitudes in every channel, assume that they are isolated.\n",
"We can make another sweep and determine their spectra so that their superposition best matches the observation. Above, this was done for use by the option `set_specta=True`. But we can call the linear solver directly:"
"These sources are initialized independently, with spectra taken from their peak position in the observations. We can make another sweep and determine their spectra such that their superposition best matches the observation. Above, this was done for us by the option `set_spectra=True`. But we can call the linear solver directly:"
]
},
{
Expand All @@ -281,7 +280,7 @@
"source": [
"## Create and Fit Model\n",
"\n",
"The `Blend` class holds the list of sources and has the machinery to fit them to the given images. In this example the code is set to run for a maximum of 100 iterations, but will end early if the likelihood and all of the constraints converge."
"The `Blend` class holds the list of sources and has the machinery to fit them to the given images. In this example the code is set to run for a maximum of 100 iterations, but will end early if the likelihood and all the constraints converge."
]
},
{
Expand Down Expand Up @@ -526,7 +525,7 @@
"metadata": {
"celltoolbar": "Raw Cell Format",
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -540,7 +539,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
208 changes: 95 additions & 113 deletions scarlet/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,44 +9,9 @@
logger = logging.getLogger("scarlet.initialisation")


def get_best_fit_spectrum(morph, images):
    """Calculate best fitting spectra for one or multiple morphologies.

    Solves min_A ||img - AS||^2 for the spectrum matrix A,
    assuming that the images only contain the given component(s).

    The problem is solved with a least-squares solver instead of an explicit
    inverse of the normal matrix ``S S^T``. This keeps the result finite when
    components are degenerate (e.g. two identical morphologies) or when a
    morphology is identically zero: in those cases the minimum-norm solution
    is returned, i.e. identical components share the amplitude equally and
    an all-zero morphology gets a zero spectrum.

    Parameters
    ----------
    morph: array or list thereof
        Morphology for each component in the source. A single 2D array is
        treated as one component; a list/tuple of 2D arrays or a 3D array is
        treated as `K` components.
    images: array
        images to get the spectrum amplitude from, with the channel as
        first axis

    Returns
    -------
    spectrum: `~numpy.array`
        Array of shape `(C,)` for a single component, `(K, C)` otherwise,
        where `C = images.shape[0]`.
    """
    if isinstance(morph, (list, tuple)) or (
        isinstance(morph, np.ndarray) and morph.ndim == 3
    ):
        morphs = morph
    else:
        morphs = (morph,)

    K = len(morphs)
    C = images.shape[0]
    # flatten the spatial axes: im is C x N, S is K x N
    im = images.reshape(C, -1)
    S = np.array(morphs).reshape(K, -1)

    # least-squares solution of S^T A = im^T; robust to a singular S S^T
    spectra = np.linalg.lstsq(S.T, im.T, rcond=None)[0]

    # a single component yields a 1D spectrum, as callers expect
    if K == 1:
        return spectra[0]
    return spectra

def get_pixel_spectrum(sky_coord, observations, correct_psf=False, models=None):
def get_pixel_spectrum(
sky_coord, observations, correct_psf=False, models=None, concat=True
):
"""Get the spectrum at `sky_coord` in `observation`.
Yields the spectrum of a single-pixel source with flux 1 in every channel,
Expand All @@ -68,24 +33,24 @@ def get_pixel_spectrum(sky_coord, observations, correct_psf=False, models=None):
If PSF shape variations in the observations should be corrected.
models: instance or list of arrays
Rendered models for this source in every observation
concat: bool
Whether spectra from multiple observations are flattened
Returns
-------
spectrum: `~numpy.array` or list thereof
spectrum: `~numpy.array`
"""
if models is not None:
assert correct_psf is False

if not hasattr(observations, "__iter__"):
single = True
observations = (observations,)
models = (models,)
else:
if models is not None:
assert len(models) == len(observations)
else:
models = (None,) * len(observations)
single = False

spectra = []
for obs, model in zip(observations, models):
Expand Down Expand Up @@ -114,12 +79,13 @@ def get_pixel_spectrum(sky_coord, observations, correct_psf=False, models=None):
else:
logger.info(msg)

if single:
return spectra[0]
if concat:
spectra = np.concatenate(spectra).reshape(-1)

return spectra


def get_psf_spectrum(sky_coord, observations, compute_snr=False):
def get_psf_spectrum(sky_coord, observations, compute_snr=False, concat=True):
"""Get spectrum for a point source at `sky_coord` in `observation`
Equivalent to point source photometry for isolated sources. For extended source,
Expand All @@ -137,17 +103,16 @@ def get_psf_spectrum(sky_coord, observations, compute_snr=False):
Observation to extract the spectrum from.
compute_snr: bool
Whether to compute the SNR of a PSF at `sky_coord`
concat: bool
Whether spectra from multiple observations are flattened
Returns
-------
spectrum: ~numpy.array` or list thereof
spectrum: `~numpy.array`
"""

if not hasattr(observations, "__iter__"):
single = True
observations = (observations,)
else:
single = False

spectra = []
if compute_snr:
Expand Down Expand Up @@ -196,8 +161,8 @@ def get_psf_spectrum(sky_coord, observations, compute_snr=False):
else:
logger.info(msg)

if single:
spectra = spectra[0]
if concat:
spectra = np.concatenate(spectra).reshape(-1)

if compute_snr:
snr = np.sum(snr_num) / np.sqrt(np.sum(snr_denom))
Expand Down Expand Up @@ -251,28 +216,28 @@ def build_initialization_image(observations, spectra=None):
Parameters
----------
observations: list of `~scarlet.observation.Observation`
Every observation with a suitable renderer will contribute to the initialization image, according to the noise level of its data
spectra: list of array
for every observation: spectrum at the center of the source
If not set, returns the detection image in all channels, instead of averaging.
Every observation with a suitable renderer will contribute to the initialization image,
according to its noise level.
spectra: list of arrays
for every observation: source spectrum to optimize for. If not set, assumes flat spectrum.
Returns
-------
image: array
image created by weighting all of the channels by SED
std: float
image created by weighting all channels by spectrum
std: array
the effective noise standard deviation of `image`
"""

if not hasattr(observations, "__iter__"):
observations = (observations,)
if spectra is not None:
spectra = (spectra,)
spectra = (spectra,)
assert len(observations) == len(spectra)

model_frame = observations[0].model_frame

# check if detection images are stored in obs[0]
# stoing in an obs avoids using the cache (see issue 256)
# storing in an obs avoids using the cache (see issue 256)
if not hasattr(observations[0], "_detect"):
# if not, map every obs and variance onto the model frame
detect, var = [], []
Expand All @@ -297,27 +262,24 @@ def build_initialization_image(observations, spectra=None):

detect, var = observations[0]._detect

# get multi-channel image for spectrum matching
if spectra is None:
nonzero = np.minimum(1, (var > 0).sum(axis=0))
detect = detect.sum(axis=0) / nonzero
var = var.sum(axis=0) / nonzero
else:
# spectrum SNR weighted combination of all observations
spectrum = []
for i, obs in enumerate(observations):
if not isinstance(obs.renderer, (NullRenderer, ConvolutionRenderer)):
continue
spectrum_ = np.zeros(model_frame.C)
# spectrum SNR weighted combination of all observations
spectrum = []
for i, obs in enumerate(observations):
if not isinstance(obs.renderer, (NullRenderer, ConvolutionRenderer)):
continue
spectrum_ = np.zeros(model_frame.C)
if spectra[i] is not None:
obs.renderer.map_channels(spectrum_)[:] = spectra[i]
spectrum.append(spectrum_)
spectrum = np.stack(spectrum, axis=0)[:, :, None, None] # L x C x Ny x Nx
weight = np.zeros(var.shape)
sel = var > 0
weight[sel] = 1 / var[sel]
weight *= spectrum
detect = (weight * detect).sum(axis=(0, 1))
var = (spectrum * weight).sum(axis=(0, 1))
else:
obs.renderer.map_channels(spectrum_)[:] = 1
spectrum.append(spectrum_)
spectrum = np.stack(spectrum, axis=0)[:, :, None, None] # L x C x Ny x Nx
weight = np.zeros(var.shape)
sel = var > 0
weight[sel] = 1 / var[sel]
weight *= spectrum
detect = (weight * detect).sum(axis=(0, 1))
var = (spectrum * weight).sum(axis=(0, 1))

return detect, np.sqrt(var)

Expand Down Expand Up @@ -546,60 +508,80 @@ def set_spectra_to_match(sources, observations):
observations = (observations,)
model_frame = observations[0].model_frame

for obs in observations:

# extract model for every component
morphs = []
parameters = []
for src in sources:
if isinstance(src, CombinedComponent):
components = src.children
else:
components = (src,)
for c in components:
if isinstance(c, FactorizedComponent):
p = c.parameters[0]
if not p.fixed:
obs.renderer.map_channels(p)[:] = 1
parameters.append(p)
model_ = obs.render(c.get_model(frame=model_frame))
morphs.append(model_)

morphs = np.array(morphs)
K = len(morphs)

images = obs.data
weights = obs.weights
C = obs.C
# extract multi-channel model for every non-degenerate component
parameters = []
update_of = []
models = []
for i, src in enumerate(sources):
if isinstance(src, CombinedComponent):
components = src.children
else:
components = (src,)

for j, c in enumerate(components):
p = c.get_parameter(
"spectrum"
) # returns None if c doesn't have parameter "spectrum"
parameters.append(p)
# correct for different flux in channels to have flat-spectrum component
if p is not None and not p.fixed:
p[:] = 1
model = c.get_model(frame=model_frame)

# check for models with identical initializations, see #282
# if duplicate: remove morph[k] from linear fit, but keep track of parameters[k]
# to set spectrum later: update_of: component index -> updated spectrum index
K_ = len(models)
update_of.append(K_)
for l in range(K_):
if np.allclose(model, models[l]):
update_of[-1] = l
message = f"Source {i}, Component {j} has a model identical to another component.\n"
message += "This is likely not intended, and the source/component should be deleted. "
message += "Spectra will be identical."
logger.warning(message)
if update_of[-1] == K_:
models.append(model)
models = np.array(models)
K = len(parameters)
K_ = len(models)

for obs in observations:
# independent channels, no mixing
# solve the linear inverse problem of the amplitudes in every channel
# given all the rendered morphologies
# spectrum = (M^T Sigma^-1 M)^-1 M^T Sigma^-1 * im
spectra = np.zeros((K, C))
C = obs.C
images = obs.data
weights = obs.weights
morphs = np.stack([obs.render(model) for model in models], axis=0)
spectra = np.zeros((K_, C))
for c in range(C):
im = images[c].reshape(-1)
w = weights[c].reshape(-1)
m = morphs[:, c, :, :].reshape(K, -1)
m = morphs[:, c, :, :].reshape(K_, -1)
mw = m * w[None, :]
# check if all components have nonzero flux in c.
# because of convolutions, flux can be outside of the box,
# so we need to compare weighted flu with unweighted flux,
# which is the same (up to a constant) for constant weights
# because of convolutions, flux can be outside the box,
# so we need to compare weighted flux with unweighted flux,
# which is the same (up to a constant) for constant weights.
# so we check if *most* of the flux is from pixels with non-zero weight
nonzero = np.sum(mw, axis=1) / np.sum(m, axis=1) / np.mean(w) > 0.1
nonzero = np.flatnonzero(nonzero)
if len(nonzero) == K:
if len(nonzero) == K_:
covar = np.linalg.inv(mw @ m.T)
spectra[:, c] = covar @ m @ (im * w)
else:
covar = np.linalg.inv(mw[nonzero] @ m[nonzero].T)
spectra[nonzero, c] = covar @ m[nonzero] @ (im * w)

for p, spectrum in zip(parameters, spectra):
obs.renderer.map_channels(p)[:] = spectrum
# update the parameters with the best-fit spectrum solution
for k, p in enumerate(parameters):
if p is not None and not p.fixed:
l = update_of[k]
obs.renderer.map_channels(p)[:] = spectra[l]

# enforce constraints
for p in parameters:
if p.constraint is not None:
if p is not None and p.constraint is not None:
p[:] = p.constraint(p, 0)
Loading

0 comments on commit 2d1f022

Please sign in to comment.