-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
120 lines (101 loc) · 3.57 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Copyright 2022 AstroLab Software
# Author: Julien Peloton
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pandas as pd
import numpy as np
from scipy.optimize import curve_fit
def apply_selection_cuts(filternames: pd.Series, min_hist_length: int) -> pd.Series:
    """ Apply selection cuts defined by the user.

    In this case, flag out alerts with not enough points in the g band.
    Alerts that do not satisfy the criteria are not processed.

    Parameters
    ----------
    filternames: pd.Series
        Series containing filter values (array-like of str). Each row
        contains all filter values for one alert (with its history).
        Rows may be NumPy arrays, lists or tuples.
    min_hist_length: int
        Minimum number of measurements in g

    Returns
    ---------
    mask: pd.Series
        Series containing `True` if the alert is valid, `False` otherwise.
        Each row contains one boolean.
    """
    # Count the g-band measurements of each alert. `np.asarray` forces an
    # element-wise comparison: with a plain list, `x == 'g'` is a single
    # `False` and the count would silently be 0.
    n_g_points = filternames.apply(
        lambda bands: np.sum(np.asarray(bands) == 'g')
    )

    mask = n_g_points >= min_hist_length

    return mask
def linear_model(x: float, a: float, b: float) -> float:
    """ Evaluate the straight line f(x) = a * x + b.

    Parameters
    ----------
    x: float
        Point at which the line is evaluated.
    a: float
        Slope of the line.
    b: float
        Intercept of the line.

    Returns
    ----------
    out: float
        Value of a * x + b.
    """
    scaled = a * x
    return scaled + b
def return_fitted_slope(func, midpointtais: list, psfluxes: list) -> float:
    """ Wrapper for `curve_fit`

    Fits `func` to the data with two free parameters (initial guess
    [0.0, 0.0]) and returns the first one. Any failure of the fit is
    reported as NaN instead of raising.

    Parameters
    ----------
    func: function
        Function to fit. Called as func(x, p0, p1).
    midpointtais: list of floats
        List containing time steps for one alert
    psfluxes: list of floats
        List containing fluxes for one alert

    Returns
    ----------
    slope: float
        First fitted parameter (the slope `a` when `func` is a linear
        model of the form a * x + b), or `np.nan` if the fit could not
        be performed.
    """
    try:
        fit, _ = curve_fit(func, midpointtais, psfluxes, p0=[0.0, 0.0])
    except (RuntimeError, ValueError, TypeError):
        # RuntimeError: the least-squares minimisation did not converge.
        # ValueError: empty input, or NaN/inf in the data.
        # TypeError: fewer data points than the two fitted parameters.
        return np.nan

    return fit[0]
def extract_history(history_list: list, field: str) -> list:
    """Extract the historical measurements contained in the alerts
    for the parameter `field`.

    Parameters
    ----------
    history_list: list of dict
        List of dictionary from alert['prvDiaSources']. `None` is
        treated as an empty history.
    field: str
        The field name for which you want to extract the data. It must be
        a key of elements of history_list (alert['prvDiaSources'])

    Returns
    ----------
    measurement: list
        List of all the `field` measurements contained in the alerts.
        If `field` is missing from any element, a message is printed and
        a list of `None` with the same length as `history_list` is
        returned.
    """
    # An alert without previous sources has nothing to extract; iterating
    # over None would otherwise raise a TypeError.
    if history_list is None:
        return []

    try:
        measurement = [obs[field] for obs in history_list]
    except KeyError:
        print('{} not in history data'.format(field))
        measurement = [None] * len(history_list)

    return measurement
def extract_field(alert: dict, field: str) -> np.array:
    """ Build the full series of a field: history first, current value last.

    Parameters
    ----------
    alert: dict
        Dictionary containing alert data. The keys `prvDiaSources`
        and `diaSource` are read.
    field: str
        Name of the field to extract.

    Returns
    ----------
    data: np.array
        Array containing previous measurements and current measurement at
        the end. If `field` is not in `prvDiaSources` fields, data will be
        [None, None, ..., alert['diaSource'][field]].
    """
    # Historical values come first, in the order stored in the alert.
    previous = extract_history(alert['prvDiaSources'], field)

    # The current measurement is wrapped in a list so it can be appended.
    current = [alert["diaSource"][field]]

    return np.concatenate([previous, current])