diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 05beb28cb..2274f86c8 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -24,7 +24,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: '3.11' - name: Install dependencies run: | @@ -39,17 +39,15 @@ jobs: - name: Smoke test with pytest run: | - pytest -k "smoke" tests/ + pytest -k "smoke" # all tests on matrix of all possible python versions and OSes normal_test: strategy: matrix: - python-version: ['3.8', '3.9', '3.10', '3.11'] + python-version: ['3.9', '3.10', '3.11'] platform: [ubuntu-latest, macos-latest, windows-latest] exclude: - - platform: macos-latest - python-version: '3.8' - platform: macos-latest python-version: '3.10' - platform: macos-latest @@ -74,4 +72,4 @@ jobs: - name: Test with pytest run: | - pytest -k "not slow" tests/ + pytest -k "not slow" diff --git a/.github/workflows/release-pip.yml b/.github/workflows/release-pip.yml index 6ac2492a2..8c3573ee2 100644 --- a/.github/workflows/release-pip.yml +++ b/.github/workflows/release-pip.yml @@ -15,7 +15,7 @@ jobs: full_test: strategy: matrix: - python-version: ['3.8', '3.9', '3.10', '3.11'] + python-version: ['3.9', '3.10', '3.11'] runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -33,7 +33,7 @@ jobs: - name: Run all test with pytest run: | - pytest -k "slow" tests/ + pytest -k "slow" - name: Check images generation for documentation run: | @@ -51,7 +51,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: '3.11' - name: Install dependencies run: | diff --git a/.github/workflows/release-win-exe.yml b/.github/workflows/release-win-exe.yml index c37ac3c77..36b25ed75 100644 --- a/.github/workflows/release-win-exe.yml +++ b/.github/workflows/release-win-exe.yml @@ -26,7 +26,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - 
python-version: 3.10 + python-version: 3.11 - name: Generate VERSION file run: python3 setup.py --help diff --git a/pymchelper/averaging.py b/pymchelper/averaging.py new file mode 100644 index 000000000..f55a84890 --- /dev/null +++ b/pymchelper/averaging.py @@ -0,0 +1,235 @@ +""" +Output from the simulation running on multiple processes needs to be aggregated. +The simplest way of doing so is to calculate the average of the data. +There are however more sophisticated cases: COUNT scorers need to be summed up, not averaged. +Phase space data needs to be concatenated, not averaged. +Each of the parallel jobs can have different number of histories, so the weighted average needs to be calculated. +Moreover, in case averaging is employed, we can estimate spread of the data as standard deviation or standard error. + +The binary output files from each job may be quite large (even ~GB for scoring in fine 3D mesh), to obtain good +statistics we sometimes parallelise the simulation of hundreds or thousands of jobs (when using HPC clusters). +In such case it's not feasible to load all the data into memory and calculate the average in one go, using standard +functions from numpy library. Instead, we need to calculate the average in an online manner, +i.e. by updating the state of respective aggregator object with each new binary output file read. + +Such approach results in a significant reduction of memory usage and is more numerically stable. + +This module contains several classes for aggregating data from multiple files: +- Aggregator: base class for all other aggregators +- WeightedStatsAggregator: for calculating weighted (using weights which do not necessarily sum up to 1) mean and variance +- ConcatenatingAggregator: for concatenating data +- SumAggregator: for calculating sum instead of mean +- NoAggregator: for cases when no aggregation is required + +All aggregators have `data` and `error` property, which can be used to obtain the result of the aggregation. 
+The `data` property returns the result of the aggregation: mean, sum or concatenated array. +The `error` property returns the spread of data for WeightedStatsAggregator, and `None` for other aggregators. + +The `update` method is used to update the state of the aggregator with new data from the file. + +For details on how this method is applied to average binary output of the MC codes, +see `fromfilelist` method from `input_output.py` module. +""" + +from dataclasses import dataclass, field +import logging +from typing import Union, Optional +import numpy as np +from numpy.typing import ArrayLike + + +@dataclass +class Aggregator: + """ + Base class for all aggregators. + The `data` property returns the result of the aggregation, needs to be implemented in derived classes. + The `error` function returns the spread of data, can be implemented in derived classes. It's a function, + not a property as different type of error can be calculated (standard deviation, standard error, etc.). + Type of errors may be then passed in optional keyword arguments `**kwargs`. + """ + + data: Union[float, ArrayLike] = float('nan') + _updated: bool = field(default=False, repr=False, init=False) + + def update(self, value: Union[float, ArrayLike], **kwargs): + """Update the state of the aggregator with new data.""" + raise NotImplementedError(f"Update function not implemented for {self.__class__.__name__}") + + def error(self, **kwargs): + """Default implementation of error function, returns None.""" + logging.debug("Error calculation not implemented for %s", self.__class__.__name__) + + @property + def updated(self) -> bool: + """ + Check if the aggregator was updated. The newly created aggregator is in the state of not being updated. + That means that no aggregation results are present via `data` or `error` properties. + We rely on the fact that `data` attribute is set to `nan` on creation of the aggregator object. 
+ On first call to the `update` method the `data` is being filled with floating point numbers or arrays. + The `update` method sets also `self._updated` to True. + """ + if isinstance(self.data, float) and np.isnan(self.data): + return False + return self._updated + + +@dataclass +class WeightedStatsAggregator(Aggregator): + """ + Calculates weighted mean and variance of a sequence of numbers or numpy arrays. + We do not use frequency weights (which sum up to 1), the total sum of all weights is not known as + we are aggregating data from multiple files with different number of histories. + The aggregation uses single pass loop over all files. + + Good overview of currently known methods to calculate online weighted mean and variance can be found in [2]. + The original method to calculate online mean and variance was proposed by Welford in [1]. + The weighted version of this algorithm is nicely illustrated in [3]. + Here we employ the algorithm proposed by West in [4] and described in Wikipedia [5]. + + [1] Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". + Technometrics. 4 (3): 419–420. + [2] Schubert, Erich, and Michael Gertz. "Numerically stable parallel computation of (co-) variance." + Proceedings of the 30th international conference on scientific and statistical database management. 2018. + [3] https://justinwillmert.com/posts/2022/notes-on-calculating-online-statistics/ + [4] West, D. H. D. (1979). "Updating Mean and Variance Estimates: An Improved Method". + Communications of the ACM. 22 (9): 532-535. 
+ [5] https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm + """ + + data: Union[float, ArrayLike] = float('nan') + + _accumulator_S: Union[float, ArrayLike] = field(default=float('nan'), repr=False, init=False) + _total_weight_squared: float = field(default=0., repr=False, init=False) + total_weight: float = 0 + + def update(self, value: Union[float, ArrayLike], weight: float = 1.0, **kwargs): + """ + Update the state of the aggregator with new data. + Note that the weights are so called "reliability weights", not frequency weights. + If unsure put here the number of histories from the file if you are aggregating data from multiple files. + """ + if weight < 0: + raise ValueError("Weight must be non-negative") + + # first pass initialization + if not self.updated: + self.data = value * 0 + self._accumulator_S = value * 0 + + # W_n = W_{n-1} + w_n + self.total_weight += weight + self._total_weight_squared += weight**2 + + mean_old = self.data + # mu_n = (1 - w_n / W_n) * mu_{n-1} + (w_n / W_n) * x_n + # or in other words: + # mu_n - mu_{n-1} = (w_n / W_n) * (x_n - mu_{n-1}) + self.data += (weight / self.total_weight) * (value - mean_old) + + self._accumulator_S += weight * (value - self.data) * (value - mean_old) + + self._updated = True + logging.debug("Updated aggregator with value %s and weight %s", value, weight) + + @property + def mean(self) -> Union[float, ArrayLike]: + """Weighted mean of the sample""" + return self.data + + @property + def variance_population(self) -> Union[float, ArrayLike]: + """Biased estimate of the variance""" + if not self.updated: + raise ValueError("No data to calculate variance") + if self.total_weight <= 0: + raise ValueError("Total weight must be positive") + return self._accumulator_S / self.total_weight + + @property + def variance_sample(self) -> Union[float, ArrayLike]: + """ + Unbiased estimate of the variance. 
+ The bias of the weighted estimator is (1 - sum w_i^2 / W_n^2), or in other words: + 1 - "sum of squares of weights" / "square of sum of weights". + For all equal weights the bias is 1 - n * w^2/((n * w)^2) = 1 - 1/n which + leads to the well known formula for the sample variance. + Here we use the weighted version of the formula. + """ + if not self.updated: + raise ValueError("No data to calculate variance") + if self.total_weight <= 0: + raise ValueError("Total weight must be positive") + return self._accumulator_S / (self.total_weight - (self._total_weight_squared / self.total_weight)) + + @property + def stddev(self) -> Union[float, ArrayLike]: + """Standard deviation of the sample""" + return np.sqrt(self.variance_sample) + + @property + def stderr(self) -> Union[float, ArrayLike]: + """ + Standard error of the sample. + For weighted data it is calculated as: + stddev * sqrt(sum w_i^2) / sum w_i + For equal weights it reduces to stddev * sqrt(n * w^2) / (n * w) = stddev / sqrt(n) + """ + return self.stddev * np.sqrt(self._total_weight_squared) / self.total_weight + + def error(self, **kwargs) -> Optional[Union[float, ArrayLike]]: + """ + Error calculation function, can be used to calculate standard deviation or standard error. + Type of error may be requested by `error_type` keyword argument with `stddev` or `stderr` values. + For other values or if the keyword argument is not present, None is returned. 
+ """ + logging.debug("Calculating error with kwargs: %s", kwargs) + if 'error_type' in kwargs: + if kwargs['error_type'] == 'stddev': + return self.stddev + if kwargs['error_type'] == 'stderr': + return self.stderr + return None + + +@dataclass +class ConcatenatingAggregator(Aggregator): + """Class for concatenating numpy arrays""" + + def update(self, value: Union[float, ArrayLike], **kwargs): + """Update the state of the aggregator with new data.""" + if not self.updated: + self.data = value + else: + self.data = np.concatenate((self.data, value)) + self._updated = True + + +@dataclass +class SumAggregator(Aggregator): + """Class for calculating sum of a sequence of numbers.""" + + def update(self, value: Union[float, ArrayLike], **kwargs): + """Update the state of the aggregator with new data.""" + # first value added + if not self.updated: + self.data = value + # subsequent values added + else: + self.data += value + self._updated = True + + +@dataclass +class NoAggregator(Aggregator): + """ + Class for cases when no aggregation is required. + Sets the data to the first value and does not update it. 
+ """ + + def update(self, value: Union[float, ArrayLike], **kwargs): + """Update the state of the aggregator with new data.""" + # set value only on first update + if not self.updated: + logging.debug("Setting data to %s", value) + self.data = value + self._updated = True diff --git a/pymchelper/estimator.py b/pymchelper/estimator.py index c32611880..f7eee212e 100644 --- a/pymchelper/estimator.py +++ b/pymchelper/estimator.py @@ -153,5 +153,5 @@ def average_with_nan(estimator_list, error_estimate=ErrorEstimate.stderr): page.error_raw /= np.sqrt(result.file_counter) # np.sqrt() always returns np.float64 else: for page in result.pages: - page.error_raw = np.zeros_like(page.data_raw) + page.error_raw = None return result diff --git a/pymchelper/executor/runner.py b/pymchelper/executor/runner.py index 6a0d6b5b3..fa805458b 100644 --- a/pymchelper/executor/runner.py +++ b/pymchelper/executor/runner.py @@ -14,7 +14,8 @@ from pymchelper.executor.options import SimulationSettings from pymchelper.simulator_type import SimulatorType -from pymchelper.input_output import frompattern, get_topas_estimators +from pymchelper.readers.topas import get_topas_estimators +from pymchelper.input_output import frompattern class OutputDataType(IntEnum): @@ -30,8 +31,12 @@ class Runner: Main class responsible for configuring and starting multiple parallel MC simulation processes It can be used to access combined averaged results of the simulation. """ - def __init__(self, settings: SimulationSettings, jobs: int=None, - keep_workspace_after_run: bool=False, output_directory: str='.'): + + def __init__(self, + settings: SimulationSettings, + jobs: int = None, + keep_workspace_after_run: bool = False, + output_directory: str = '.'): self.settings = settings # create pool of processes, waiting to be started by run method @@ -234,6 +239,7 @@ class WorkspaceManager: A workspace consists of multiple working directories (i.e. run_1, run_2), each per one of the parallel simulation run. 
""" + def __init__(self, output_directory='.', keep_workspace_after_run=False): self.output_dir_absolute_path = os.path.abspath(output_directory) self.keep_workspace_after_run = keep_workspace_after_run diff --git a/pymchelper/input_output.py b/pymchelper/input_output.py index ec06f86d6..642d47d3c 100644 --- a/pymchelper/input_output.py +++ b/pymchelper/input_output.py @@ -1,12 +1,14 @@ +from enum import IntEnum import logging +import gc import os from collections import defaultdict from glob import glob from pathlib import Path from typing import List, Optional -import numpy as np - +from pymchelper.averaging import (Aggregator, SumAggregator, WeightedStatsAggregator, ConcatenatingAggregator, + NoAggregator) from pymchelper.estimator import ErrorEstimate, Estimator, average_with_nan from pymchelper.readers.topas import TopasReaderFactory from pymchelper.readers.fluka import FlukaReader, FlukaReaderFactory @@ -17,10 +19,30 @@ logger = logging.getLogger(__name__) +class AggregationType(IntEnum): + """ + Enum for different types of aggregation. + This enum is related to integer value stored in SHIELD-HIT12A binary files, + which defines how the data is aggregated. + Below few examples of how such aggregation is used in SHIELD-HIT12A: + - NoAggregation is used for density (RHO) and material scorer. + - Sum is used for particle counter (COUNT). + - AveragingCumulative is used for dose and fluence scorers. + - AveragingPerPrimary is used for LET scorers (TLET and DLET). + - Concatenation is used for phase space (MCPL) scorer. + """ + + NoAggregation = 0 + Sum = 1 + AveragingCumulative = 2 + AveragingPerPrimary = 3 + Concatenation = 4 + + def guess_reader(filename): """ Guess a reader based on file contents or extensions. - In some cases (i.e. binary SH12A files) access to file contents is needed. + In some cases (i.e. binary SHIELD-HIT12A files) access to file contents is needed. 
:param filename: :return: Instantiated reader object """ @@ -42,7 +64,7 @@ def guess_reader(filename): def guess_corename(filename): """ Guess a reader based on file contents or extensions. - In some cases (i.e. binary SH12A files) access to file contents is needed. + In some cases (i.e. binary SHIELD-HIT12A files) access to file contents is needed. :param filename: :return: the corename of the file (i.e. the basename without the running number for averaging) """ @@ -53,7 +75,14 @@ def guess_corename(filename): def fromfile(filename: str) -> Optional[Estimator]: - """Read estimator data from a binary file ```filename```""" + """ + Read estimator data from a binary file `filename` + Note that in some cases the data are post-processed (i.e. normalized) after reading. + For example SHIELD-HIT12A saves dose and fluence data as cumulative values, + which are normalized by the number of primaries afterwards by the Reader responsible for parsing binary files. + This way dose and fluence (and other similar quantities) are saved in Estimator as "per primary" values. + Fluka on the other hand saves dose and fluence as "per primary" values, so no normalization is needed. + """ reader = guess_reader(filename) if reader is None: @@ -66,9 +95,11 @@ def fromfile(filename: str) -> Optional[Estimator]: return estimator -def fromfilelist(input_file_list, error: ErrorEstimate = ErrorEstimate.stderr, nan: bool = True) -> Optional[Estimator]: +def fromfilelist(input_file_list, + error: ErrorEstimate = ErrorEstimate.stderr, + nan: bool = False) -> Optional[Estimator]: """ - Reads all files from a given list, and returns a list of averaged estimators. + Reads all files from a given list using `fromfile` method, and returns a list of averaged estimators. 
:param input_file_list: list of files to be read :param error: error estimation, see class ErrorEstimate class in pymchelper.estimator @@ -81,8 +112,6 @@ def fromfilelist(input_file_list, error: ErrorEstimate = ErrorEstimate.stderr, n if nan: estimator_list = [fromfile(filename) for filename in input_file_list] result = average_with_nan(estimator_list, error) - if not result: # TODO check here ! - return None elif len(input_file_list) == 1: result = fromfile(input_file_list[0]) if not result: @@ -92,53 +121,54 @@ def fromfilelist(input_file_list, error: ErrorEstimate = ErrorEstimate.stderr, n if not result: return None - # allocate memory for accumulator in standard deviation calculation - # not needed if user requested not to include errors - if error != ErrorEstimate.none: - for page in result.pages: - page.error_raw = np.zeros_like(page.data_raw) - - # loop over all files with n running from 2 - for n, filename in enumerate(input_file_list[1:], start=2): - current_estimator = fromfile(filename) # x - logger.info("Reading file %s (%d/%d)", filename, n, len(input_file_list)) - - if not current_estimator: - logger.warning("File %s could not be read", filename) - return None - + # _aggregator_mapping maps SHIELD-HIT12A normalization types (integers) to pymchelper aggregators + # using enums for clarity. AveragingCumulative (e.g., dose) and AveragingPerPrimary (e.g., LET) + # both utilize WeightedStatsAggregator. SHIELD-HIT12A stores "cumulative-like" data (e.g., dose, + # fluence) in BDO format as quantities for all particles. pymchelper normalizes this upon reading + # a BDO file by the number of primaries, making the `estimator` object data pre-normalized. Hence, + # aggregation for "cumulative-like" and "per-primary" data is handled uniformly in this mapping. 
+ _aggregator_mapping: dict[AggregationType, Aggregator] = { + AggregationType.NoAggregation: NoAggregator, + AggregationType.Sum: SumAggregator, + AggregationType.AveragingCumulative: WeightedStatsAggregator, + AggregationType.AveragingPerPrimary: WeightedStatsAggregator, + AggregationType.Concatenation: ConcatenatingAggregator + } + + # create aggregators for each page and fill them with data from first file + page_aggregators = [] + for page in result.pages: + + # if no normalization attribute present (Fluka?) we can assume it is a cumulative-like quantity + current_page_normalisation = getattr(page, 'page_normalized', AggregationType.AveragingCumulative.value) + + # guess the aggregator based on the normalisation type + aggregator = _aggregator_mapping.get(current_page_normalisation, WeightedStatsAggregator)() + logger.debug("Selected aggregator %s for page %s", aggregator, page.name) + + # feed the aggregator with data from the first file + aggregator.update(value=page.data_raw, weight=result.number_of_primaries) + page_aggregators.append(aggregator) + + # process all other files, if there are any + for filename in input_file_list[1:]: + current_estimator = fromfile(filename) + for current_page, aggregator in zip(current_estimator.pages, page_aggregators): + aggregator.update(value=current_page.data_raw, weight=current_estimator.number_of_primaries) + + # force garbage collection if the estimator is too large + estimator_size_mbytes = sum(page.data_raw.nbytes for page in current_estimator.pages) / 1024 / 1024 + gc_threshold_mbytes = 100 + if estimator_size_mbytes > gc_threshold_mbytes: + logger.info("Large estimator (%.1f MB) detected, performing garbage collection", estimator_size_mbytes) + gc.collect() result.number_of_primaries += current_estimator.number_of_primaries - for current_page, result_page in zip(current_estimator.pages, result.pages): - # got a page with "concatenate normalisation" - if getattr(current_page, 'page_normalized', 2) == 4: - 
logger.debug("Concatenating page %s", current_page.name) - result_page.data_raw = np.concatenate((result_page.data_raw, current_page.data_raw)) - else: - logger.debug("Averaging page %s", current_page.name) - # Running variance algorithm based on algorithm by B. P. Welford, - # presented in Donald Knuth's Art of Computer Programming, Vol 2, page 232, 3rd edition. - # Can be found here: http://www.johndcook.com/blog/standard_deviation/ - # and https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm - delta = current_page.data_raw - result_page.data_raw # delta = x - mean - result_page.data_raw += delta / np.float64(n) - if error != ErrorEstimate.none: - # the line below is equivalent to M2 += delta * (x - mean) - result_page.error_raw += delta * (current_page.data_raw - result_page.data_raw) - - # unbiased sample variance is stored in `__M2 / (n - 1)` - # unbiased sample standard deviation in classical algorithm is calculated as (sqrt(1/(n-1)sum(x-)**2) - # here it is calculated as square root of unbiased sample variance: - if len(input_file_list) > 1 and error != ErrorEstimate.none: - for page in result.pages: - page.error_raw = np.sqrt(page.error_raw / (len(input_file_list) - 1.0)) - - # if user requested standard error then we calculate it as: - # S = stderr = stddev / sqrt(N), or in other words, - # S = s/sqrt(N) where S is the corrected standard deviation of the mean. 
- if len(input_file_list) > 1 and error == ErrorEstimate.stderr: - for page in result.pages: - page.error_raw /= np.sqrt(len(input_file_list)) # np.sqrt() always returns np.float64 + # extract data from aggregators and fill them into the result + for page, aggregator in zip(result.pages, page_aggregators): + logger.debug("Extracting data from aggregator %s for page %s", aggregator, page.name) + page.data_raw = aggregator.data + page.error_raw = aggregator.error(error_type=error.name) result.file_counter = len(input_file_list) core_names_dict = group_input_files(input_file_list) diff --git a/pymchelper/page.py b/pymchelper/page.py index b5dc7d1dc..53b1ded05 100644 --- a/pymchelper/page.py +++ b/pymchelper/page.py @@ -5,15 +5,16 @@ class Page: + def __init__(self, estimator=None): self.estimator = estimator self.data_raw = np.array([float("NaN")]) # linear data storage - self.error_raw = np.array([float("NaN")]) # linear data storage + self.error_raw = None # linear data storage - self.name : str = "" - self.unit : str = "" + self.name: str = "" + self.unit: str = "" self.dettyp = None # Dose, Fluence, LET etc... 
@@ -96,8 +97,8 @@ def data(self): if self.dettyp == SHDetType.mcpl: return self._reshape(data_1d=self.data_raw, shape=(8, -1)) return self._reshape(data_1d=self.data_raw, - shape=(self.estimator.x.n, self.estimator.y.n, self.estimator.z.n, - self.diff_axis1.n, self.diff_axis2.n)) + shape=(self.estimator.x.n, self.estimator.y.n, self.estimator.z.n, self.diff_axis1.n, + self.diff_axis2.n)) return self.data_raw @property @@ -110,12 +111,14 @@ def error(self): """ if self.estimator: return self._reshape(data_1d=self.error_raw, - shape=(self.estimator.x.n, self.estimator.y.n, self.estimator.z.n, - self.diff_axis1.n, self.diff_axis2.n)) + shape=(self.estimator.x.n, self.estimator.y.n, self.estimator.z.n, self.diff_axis1.n, + self.diff_axis2.n)) return self.error_raw - def _reshape(self, data_1d, shape : tuple): + def _reshape(self, data_1d, shape: tuple): # TODO check also tests/res/shieldhit/single/ex_yzmsh.bdo as it is saved in bin2010 format + if data_1d is None: + return None if self.estimator: order = 'C' if self.estimator.file_format in {'bdo2016', 'bdo2019', 'fluka_binary'}: @@ -124,7 +127,7 @@ def _reshape(self, data_1d, shape : tuple): else: return data_1d - def plot_axis(self, id : int): + def plot_axis(self, id: int): """ Calculate new order of detector axis, axis with data (n>1) comes first Axes with constant value goes last. 
diff --git a/pymchelper/readers/common.py b/pymchelper/readers/common.py index 602c529a0..08224e349 100644 --- a/pymchelper/readers/common.py +++ b/pymchelper/readers/common.py @@ -1,12 +1,11 @@ from abc import abstractmethod import logging -import numpy as np - logger = logging.getLogger(__name__) class ReaderFactory(object): + def __init__(self, filename): self.filename = filename @@ -16,6 +15,7 @@ def get_reader(self): class Reader(object): + def __init__(self, filename): self.filename = filename @@ -24,7 +24,7 @@ def read(self, estimator): if not result: return False for page in estimator.pages: - page.error_raw = np.zeros_like(page.data_raw) * np.nan + page.error_raw = None return True @abstractmethod diff --git a/pymchelper/readers/fluka.py b/pymchelper/readers/fluka.py index c5f5abd2d..5db52458e 100644 --- a/pymchelper/readers/fluka.py +++ b/pymchelper/readers/fluka.py @@ -106,7 +106,6 @@ def parse_usrbin(self, estimator): # TODO cross-check if reshaping is needed page.data_raw = np.array(unpackArray(usr_object.readData(det_no))) page.data_raw *= rescaling_factor - page.error_raw = np.empty_like(page.data_raw) estimator.add_page(page) @@ -180,7 +179,6 @@ def parse_usrbdx(self, estimator): # unpack detector data # TODO cross-check if reshaping is needed page.data_raw = np.array(unpackArray(usr_object.readData(det_no))) - page.error_raw = np.empty_like(page.data_raw) estimator.add_page(page) @@ -241,7 +239,6 @@ def parse_usrtrack(self, estimator): # unpack detector data # TODO cross-check if reshaping is needed page.data_raw = np.array(unpackArray(usr_object.readData(det_no))) - page.error_raw = np.empty_like(page.data_raw) estimator.add_page(page) return usr_object @@ -263,7 +260,6 @@ def parse_resnuclei(self, estimator): # unpack detector data # TODO cross-check if reshaping is needed page.data_raw = np.array(unpackArray(usr_object.readData(det_no))) - page.error_raw = np.empty_like(page.data_raw) estimator.add_page(page) return usr_object diff --git 
a/pymchelper/readers/shieldhit/reader_base.py b/pymchelper/readers/shieldhit/reader_base.py index b5e1b4206..815ecdd4f 100644 --- a/pymchelper/readers/shieldhit/reader_base.py +++ b/pymchelper/readers/shieldhit/reader_base.py @@ -64,8 +64,9 @@ def mesh_unit_and_name(estimator, axis): unit = _geotyp_units.get(estimator.geotyp, _default_units)[axis] - if estimator.geotyp in {SHGeoType.msh, SHGeoType.dmsh, SHGeoType.voxscore, SHGeoType.geomap, - SHGeoType.plane, SHGeoType.dplane}: + if estimator.geotyp in { + SHGeoType.msh, SHGeoType.dmsh, SHGeoType.voxscore, SHGeoType.geomap, SHGeoType.plane, SHGeoType.dplane + }: name = ("Position (X)", "Position (Y)", "Position (Z)")[axis] elif estimator.geotyp in {SHGeoType.cyl, SHGeoType.dcyl}: name = ("Radius (R)", "Angle (PHI)", "Position (Z)")[axis] @@ -161,9 +162,7 @@ def read_next_token(f): f is an open and readable file pointer. returns None if no token was found / EOF """ - tag = np.dtype([('pl_id', ' List[Estimator]: + """Get Topas estimators from provided directory""" + estimators_list = [] + for path in Path(output_files_path).iterdir(): + topas_reader = TopasReaderFactory(str(path)).get_reader() + if topas_reader: + reader = topas_reader(path) + estimator = Estimator() + reader.read(estimator) + estimators_list.append(estimator) + + return estimators_list + + def extract_parameter_filename(header_line: str) -> Optional[str]: """Get parameter filename from the output file""" pattern = r"# Parameter File: (.*)" @@ -43,9 +57,7 @@ def extract_bins_data(dimensions: List[str], header_lines: List[str]) -> Optiona for line_index, dimension in enumerate(dimensions): match = re.search(pattern.format(dimension), header_lines[line_index]) if match: - bins_data[dimension] = {'num': int(match.group(1)), - 'size': float(match.group(2)), - 'unit': match.group(3)} + bins_data[dimension] = {'num': int(match.group(1)), 'size': float(match.group(2)), 'unit': match.group(3)} else: return None return bins_data @@ -63,10 +75,11 @@ def 
extract_scorer_name(header_line: str) -> Optional[str]: def extract_scorer_unit_results(header_line: str) -> Optional[Tuple[str, str, List]]: """Get scoring quantity, unit and the scoring values (sum/mean/etc.) from the output file""" - scorers = ['DoseToMedium', 'DoseToWater', 'DoseToMaterial', 'TrackLengthEstimator', - 'AmbientDoseEquivalent', 'EnergyDeposit', 'Fluence', 'EnergyFluence', - 'StepCount', 'OpticalPhotonCount', 'OriginCount', 'Charge', 'EffectiveCharge', - 'ProtonLET', 'SurfaceCurrent', 'SurfaceTrackCount', 'PhaseSpace'] + scorers = [ + 'DoseToMedium', 'DoseToWater', 'DoseToMaterial', 'TrackLengthEstimator', 'AmbientDoseEquivalent', + 'EnergyDeposit', 'Fluence', 'EnergyFluence', 'StepCount', 'OpticalPhotonCount', 'OriginCount', 'Charge', + 'EffectiveCharge', 'ProtonLET', 'SurfaceCurrent', 'SurfaceTrackCount', 'PhaseSpace' + ] for scorer in scorers: if scorer in header_line: unit = "" @@ -93,8 +106,12 @@ def extract_differential_axis(header_line: str) -> Optional[MeshAxis]: min_val = float(match.group(5)) max_val = float(match.group(7)) - return MeshAxis(n=num_bins, min_val=min_val, max_val=max_val, - name=binned_by, unit=unit, binning=MeshAxis.BinningType.linear) + return MeshAxis(n=num_bins, + min_val=min_val, + max_val=max_val, + name=binned_by, + unit=unit, + binning=MeshAxis.BinningType.linear) return None @@ -105,6 +122,7 @@ def __init__(self, filename): super(TopasReader, self).__init__(filename) self.directory = Path(filename).parent + # skipcq: PY-R1000 def read_data(self, estimator: Estimator) -> bool: """ Read the data from the file and store them in the provided estimator object. 
@@ -137,30 +155,34 @@ def read_data(self, estimator: Estimator) -> bool: estimator.number_of_primaries = num_histories estimator.file_format = "csv" - dimensions = [['X', 'Y', 'Z'], - ['R', 'Phi', 'Z'], - ['R', 'Phi', 'Theta']] + dimensions = [['X', 'Y', 'Z'], ['R', 'Phi', 'Z'], ['R', 'Phi', 'Theta']] actual_dimensions = None for curr_dimensions in dimensions: - for line_index in range(len(header_lines)-3): - bins_data = extract_bins_data(curr_dimensions, header_lines[line_index:line_index+3]) + for line_index in range(len(header_lines) - 3): + bins_data = extract_bins_data(curr_dimensions, header_lines[line_index:line_index + 3]) if bins_data is not None: actual_dimensions = curr_dimensions - x_max = bins_data[actual_dimensions[0]]['size']*bins_data[actual_dimensions[0]]['num'] + x_max = bins_data[actual_dimensions[0]]['size'] * bins_data[actual_dimensions[0]]['num'] estimator.x = MeshAxis(n=bins_data[actual_dimensions[0]]['num'], - min_val=0.0, max_val=x_max, - name=actual_dimensions[0], unit=bins_data[actual_dimensions[0]]['unit'], + min_val=0.0, + max_val=x_max, + name=actual_dimensions[0], + unit=bins_data[actual_dimensions[0]]['unit'], binning=MeshAxis.BinningType.linear) - y_max = bins_data[actual_dimensions[1]]['size']*bins_data[actual_dimensions[1]]['num'] + y_max = bins_data[actual_dimensions[1]]['size'] * bins_data[actual_dimensions[1]]['num'] estimator.y = MeshAxis(n=bins_data[actual_dimensions[1]]['num'], - min_val=0.0, max_val=y_max, - name=actual_dimensions[1], unit=bins_data[actual_dimensions[1]]['unit'], + min_val=0.0, + max_val=y_max, + name=actual_dimensions[1], + unit=bins_data[actual_dimensions[1]]['unit'], binning=MeshAxis.BinningType.linear) - z_max = bins_data[actual_dimensions[2]]['size']*bins_data[actual_dimensions[2]]['num'] + z_max = bins_data[actual_dimensions[2]]['size'] * bins_data[actual_dimensions[2]]['num'] estimator.z = MeshAxis(n=bins_data[actual_dimensions[2]]['num'], - min_val=0.0, max_val=z_max, - name=actual_dimensions[2], 
unit=bins_data[actual_dimensions[2]]['unit'], + min_val=0.0, + max_val=z_max, + name=actual_dimensions[2], + unit=bins_data[actual_dimensions[2]]['unit'], binning=MeshAxis.BinningType.linear) no_bins = False break @@ -195,7 +217,6 @@ def read_data(self, estimator: Estimator) -> bool: num_results = len(results) page = Page(estimator=estimator) set_data = False - set_error = False for column, result in enumerate(results): if result not in ['Mean', 'Standard_Deviation']: continue @@ -216,8 +237,8 @@ def read_data(self, estimator: Estimator) -> bool: else: return False lines = np.genfromtxt(self.filename, delimiter=',') - last_bin_index = len(lines[0]) - num_results*additional_bins - scores = lines[:, column+num_results:last_bin_index:num_results].flatten() + last_bin_index = len(lines[0]) - num_results * additional_bins + scores = lines[:, column + num_results:last_bin_index:num_results].flatten() else: # When there is no differential axis, each line in csv file looks like this: @@ -233,7 +254,7 @@ def read_data(self, estimator: Estimator) -> bool: scores = np.array([data[column]]) else: lines = np.genfromtxt(self.filename, delimiter=',') - scores = lines[:, column+3] + scores = lines[:, column + 3] for line in header_lines: title = extract_scorer_name(line) @@ -249,13 +270,10 @@ def read_data(self, estimator: Estimator) -> bool: set_data = True elif result == 'Standard_Deviation': page.error_raw = scores - set_error = True # If we didn't find mean results for the scorer, we return False if not set_data: return False - if not set_error: - page.error_raw = np.empty_like(page.data_raw) estimator.add_page(page) return True diff --git a/pymchelper/writers/excel.py b/pymchelper/writers/excel.py index e65224888..54c5a9655 100644 --- a/pymchelper/writers/excel.py +++ b/pymchelper/writers/excel.py @@ -1,6 +1,6 @@ import logging -import numpy as np +from pymchelper.estimator import Estimator logger = logging.getLogger(__name__) @@ -9,45 +9,44 @@ class ExcelWriter: """ Supports 
writing XLS files (MS Excel 2003 format) """ + def __init__(self, filename, options): self.filename = filename if not self.filename.endswith(".xls"): self.filename += ".xls" - def write(self, estimator): - if len(estimator.pages) > 1: - print("Conversion of data with multiple pages not supported yet") - return False - + def write(self, estimator: Estimator): try: import xlwt except ImportError as e: - logger.error("Generating Excel files not available on your platform (you are probably running Python 3.2).") + logger.error("Generating Excel files not available on your platform (please install xlwt).") raise e - page = estimator.pages[0] + # create workbook + wb = xlwt.Workbook() - # save only 1-D data - if page.dimension != 1: - logger.warning("page dimension {:d} != 1, XLS output not supported".format(estimator.dimension)) - return 1 + for page_id, page in enumerate(estimator.pages): - # create workbook with single sheet - wb = xlwt.Workbook() - ws = wb.add_sheet('Data') + # save only 1-D data + if page.dimension != 1: + logger.warning("page dimension {:d} != 1, XLS output not supported".format(estimator.dimension)) + return 1 + + # create worksheet + ws = wb.add_sheet(f'Data_p{page_id}') - # save X axis data - for i, x in enumerate(page.plot_axis(0).data): - ws.write(i, 0, x) + # save X axis data + for i, x in enumerate(page.plot_axis(0).data): + ws.write(i, 0, x) - # save Y axis data - for i, y in enumerate(page.data_raw): - ws.write(i, 1, y) + # save Y axis data + for i, y in enumerate(page.data_raw): + ws.write(i, 1, y) - # save error column (if present) - if np.all(np.isfinite(page.error_raw)): - for i, e in enumerate(page.error_raw): - ws.write(i, 2, e) + # save error column (if present) + if page.error_raw is not None: + for i, e in enumerate(page.error_raw): + ws.write(i, 2, e) # save file logger.info("Writing: " + self.filename) diff --git a/pymchelper/writers/hdf.py b/pymchelper/writers/hdf.py index 5a5506e0c..dac754ba9 100644 --- 
a/pymchelper/writers/hdf.py +++ b/pymchelper/writers/hdf.py @@ -1,5 +1,6 @@ import logging -import numpy as np + +from pymchelper.estimator import Estimator logger = logging.getLogger(__name__) @@ -17,7 +18,7 @@ def __init__(self, filename, options): if not self.filename.endswith(".h5"): self.filename += ".h5" - def write(self, estimator): + def write(self, estimator: Estimator): if len(estimator.pages) == 0: print("No pages in the output file, conversion to HDF5 skipped.") return False @@ -42,7 +43,7 @@ def write(self, estimator): dset = hdf_file.create_dataset(dataset_name, data=page.data, compression="gzip", compression_opts=9) # save error (if present) - if not np.all(np.isnan(page.error_raw)) and np.any(page.error_raw): + if page.error is not None: hdf_file.create_dataset(dataset_error_name, data=page.error, compression="gzip", compression_opts=9) # save metadata diff --git a/pymchelper/writers/inspector.py b/pymchelper/writers/inspector.py index e6057b134..dec8eaf8f 100644 --- a/pymchelper/writers/inspector.py +++ b/pymchelper/writers/inspector.py @@ -1,14 +1,16 @@ import logging +from pymchelper.estimator import Estimator logger = logging.getLogger(__name__) class Inspector: + def __init__(self, filename, options): logger.debug("Initialising Inspector writer") self.options = options - def write(self, estimator): + def write(self, estimator: Estimator): """Print all keys and values from estimator structure they include also a metadata read from binary output file @@ -16,7 +18,7 @@ def write(self, estimator): for name, value in sorted(estimator.__dict__.items()): # skip non-metadata fields if name not in {'data', 'data_raw', 'error', 'error_raw', 'counter', 'pages'}: - line = "{:24s}: '{:s}'".format(str(name), str(value)) + line = f"{name:24s}: {value}" print(line) # print some data-related statistics print(75 * "*") @@ -26,10 +28,13 @@ def write(self, estimator): for name, value in sorted(page.__dict__.items()): # skip non-metadata fields if name not in 
{'data', 'data_raw', 'error', 'error_raw'}: - line = "\t{:24s}: '{:s}'".format(str(name), str(value)) + line = f"\t{name:24s}: {value}" print(line) - print("Data min: {:g}, max: {:g}, mean: {:g}".format( - page.data_raw.min(), page.data_raw.max(), page.data_raw.mean())) + print(f"Data min: {page.data.min():g}, max: {page.data.max():g}, mean: {page.data.mean():g}") + if page.error is not None: + print(f"Error min: {page.error.min():g}, max: {page.error.max():g}, mean: {page.error.mean():g}") + else: + print("No error data") print(75 * "-") if self.options.details: diff --git a/pymchelper/writers/plots.py b/pymchelper/writers/plots.py index a79e2884a..033097195 100644 --- a/pymchelper/writers/plots.py +++ b/pymchelper/writers/plots.py @@ -37,7 +37,7 @@ def write_single_page(self, page: Page, output_path: Path): # special case for 0-dim data if page.dimension == 0: # save two numbers to the file - if not np.all(np.isnan(page.error_raw)) and np.any(page.error_raw): + if page.error is not None: np.savetxt(self.output_path, [[page.data_raw, page.error_raw]], fmt="%g %g", delimiter=' ') else: # save one number to the file np.savetxt(self.output_path, [page.data_raw], fmt="%g", delimiter=' ') @@ -56,7 +56,7 @@ def write_single_page(self, page: Page, output_path: Path): data_to_save = axis_data_columns_long + [page.data_raw] # if error information is present save it as additional column - if not np.all(np.isnan(page.error_raw)) and np.any(page.error_raw): + if page.error is not None: fmt += " %g" data_to_save += [page.error_raw] @@ -124,7 +124,7 @@ def get_page_figure(self, page): ax.set_yscale('log') # add optional error area - if np.any(page.error): + if page.error is not None: ax.fill_between(plot_x_axis.data, (data_raw - error_raw).clip(0.0), (data_raw + error_raw).clip(0.0, 1.05 * data_raw.max()), alpha=0.2, diff --git a/pymchelper/writers/shieldhit.py b/pymchelper/writers/shieldhit.py index 32076e5ca..b36d1c133 100644 --- a/pymchelper/writers/shieldhit.py +++ 
b/pymchelper/writers/shieldhit.py @@ -10,6 +10,7 @@ class SHBinaryWriter: + def __init__(self, filename, options): self.filename = filename @@ -19,6 +20,7 @@ def write(self, estimator): class TxtWriter: + @staticmethod def _axis_name(geo_type, axis_no): cyl = ('R', 'PHI', 'Z') @@ -40,11 +42,20 @@ def __init__(self, filename, options): def _header_first_line(estimator): """first line with estimator geo type""" result = "# DETECTOR OUTPUT\n" - if estimator.geotyp in (SHGeoType.plane, SHGeoType.dplane,): + if estimator.geotyp in ( + SHGeoType.plane, + SHGeoType.dplane, + ): result = "# DETECTOR OUTPUT PLANE/DPLANE\n" - elif estimator.geotyp in (SHGeoType.zone, SHGeoType.dzone,): + elif estimator.geotyp in ( + SHGeoType.zone, + SHGeoType.dzone, + ): result = "# DETECTOR OUTPUT ZONE/DZONE\n" - elif estimator.geotyp in (SHGeoType.msh, SHGeoType.dmsh,): + elif estimator.geotyp in ( + SHGeoType.msh, + SHGeoType.dmsh, + ): result = "# DETECTOR OUTPUT MSH/DMSH\n" elif estimator.geotyp == SHGeoType.geomap: result = "# DETECTOR OUTPUT GEOMAP\n" @@ -74,7 +85,10 @@ def _header_scored_value(geotyp, dettyp, particle): result += f"# JPART:{particle:6d} DETECTOR TYPE: {str(dettyp).ljust(10)}\n" else: det_type_name = str(dettyp) - if dettyp in (SHDetType.zone, SHDetType.medium,): + if dettyp in ( + SHDetType.zone, + SHDetType.medium, + ): det_type_name += "#" result += f"# DETECTOR TYPE: {str(det_type_name).ljust(10)}\n" return result @@ -145,8 +159,8 @@ def write_single_page(self, page, filename): header += self._header_geometric_info(page.estimator) - header += self._header_scored_value( - page.estimator.geotyp, page.dettyp, getattr(page.estimator, 'particle', None)) + header += self._header_scored_value(page.estimator.geotyp, page.dettyp, + getattr(page.estimator, 'particle', None)) header += self._header_no_of_bins_and_prim(page.estimator) @@ -155,9 +169,10 @@ def write_single_page(self, page, filename): logger.info("Writing: %s", filename) fout.write(header) - det_error = 
page.error_raw.ravel() - if np.all(np.isnan(page.error_raw)): + if page.error is None: det_error = [None] * page.data_raw.size + else: + det_error = page.error_raw.ravel() xmesh = page.axis(0) ymesh = page.axis(1) zmesh = page.axis(2) diff --git a/setup.py b/setup.py index 6c2ec291d..5d0012515 100644 --- a/setup.py +++ b/setup.py @@ -34,10 +34,7 @@ def write_version_py(): 'image': ['matplotlib'], 'excel': ['xlwt'], 'hdf': ['h5py'], - 'dicom': [ - "pydicom>=2.3.1 ; python_version == '3.11'", - "pydicom ; python_version < '3.11'" - ], + 'dicom': ["pydicom>=2.3.1 ; python_version == '3.11'", "pydicom ; python_version < '3.11'"], 'pytrip': [ 'scipy', "pytrip98>=3.8.0,<3.9.0 ; platform_system == 'Darwin' and python_version >= '3.11'", @@ -81,10 +78,8 @@ def write_version_py(): # |---------------------------------------------------| # see https://www.python.org/dev/peps/pep-0508/ for language specification install_requires = [ - "numpy>=1.23.3 ; python_version == '3.11'", - "numpy>=1.21 ; python_version == '3.10'", - "numpy>=1.20,<1.26.0 ; python_version == '3.9'", - "numpy>=1.18,<1.26.0 ; python_version == '3.8'", + "numpy>=1.23.3 ; python_version == '3.11'", "numpy>=1.21 ; python_version == '3.10'", + "numpy>=1.20,<1.26.0 ; python_version == '3.9'" ] setuptools.setup( @@ -118,7 +113,6 @@ def write_version_py(): # Specify the Python versions you support here. In particular, ensure # that you indicate whether you support Python 2, Python 3 or both. 
- 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', @@ -134,5 +128,5 @@ def write_version_py(): package_data={'pymchelper': ['flair/db/*', 'VERSION']}, install_requires=install_requires, extras_require=EXTRAS_REQUIRE, - python_requires='>=3.8', + python_requires='>=3.9', ) diff --git a/tests/integration/shieldhit/test_averaging.py b/tests/integration/shieldhit/test_averaging.py new file mode 100644 index 000000000..9ecb518ab --- /dev/null +++ b/tests/integration/shieldhit/test_averaging.py @@ -0,0 +1,95 @@ +import logging +from pathlib import Path +from typing import Generator + +import numpy as np +import pytest + +from pymchelper.input_output import fromfile, fromfilelist + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope='module') +def averaging_bdos_directory(main_dir) -> Generator[Path, None, None]: + """Path to directory with BDO files""" + yield main_dir / "res" / "shieldhit" / "averaging" + + +@pytest.fixture(scope='function') +def large_stat_bdo_pattern() -> Generator[str, None, None]: + """Part of filename denoting large statistics BDO file""" + yield "_000?.bdo" + + +@pytest.fixture(scope='function') +def small_stat_bdo_pattern() -> Generator[str, None, None]: + """Part of filename denoting small statistics BDO file""" + yield "_001?.bdo" + + +@pytest.mark.parametrize("output_type", [ + "normalisation-1_aggregation-none", "normalisation-2_aggregation-sum", "normalisation-3_aggregation-mean", + "normalisation-4_aggregation-concat", "normalisation-5_aggregation-mean" +]) +def test_aggregating_equal_stats(averaging_bdos_directory, small_stat_bdo_pattern, large_stat_bdo_pattern, output_type): + """ + Check if data from averaged estimator is equal to data from all estimators + In both sets of data, the same number of primary particles was used + Therefore we can use simple averaging + """ + for stat_pattern in 
(small_stat_bdo_pattern, large_stat_bdo_pattern): + bdo_file_pattern = f"{output_type}{stat_pattern}" + bdo_file_list = list(averaging_bdos_directory.glob(bdo_file_pattern)) + + averaged_estimators = fromfilelist(input_file_list=[str(path) for path in bdo_file_list]) + assert len(averaged_estimators.pages) > 0 + + list_of_estimators_for_each_file = [fromfile(str(path)) for path in bdo_file_list] + assert len(list_of_estimators_for_each_file) > 0 + + for page_id, page in enumerate(averaged_estimators.pages): + list_of_entries_to_aggregate = [] + for estimator in list_of_estimators_for_each_file: + assert len(estimator.pages) > page_id + list_of_entries_to_aggregate.append(estimator.pages[page_id].data) + + if "mean" in output_type: + assert np.average(list_of_entries_to_aggregate) == pytest.approx(page.data) + assert np.std(list_of_entries_to_aggregate, ddof=1, axis=0) / np.sqrt( + len(bdo_file_list)) == pytest.approx(page.error) + elif "sum" in output_type: + assert np.sum(list_of_entries_to_aggregate) == pytest.approx(page.data) + assert page.error is None + elif "none" in output_type: + assert list_of_entries_to_aggregate[0] == pytest.approx(page.data) + assert page.error is None + elif "concat" in output_type: + assert np.concatenate(list_of_entries_to_aggregate, axis=1) == pytest.approx(page.data) + assert page.error is None + + +@pytest.mark.parametrize( + "output_type", + ["normalisation-2_aggregation-sum", "normalisation-3_aggregation-mean", "normalisation-5_aggregation-mean"]) +def test_aggregating_weighted_stats(averaging_bdos_directory, small_stat_bdo_pattern, large_stat_bdo_pattern, + output_type): + """ + For weighted statistics, we need to calculate the weighted average + The average from all files, should be approximately the same as from the large statistics file + """ + large_stat_bdo_files = list(averaging_bdos_directory.glob(f"{output_type}{large_stat_bdo_pattern}")) + all_bdo_files = 
list(averaging_bdos_directory.glob(f"{output_type}{small_stat_bdo_pattern}")) + large_stat_bdo_files + from pymchelper.estimator import ErrorEstimate + + averaged_estimators_all = fromfilelist(input_file_list=[str(path) for path in all_bdo_files], + error=ErrorEstimate.stddev) + assert len(averaged_estimators_all.pages) > 0 + + averaged_estimators_large_stat = fromfilelist(input_file_list=[str(path) for path in large_stat_bdo_files], + error=ErrorEstimate.stddev) + assert len(averaged_estimators_large_stat.pages) > 0 + + for all_stat_pages, large_stat_pages in zip(averaged_estimators_all.pages, averaged_estimators_large_stat.pages): + # the small stats should not affect the result by more than 1% + assert all_stat_pages.data == pytest.approx(large_stat_pages.data, rel=1e-2) diff --git a/tests/res/shieldhit/averaging/README.md b/tests/res/shieldhit/averaging/README.md new file mode 100644 index 000000000..81ff93ba0 --- /dev/null +++ b/tests/res/shieldhit/averaging/README.md @@ -0,0 +1,8 @@ +# Averaging + +We focus on couple of averaging methods: + + - concatenate (phase space spectra) + - average (LET, dose, fluence) + - sum (counters) + - no averaging (medium) diff --git a/tests/res/shieldhit/averaging/beam.dat b/tests/res/shieldhit/averaging/beam.dat new file mode 100644 index 000000000..3800f4f66 --- /dev/null +++ b/tests/res/shieldhit/averaging/beam.dat @@ -0,0 +1,9 @@ +* Input file FOR023.DAT for the SHIELD Transport Code +RNDSEED 89736501 ! Random seed +JPART0 25 ! Incident particle type +HIPROJ 12.0 6.0 ! A and Z of heavy ion +TMAX0 391.0 0.0 ! Incident energy; (MeV/nucl) +NSTAT 1000 1000 ! NSTAT, Step of saving +STRAGG 2 ! Straggling: 0-Off 1-Gauss, 2-Vavilov +MSCAT 2 ! Mult. scatt 0-Off 1-Gauss, 2-Moliere +NUCRE 1 ! Nucl.Reac. 
switcher: 1-ON, 0-OFF diff --git a/tests/res/shieldhit/averaging/detect.dat b/tests/res/shieldhit/averaging/detect.dat new file mode 100644 index 000000000..214900243 --- /dev/null +++ b/tests/res/shieldhit/averaging/detect.dat @@ -0,0 +1,38 @@ +Geometry Mesh + Name MyMesh_YZ + X -5.0 5.0 1 + Y -5.0 5.0 1 + Z 16.0 16.1 1 + +Filter + Name Carbon + Z == 6 + A == 12 + +Output + Filename normalisation-1_aggregation-none.bdo + Geo MyMesh_YZ + Quantity Material + Quantity Rho + +Output + Filename normalisation-2_aggregation-sum.bdo + Geo MyMesh_YZ + Quantity Count + +Output + Filename normalisation-3_aggregation-mean.bdo + Geo MyMesh_YZ + Quantity DLET + Quantity TLET + +Output + Filename normalisation-4_aggregation-concat.bdo + Geo MyMesh_YZ + Quantity MCPL Carbon + +Output + Filename normalisation-5_aggregation-mean.bdo + Geo MyMesh_YZ + Quantity Dose + Quantity Fluence diff --git a/tests/res/shieldhit/averaging/geo.dat b/tests/res/shieldhit/averaging/geo.dat new file mode 100644 index 000000000..1f666a339 --- /dev/null +++ b/tests/res/shieldhit/averaging/geo.dat @@ -0,0 +1,16 @@ +*---><---><--------><------------------------------------------------> + 0 0 C12 200 MeV/A, H2O 30 cm cylinder, r=10, 1 zone +*---><---><--------><--------><--------><--------><--------><--------> + RCC 1 0.0 0.0 0.0 0.0 0.0 30.0 + 10.0 + RCC 2 0.0 0.0 -5.0 0.0 0.0 35.0 + 15.0 + RCC 3 0.0 0.0 -10.0 0.0 0.0 40.0 + 20.0 + END + 001 +1 + 002 +2 -1 + 003 +3 -2 + END + 1 2 3 + 1 1000 0 diff --git a/tests/res/shieldhit/averaging/mat.dat b/tests/res/shieldhit/averaging/mat.dat new file mode 100644 index 000000000..2c80bf7c7 --- /dev/null +++ b/tests/res/shieldhit/averaging/mat.dat @@ -0,0 +1,3 @@ +MEDIUM 1 +ICRU 276 +END diff --git a/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0001.bdo b/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0001.bdo new file mode 100644 index 000000000..364924786 Binary files /dev/null and 
b/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0001.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0002.bdo b/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0002.bdo new file mode 100644 index 000000000..015801b6a Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0002.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0003.bdo b/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0003.bdo new file mode 100644 index 000000000..25d8ca9c4 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0003.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0011.bdo b/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0011.bdo new file mode 100644 index 000000000..3b91d0371 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0011.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0012.bdo b/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0012.bdo new file mode 100644 index 000000000..5f521d3fa Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0012.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0013.bdo b/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0013.bdo new file mode 100644 index 000000000..1ac12dcc8 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0013.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0014.bdo b/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0014.bdo new file mode 100644 index 000000000..15bf5f480 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-1_aggregation-none_0014.bdo differ diff --git 
a/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0001.bdo b/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0001.bdo new file mode 100644 index 000000000..a0f3d0d44 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0001.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0002.bdo b/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0002.bdo new file mode 100644 index 000000000..d00557bb2 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0002.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0003.bdo b/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0003.bdo new file mode 100644 index 000000000..79ab96da4 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0003.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0011.bdo b/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0011.bdo new file mode 100644 index 000000000..f80bbf15b Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0011.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0012.bdo b/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0012.bdo new file mode 100644 index 000000000..c60621884 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0012.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0013.bdo b/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0013.bdo new file mode 100644 index 000000000..1b36e4642 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0013.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0014.bdo 
b/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0014.bdo new file mode 100644 index 000000000..3dbe8364d Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-2_aggregation-sum_0014.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0001.bdo b/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0001.bdo new file mode 100644 index 000000000..30c80825d Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0001.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0002.bdo b/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0002.bdo new file mode 100644 index 000000000..66dd1deeb Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0002.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0003.bdo b/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0003.bdo new file mode 100644 index 000000000..af8b140f7 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0003.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0011.bdo b/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0011.bdo new file mode 100644 index 000000000..79ec75956 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0011.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0012.bdo b/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0012.bdo new file mode 100644 index 000000000..d240211f8 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0012.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0013.bdo b/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0013.bdo new file mode 100644 index 
000000000..020a4e6c2 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0013.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0014.bdo b/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0014.bdo new file mode 100644 index 000000000..b1802e0b8 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-3_aggregation-mean_0014.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0001.bdo b/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0001.bdo new file mode 100644 index 000000000..2f37ebc79 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0001.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0002.bdo b/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0002.bdo new file mode 100644 index 000000000..42e31ed3d Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0002.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0003.bdo b/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0003.bdo new file mode 100644 index 000000000..12da01380 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0003.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0011.bdo b/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0011.bdo new file mode 100644 index 000000000..5efd5d3d1 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0011.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0012.bdo b/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0012.bdo new file mode 100644 index 000000000..e1c8a631d Binary files /dev/null and 
b/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0012.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0013.bdo b/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0013.bdo new file mode 100644 index 000000000..b01dfbc4a Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0013.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0014.bdo b/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0014.bdo new file mode 100644 index 000000000..4abe29ee1 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-4_aggregation-concat_0014.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0001.bdo b/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0001.bdo new file mode 100644 index 000000000..a30d15f65 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0001.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0002.bdo b/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0002.bdo new file mode 100644 index 000000000..6c49cb243 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0002.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0003.bdo b/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0003.bdo new file mode 100644 index 000000000..ee4f35ae0 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0003.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0011.bdo b/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0011.bdo new file mode 100644 index 000000000..0cd96272a Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0011.bdo differ diff --git 
a/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0012.bdo b/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0012.bdo new file mode 100644 index 000000000..4046c6536 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0012.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0013.bdo b/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0013.bdo new file mode 100644 index 000000000..9296e32d4 Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0013.bdo differ diff --git a/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0014.bdo b/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0014.bdo new file mode 100644 index 000000000..b2ca88e1f Binary files /dev/null and b/tests/res/shieldhit/averaging/normalisation-5_aggregation-mean_0014.bdo differ diff --git a/tests/res/shieldhit/averaging/run.sh b/tests/res/shieldhit/averaging/run.sh new file mode 100755 index 000000000..c3dde579e --- /dev/null +++ b/tests/res/shieldhit/averaging/run.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +rm --force ./*.bdo +rm --force ./*.log + +# large statistic example which should converge to the same result +shieldhit --nstat=1000 --seedoffset=1 --silent . 1>/dev/null & +shieldhit --nstat=1000 --seedoffset=2 --silent . 1>/dev/null & +shieldhit --nstat=1000 --seedoffset=3 --silent . 1>/dev/null & + +# small statistic example which should deviate from run to run +shieldhit --nstat=10 --seedoffset=11 --silent . 1>/dev/null & +shieldhit --nstat=10 --seedoffset=12 --silent . 1>/dev/null & +shieldhit --nstat=10 --seedoffset=13 --silent . 1>/dev/null & +shieldhit --nstat=10 --seedoffset=14 --silent . 
1>/dev/null & + +# wait for all background runs to finish; a bare `wait` always returns 0, so a failed run does not change the exit code +wait diff --git a/tests/test_averaging.py b/tests/test_averaging.py index 6859e8c89..5a54e6264 100644 --- a/tests/test_averaging.py +++ b/tests/test_averaging.py @@ -11,9 +11,8 @@ @pytest.fixture(scope='module') -def shieldhit_multiple_result_directory() -> Generator[Path, None, None]: +def shieldhit_multiple_result_directory(main_dir) -> Generator[Path, None, None]: """Return path to directory with single SHIELD-HIT12A result files""" - main_dir = Path(__file__).resolve().parent yield main_dir / "res" / "shieldhit" / "generated" / "many" / "msh" diff --git a/tests/test_default_converter.py b/tests/test_default_converter.py index 33af45f85..0bed0166e 100644 --- a/tests/test_default_converter.py +++ b/tests/test_default_converter.py @@ -1,142 +1,67 @@ -import difflib -import filecmp -import os -import sys -import subprocess -import tempfile -import shutil -import unittest -import logging - import numpy as np import pytest - -from pymchelper import run -from pymchelper.estimator import ErrorEstimate +from typing import List, Tuple +from pymchelper.estimator import Estimator, ErrorEstimate from pymchelper.input_output import fromfilelist, fromfile -logger = logging.getLogger(__name__) - -def shieldhit_binary(): - """ - :return: location of bdo2txt binary suitable for Linux or Windows +@pytest.fixture +def file_list() -> List[str]: """ - exe_path = os.path.join("tests", "res", "shieldhit", "converter", "bdo2txt") - if os.name == 'nt': - exe_path += ".exe" - logger.debug("bdo2txt binary: " + exe_path) - return exe_path - - -def run_bdo2txt_binary(inputfile, working_dir, bdo2txt_path, silent=True): - logger.info("running: " + bdo2txt_path + " " + inputfile) - with open(os.devnull, 'w') as shutup: - if silent: - retVal = subprocess.call(args=[bdo2txt_path, inputfile], stdout=shutup, stderr=shutup) - else: - retVal = subprocess.call(args=[bdo2txt_path, inputfile]) - return retVal 
- - -class TestErrorEstimate(unittest.TestCase): - def test_normal_numbers(self): - - # several files for the same estimator, coming from runs with different RNG seed - file_list = ["tests/res/shieldhit/generated/many/msh/aen_0_p000{:d}.bdo".format(i) for i in range(1, 4)] - - # read each of the files individually into estimator object - estimator_list = [fromfile(file_path) for file_path in file_list] - - for error in ErrorEstimate: # all possible error options (none, stddev, stderr) - logger.debug("Checking error calculation for error = {:s}".format(error.name)) - for nan in (False, True): # include or not NaNs in averaging - logger.debug("Checking error calculation for nan option = {:s}".format(str(nan))) - # read list of the files into one estimator object, doing averaging and error calculation - merged_estimators = fromfilelist(file_list, error=error, nan=nan) - - # manually calculate mean and check if correct - for page_no, page in enumerate(merged_estimators.pages): - mean_value = np.mean([estimator.pages[page_no].data_raw for estimator in estimator_list]) - self.assertEqual(mean_value, merged_estimators.pages[page_no].data_raw) - - # manually calculate mean and check if correct - if error == ErrorEstimate.none: - for page in merged_estimators.pages: - self.assertTrue(np.isnan(page.error_raw) or not np.any(page.error_raw)) - elif error == ErrorEstimate.stddev: - for page_no, page in enumerate(merged_estimators.pages): - error_value = np.std([estimator.pages[page_no].data_raw for estimator in estimator_list], - ddof=1) - self.assertEqual(error_value, merged_estimators.pages[page_no].error_raw) - elif error == ErrorEstimate.stderr: - for page_no, page in enumerate(merged_estimators.pages): - error_value = np.std([estimator.pages[page_no].data_raw for estimator in estimator_list], - ddof=1) - error_value /= np.sqrt(len(estimator_list)) - self.assertEqual(error_value, merged_estimators.pages[page_no].error_raw) - else: - return + Fixture to generate a list of 
file paths for testing. + Returns: + List[str]: A list of file paths for the test files. + """ + return [f"tests/res/shieldhit/generated/many/msh/aen_0_p000{i}.bdo" for i in range(1, 4)] -@pytest.mark.slow -class TestDefaultConverter(unittest.TestCase): - main_dir = os.path.join("tests", "res", "shieldhit", "generated") - single_dir = os.path.join(main_dir, "single") - many_dir = os.path.join(main_dir, "many") - bdo2txt_binary = shieldhit_binary() - - def test_shieldhit_files(self): - # skip tests on MacOSX, as there is no suitable bdo2txt converter available yet - if sys.platform.endswith('arwin'): - return - - # loop over all .bdo files in all subdirectories - for root, dirs, filenames in os.walk(self.single_dir): - for input_basename in filenames: - logger.info("root: {:s}, file: {:s}".format(root, input_basename)) - - inputfile_rel_path = os.path.join(root, input_basename) # choose input file - self.assertTrue(inputfile_rel_path.endswith(".bdo")) - - working_dir = tempfile.mkdtemp() # make temp working dir for converter output files - logger.info("Creating directory {:s}".format(working_dir)) - - # generate output file with native SHIELD-HIT12A converter - ret_value = run_bdo2txt_binary(inputfile_rel_path, - working_dir=working_dir, - bdo2txt_path=self.bdo2txt_binary) - self.assertEqual(ret_value, 0) - - # assuming input name 1.bdo, output file will be called 1.txt - shieldhit_output = inputfile_rel_path[:-3] + "txt" - logger.info("Expecting file {:s} to be generated by SHIELD-HIT12A converter".format(shieldhit_output)) - self.assertTrue(os.path.exists(shieldhit_output)) - - shutil.move(shieldhit_output, working_dir) - shieldhit_output_moved = os.path.join(working_dir, os.path.basename(shieldhit_output)) - logger.info("New location of SH12A file: {:s}".format(shieldhit_output_moved)) - self.assertTrue(os.path.exists(shieldhit_output_moved)) - # generate output with pymchelper assuming .ref extension for output file - pymchelper_output = 
os.path.join(working_dir, input_basename[:-3] + "ref.txt") - logger.info("Expecting file {:s} to be generated by pymchelper converter".format(pymchelper_output)) - run.main(["txt", inputfile_rel_path, pymchelper_output, '--error', 'none']) - self.assertTrue(os.path.exists(pymchelper_output)) +@pytest.fixture +def estimator_list(file_list: List[str]) -> List[Estimator]: + """ + Fixture to prepare a list of estimator objects based on a given list of file paths. - # compare both files - comparison = filecmp.cmp(shieldhit_output_moved, pymchelper_output) - if not comparison: - with open(shieldhit_output_moved, 'r') as f1, open(pymchelper_output, 'r') as f2: - diff = difflib.unified_diff(f1.readlines(), f2.readlines()) - diffs_to_print = [next(diff) for _ in range(30)] - for item in diffs_to_print: - logger.info(item) - self.assertTrue(comparison) + Args: + file_list (List[str]): A list of file paths from which to load the estimator objects. - logger.info("Removing directory {:s}".format(working_dir)) - shutil.rmtree(working_dir) + Returns: + List[Estimator]: A list of estimator objects loaded from the given file paths. + """ + return [fromfile(file_path) for file_path in file_list] -if __name__ == '__main__': - unittest.main() +@pytest.mark.parametrize("error", list(ErrorEstimate)) +@pytest.mark.parametrize("nan", [False, True]) +def test_normal_numbers_with_params(file_list: List[str], estimator_list: List[Estimator], error: ErrorEstimate, + nan: bool) -> None: + """ + Test the error calculations for merged estimators under various conditions. + + This function tests the error calculations by merging a list of estimators + with specified error handling and NaN inclusion options. It verifies that the + calculated mean values are consistent with manually calculated mean values from + the individual estimators. It also checks that the error values (standard deviation + or standard error) are calculated correctly, according to the specified error type. 
+ + Args: + file_list (List[str]): A list of file paths used to create the estimator objects. + estimator_list (List[Estimator]): A list of estimator objects created from the file paths. + error (ErrorEstimate): The type of error calculation to apply when merging the estimators. + nan (bool): Flag indicating whether to include NaN values in the error calculations. + """ + merged_estimators = fromfilelist(file_list, error=error, nan=nan) + + for page_no, page in enumerate(merged_estimators.pages): + mean_value = np.mean([est.pages[page_no].data_raw for est in estimator_list]) + assert mean_value == page.data_raw, f"Mean value calculation mismatch for error={error}, nan={nan}" + + if error == ErrorEstimate.none: + for page in merged_estimators.pages: + assert page.error_raw is None, f"Error value should be None for error={error}, nan={nan}, file_list={file_list}" + if error in (ErrorEstimate.stddev, ErrorEstimate.stderr): + for page_no, page in enumerate(merged_estimators.pages): + error_value = np.std([est.pages[page_no].data_raw for est in estimator_list], ddof=1) + if error == ErrorEstimate.stderr: + error_value /= np.sqrt(len(estimator_list)) + assert np.allclose(error_value, + page.error_raw), f"Error value calculation mismatch for error={error}, nan={nan}" diff --git a/tests/test_shieldhit12a_generate.py b/tests/test_shieldhit12a_generate.py index 835472399..5c6faf7af 100644 --- a/tests/test_shieldhit12a_generate.py +++ b/tests/test_shieldhit12a_generate.py @@ -19,6 +19,7 @@ class TestSHGenerate(unittest.TestCase): + def test_create(self): outdir = os.path.join("tests", "res", "shieldhit", "generated") gen.main([outdir]) @@ -103,8 +104,9 @@ def test_image(self): if pattern in os.path.basename(infile): will_produce_output = True if will_produce_output: - for options in ([], ["--colormap", "gnuplot2"], - ["--error", "stderr"], ["--error", "stddev"], ["--error", "none"]): + for options in ([], ["--colormap", + "gnuplot2"], ["--error", "stderr"], ["--error", 
"stddev"], ["--error", + "none"]): fd, outfile = tempfile.mkstemp() os.close(fd) os.remove(outfile) @@ -152,7 +154,3 @@ def test_get_object(self): self.assertEqual(estimator.number_of_primaries, 1000) self.assertGreaterEqual(len(estimator.pages), 1) self.assertGreaterEqual(estimator.pages[0].data_raw.size, 1) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_weighted_stats.py b/tests/test_weighted_stats.py new file mode 100644 index 000000000..d596003c7 --- /dev/null +++ b/tests/test_weighted_stats.py @@ -0,0 +1,133 @@ +import pytest +import numpy as np +from pymchelper.averaging import WeightedStatsAggregator + + +def test_initial_state(): + ws = WeightedStatsAggregator() + # check if ws.mean is nan + assert np.isnan(ws.mean) + assert ws.total_weight == 0 + + +def test_single_update(): + ws = WeightedStatsAggregator() + ws.update(value=10, weight=2) + assert ws.mean == 10 + assert ws.total_weight == 2 + + +def test_multiple_updates(): + ws = WeightedStatsAggregator() + updates = [(10, 2), (20, 3), (30, 5)] + total_weight = sum(weight for _, weight in updates) + weighted_sum = sum(value * weight for value, weight in updates) + expected_mean = weighted_sum / total_weight + + for value, weight in updates: + ws.update(value, weight) + + assert ws.total_weight == total_weight + assert pytest.approx(ws.mean, 0.001) == expected_mean + + +def test_zero_weight(): + ws = WeightedStatsAggregator() + with pytest.raises(Exception): + ws.update(value=10, weight=0) + + +def test_negative_weight(): + ws = WeightedStatsAggregator() + with pytest.raises(Exception): + ws.update(value=10, weight=-1) + + +def test_update_with_1d_array(): + ws = WeightedStatsAggregator() + values = np.array([10, 20, 30]) + weights = np.array([2, 3, 5]) + total_weight = weights.sum() + weighted_sum = np.dot(values, weights) + expected_mean = weighted_sum / total_weight + + for value, weight in zip(values, weights): + ws.update(value, weight) + + assert ws.total_weight == 
total_weight + assert pytest.approx(ws.mean, 0.001) == expected_mean + + +def test_update_with_flattened_array(): + ws = WeightedStatsAggregator() + values = np.array([[10, 20], [30, 40]]).flatten() + weights = np.array([[2, 3], [4, 1]]).flatten() + total_weight = weights.sum() + weighted_sum = np.dot(values, weights) + expected_mean = weighted_sum / total_weight + + for value, weight in zip(values, weights): + ws.update(value, weight) + + assert ws.total_weight == total_weight + assert pytest.approx(ws.mean, 0.001) == expected_mean + + +def compute_expected_variance(values, weights, total_weight, is_sample=False): + """Utility function to compute the expected variance.""" + weighted_mean = np.average(values, weights=weights) + variance = np.sum(weights * (values - weighted_mean)**2) + if is_sample: + variance /= (total_weight - (np.sum(weights**2) / total_weight)) + else: + variance /= total_weight + return variance + + +def test_variance_population_single_update(): + ws = WeightedStatsAggregator() + ws.update(value=10, weight=2) + # Variance should be 0 for a single value + assert ws.variance_population == 0 + + +def test_variance_population_multiple_updates(): + ws = WeightedStatsAggregator() + values = np.array([10, 20, 30]) + weights = np.array([2, 3, 5]) + total_weight = weights.sum() + + for value, weight in zip(values, weights): + ws.update(value, weight) + + expected_variance = compute_expected_variance(values, weights, total_weight) + assert pytest.approx(ws.variance_population, 0.001) == expected_variance + + +def test_variance_sample_multiple_updates(): + ws = WeightedStatsAggregator() + values = np.array([10, 20, 30]) + weights = np.array([2, 3, 5]) + total_weight = weights.sum() + + for value, weight in zip(values, weights): + ws.update(value, weight) + + expected_variance = compute_expected_variance(values, weights, total_weight, is_sample=True) + assert pytest.approx(ws.variance_sample, 0.001) == expected_variance + + +def 
test_variance_with_1d_array(): + ws = WeightedStatsAggregator() + values = np.array([10, 20, 30]) + weights = np.array([2, 3, 5]) + total_weight = weights.sum() + + for value, weight in zip(values, weights): + ws.update(value, weight) + + expected_variance_population = compute_expected_variance(values, weights, total_weight) + assert pytest.approx(ws.variance_population, 0.001) == expected_variance_population + + expected_variance_sample = compute_expected_variance(values, weights, total_weight, is_sample=True) + assert pytest.approx(ws.variance_sample, 0.001) == expected_variance_sample