Skip to content

Commit

Permalink
minor update
Browse files Browse the repository at this point in the history
  • Loading branch information
wcxve committed Mar 1, 2024
1 parent 1bd13bc commit 852f84d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 8 deletions.
41 changes: 36 additions & 5 deletions src/elisa/model/conv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""ConvolutionComponent models."""
from __future__ import annotations

from abc import abstractmethod
from typing import Callable

import jax
Expand All @@ -12,7 +13,7 @@
__all__ = ['EnFlux', 'PhFlux', 'RedShift', 'VelocityShift']


class FluxNorm(ConvolutionComponent):
class NormConvolution(ConvolutionComponent):
_args = ('emin', 'emax')
_kwargs = ('ngrid', 'elog')
_supported = frozenset({'add'})
Expand All @@ -39,6 +40,36 @@ def __init__(

super().__init__(params, latex)

@staticmethod
@abstractmethod
def convolve(
egrid: JAXArray,
params: NameValMapping,
model_fn: Callable[[JAXArray], JAXArray],
flux_egrid: JAXArray,
) -> JAXArray:
"""Convolve a model function.
Parameters
----------
egrid : ndarray
Photon energy grid in units of keV.
params : dict
Parameter dict for the convolution model.
model_fn : callable
The model function to be convolved, which takes the energy grid as
input and returns the model flux over the grid.
flux_egrid : ndarray
Photon energy grid used to calculate flux in units of keV.
Returns
-------
value : ndarray
The re-normalized model over `egrid`, in units of cm⁻² s⁻¹ keV⁻¹.
"""
pass

@property
def eval(self) -> ConvolveEval:
if self._prev_config == (self.emin, self.emax, self.ngrid, self.elog):
Expand Down Expand Up @@ -107,7 +138,7 @@ def elog(self, value: bool):
self._elog = bool(value)


class PhFlux(FluxNorm):
class PhFlux(NormConvolution):
r"""Normalize an additive model by photon flux between `emin` and `emax`.
Warnings
Expand All @@ -125,7 +156,7 @@ class PhFlux(FluxNorm):
Flux parameter, in units of cm⁻² s⁻¹.
latex : str, optional
:math:`\LaTeX` format of the component. Defaults to class name.
ngrid : int or None, optional
ngrid : int, optional
The energy grid number to use. The default is 1000.
elog : bool, optional
Whether to use logarithmically regular energy grids.
Expand All @@ -152,7 +183,7 @@ def convolve(
return F / mflux * flux


class EnFlux(FluxNorm):
class EnFlux(NormConvolution):
r"""Normalize an additive model by energy flux between `emin` and `emax`.
Warnings
Expand All @@ -170,7 +201,7 @@ class EnFlux(FluxNorm):
Flux parameter, in units of erg cm⁻² s⁻¹.
latex : str, optional
:math:`\LaTeX` format of the component. Defaults to class name.
ngrid : int or None, optional
ngrid : int, optional
The energy grid number to use. The default is 1000.
elog : bool, optional
Whether to use logarithmically regular energy grids.
Expand Down
6 changes: 3 additions & 3 deletions src/elisa/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,13 +1394,13 @@ def convolve(*args, **kwargs) -> JAXArray:
params : dict
Parameter dict for the convolution model.
model_fn : callable
The model function to be convolved, which takes energy grid as
input and returns the model value at the grid.
The model function to be convolved, which takes the energy grid as
input and returns the model value over the grid.
Returns
-------
value : ndarray
The convolved model value at the energy grid.
The convolved model value over `egrid`.
"""
pass
Expand Down

0 comments on commit 852f84d

Please sign in to comment.