Skip to content

Commit

Permalink
add: new type "expression" for common.models.map.PointByPointScanData
Browse files Browse the repository at this point in the history
  • Loading branch information
keara-soloway committed Feb 2, 2024
1 parent 5665845 commit a5318f7
Showing 1 changed file with 118 additions and 16 deletions.
134 changes: 118 additions & 16 deletions CHAP/common/models/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ def get_index(self, scan_number:int, scan_step_index:int, map_config):
coordinate_index = list(
map_config.coords[independent_dimension.label]).index(
independent_dimension.get_value(
self, scan_number, scan_step_index))
self, scan_number, scan_step_index,
map_config.scalar_data))
index = (coordinate_index, *index)
return index

Expand Down Expand Up @@ -235,7 +236,7 @@ class PointByPointScanData(BaseModel):
"""
label: constr(min_length=1)
units: constr(strip_whitespace=True, min_length=1)
data_type: Literal['spec_motor', 'scan_column', 'smb_par']
data_type: Literal['spec_motor', 'scan_column', 'smb_par', 'expression']
name: constr(strip_whitespace=True, min_length=1)

@validator('label')
Expand Down Expand Up @@ -308,9 +309,47 @@ def validate_for_spec_scans(
f'on scan number {scan_number} '
f'for index {index} '
f'in spec file {scans.spec_file}')
def validate_for_scalar_data(self, scalar_data):
    """Used for `PointByPointScanData` objects with a `data_type`
    of `'expression'`. Validate that the `scalar_data` field of a
    `MapConfig` object contains all the items necessary for
    evaluating the expression.

    :param scalar_data: the `scalar_data` field of a `MapConfig`
        that this `PointByPointScanData` object will be validated
        against
    :type scalar_data: list[PointByPointScanData]
    :raises ValueError: if `scalar_data` does not contain items
        needed for evaluating the expression.
    :return: None
    """
    import ast

    # Collect every bare variable name used in the expression. The
    # stdlib ast module covers this simple name scan directly.
    labels = {node.id for node in ast.walk(ast.parse(self.name))
              if isinstance(node, ast.Name)}
    # These names are always provided in the evaluation symbol table
    # (see get_expression_value), so they need no matching
    # scalar_data item.
    labels -= {'round', 'np', 'numpy'}
    scalar_data_labels = {s_d.label for s_d in scalar_data}
    for label in labels:
        if label not in scalar_data_labels:
            raise ValueError(
                f'{label} is not the label of an item in scalar_data')

def get_value(self, spec_scans:SpecScans,
scan_number:int, scan_step_index:int=0):
scan_number:int, scan_step_index:int=0, scalar_data=[]):
"""Return the value recorded for this instance of
`PointByPointScanData` at a specific scan step.
Expand All @@ -322,6 +361,10 @@ def get_value(self, spec_scans:SpecScans,
:type scan_number: int
:param scan_step_index: The index of the requested scan step.
:type scan_step_index: int
:param scalar_data: list of scalar data configurations used to
get values for `PointByPointScanData` objects with
`data_type == 'expression'`, optional
:type scalar_data: list[PointByPointScanData], defaults to []
:return: The value recorded of the data represented by this
instance of `PointByPointScanData` at the scan step
requested.
Expand All @@ -341,6 +384,12 @@ def get_value(self, spec_scans:SpecScans,
return get_smb_par_value(spec_scans.spec_file,
scan_number,
self.name)
elif self.data_type == 'expression':
return get_expression_value(spec_scans,
scan_number,
scan_step_index,
self.name,
scalar_data)
return None


Expand Down Expand Up @@ -426,6 +475,43 @@ def get_smb_par_value(spec_file:str, scan_number:int, par_name:str):
return scanparser.pars[par_name]


def get_expression_value(spec_scans:SpecScans, scan_number:int,
                         scan_step_index:int, expression:str,
                         scalar_data:list[PointByPointScanData]):
    """Return the value of an evaluated expression of other sources of
    point-by-point scalar scan data for a single point.

    :param spec_scans: An instance of `SpecScans` in which the
        requested scan step occurs.
    :type spec_scans: SpecScans
    :param scan_number: The number of the scan in which the requested
        scan step occurs.
    :type scan_number: int
    :param scan_step_index: The index of the requested scan step.
    :type scan_step_index: int
    :param expression: the string expression to evaluate
    :type expression: str
    :param scalar_data: the `scalar_data` field of a `MapConfig`
        object (used to provide values for variables used in
        `expression`)
    :type scalar_data: list[PointByPointScanData]
    :return: The value of `expression` evaluated at the requested scan
        step.
    :rtype: float
    """
    from ast import parse
    from asteval import get_ast_names, Interpreter

    labels = get_ast_names(parse(expression))
    symtable = {}
    for label in labels:
        # Built-in helpers are always available and take precedence;
        # validate_for_scalar_data exempts exactly these names from
        # needing a scalar_data item.
        if label == 'round':
            symtable[label] = round
            continue
        if label in ('np', 'numpy'):
            symtable[label] = np
            continue
        # Everything else must be supplied by a scalar_data item with
        # a matching label, evaluated at the same scan step.
        for s_d in scalar_data:
            if s_d.label == label:
                symtable[label] = s_d.get_value(
                    spec_scans, scan_number, scan_step_index,
                    scalar_data)
                break
    aeval = Interpreter(symtable=symtable)
    return aeval(expression)

def validate_data_source_for_map_config(data_source, values):
"""Confirm that an instance of PointByPointScanData is valid for
the station and scans provided by a map configuration dictionary.
Expand All @@ -439,11 +525,22 @@ def validate_data_source_for_map_config(data_source, values):
:return: `data_source`, if it is valid.
:rtype: PointByPointScanData
"""
if data_source is not None:
import_scanparser(values.get('station'), values.get('experiment_type'))
data_source.validate_for_station(values.get('station'))
data_source.validate_for_spec_scans(values.get('spec_scans'))
return data_source
def _validate_data_source_for_map_config(
data_source, values, parent_list=None):
if isinstance(data_source, list):
return [_validate_data_source_for_map_config(
d_s, values, parent_list=data_source) for d_s in data_source]
if data_source is not None:
if data_source.data_type == 'expression':
data_source.validate_for_scalar_data(
values.get('scalar_data', parent_list))
else:
import_scanparser(
values.get('station'), values.get('experiment_type'))
data_source.validate_for_station(values.get('station'))
data_source.validate_for_spec_scans(values.get('spec_scans'))
return(data_source)
return _validate_data_source_for_map_config(data_source, values)


class IndependentDimension(PointByPointScanData):
Expand Down Expand Up @@ -516,7 +613,7 @@ def reserved_labels(cls):
:return: A list of reserved labels.
:rtype: list[str]
"""
return list(cls.__fields__['label'].type_.__args__)
return list((*cls.__fields__['label'].type_.__args__, 'round'))


class PresampleIntensity(CorrectionsData):
Expand Down Expand Up @@ -678,12 +775,12 @@ class MapConfig(BaseModel):
experiment_type: Literal['SAXSWAXS', 'EDD', 'XRF', 'TOMO']
sample: Sample
spec_scans: conlist(item_type=SpecScans, min_items=1)
scalar_data: Optional[list[PointByPointScanData]] = []
independent_dimensions: conlist(
item_type=IndependentDimension, min_items=1)
presample_intensity: Optional[PresampleIntensity]
dwell_time_actual: Optional[DwellTimeActual]
postsample_intensity: Optional[PostsampleIntensity]
scalar_data: Optional[list[PointByPointScanData]] = []
map_type: Optional[Literal['structured', 'unstructured']] = 'structured'
_coords: dict = PrivateAttr()
_dims: tuple = PrivateAttr()
Expand All @@ -705,7 +802,6 @@ class MapConfig(BaseModel):
allow_reuse=True)(validate_data_source_for_map_config)
_validate_scalar_data = validator(
'scalar_data',
each_item=True,
allow_reuse=True)(validate_data_source_for_map_config)

@root_validator(pre=True)
Expand Down Expand Up @@ -745,6 +841,7 @@ def validate_map_type(cls, map_type, values):
dims = {}
spec_scans = values.get('spec_scans')
independent_dimensions = values.get('independent_dimensions')
scalar_data = values.get('scalar_data')
import_scanparser(values.get('station'), values.get('experiment_type'))
for i, dim in enumerate(deepcopy(independent_dimensions)):
dims[dim.label] = []
Expand All @@ -754,7 +851,8 @@ def validate_map_type(cls, map_type, values):
for scan_step_index in range(
scanparser.spec_scan_npts):
dims[dim.label].append(dim.get_value(
scans, scan_number, scan_step_index))
scans, scan_number, scan_step_index,
scalar_data))
dims[dim.label] = np.unique(dims[dim.label])
if dim.end is None:
dim.end = len(dims[dim.label])
Expand All @@ -769,7 +867,8 @@ def validate_map_type(cls, map_type, values):
for scan_step_index in range(scanparser.spec_scan_npts):
coords[tuple([
list(dims[dim.label]).index(
dim.get_value(scans, scan_number, scan_step_index))
dim.get_value(scans, scan_number, scan_step_index,
scalar_data))
for dim in independent_dimensions])] += 1
if any(True for v in coords.flatten() if v == 0 or v > 1):
return 'unstructured'
Expand Down Expand Up @@ -827,7 +926,8 @@ def coords(self):
for scan_step_index in range(
scanparser.spec_scan_npts):
coords[dim.label].append(dim.get_value(
scans, scan_number, scan_step_index))
scans, scan_number, scan_step_index,
self.scalar_data))
if self.map_type == 'structured':
coords[dim.label] = np.unique(coords[dim.label])
self._coords = coords
Expand Down Expand Up @@ -921,7 +1021,8 @@ def get_scan_step_index(self, map_index):
map_coords = self.get_coords(map_index)
for scans, scan_number, scan_step_index in self.scan_step_indices:
coords = {dim.label:dim.get_value(
scans, scan_number, scan_step_index)
scans, scan_number, scan_step_index,
self.scalar_data)
for dim in self.independent_dimensions}
if coords == map_coords:
return scans, scan_number, scan_step_index
Expand All @@ -942,7 +1043,8 @@ def get_value(self, data, map_index):
"""
scans, scan_number, scan_step_index = \
self.get_scan_step_index(map_index)
return data.get_value(scans, scan_number, scan_step_index)
return data.get_value(scans, scan_number, scan_step_index,
self.scalar_data)


def import_scanparser(station, experiment):
Expand Down

0 comments on commit a5318f7

Please sign in to comment.