From 80ea0d4884b7356ba9e23c95764f6c681d15e348 Mon Sep 17 00:00:00 2001 From: Leszek Grzanka Date: Fri, 22 Mar 2024 22:35:04 +0100 Subject: [PATCH] fixes --- pymchelper/averaging.py | 17 +++---------- pymchelper/input_output.py | 5 ++-- pymchelper/readers/common.py | 2 -- pymchelper/readers/topas.py | 2 -- pymchelper/writers/excel.py | 49 ++++++++++++++++++------------------ pymchelper/writers/hdf.py | 6 +++-- 6 files changed, 35 insertions(+), 46 deletions(-) diff --git a/pymchelper/averaging.py b/pymchelper/averaging.py index ff39385e7..105850db2 100644 --- a/pymchelper/averaging.py +++ b/pymchelper/averaging.py @@ -49,17 +49,12 @@ class Aggregator: _updated: bool = field(default=False, repr=False, init=False) def error(self, **kwargs): - """ - Default implementation of error function, returns None. - """ + """Default implementation of error function, returns None.""" logging.warning("Error calculation not implemented for %s", self.__class__.__name__) - return None @property def updated(self): - """ - Check if the aggregator was updated. - """ + """Check if the aggregator was updated.""" if isinstance(self.data, float) and np.isnan(self.data): return False return self._updated @@ -154,9 +149,7 @@ def error(self, **kwargs): @dataclass class ConcatenatingAggregator(Aggregator): - """ - Class for concatenating numpy arrays - """ + """Class for concatenating numpy arrays""" def update(self, value: Union[float, ArrayLike]): "" @@ -169,9 +162,7 @@ def update(self, value: Union[float, ArrayLike]): @dataclass class SumAggregator(Aggregator): - """ - Class for calculating sum of a sequence of numbers. - """ + """Class for calculating sum of a sequence of numbers.""" def update(self, value: Union[float, ArrayLike]): # first value added diff --git a/pymchelper/input_output.py b/pymchelper/input_output.py index 2793ead30..0a0ff79cd 100644 --- a/pymchelper/input_output.py +++ b/pymchelper/input_output.py @@ -18,9 +18,8 @@ class AggregationType(IntEnum): - """ - Enum for different types of aggregation. - """ + """Enum for different types of aggregation.""" + NoAggregation = 0 Sum = 1 AveragingCumulative = 2 diff --git a/pymchelper/readers/common.py b/pymchelper/readers/common.py index b3400163b..08224e349 100644 --- a/pymchelper/readers/common.py +++ b/pymchelper/readers/common.py @@ -1,8 +1,6 @@ from abc import abstractmethod import logging -import numpy as np - logger = logging.getLogger(__name__) diff --git a/pymchelper/readers/topas.py b/pymchelper/readers/topas.py index 56e9f56d7..3043fc91d 100644 --- a/pymchelper/readers/topas.py +++ b/pymchelper/readers/topas.py @@ -216,7 +216,6 @@ def read_data(self, estimator: Estimator) -> bool: num_results = len(results) page = Page(estimator=estimator) set_data = False - set_error = False for column, result in enumerate(results): if result not in ['Mean', 'Standard_Deviation']: continue @@ -270,7 +269,6 @@ def read_data(self, estimator: Estimator) -> bool: set_data = True elif result == 'Standard_Deviation': page.error_raw = scores - set_error = True # If we didn't find mean results for the scorer, we return False if not set_data: diff --git a/pymchelper/writers/excel.py b/pymchelper/writers/excel.py index e65224888..706ded36b 100644 --- a/pymchelper/writers/excel.py +++ b/pymchelper/writers/excel.py @@ -1,5 +1,7 @@ import logging +from pymchelper.estimator import Estimator + import numpy as np logger = logging.getLogger(__name__) @@ -9,45 +11,44 @@ class ExcelWriter: """ Supports writing XLS files (MS Excel 2003 format) """ + def __init__(self, filename, options): self.filename = filename if not self.filename.endswith(".xls"): self.filename += ".xls" - def write(self, estimator): - if len(estimator.pages) > 1: - print("Conversion of data with multiple pages not supported yet") - return False - + def write(self, estimator: Estimator): try: import xlwt except ImportError as e: - logger.error("Generating Excel files not available on your platform (you are probably running Python 3.2).") + logger.error("Generating Excel files not available on your platform (please install xlwt).") raise e - page = estimator.pages[0] + # create workbook + wb = xlwt.Workbook() - # save only 1-D data - if page.dimension != 1: - logger.warning("page dimension {:d} != 1, XLS output not supported".format(estimator.dimension)) - return 1 + for page_id, page in enumerate(estimator.pages): - # create workbook with single sheet - wb = xlwt.Workbook() - ws = wb.add_sheet('Data') + # save only 1-D data + if page.dimension != 1: + logger.warning("page dimension {:d} != 1, XLS output not supported".format(estimator.dimension)) + return 1 + + # create worksheet + ws = wb.add_sheet(f'Data_p{page_id}') - # save X axis data - for i, x in enumerate(page.plot_axis(0).data): - ws.write(i, 0, x) + # save X axis data + for i, x in enumerate(page.plot_axis(0).data): + ws.write(i, 0, x) - # save Y axis data - for i, y in enumerate(page.data_raw): - ws.write(i, 1, y) + # save Y axis data + for i, y in enumerate(page.data_raw): + ws.write(i, 1, y) - # save error column (if present) - if np.all(np.isfinite(page.error_raw)): - for i, e in enumerate(page.error_raw): - ws.write(i, 2, e) + # save error column (if present) + if page.error_raw is not None: + for i, e in enumerate(page.error_raw): + ws.write(i, 2, e) # save file logger.info("Writing: " + self.filename) diff --git a/pymchelper/writers/hdf.py b/pymchelper/writers/hdf.py index 5a5506e0c..9bd8310b0 100644 --- a/pymchelper/writers/hdf.py +++ b/pymchelper/writers/hdf.py @@ -1,6 +1,8 @@ import logging import numpy as np +from pymchelper.estimator import Estimator + logger = logging.getLogger(__name__) @@ -17,7 +19,7 @@ def __init__(self, filename, options): if not self.filename.endswith(".h5"): self.filename += ".h5" - def write(self, estimator): + def write(self, estimator: Estimator): if len(estimator.pages) == 0: print("No pages in the output file, conversion to HDF5 skipped.") return False @@ -42,7 +44,7 @@ def write(self, estimator): dset = hdf_file.create_dataset(dataset_name, data=page.data, compression="gzip", compression_opts=9) # save error (if present) - if not np.all(np.isnan(page.error_raw)) and np.any(page.error_raw): + if page.error is not None: hdf_file.create_dataset(dataset_error_name, data=page.error, compression="gzip", compression_opts=9) # save metadata