-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
One possible implementation of models
- Loading branch information
1 parent
678d58e
commit 7519ef0
Showing
9 changed files
with
238 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1 @@ | ||
from .example_module import greetings, meaning | ||
|
||
__all__ = ["greetings", "meaning"] | ||
from .base_models import PhysicalModel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import abc | ||
|
||
class PhysicalModel(abc.ABC): | ||
def __init__(self, ra=None, dec=None, distance=None, **kwargs): | ||
self.ra = ra | ||
self.dec = dec | ||
self.distance = distance | ||
self.effects = [] | ||
|
||
def add_effect(self, effect): | ||
"""Add a transformational effect to the PhysicalModel. | ||
Effects are applied in the order in which they are added. | ||
Parameters | ||
---------- | ||
effect : `EffectModel` | ||
The effect to apply. | ||
Raises | ||
------ | ||
Raises a ``AttributeError`` if the PhysicalModel does not have all of the | ||
required attributes. | ||
""" | ||
required: list = effect.required_parameters() | ||
for parameter in required: | ||
# Raise an AttributeError if the parameter is missing or set to None. | ||
if (getattr(self, parameter) is None): | ||
raise AttributeError(f"Parameter {parameter} unset for model {type(self).__name__}") | ||
|
||
self.effects.append(effect) | ||
|
||
def _observe(self, times, bands=None, **kwargs): | ||
"""Draw effect-free observations for this object. | ||
Parameters | ||
---------- | ||
times : `numpy.ndarray` | ||
An array of timestamps. | ||
bands : `numpy.ndarray`, optional | ||
An array of bands. | ||
Returns | ||
------- | ||
flux_density : `numpy.ndarray` | ||
The results. | ||
""" | ||
raise NotImplementedError() | ||
|
||
def observe(self, times, bands=None, **kwargs): | ||
"""Draw observations for this object and apply the noise. | ||
Parameters | ||
---------- | ||
times : `numpy.ndarray` | ||
An array of timestamps. | ||
bands : `numpy.ndarray`, optional | ||
An array of bands. | ||
Returns | ||
------- | ||
flux_density : `numpy.ndarray` | ||
The results. | ||
""" | ||
flux_density = self._observe(times, bands, **kwargs) | ||
for effect in self.effects: | ||
flux_density = effect.apply(flux_density, bands, self, **kwargs) | ||
return flux_density | ||
|
||
|
||
class EffectModel(abc.ABC): | ||
def __init__(self, **kwargs): | ||
pass | ||
|
||
def required_parameters(self): | ||
"""Returns a list of the parameters of a PhysicalModel | ||
that this effect needs to access. | ||
Returns | ||
------- | ||
parameters : `list` of `str` | ||
A list of every required parameter the effect needs. | ||
""" | ||
return [] | ||
|
||
def apply(self, flux_density, bands=None, physical_model=None, **kwargs): | ||
"""Apply the effect to observations (flux_density values) | ||
Parameters | ||
---------- | ||
flux_density : `numpy.ndarray` | ||
An array of flux density values. | ||
bands : `numpy.ndarray`, optional | ||
An array of bands. | ||
physical_model : `PhysicalModel` | ||
A PhysicalModel from which the effect may query parameters | ||
such as redshift, position, or distance. | ||
Returns | ||
------- | ||
flux_density : `numpy.ndarray` | ||
The results. | ||
""" | ||
raise NotImplementedError() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import numpy as np | ||
import types | ||
|
||
from tdastro.base_models import EffectModel, PhysicalModel | ||
|
||
class WhiteNoise(EffectModel): | ||
def __init__(self, scale, **kwargs): | ||
super().__init__(**kwargs) | ||
self.scale = scale | ||
|
||
def apply(self, flux_density, bands=None, physical_model=None, **kwargs): | ||
"""Apply the effect to observations (flux_density values) | ||
Parameters | ||
---------- | ||
flux_density : `numpy.ndarray` | ||
An array of flux density values. | ||
bands : `numpy.ndarray`, optional | ||
An array of bands. | ||
physical_model : `PhysicalModel` | ||
A PhysicalModel from which the effect may query parameters | ||
such as redshift, position, or distance. | ||
Returns | ||
------- | ||
flux_density : `numpy.ndarray` | ||
The results. | ||
""" | ||
return np.random.normal(loc=flux_density, scale=self.scale) | ||
|
||
|
||
class DistanceBasedWhiteNoise(EffectModel): | ||
def __init__(self, scale, dist_multiplier, **kwargs): | ||
super().__init__(**kwargs) | ||
self.scale = scale | ||
self.dist_multiplier = dist_multiplier | ||
|
||
def required_parameters(self): | ||
"""Returns a list of the parameters of a PhysicalModel | ||
that this effect needs to access. | ||
Returns | ||
------- | ||
parameters : `list` of `str` | ||
A list of every required parameter the effect needs. | ||
""" | ||
return ['distance'] | ||
|
||
def apply(self, flux_density, bands=None, physical_model=None, **kwargs): | ||
"""Apply the effect to observations (flux_density values) | ||
Parameters | ||
---------- | ||
flux_density : `numpy.ndarray` | ||
An array of flux density values. | ||
bands : `numpy.ndarray`, optional | ||
An array of bands. | ||
physical_model : `PhysicalModel` | ||
A PhysicalModel from which the effect may query parameters | ||
such as redshift, position, or distance. | ||
Returns | ||
------- | ||
flux_density : `numpy.ndarray` | ||
The results. | ||
""" | ||
scale_value = self.scale + self.dist_multiplier * physical_model.distance | ||
return np.random.normal(loc=flux_density, scale=scale_value) |
This file was deleted.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import numpy as np | ||
import types | ||
|
||
from tdastro.base_models import PhysicalModel | ||
|
||
class StaticSource(PhysicalModel): | ||
def __init__(self, brightness, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
if brightness is None: | ||
# If we were not given the parameter, use a default sampling function. | ||
self.brightness = np.random.rand(10.0, 20.0) | ||
elif type(brightness) is types.FunctionType: | ||
# If we were given a sampling function, use it. | ||
self.brightness = brightness(**kwargs) | ||
else: | ||
# Otherwise assume we were given the parameter itself. | ||
self.brightness = brightness | ||
|
||
def _observe(self, times, bands=None, **kwargs): | ||
"""Draw effect-free observations for this object. | ||
Parameters | ||
---------- | ||
times : `numpy.ndarray` | ||
An array of timestamps. | ||
bands : `numpy.ndarray`, optional | ||
An array of bands. If ``None`` then does something. | ||
Returns | ||
------- | ||
flux : `numpy.ndarray` | ||
The results. | ||
""" | ||
return np.full_like(times, self.brightness) | ||
|
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from tdastro.effects.white_noise import DistanceBasedWhiteNoise, WhiteNoise | ||
from tdastro.sources.static_source import StaticSource | ||
|
||
def brightness_generator(): | ||
return 10.0 + 0.5 * np.random.rand(1) | ||
|
||
def test_white_noise() -> None: | ||
model = StaticSource(brightness=brightness_generator) | ||
model.add_effect(WhiteNoise(scale=0.01)) | ||
|
||
values = model.observe(np.array([1, 2, 3, 4, 5])) | ||
assert len(values) == 5 | ||
assert not np.all(values == 10.0) | ||
assert np.all((np.abs(values - 10.0) < 1.0)) | ||
|
||
def test_distance_based_white_noise() -> None: | ||
model1 = StaticSource(brightness=10.0, distance=10.0) | ||
model1.add_effect(DistanceBasedWhiteNoise(scale=0.01, dist_multiplier=0.05)) | ||
|
||
values = model1.observe(np.array([1, 2, 3, 4, 5])) | ||
assert len(values) == 5 | ||
assert not np.all(values == 10.0) | ||
|
||
# Fail if distance is not specified. | ||
model2 = StaticSource(brightness=10.0) | ||
with pytest.raises(AttributeError): | ||
model2.add_effect(DistanceBasedWhiteNoise(scale=0.01, dist_multiplier=0.05)) |