From 00b94f4b64fd304a3936420768f943c545ea004b Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Fri, 23 Feb 2024 11:20:28 +0100 Subject: [PATCH] Big Bang --- .github/workflows/pre_commit.yml | 19 + .gitignore copy | 160 ++++++ .pre-commit-config.yaml | 17 + README.md | 146 +++++- exponax/__init__.py | 94 ++++ exponax/base_stepper.py | 191 +++++++ exponax/exponential_integrators.py | 281 +++++++++++ exponax/forced_stepper.py | 103 ++++ exponax/initial_conditions.py | 472 ++++++++++++++++++ exponax/nonlinear_functions/__init__.py | 14 + exponax/nonlinear_functions/base.py | 69 +++ exponax/nonlinear_functions/convection.py | 68 +++ exponax/nonlinear_functions/gradient_norm.py | 73 +++ exponax/nonlinear_functions/polynomial.py | 61 +++ exponax/nonlinear_functions/reaction.py | 146 ++++++ .../vorticity_convection.py | 116 +++++ exponax/nonlinear_functions/zero.py | 31 ++ exponax/normalized_stepper/__init__.py | 11 + exponax/normalized_stepper/convection.py | 110 ++++ exponax/normalized_stepper/gradient_norm.py | 86 ++++ exponax/normalized_stepper/linear.py | 53 ++ exponax/normalized_stepper/utils.py | 76 +++ exponax/poisson.py | 102 ++++ exponax/repeated_stepper.py | 58 +++ exponax/sample_stepper/__init__.py | 32 ++ exponax/sample_stepper/burgers.py | 63 +++ exponax/sample_stepper/convection.py | 75 +++ exponax/sample_stepper/gradient_norm.py | 75 +++ exponax/sample_stepper/korteveg_de_vries.py | 86 ++++ .../sample_stepper/kuramoto_sivashinsky.py | 140 ++++++ exponax/sample_stepper/linear.py | 314 ++++++++++++ exponax/sample_stepper/navier_stokes.py | 127 +++++ exponax/sample_stepper/nikolaevskiy.py | 141 ++++++ exponax/sample_stepper/reaction.py | 350 +++++++++++++ exponax/spectral.py | 412 +++++++++++++++ exponax/utils.py | 365 ++++++++++++++ ks_rollout.png | Bin 0 -> 197512 bytes pyproject.toml | 21 + setup.cfg | 8 + setup.py | 8 + tests/test_builtin_solvers.py | 131 +++++ tests/test_validation.py | 66 +++ 42 files changed, 4969 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/pre_commit.yml create mode 100644 .gitignore copy create mode 100644 .pre-commit-config.yaml create mode 100644 exponax/__init__.py create mode 100644 exponax/base_stepper.py create mode 100644 exponax/exponential_integrators.py create mode 100644 exponax/forced_stepper.py create mode 100644 exponax/initial_conditions.py create mode 100644 exponax/nonlinear_functions/__init__.py create mode 100644 exponax/nonlinear_functions/base.py create mode 100644 exponax/nonlinear_functions/convection.py create mode 100644 exponax/nonlinear_functions/gradient_norm.py create mode 100644 exponax/nonlinear_functions/polynomial.py create mode 100644 exponax/nonlinear_functions/reaction.py create mode 100644 exponax/nonlinear_functions/vorticity_convection.py create mode 100644 exponax/nonlinear_functions/zero.py create mode 100644 exponax/normalized_stepper/__init__.py create mode 100644 exponax/normalized_stepper/convection.py create mode 100644 exponax/normalized_stepper/gradient_norm.py create mode 100644 exponax/normalized_stepper/linear.py create mode 100644 exponax/normalized_stepper/utils.py create mode 100644 exponax/poisson.py create mode 100644 exponax/repeated_stepper.py create mode 100644 exponax/sample_stepper/__init__.py create mode 100644 exponax/sample_stepper/burgers.py create mode 100644 exponax/sample_stepper/convection.py create mode 100644 exponax/sample_stepper/gradient_norm.py create mode 100644 exponax/sample_stepper/korteveg_de_vries.py create mode 100644 exponax/sample_stepper/kuramoto_sivashinsky.py create mode 100644 exponax/sample_stepper/linear.py create mode 100644 exponax/sample_stepper/navier_stokes.py create mode 100644 exponax/sample_stepper/nikolaevskiy.py create mode 100644 exponax/sample_stepper/reaction.py create mode 100644 exponax/spectral.py create mode 100644 exponax/utils.py create mode 100644 ks_rollout.png create mode 100644 pyproject.toml create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tests/test_builtin_solvers.py create mode 100644 tests/test_validation.py diff --git a/.github/workflows/pre_commit.yml b/.github/workflows/pre_commit.yml new file mode 100644 index 0000000..37f41e9 --- /dev/null +++ b/.github/workflows/pre_commit.yml @@ -0,0 +1,19 @@ +name: Code linting + +on: + pull_request: + + push: + branches: + - main + +jobs: + pre-commit: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.12' + - uses: pre-commit/action@v3.0.0 \ No newline at end of file diff --git a/.gitignore copy b/.gitignore copy new file mode 100644 index 0000000..68bc17f --- /dev/null +++ b/.gitignore copy @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..c56e09d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,17 @@ +repos: + - repo: https://github.com/ambv/black + rev: 23.12.1 + hooks: + - id: black-jupyter + language_version: python3 + + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files"] + + - repo: https://github.com/pycqa/flake8 + rev: 7.0.0 + hooks: + - id: flake8 diff --git a/README.md b/README.md index 1e32121..2e6e767 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,144 @@ -# exponax -the all new exponax in multiple dimensions with multiple channels +# Exponax + +A suite of simple solvers for 1d PDEs on periodic domains based on exponential +time differencing algorithms, built on top of +[JAX](https://github.com/google/jax). **Efficient**, **Elegant**, +**Vectorizable**, and **Differentiable**. + +### Quickstart - 1d Kuramoto-Sivashinsky equation + +```python +import jax +import exponax as ex +import matplotlib.pyplot as plt + +ks_stepper = ex.KuramotoSivashinskyConservative( + num_spatial_dims=1, domain_extent=100.0, + num_points=200, dt=0.1, +) + +u_0 = ex.RandomTruncatedFourierSeries( + num_spatial_dims=1, cutoff=5 +)(num_points=200, key=jax.random.PRNGKey(0)) + +trajectory = ex.rollout(ks_stepper, 500, include_init=True)(u_0) + +plt.imshow(trajectory[:, 0, :].T, aspect='auto', cmap='RdBu', vmin=-2, vmax=2, origin="lower") +plt.xlabel("Time"); plt.ylabel("Space"); plt.show() +``` + +![](ks_rollout.png) + +See also the *examples* folder for more. It is best to start with +`simple_advection_example.ipynb` to get familiar with the ideoms of the package, +especially if not too familiar with JAX. Then, continue with the +`solver_showcase.ipynb`. To see the solvers in action to solve a supervised +learning problem, see `learning_burgers_autoregressive_neural_operator.ipynb`. A +tutorial notebook that requires the differentiability of the solvers is in the +works. + +### Features + +Using JAX as the computational backend gives: + +1. **Backend agnotistic code** - run on CPU, GPU, or TPU, in both single and double + precision. +2. **Automatic differentiation** over the timesteppers - compute gradients of + solutions with respect to initial conditions, parameters, etc. +3. Also helpful for **tight integration with Deep Learning** since each + timestepper is also just an [Equinox](https://github.com/patrick-kidger/equinox) Module. +4. **Automatic Vectorization** using `jax.vmap` (or `equinox.filter_vmap`) + allowing to advance multiple states in time or instantiate multiple solvers at a time that operate efficiently in batch. + +Exponax strives to be lightweight and without custom types; there is no `grid` or `state` object. Everything is based on `jax.numpy` arrays. + +### Background + +Exponax supports the efficient solution of 1d (semi-linear) partial differential equations on periodic domains. Those are PDEs of the form + +$$ \partial u/ \partial t = Lu + N(u) $$ + +where $L$ is a linear differential operator and $N$ is a nonlinear differential +operator. The linear part can be exactly solved using a (matrix) exponential, +and the nonlinear part is approximated using Runge-Kutta methods of various +orders. These methods have been known in various disciplines in science for a +long time and have been unified for a first time by [Cox & Matthews](https://doi.org/10.1006/jcph.2002.6995) [1]. In particular, this package uses the complex contour integral method of [Kassam & Trefethen](https://doi.org/10.1137/S1064827502410633) [2] for numerical stability. The package is restricted to original first, second, third and fourth order method. Since the package of [1] many extensions have been developed. A recent study by [Montanelli & Bootland](https://doi.org/10.1016/j.matcom.2020.06.008) [3] showed that the original *ETDRK4* method is still one of the most efficient methods for these types of PDEs. + +### Built-In solvers + +This package comes with the following solvers: + +* Linear PDEs: + * Advection equation + * Diffusion equation + * Advection-Diffusion equation + * Dispersion equation + * Hyper-Diffusion equation + * General linear equation containing zeroth, first, second, third, and fourth order derivatives +* Nonlinear PDEs: + * Burgers equation + * Kuramoto-Sivashinsky equation + * Korteweg-de Vries equation + +Other equations can easily be implemented by subclassing from the `BaseStepper` +module. + +### Other functionality + +Next to the timesteppers operating on JAX array states, it also comes with: + +* Initial Conditions: + * Random sine waves + * Diffused Noise + * Random Discontinuities + * Gaussian Random Fields +* Utilities: + * Mesh creation + * Rollout functions + * Spectral derivatives + * Initial condition set creation +* Poisson solver +* Modification to make solvers take an additional forcing argument +* Modification to make solvers perform substeps for more accurate simulation + +### Similar projects and motivation for this package + +This package is greatly inspired by the [chebfun](https://www.chebfun.org/) +package in *MATLAB*, in particular the +[`spinX`](https://www.chebfun.org/docs/guide/guide19.html) module within it. It +has been used extensively as a data generator in early works for supervised +physics-informed ML, e.g., the +[DeepHiddenPhysics](https://github.com/maziarraissi/DeepHPMs/tree/7b579dbdcf5be4969ebefd32e65f709a8b20ec44/Matlab) +and [Fourier Neural +Operators](https://github.com/neuraloperator/neuraloperator/tree/af93f781d5e013f8ba5c52baa547f2ada304ffb0/data_generation) +(the links show where in their public repos they use the `spinX` module). The +approach of pre-sampling the solvers, writing out the trajectories, and then +using them for supervised training worked for these problems, but of course +limits to purely supervised problem. Modern research ideas like correcting +coarse solvers (see for instance the [Solver-in-the-Loop +paper](https://arxiv.org/abs/2007.00016) or the [ML-accelerated CFD +paper](https://arxiv.org/abs/2102.01010)) requires the coarse solvers to be +[differentiable](https://physicsbaseddeeplearning.org/diffphys.html). Some ideas +of diverted chain training also requires the fine solver to be differentiable! +Even for applications without differentiable solvers, we still have the +**interface problem** with legacy solvers (like the MATLAB ones). Hence, we +cannot easily query them "on-the-fly" for sth like active learning tasks, nor do +they run efficiently on hardward accelerators (GPUs, TPUs, etc.). Additionally, +they were not designed with batch execution (in the sense of vectorized +application) in mind which we get more or less for free by `jax.vmap`. With the +reproducible randomness of `JAX` we might not even have to ever write out a +dataset and can re-create it in seconds! + +This package took much inspiration from the +[FourierFlows.jl](https://github.com/FourierFlows/FourierFlows.jl) in the +*Julia* ecosystem, especially for checking the implementation of the contout +integral method of [2] and how to handle (de)aliasing. + + +### References + +[1] Cox, Steven M., and Paul C. Matthews. "Exponential time differencing for stiff systems." Journal of Computational Physics 176.2 (2002): 430-455. + +[2] Kassam, A.K. and Trefethen, L.N., 2005. Fourth-order time-stepping for stiff PDEs. SIAM Journal on Scientific Computing, 26(4), pp.1214-1233. + +[3] Montanelli, Hadrien, and Niall Bootland. "Solving periodic semilinear stiff PDEs in 1D, 2D and 3D with exponential integrators." Mathematics and Computers in Simulation 178 (2020): 307-327. \ No newline at end of file diff --git a/exponax/__init__.py b/exponax/__init__.py new file mode 100644 index 0000000..f34baf8 --- /dev/null +++ b/exponax/__init__.py @@ -0,0 +1,94 @@ +from .forced_stepper import ForcedStepper +from .initial_conditions import ( + MultiChannelIC, + RandomMultiChannelICGenerator, + RandomTruncatedFourierSeries, + DiffusedNoise, + GaussianRandomField, +) +from .poisson import Poisson +from .repeated_stepper import RepeatedStepper +from .sample_stepper import ( + Advection, + Diffusion, + AdvectionDiffusion, + Dispersion, + HyperDiffusion, + GeneralLinearStepper, + Burgers, + KuramotoSivashinsky, + KuramotoSivashinskyConservative, + Nikolaevskiy, + NikolaevskiyConservative, + GeneralConvectionStepper, + GeneralGradientNormStepper, + NavierStokesVorticity2d, + KolmogorovFlowVorticity2d, + SwiftHohenberg, + GrayScott, + KortevegDeVries, + FisherKPP, + AllenCahn, + CahnHilliard, + BelousovZhabotinsky, +) +from .normalized_stepper import ( + NormalizedLinearStepper, + NormalizedConvectionStepper, + NormalizedGradientNormStepper, + normalize_coefficients, + denormalize_coefficients, + normalize_convection_scale, + denormalize_convection_scale, + normalize_gradient_norm_scale, + denormalize_gradient_norm_scale, +) +from .utils import ( + get_grid, + get_animation, + get_grouped_animation, + rollout, + repeat, + stack_sub_trajectories, + build_ic_set, +) +from .spectral import ( + derivative, +) + +__all__ = [ + "ForcedStepper", + "SineWaves", + "RandomSineWaves", + "DiffusedNoise", + "RandomDiffusedNoise", + "Poisson", + "RepeatedStepper", + "Advection", + "Advection1d", + "Advection2d", + "Advection3d", + "Diffusion", + "Diffusion1d", + "Diffusion2d", + "Diffusion3d", + "AdvectionDiffusion", + "Dispersion", + "HyperDiffusion", + "Burgers", + "Burgers1d", + "Burgers2d", + "Burgers3d", + "KuramotoSivashinsky", + "KuramotoSivashinsky1d", + "KuramotoSivashinsky2d", + "KuramotoSivashinsky3d", + "NavierStokesVorticity2d", + "get_grid", + "get_animation", + "get_grouped_animation", + "rollout", + "repeat", + "stack_sub_trajectories", + "build_ic_set", +] diff --git a/exponax/base_stepper.py b/exponax/base_stepper.py new file mode 100644 index 0000000..d9f2551 --- /dev/null +++ b/exponax/base_stepper.py @@ -0,0 +1,191 @@ +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Array, Float, Complex + +from .exponential_integrators import BaseETDRK, ETDRK0, ETDRK1, ETDRK2, ETDRK3, ETDRK4 +from .spectral import ( + build_derivative_operator, + space_indices, + spatial_shape, + wavenumber_shape, +) + +from .nonlinear_functions import BaseNonlinearFun + + +class BaseStepper(eqx.Module): + num_spatial_dims: int + domain_extent: float + num_points: int + num_channels: int + dt: float + dx: float + + _integrator: BaseETDRK + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + num_channels: int, + order: int, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.num_spatial_dims = num_spatial_dims + self.domain_extent = domain_extent + self.num_points = num_points + self.dt = dt + self.num_channels = num_channels + + # Uses the convention that N does **not** include the right boundary + # point + self.dx = domain_extent / num_points + + derivative_operator = build_derivative_operator( + num_spatial_dims, domain_extent, num_points + ) + + linear_operator = self._build_linear_operator(derivative_operator) + single_channel_shape = (1,) + wavenumber_shape( + self.num_spatial_dims, self.num_points + ) # Same operator for each channel (i.e., we broadcast) + multi_channel_shape = (self.num_channels,) + wavenumber_shape( + self.num_spatial_dims, self.num_points + ) # Different operator for each channel + if linear_operator.shape not in (single_channel_shape, multi_channel_shape): + raise ValueError( + f"Expected linear operator to have shape {single_channel_shape} or {multi_channel_shape}, got {linear_operator.shape}." + ) + nonlinear_fun = self._build_nonlinear_fun(derivative_operator) + + if order == 0: + self._integrator = ETDRK0( + dt, + linear_operator, + ) + elif order == 1: + self._integrator = ETDRK1( + dt, + linear_operator, + nonlinear_fun, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + elif order == 2: + self._integrator = ETDRK2( + dt, + linear_operator, + nonlinear_fun, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + elif order == 3: + self._integrator = ETDRK3( + dt, + linear_operator, + nonlinear_fun, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + elif order == 4: + self._integrator = ETDRK4( + dt, + linear_operator, + nonlinear_fun, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + else: + raise NotImplementedError(f"Order {order} not implemented.") + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "D ... (N//2)+1"]: + """ + Assemble the L operator in Fourier space. + + **Arguments:** + - `derivative_operator`: The derivative operator, shape `( D, ..., + N//2+1 )`. The ellipsis are (D-1) axis of size N (**not** of size + N//2+1). + + **Returns:** + - `L`: The linear operator, shape `( D, ..., N//2+1 )`. + """ + raise NotImplementedError("Must be implemented in subclass.") + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> BaseNonlinearFun: + """ + Build the function that evaluates nonlinearity in physical space, + transforms to Fourier space, and evaluates derivatives there. + + **Arguments:** + - `derivative_operator`: The derivative operator, shape `( D, ..., N//2+1 )`. + + **Returns:** + - `nonlinear_fun`: A function that evaluates the nonlinearities in + time space, transforms to Fourier space, and evaluates the + derivatives there. Should be a subclass of `BaseNonlinearFun`. + """ + raise NotImplementedError("Must be implemented in subclass.") + + def step(self, u: Float[Array, "C ... N"]) -> Float[Array, "C ... N"]: + """ + Perform one step of the time integration. + + **Arguments:** + - `u`: The state vector, shape `(C, ..., N,)`. + + **Returns:** + - `u_next`: The state vector after one step, shape `(C, ..., N,)`. + """ + u_hat = jnp.fft.rfftn(u, axes=space_indices(self.num_spatial_dims)) + u_next_hat = self.step_fourier(u_hat) + u_next = jnp.fft.irfftn( + u_next_hat, + s=spatial_shape(self.num_spatial_dims, self.num_points), + axes=space_indices(self.num_spatial_dims), + ) + return u_next + + def step_fourier( + self, u_hat: Complex[Array, "C ... (N//2)+1"] + ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Perform one step of the time integration in Fourier space. Oftentimes, + this is more efficient than `step` since it avoids back and forth + transforms. + + **Arguments:** + - `u_hat`: The (real) Fourier transform of the state vector + + **Returns:** + - `u_next_hat`: The (real) Fourier transform of the state vector + after one step + """ + return self._integrator.step_fourier(u_hat) + + def __call__( + self, + u: Float[Array, "C ... N"], + ) -> Float[Array, "C ... N"]: + """ + Performs a check + """ + expected_shape = (self.num_channels,) + spatial_shape( + self.num_spatial_dims, self.num_points + ) + if u.shape != expected_shape: + raise ValueError( + f"Expected shape {expected_shape}, got {u.shape}. For batched operation use `jax.vmap` on this function." + ) + return self.step(u) diff --git a/exponax/exponential_integrators.py b/exponax/exponential_integrators.py new file mode 100644 index 0000000..efbb1c9 --- /dev/null +++ b/exponax/exponential_integrators.py @@ -0,0 +1,281 @@ +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Complex, Array, Float +from typing import Callable + +from .nonlinear_functions import BaseNonlinearFun, ZeroNonlinearFun + +# E can either be 1 (single channel) or num_channels (multi-channel) for either +# the same linear operator for each channel or a different linear operator for +# each channel, respectively. +# +# So far, we do **not** support channel mixing via the linear operator (for +# example if we solved the wave equation or the sine-Gordon equation). + + +class BaseETDRK(eqx.Module): + dt: float + _exp_term: Complex[Array, "E ... (N//2)+1"] + + def __init__( + self, + dt: float, + linear_operator: Complex[Array, "E ... (N//2)+1"], + ): + self.dt = dt + self._exp_term = jnp.exp(self.dt * linear_operator) + + def step_fourier( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Advance the state in Fourier space. + """ + raise NotImplementedError("Must be implemented by subclass") + + +class ETDRK0(BaseETDRK): + """ + Exactly solve a linear PDE in Fourier space + """ + + def step_fourier( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + return self._exp_term * u_hat + + +def roots_of_unity(M: int) -> Complex[Array, "M"]: + """ + Return (complex-valued) array with M roots of unity. + """ + # return jnp.exp(1j * jnp.pi * (jnp.arange(1, M+1) - 0.5) / M) + return jnp.exp(2j * jnp.pi * (jnp.arange(1, M + 1) - 0.5) / M) + + +class ETDRK1(BaseETDRK): + _nonlinear_fun: BaseNonlinearFun + _coef_1: Complex[Array, "E ... (N//2)+1"] + + def __init__( + self, + dt: float, + linear_operator: Complex[Array, "E ... (N//2)+1"], + nonlinear_fun: BaseNonlinearFun, + *, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + super().__init__(dt, linear_operator) + self._nonlinear_fun = nonlinear_fun + + LR = ( + circle_radius * roots_of_unity(n_circle_points) + + linear_operator[..., jnp.newaxis] * dt + ) + + self._coef_1 = dt * jnp.mean((jnp.exp(LR) - 1) / LR, axis=-1).real + + def step_fourier( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + return self._exp_term * u_hat + self._coef_1 * self._nonlinear_fun(u_hat) + + +class ETDRK2(BaseETDRK): + _nonlinear_fun: BaseNonlinearFun + _coef_1: Complex[Array, "E ... (N//2)+1"] + _coef_2: Complex[Array, "E ... (N//2)+1"] + + def __init__( + self, + dt: float, + linear_operator: Complex[Array, "E ... (N//2)+1"], + nonlinear_fun: BaseNonlinearFun, + *, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + super().__init__(dt, linear_operator) + self._nonlinear_fun = nonlinear_fun + + LR = ( + circle_radius * roots_of_unity(n_circle_points) + + linear_operator[..., jnp.newaxis] * dt + ) + + self._coef_1 = dt * jnp.mean((jnp.exp(LR) - 1) / LR, axis=-1).real + + self._coef_2 = dt * jnp.mean((jnp.exp(LR) - 1 - LR) / LR**2, axis=-1).real + + def step_fourier( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + u_nonlin_hat = self._nonlinear_fun(u_hat) + u_stage_1_hat = self._exp_term * u_hat + self._coef_1 * u_nonlin_hat + + u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat) + u_next_hat = u_stage_1_hat + self._coef_2 * ( + u_stage_1_nonlin_hat - u_nonlin_hat + ) + return u_next_hat + + +class ETDRK3(BaseETDRK): + _nonlinear_fun: BaseNonlinearFun + _half_exp_term: Complex[Array, "E ... (N//2)+1"] + _coef_1: Complex[Array, "E ... (N//2)+1"] + _coef_2: Complex[Array, "E ... (N//2)+1"] + _coef_3: Complex[Array, "E ... (N//2)+1"] + _coef_4: Complex[Array, "E ... (N//2)+1"] + _coef_5: Complex[Array, "E ... (N//2)+1"] + + def __init__( + self, + dt: float, + linear_operator: Complex[Array, "E ... (N//2)+1"], + nonlinear_fun: BaseNonlinearFun, + *, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + super().__init__(dt, linear_operator) + self._nonlinear_fun = nonlinear_fun + self._half_exp_term = jnp.exp(0.5 * dt * linear_operator) + + LR = ( + circle_radius * roots_of_unity(n_circle_points) + + linear_operator[..., jnp.newaxis] * dt + ) + + self._coef_1 = dt * jnp.mean((jnp.exp(LR / 2) - 1) / LR, axis=-1).real + + self._coef_2 = dt * jnp.mean((jnp.exp(LR) - 1) / LR, axis=-1).real + + self._coef_3 = ( + dt + * jnp.mean( + (-4 - LR + jnp.exp(LR) * (4 - 3 * LR + LR**2)) / (LR**3), axis=-1 + ).real + ) + + self._coef_4 = ( + dt + * jnp.mean( + (4.0 * (2.0 + LR + jnp.exp(LR) * (-2 + LR))) / (LR**3), axis=-1 + ).real + ) + + self._coef_5 = ( + dt + * jnp.mean( + (-4 - 3 * LR - LR**2 + jnp.exp(LR) * (4 - LR)) / (LR**3), axis=-1 + ).real + ) + + def step_fourier( + self, + u_hat: Complex[Array, "E ... (N//2)+1"], + ) -> Complex[Array, "E ... (N//2)+1"]: + u_nonlin_hat = self._nonlinear_fun(u_hat) + u_stage_1_hat = self._half_exp_term * u_hat + self._coef_1 * u_nonlin_hat + + u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat) + u_stage_2_hat = self._exp_term * u_hat + self._coef_2 * ( + 2 * u_stage_1_nonlin_hat - u_nonlin_hat + ) + + u_stage_2_nonlin_hat = self._nonlinear_fun(u_stage_2_hat) + + u_next_hat = ( + self._exp_term * u_hat + + self._coef_3 * u_nonlin_hat + + self._coef_4 * u_stage_1_nonlin_hat + + self._coef_5 * u_stage_2_nonlin_hat + ) + + return u_next_hat + + +class ETDRK4(BaseETDRK): + _nonlinear_fun: BaseNonlinearFun + _half_exp_term: Complex[Array, "E ... (N//2)+1"] + _coef_1: Complex[Array, "E ... (N//2)+1"] + _coef_2: Complex[Array, "E ... (N//2)+1"] + _coef_3: Complex[Array, "E ... (N//2)+1"] + _coef_4: Complex[Array, "E ... (N//2)+1"] + _coef_5: Complex[Array, "E ... (N//2)+1"] + _coef_6: Complex[Array, "E ... (N//2)+1"] + + def __init__( + self, + dt: float, + linear_operator: Complex[Array, "E ... (N//2)+1"], + nonlinear_fun: BaseNonlinearFun, + *, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + super().__init__(dt, linear_operator) + self._nonlinear_fun = nonlinear_fun + self._half_exp_term = jnp.exp(0.5 * dt * linear_operator) + + LR = ( + circle_radius * roots_of_unity(n_circle_points) + + linear_operator[..., jnp.newaxis] * dt + ) + + self._coef_1 = dt * jnp.mean((jnp.exp(LR / 2) - 1) / LR, axis=-1).real + + self._coef_2 = self._coef_1 + self._coef_3 = self._coef_1 + + self._coef_4 = ( + dt + * jnp.mean( + (-4 - LR + jnp.exp(LR) * (4 - 3 * LR + LR**2)) / (LR**3), axis=-1 + ).real + ) + + self._coef_5 = ( + dt * jnp.mean((2 + LR + jnp.exp(LR) * (-2 + LR)) / (LR**3), axis=-1).real + ) + + self._coef_6 = ( + dt + * jnp.mean( + (-4 - 3 * LR - LR**2 + jnp.exp(LR) * (4 - LR)) / (LR**3), axis=-1 + ).real + ) + + def step_fourier( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + u_nonlin_hat = self._nonlinear_fun(u_hat) + u_stage_1_hat = self._half_exp_term * u_hat + self._coef_1 * u_nonlin_hat + + u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat) + u_stage_2_hat = ( + self._half_exp_term * u_hat + self._coef_2 * u_stage_1_nonlin_hat + ) + + u_stage_2_nonlin_hat = self._nonlinear_fun(u_stage_2_hat) + u_stage_3_hat = self._half_exp_term * u_stage_1_hat + self._coef_3 * ( + 2 * u_stage_2_nonlin_hat - u_nonlin_hat + ) + + u_stage_3_nonlin_hat = self._nonlinear_fun(u_stage_3_hat) + + u_next_hat = ( + self._exp_term * u_hat + + self._coef_4 * u_nonlin_hat + + self._coef_5 * 2 * (u_stage_1_nonlin_hat + u_stage_2_nonlin_hat) + + self._coef_6 * u_stage_3_nonlin_hat + ) + + return u_next_hat diff --git a/exponax/forced_stepper.py b/exponax/forced_stepper.py new file mode 100644 index 0000000..ff2d421 --- /dev/null +++ b/exponax/forced_stepper.py @@ -0,0 +1,103 @@ +from typing import Any +import equinox as eqx +from .base_stepper import BaseStepper + +from jaxtyping import Array, Float, Complex + + +class ForcedStepper(eqx.Module): + stepper: BaseStepper + + def __init__( + self, + stepper: BaseStepper, + ): + """ + Transform a stepper of signature `(u,) -> u_next` into a stepper of + signature `(u, f) -> u_next` that also accepts a forcing vector `f`. + + Transforms a stepper for a PDE of the form u_t = Lu + N(u) into a stepper + for a PDE of the form u_t = Lu + N(u) + f, where f is a forcing term. For + this, we split by operators + + v_t = f + + u_t = Lv + N(v) + + Since we assume to only have access to the forcing function evaluated at one + time level (but on the same grid as the state), we use a forward Euler + scheme to integrate the first equation. The second equation is integrated + using the original stepper. + + Note: This operator splitting makes the total scheme only first order + accurate in time. It is a quick hack to extend the other sophisticated + transient integrators to forced problems. + + **Arguments**: + - `stepper`: The stepper to be transformed. + """ + self.stepper = stepper + + def step( + self, + u: Float[Array, "C ... N"], + f: Float[Array, "C ... N"], + ) -> Float[Array, "C ... N"]: + """ + Step the PDE forward in time by one time step given the current state + `u` and the forcing term `f`. + + The forcing term `f` is assumed to be evaluated on the same grid as `u`. + + **Arguments**: + - `u`: The current state. + - `f`: The forcing term. + + **Returns**: + - `u_next`: The state after one time step. + """ + u_with_force = u + self.stepper.dt * f + return self.stepper.step(u_with_force) + + def step_fourier( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + f_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Step the PDE forward in time by one time step given the current state + `u_hat` in Fourier space and the forcing term `f_hat` in Fourier space. + + The forcing term `f_hat` is assumed to be evaluated on the same grid as + `u_hat`. + + **Arguments**: + - `u_hat`: The current state in Fourier space. + - `f_hat`: The forcing term in Fourier space. + + **Returns**: + - `u_next_hat`: The state after one time step in Fourier space. + """ + u_hat_with_force = u_hat + self.stepper.dt * f_hat + return self.stepper.step_fourier(u_hat_with_force) + + def __call__( + self, + u: Float[Array, "C ... N"], + f: Float[Array, "C ... N"], + ) -> Float[Array, "C ... N"]: + """ + Step the PDE forward in time by one time step given the current state + `u` and the forcing term `f`. + + The forcing term `f` is assumed to be evaluated on the same grid as `u`. + + **Arguments**: + - `u`: The current state. + - `f`: The forcing term. + + **Returns**: + - `u_next`: The state after one time step. + """ + + return self.step(u, f) diff --git a/exponax/initial_conditions.py b/exponax/initial_conditions.py new file mode 100644 index 0000000..371e296 --- /dev/null +++ b/exponax/initial_conditions.py @@ -0,0 +1,472 @@ +import jax.numpy as jnp +import jax.random as jr +from typing import List +import equinox as eqx +from jaxtyping import Complex, Array, Float, PRNGKeyArray + +from abc import ABC, abstractmethod +from typing import Optional + +from .sample_stepper import Diffusion +from .spectral import ( + build_scaled_wavenumbers, + spatial_shape, + wavenumber_shape, + low_pass_filter_mask, + space_indices, + build_scaling_array, +) +from .utils import get_grid + +### --- Base classes --- ### + + +class BaseIC(eqx.Module, ABC): + + @abstractmethod + def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]: + """ + Evaluate the initial condition. + + **Arguments**: + - `x`: The grid points. + + **Returns**: + - `u`: The initial condition evaluated at the grid points. + """ + pass + + +class BaseRandomICGenerator(eqx.Module): + num_spatial_dims: int + domain_extent: float + indexing: str = "ij" + + def gen_ic_fun(self, num_points: int, *, key: PRNGKeyArray) -> BaseIC: + """ + Generate an initial condition function. + + **Arguments**: + - `num_points`: The number of grid points in each dimension. + - `key`: A jax random key. + + **Returns**: + - `ic`: An initial condition function that can be evaluated at + degree of freedom locations. + """ + raise NotImplementedError( + "This random ic generator cannot represent its initial condition as a function. Directly evaluate it." + ) + + def __call__( + self, + num_points: int, + *, + key: PRNGKeyArray, + ) -> Float[Array, "1 ... N"]: + """ + Generate a random initial condition. + + **Arguments**: + - `num_points`: The number of grid points in each dimension. + - `key`: A jax random key. + - `indexing`: The indexing convention for the grid. + + **Returns**: + - `u`: The initial condition evaluated at the grid points. + """ + ic_fun = self.gen_ic_fun(num_points, key=key) + grid = get_grid( + self.num_spatial_dims, + self.domain_extent, + num_points, + indexing=self.indexing, + ) + return ic_fun(grid) + + +### Utilities to create ICs for multi-channel fields + + +class MultiChannelIC(eqx.Module): + initial_conditions: List[BaseIC] + + def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "C ... N"]: + """ + Evaluate the initial condition. + + **Arguments**: + - `x`: The grid points. + + **Returns**: + - `u`: The initial condition evaluated at the grid points. + """ + return jnp.concatenate([ic(x) for ic in self.initial_conditions], axis=0) + + +class RandomMultiChannelICGenerator(eqx.Module): + ic_generators: List[BaseRandomICGenerator] + + def gen_ic_fun(self, num_points: int, *, key: PRNGKeyArray) -> MultiChannelIC: + ic_funs = [ + ic_gen.gen_ic_fun(num_points, key=key) for ic_gen in self.ic_generators + ] + return MultiChannelIC(ic_funs) + + def __call__( + self, num_points: int, *, key: PRNGKeyArray + ) -> Float[Array, "C ... N"]: + u_list = [ic_gen(num_points, key=key) for ic_gen in self.ic_generators] + return jnp.concatenate(u_list, axis=0) + + +### New version + +# class TruncatedFourierSeries(BaseIC): +# coefficient_array: Complex[Array, "1 ... (N//2)+1"] + +# def __init__( +# self, +# D: int, +# L: float, # unused +# N: int, +# *, +# coefficient_array: Complex[Array, "1 ... N"], +# ): +# super().__init__(D, N) +# self.coefficient_array = coefficient_array + +# def evaluate(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]: +# return jnp.fft.irfftn( +# self.coefficient_array, +# s=spatial_shape(self.D, self.N), +# axes=space_indices(self.D), +# ) + + +class RandomTruncatedFourierSeries(BaseRandomICGenerator): + num_spatial_dims: int + domain_extent: float + cutoff: int + amplitude_range: tuple[int, int] + angle_range: tuple[int, int] + offset_range: tuple[int, int] + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float = 1.0, + *, + cutoff: int = 10, + amplitude_range: tuple[int, int] = (-1.0, 1.0), + angle_range: tuple[int, int] = (0.0, 2.0 * jnp.pi), + offset_range: tuple[int, int] = (0.0, 0.0), # no offset by default + ): + self.num_spatial_dims = num_spatial_dims + self.domain_extent = domain_extent + + self.cutoff = cutoff + self.amplitude_range = amplitude_range + self.angle_range = angle_range + self.offset_range = offset_range + + def __call__( + self, num_points: int, *, key: PRNGKeyArray + ) -> Float[Array, "1 ... N"]: + fourier_noise_shape = (1,) + wavenumber_shape(self.num_spatial_dims, num_points) + amplitude_key, angle_key, offset_key = jr.split(key, 3) + + amplitude = jr.uniform( + amplitude_key, + shape=fourier_noise_shape, + minval=self.amplitude_range[0], + maxval=self.amplitude_range[1], + ) + angle = jr.uniform( + angle_key, + shape=fourier_noise_shape, + minval=self.angle_range[0], + maxval=self.angle_range[1], + ) + + fourier_noise = amplitude * jnp.exp(1j * angle) + + low_pass_filter = low_pass_filter_mask( + self.num_spatial_dims, num_points, cutoff=self.cutoff, axis_separate=True + ) + + fourier_noise = fourier_noise * low_pass_filter + + offset = jr.uniform( + offset_key, + shape=(1,), + minval=self.offset_range[0], + maxval=self.offset_range[1], + )[0] + fourier_noise = ( + fourier_noise.flatten().at[0].set(offset).reshape(fourier_noise_shape) + ) + + fourier_noise = fourier_noise * build_scaling_array( + self.num_spatial_dims, num_points + ) + + u = jnp.fft.irfftn( + fourier_noise, + s=spatial_shape(self.num_spatial_dims, num_points), + axes=space_indices(self.num_spatial_dims), + ) + + return u + + +### --- Legacy Sine Waves (truncated Fourier series) --- ### + +# class SineWaves(BaseIC): +# L: float +# filter_mask: Float[Array, "1 ... (N//2)+1"] +# zero_mean: bool +# key: PRNGKeyArray + + +# def __init__( +# self, +# D: int, +# L: float, +# N: int, +# *, +# cutoff: int, +# zero_mean: bool, +# axis_separate: bool = True, +# key: PRNGKeyArray, +# ): +# super().__init__(D, N) +# self.L = L +# self.filter_mask = low_pass_filter_mask(D, N, cutoff=cutoff, axis_separate=axis_separate) +# self.zero_mean = zero_mean +# self.key = key + +# def evaluate(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]: +# noise_shape = (1,) + spatial_shape(self.D, self.N) + +# noise = jr.normal(self.key, shape=noise_shape) +# noise_hat = jnp.fft.rfftn(noise, axes=space_indices(self.D)) +# noise_hat = noise_hat * self.filter_mask + +# noise = jnp.fft.irfftn(noise_hat, s=spatial_shape(self.D, self.N), axes=space_indices(self.D)) + +# if self.zero_mean: +# noise = noise - jnp.mean(noise) + +# return noise + +# class RandomSineWaves(BaseRandomICGenerator): +# D: int +# L: float +# N: int +# cutoff: int +# zero_mean: bool +# axis_separate: bool + +# def __init__( +# self, +# D: int, +# L: float, +# N: int, +# *, +# cutoff: int, +# zero_mean: bool, +# axis_separate: bool = True, +# ): +# """ +# Randomly generated initial condition consisting of a truncated Fourier series. + +# Arguments are drawn from uniform distributions. + +# **Arguments**: +# - `D`: The dimension of the domain. +# - `N`: The number of grid points in each dimension. +# - `L`: The length of the domain. +# - `cutoff`: The cutoff wavenumber. +# - `zero_mean`: Whether to subtract the mean. +# - `axis_separate`: Whether to draw the wavenumber cutoffs for each +# axis separately. +# """ +# self.D = D +# self.N = N +# self.L = L +# self.cutoff = cutoff +# self.zero_mean = zero_mean +# self.axis_separate = axis_separate + +# def __call__(self, key: PRNGKeyArray) -> SineWaves: +# return SineWaves( +# self.D, +# self.L, +# self.N, +# cutoff=self.cutoff, +# zero_mean=self.zero_mean, +# axis_separate=self.axis_separate, +# key=key, +# ) + + +# --- Diffused Noise --- ### + +# class DiffusedNoise(BaseIC): +# L: float +# intensity: float +# zero_mean: bool +# key: PRNGKeyArray + +# def __init__( +# self, +# D: int, +# L: float, +# N: int, +# *, +# intensity: float, +# zero_mean: bool, +# key: PRNGKeyArray, +# ): +# super().__init__(D, N) +# self.L = L +# self.intensity = intensity +# self.zero_mean = zero_mean +# self.key = key + +# def evaluate(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]: +# noise_shape = (1,) + spatial_shape(self.D, self.N) +# noise = jr.normal(self.key, shape=noise_shape) + +# diffusion_stepper = Diffusion(self.D, self.L, self.N, 1.0, diffusivity=self.intensity) +# ic = diffusion_stepper(noise) + +# if self.zero_mean: +# ic = ic - jnp.mean(ic) + +# return ic + + +class DiffusedNoise(BaseRandomICGenerator): + num_spatial_dims: int + domain_extent: float + intensity: float + zero_mean: bool + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float = 1.0, + *, + intensity=0.001, + zero_mean: bool = False, + ): + """ + Randomly generated initial condition consisting of a diffused noise field. + + Arguments are drawn from uniform distributions. + + **Arguments**: + - `D`: The dimension of the domain. + - `L`: The length of the domain. + - `N`: The number of grid points in each dimension. + - `intensity`: The diffusivity. + - `zero_mean`: Whether to subtract the mean. + """ + self.num_spatial_dims = num_spatial_dims + self.domain_extent = domain_extent + self.intensity = intensity + self.zero_mean = zero_mean + + def __call__( + self, num_points: int, *, key: PRNGKeyArray + ) -> Float[Array, "1 ... N"]: + noise_shape = (1,) + spatial_shape(self.num_spatial_dims, num_points) + noise = jr.normal(key, shape=noise_shape) + + diffusion_stepper = Diffusion( + self.num_spatial_dims, + self.domain_extent, + num_points, + 1.0, + diffusivity=self.intensity, + ) + ic = diffusion_stepper(noise) + + if self.zero_mean: + ic = ic - jnp.mean(ic) + + return ic + + +### Gausian Random Field ### + + +class GaussianRandomField(BaseRandomICGenerator): + num_spatial_dims: int + domain_extent: float + powerlaw_exponent: float + normalize: bool + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float = 1.0, + *, + powerlaw_exponent: float = 3.0, + normalize: bool = True, + ): + """ + Randomly generated initial condition consisting of a Gaussian random field. + """ + self.num_spatial_dims = num_spatial_dims + self.domain_extent = domain_extent + self.powerlaw_exponent = powerlaw_exponent + self.normalize = normalize + + def __call__( + self, num_points: int, *, key: PRNGKeyArray + ) -> Float[Array, "1 ... N"]: + wavenumber_grid = build_scaled_wavenumbers( + self.num_spatial_dims, self.domain_extent, num_points + ) + wavenumer_norm_grid = jnp.linalg.norm(wavenumber_grid, axis=0, keepdims=True) + amplitude = jnp.power(wavenumer_norm_grid, -self.powerlaw_exponent / 2.0) + amplitude = ( + amplitude.flatten().at[0].set(0.0).reshape(wavenumer_norm_grid.shape) + ) + + real_key, imag_key = jr.split(key, 2) + noise = jr.normal( + real_key, + shape=(1,) + wavenumber_shape(self.num_spatial_dims, num_points), + ) + 1j * jr.normal( + imag_key, + shape=(1,) + wavenumber_shape(self.num_spatial_dims, num_points), + ) + + noise = noise * amplitude + + ic = jnp.fft.irfftn( + noise, + s=spatial_shape(self.num_spatial_dims, num_points), + axes=space_indices(self.num_spatial_dims), + ) + + if self.normalize: + ic = ic - jnp.mean(ic) + ic = ic / jnp.std(ic) + + return ic + + +### Discontinuities ### + + +class Discontinuities(BaseIC): + pass + + +class RandomDiscontinuities(BaseRandomICGenerator): + pass diff --git a/exponax/nonlinear_functions/__init__.py b/exponax/nonlinear_functions/__init__.py new file mode 100644 index 0000000..28c821b --- /dev/null +++ b/exponax/nonlinear_functions/__init__.py @@ -0,0 +1,14 @@ +from .base import BaseNonlinearFun +from .convection import ConvectionNonlinearFun +from .gradient_norm import GradientNormNonlinearFun +from .polynomial import PolynomialNonlinearFun +from .reaction import ( + GrayScottNonlinearFun, + CahnHilliardNonlinearFun, + BelousovZhabotinskyNonlinearFun, +) +from .vorticity_convection import ( + VorticityConvection2d, + VorticityConvection2dKolmogorov, +) +from .zero import ZeroNonlinearFun diff --git a/exponax/nonlinear_functions/base.py b/exponax/nonlinear_functions/base.py new file mode 100644 index 0000000..cc51496 --- /dev/null +++ b/exponax/nonlinear_functions/base.py @@ -0,0 +1,69 @@ +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Complex, Array, Float, Bool +from ..spectral import ( + wavenumber_shape, + low_pass_filter_mask, +) +from abc import ABC, abstractmethod + + +class BaseNonlinearFun(eqx.Module, ABC): + num_spatial_dims: int + num_points: int + num_channels: int + derivative_operator: Complex[Array, "D ... (N//2)+1"] + dealiasing_mask: Bool[Array, "1 ... (N//2)+1"] + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + ): + self.num_spatial_dims = num_spatial_dims + self.num_points = num_points + self.num_channels = num_channels + self.derivative_operator = derivative_operator + + # Can be done because num_points is identical in all spatial dimensions + nyquist_mode = (num_points // 2) + 1 + highest_resolved_mode = nyquist_mode - 1 + start_of_aliased_modes = dealiasing_fraction * highest_resolved_mode + + self.dealiasing_mask = low_pass_filter_mask( + num_spatial_dims, + num_points, + cutoff=start_of_aliased_modes - 1, + ) + + @abstractmethod + def evaluate( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Evaluate all potential nonlinearities "pseudo-spectrally", account for dealiasing. + """ + raise NotImplementedError("Must be implemented by subclass") + + def __call__( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Perform check + """ + expected_shape = (self.num_channels,) + wavenumber_shape( + self.num_spatial_dims, self.num_points + ) + if u_hat.shape != expected_shape: + raise ValueError( + f"Expected shape {expected_shape}, got {u_hat.shape}. For batched operation use `jax.vmap` on this function." + ) + + return self.evaluate(u_hat) diff --git a/exponax/nonlinear_functions/convection.py b/exponax/nonlinear_functions/convection.py new file mode 100644 index 0000000..a4854d8 --- /dev/null +++ b/exponax/nonlinear_functions/convection.py @@ -0,0 +1,68 @@ +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Complex, Array, Float, Bool +from ..spectral import ( + space_indices, + spatial_shape, +) + +from .base import BaseNonlinearFun + + +class ConvectionNonlinearFun(BaseNonlinearFun): + convection_scale: float + zero_mode_fix: bool + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + convection_scale: float = 0.5, + zero_mode_fix: bool = False, + ): + self.convection_scale = convection_scale + self.zero_mode_fix = zero_mode_fix + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + + def zero_fix( + self, + f: Float[Array, "... N"], + ): + return f - jnp.mean(f) + + def evaluate( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + u_hat_dealiased = self.dealiasing_mask * u_hat + u = jnp.fft.irfftn( + u_hat_dealiased, + s=spatial_shape(self.num_spatial_dims, self.num_points), + axes=space_indices(self.num_spatial_dims), + ) + u_outer_product = u[:, None] * u[None, :] + + if self.zero_mode_fix: + # Maybe there is more efficient way + u_outer_product = jax.vmap(self.zero_fix)(u_outer_product) + + u_outer_product_hat = jnp.fft.rfftn( + u_outer_product, axes=space_indices(self.num_spatial_dims) + ) + u_divergence_on_outer_product_hat = jnp.sum( + self.derivative_operator[None, :] * u_outer_product_hat, + axis=1, + ) + # Requires minus to move term to the rhs + return -self.convection_scale * u_divergence_on_outer_product_hat diff --git a/exponax/nonlinear_functions/gradient_norm.py b/exponax/nonlinear_functions/gradient_norm.py new file mode 100644 index 0000000..957dedf --- /dev/null +++ b/exponax/nonlinear_functions/gradient_norm.py @@ -0,0 +1,73 @@ +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Complex, Array, Float, Bool +from ..spectral import ( + space_indices, + spatial_shape, +) + +from .base import BaseNonlinearFun + + +class GradientNormNonlinearFun(BaseNonlinearFun): + scale: float + zero_mode_fix: bool + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + zero_mode_fix: bool = True, + scale: float = 0.5, + ): + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + self.zero_mode_fix = zero_mode_fix + self.scale = scale + + def zero_fix( + self, + f: Float[Array, "... N"], + ): + return f - jnp.mean(f) + + def evaluate( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + u_gradient_hat = self.derivative_operator[None, :] * u_hat[:, None] + u_gradient_dealiased_hat = self.dealiasing_mask * u_gradient_hat + u_gradient = jnp.fft.irfftn( + u_gradient_dealiased_hat, + s=spatial_shape(self.num_spatial_dims, self.num_points), + axes=space_indices(self.num_spatial_dims), + ) + + # Reduces the axis introduced by the gradient + u_gradient_norm_squared = jnp.sum(u_gradient**2, axis=1) + + if self.zero_mode_fix: + # Maybe there is more efficient way + u_gradient_norm_squared = jax.vmap(self.zero_fix)(u_gradient_norm_squared) + + u_gradient_norm_squared_hat = jnp.fft.rfftn( + u_gradient_norm_squared, axes=space_indices(self.num_spatial_dims) + ) + # if self.zero_mode_fix: + # # Fix the mean mode + # u_gradient_norm_squared_hat = u_gradient_norm_squared_hat.at[..., 0].set( + # u_hat[..., 0] + # ) + + # Requires minus to move term to the rhs + return -self.scale * u_gradient_norm_squared_hat diff --git a/exponax/nonlinear_functions/polynomial.py b/exponax/nonlinear_functions/polynomial.py new file mode 100644 index 0000000..46bd1d5 --- /dev/null +++ b/exponax/nonlinear_functions/polynomial.py @@ -0,0 +1,61 @@ +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Complex, Array, Float, Bool +from ..spectral import ( + space_indices, + spatial_shape, +) + +from .base import BaseNonlinearFun + + +class PolynomialNonlinearFun(BaseNonlinearFun): + """ + Channel-separate evaluation; and no mixed terms. + """ + + coefficients: list[float] # Starting from order 0 + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + coefficients: list[float], + ): + """ + Coefficient list starts from order 0. + """ + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + self.coefficients = coefficients + + def evaluate( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + u_hat_dealiased = self.dealiasing_mask * u_hat + u = jnp.fft.irfftn( + u_hat_dealiased, + s=spatial_shape(self.num_spatial_dims, self.num_points), + axes=space_indices(self.num_spatial_dims), + ) + u_power = 1.0 + u_nonlin = 0.0 + for coeff in self.coefficients: + u_nonlin += coeff * u_power + u_power = u_power * u + + u_nonlin_hat = jnp.fft.rfftn( + u_nonlin, axes=space_indices(self.num_spatial_dims) + ) + return u_nonlin_hat diff --git a/exponax/nonlinear_functions/reaction.py b/exponax/nonlinear_functions/reaction.py new file mode 100644 index 0000000..ad44a98 --- /dev/null +++ b/exponax/nonlinear_functions/reaction.py @@ -0,0 +1,146 @@ +""" +Nonlinear terms as they are found in reaction-diffusion(-advection) equations. +""" + +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Complex, Array, Float, Bool +from typing import Callable +from ..spectral import ( + space_indices, + spatial_shape, + build_laplace_operator, +) + +from .base import BaseNonlinearFun + + +class GrayScottNonlinearFun(BaseNonlinearFun): + b: float + d: float + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + b: float, + d: float, + ): + if num_channels != 2: + raise ValueError(f"Expected num_channels = 2, got {num_channels}.") + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + self.b = b + self.d = d + + def evaluate( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + u_hat_dealiased = self.dealiasing_mask * u_hat + u = jnp.fft.irfftn( + u_hat_dealiased, + s=spatial_shape(self.num_spatial_dims, self.num_points), + axes=space_indices(self.num_spatial_dims), + ) + u_power = jnp.stack( + [ + self.b * (1 - u[0]) - u[0] * u[1] ** 2, + -self.d * u[1] + u[0] * u[1] ** 2, + ] + ) + u_power_hat = jnp.fft.rfftn(u_power, axes=space_indices(self.num_spatial_dims)) + return u_power_hat + + +class CahnHilliardNonlinearFun(BaseNonlinearFun): + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + ): + if num_channels != 1: + raise ValueError(f"Expected num_channels = 1, got {num_channels}.") + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + + def evaluate( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + u_hat_dealiased = self.dealiasing_mask * u_hat + u = jnp.fft.irfftn( + u_hat_dealiased, + s=spatial_shape(self.num_spatial_dims, self.num_points), + axes=space_indices(self.num_spatial_dims), + ) + u_power = u[0] ** 3 + u_power_hat = jnp.fft.rfftn(u_power, axes=space_indices(self.num_spatial_dims)) + u_power_laplace_hat = ( + build_laplace_operator(self.derivative_operator, order=2) * u_power_hat + ) + return u_power_laplace_hat + + +class BelousovZhabotinskyNonlinearFun(BaseNonlinearFun): + """ + Taken from: https://github.com/chebfun/chebfun/blob/db207bc9f48278ca4def15bf90591bfa44d0801d/spin.m#L73 + """ + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + ): + if num_channels != 3: + raise ValueError(f"Expected num_channels = 3, got {num_channels}.") + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + + def evaluate( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + u_hat_dealiased = self.dealiasing_mask * u_hat + u = jnp.fft.irfftn( + u_hat_dealiased, + s=spatial_shape(self.num_spatial_dims, self.num_points), + axes=space_indices(self.num_spatial_dims), + ) + u_power = jnp.stack( + [ + u[0] + u[1] - u[0] * u[1] - u[0] ** 2, + u[2] - u[1] - u[0] * u[1], + u[0] - u[2], + ] + ) + u_power_hat = jnp.fft.rfftn(u_power, axes=space_indices(self.num_spatial_dims)) + return u_power_hat diff --git a/exponax/nonlinear_functions/vorticity_convection.py b/exponax/nonlinear_functions/vorticity_convection.py new file mode 100644 index 0000000..847f3f6 --- /dev/null +++ b/exponax/nonlinear_functions/vorticity_convection.py @@ -0,0 +1,116 @@ +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Complex, Array, Float, Bool +from ..spectral import ( + build_laplace_operator, + build_wavenumbers, + build_scaling_array, +) + +from .base import BaseNonlinearFun + + +class VorticityConvection2d(BaseNonlinearFun): + inv_laplacian: Complex[Array, "1 ... (N//2)+1"] + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + ): + if num_spatial_dims != 2: + raise ValueError(f"Expected num_spatial_dims = 2, got {num_spatial_dims}.") + if num_channels != 1: + raise ValueError(f"Expected num_channels = 1, got {num_channels}.") + + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + + laplacian = build_laplace_operator(derivative_operator, order=2) + + # Uses the UNCHANGED mean solution to the Poisson equation (hence, the + # mean of the "right-hand side" will be the mean of the solution) + self.inv_laplacian = jnp.where(laplacian == 0, 1.0, 1 / laplacian) + + def evaluate( + self, u_hat: Complex[Array, "1 ... (N//2)+1"] + ) -> Complex[Array, "1 ... (N//2)+1"]: + vorticity_hat = u_hat + stream_function_hat = self.inv_laplacian * vorticity_hat + + u_hat = +self.derivative_operator[1:2] * stream_function_hat + v_hat = -self.derivative_operator[0:1] * stream_function_hat + del_vorticity_del_x_hat = self.derivative_operator[0:1] * vorticity_hat + del_vorticity_del_y_hat = self.derivative_operator[1:2] * vorticity_hat + + u = jnp.fft.irfft2( + u_hat * self.dealiasing_mask, s=(self.num_points, self.num_points) + ) + v = jnp.fft.irfft2( + v_hat * self.dealiasing_mask, s=(self.num_points, self.num_points) + ) + del_vorticity_del_x = jnp.fft.irfft2( + del_vorticity_del_x_hat * self.dealiasing_mask, + s=(self.num_points, self.num_points), + ) + del_vorticity_del_y = jnp.fft.irfft2( + del_vorticity_del_y_hat * self.dealiasing_mask, + s=(self.num_points, self.num_points), + ) + + convection = u * del_vorticity_del_x + v * del_vorticity_del_y + + convection_hat = jnp.fft.rfft2(convection) + + # Do we need another dealiasing mask here? + # convection_hat = self.dealiasing_mask * convection_hat + + # Requires minus to move term to the rhs + return -convection_hat + + +class VorticityConvection2dKolmogorov(VorticityConvection2d): + injection: Complex[Array, "1 ... (N//2)+1"] + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + injection_mode: int = 4, + injection_scale: float = 1.0, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + ): + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + + wavenumbers = build_wavenumbers(num_spatial_dims, num_points) + injection_mask = (wavenumbers[0] == 0) & (wavenumbers[1] == injection_mode) + self.injection = jnp.where( + injection_mask, + injection_scale * build_scaling_array(num_spatial_dims, num_points), + 0.0, + ) + + def evaluate( + self, u_hat: Complex[Array, "1 ... (N//2)+1"] + ) -> Complex[Array, "1 ... (N//2)+1"]: + neg_convection_hat = super().evaluate(u_hat) + return neg_convection_hat + self.injection diff --git a/exponax/nonlinear_functions/zero.py b/exponax/nonlinear_functions/zero.py new file mode 100644 index 0000000..91e6033 --- /dev/null +++ b/exponax/nonlinear_functions/zero.py @@ -0,0 +1,31 @@ +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Complex, Array, Float, Bool + +from .base import BaseNonlinearFun + + +class ZeroNonlinearFun(BaseNonlinearFun): + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float = 1.0, + ): + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + + def evaluate( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + return jnp.zeros_like(u_hat) diff --git a/exponax/normalized_stepper/__init__.py b/exponax/normalized_stepper/__init__.py new file mode 100644 index 0000000..9f61198 --- /dev/null +++ b/exponax/normalized_stepper/__init__.py @@ -0,0 +1,11 @@ +from .convection import NormalizedConvectionStepper +from .gradient_norm import NormalizedGradientNormStepper +from .linear import NormalizedLinearStepper +from .utils import ( + denormalize_coefficients, + denormalize_convection_scale, + denormalize_gradient_norm_scale, + normalize_coefficients, + normalize_convection_scale, + normalize_gradient_norm_scale, +) diff --git a/exponax/normalized_stepper/convection.py b/exponax/normalized_stepper/convection.py new file mode 100644 index 0000000..04ae0fa --- /dev/null +++ b/exponax/normalized_stepper/convection.py @@ -0,0 +1,110 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ConvectionNonlinearFun +from jaxtyping import Complex, Float, Array + + +class NormalizedConvectionStepper(BaseStepper): + normalized_coefficients: list[float] + normalized_convection_scale: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + *, + dt: float = 0.1, + normalized_coefficients: list[float] = [0.0, 0.0, 0.01 * 0.1], + normalized_convection_scale: float = 0.5, + order: int = 2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + """ + By default: Behaves like a Burgers with + + ``` Burgers( + D=D, L=1, N=N, dt=dt, diffusivity=0.01, + ) + ``` + + If you set `L=2 * jnp.pi` of your unnormalized scenario, then you have + to set your coefficients to `alpha_i * dt` (make sure to use the same dt + as is used here as the keyword based argument). + + If you set `L=1` of your unnormalized scenario, then you have to set + your coefficients to `alpha_i * dt / (2 * jnp.pi)^s` (make sure to use + the same dt as is used here as the keyword based argument) **and** set + your convection scale to whatever you had prior multiplied by 2 * + jnp.pi. + + If you set `L=L` of your unnormalized scenario, then you have to set + your coefficients to `alpha_i * dt * (L / (2 * jnp.pi))^s` (make sure to + use the same dt as is used here as the keyword based argument) **and** + set your convection scale to whatever you had prior multiplied by 2 * + jnp.pi / L. + + number of channels grow with number of spatial dimensions + + **Arguments:** + + - `num_spatial_dims`: number of spatial dimensions + - `num_points`: number of points in each spatial dimension + - `dt`: time step (default: 0.1) + - `normalized_coefficients`: coefficients for the linear operator, + `normalized_coefficients[i]` is the coefficient for the `i`-th + derivative (default: [0.0, 0.0, 0.01 * 0.1] refers to a diffusion + (2nd) order term) + - `normalized_convection_scale`: convection scale for the nonlinear + function (default: 0.5) + - `order`: order of exponential time differencing Runge Kutta method, + can be 1, 2, 3, 4 (default: 2) + - `dealiasing_fraction`: fraction of the wavenumbers being kept before + applying any nonlinearity (default: 2/3) + - `n_circle_points`: number of points to use for the complex contour + integral when computing coefficients for the exponential time + differencing Runge Kutta method (default: 16) + - `circle_radius`: radius of the complex contour integral when computing + coefficients for the exponential time differencing Runge Kutta method + (default: 1.0) + """ + self.normalized_coefficients = normalized_coefficients + self.normalized_convection_scale = normalized_convection_scale + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=1.0, # Derivative operator is just scaled with 2 * jnp.pi + num_points=num_points, + dt=dt, + num_channels=num_spatial_dims, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator(self, derivative_operator: Array) -> Array: + # Now the linear operator is unscaled + linear_operator = sum( + jnp.sum( + c / self.dt * (derivative_operator) ** i, + axis=0, + keepdims=True, + ) + for i, c in enumerate(self.normalized_coefficients) + ) + return linear_operator + + def _build_nonlinear_fun(self, derivative_operator: Array): + return ConvectionNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + convection_scale=self.normalized_convection_scale, + ) diff --git a/exponax/normalized_stepper/gradient_norm.py b/exponax/normalized_stepper/gradient_norm.py new file mode 100644 index 0000000..40f6a0b --- /dev/null +++ b/exponax/normalized_stepper/gradient_norm.py @@ -0,0 +1,86 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import GradientNormNonlinearFun +from jaxtyping import Complex, Float, Array + + +class NormalizedGradientNormStepper(BaseStepper): + normalized_coefficients: list[float] + normalized_gradient_norm_scale: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + *, + dt: float = 0.1, + normalized_coefficients: list[float] = [0.0, 0.0, 0.01 * 0.1], + normalized_gradient_norm_scale: float = 0.5, + order: int = 2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + """ + the number of channels do **not** grow with the number of spatial + dimensions. They are always 1. + + **Arguments:** + - `num_spatial_dims`: number of spatial dimensions + - `num_points`: number of points in each spatial dimension + - `dt`: time step (default: 0.1) + - `normalized_coefficients`: coefficients for the linear operator, + `normalized_coefficients[i]` is the coefficient for the `i`-th + derivative (default: [0.0, 0.0, 0.01 * 0.1] refers to a diffusion + operator) + - `normalized_gradient_norm_scale`: scale for the gradient norm + (default: 0.5) + - `order`: order of the derivative operator (default: 2) + - `dealiasing_fraction`: fraction of the wavenumbers being kept before + applying any nonlinearity (default: 2/3) + - `n_circle_points`: number of points to use for the complex contour + integral when computing coefficients for the exponential time + differencing Runge Kutta method (default: 16) + - `circle_radius`: radius of the complex contour integral when computing + coefficients for the exponential time differencing Runge Kutta method + (default: 1.0) + """ + self.normalized_coefficients = normalized_coefficients + self.normalized_gradient_norm_scale = normalized_gradient_norm_scale + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=1.0, # Derivative operator is just scaled with 2 * jnp.pi + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator(self, derivative_operator: Array) -> Array: + linear_operator = sum( + jnp.sum( + c / self.dt * (derivative_operator) ** i, + axis=0, + keepdims=True, + ) + for i, c in enumerate(self.normalized_coefficients) + ) + return linear_operator + + def _build_nonlinear_fun(self, derivative_operator: Array): + return GradientNormNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + scale=self.normalized_gradient_norm_scale, + zero_mode_fix=True, + ) diff --git a/exponax/normalized_stepper/linear.py b/exponax/normalized_stepper/linear.py new file mode 100644 index 0000000..e39d804 --- /dev/null +++ b/exponax/normalized_stepper/linear.py @@ -0,0 +1,53 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ZeroNonlinearFun +from jaxtyping import Complex, Float, Array + + +class NormalizedLinearStepper(BaseStepper): + normalized_coefficients: list[float] + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + *, + normalized_coefficients: list[float] = [0.0, -0.5, 0.01], + ): + """ + By default: advection-diffusion with normalized advection of 0.5, and + normalized diffusion of 0.01. + + Take care of the signs! + """ + self.normalized_coefficients = normalized_coefficients + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=1.0, # Derivative operator is just scaled with 2 * jnp.pi + num_points=num_points, + dt=1.0, + num_channels=1, + order=0, + ) + + def _build_linear_operator(self, derivative_operator: Array) -> Array: + linear_operator = sum( + jnp.sum( + c * (derivative_operator) ** i, + axis=0, + keepdims=True, + ) + for i, c in enumerate(self.normalized_coefficients) + ) + return linear_operator + + def _build_nonlinear_fun(self, derivative_operator: Array): + return ZeroNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + ) diff --git a/exponax/normalized_stepper/utils.py b/exponax/normalized_stepper/utils.py new file mode 100644 index 0000000..f0e20e9 --- /dev/null +++ b/exponax/normalized_stepper/utils.py @@ -0,0 +1,76 @@ +import jax.numpy as jnp + + +def normalize_coefficients( + domain_extent: float, + dt: float, + coefficients: tuple[float], +) -> tuple[float]: + """ + Normalize the coefficients to a linear time stepper to be used with the + normalized linear stepper. + + **Arguments:** + - `domain_extent`: extent of the domain + - `dt`: time step + - `coefficients`: coefficients for the linear operator, `coefficients[i]` is + the coefficient for the `i`-th derivative + """ + normalized_coefficients = tuple( + c * dt / (domain_extent**i) for i, c in enumerate(coefficients) + ) + return normalized_coefficients + + +def denormalize_coefficients( + domain_extent: float, + dt: float, + normalized_coefficients: tuple[float], +) -> tuple[float]: + """ + Denormalize the coefficients as they were used in the normalized linear to + then be used again in a regular linear stepper. + + **Arguments:** + - `domain_extent`: extent of the domain + - `dt`: time step + - `normalized_coefficients`: coefficients for the linear operator, + `normalized_coefficients[i]` is the coefficient for the `i`-th + derivative + """ + coefficients = tuple( + c_n / dt * domain_extent**i for i, c_n in enumerate(normalized_coefficients) + ) + return coefficients + + +def normalize_convection_scale( + domain_extent: float, + convection_scale: float, +) -> float: + normalized_convection_scale = convection_scale / domain_extent + return normalized_convection_scale + + +def denormalize_convection_scale( + domain_extent: float, + normalized_convection_scale: float, +) -> float: + convection_scale = normalized_convection_scale * domain_extent + return convection_scale + + +def normalize_gradient_norm_scale( + domain_extent: float, + gradient_norm_scale: float, +): + normalized_gradient_norm_scale = gradient_norm_scale / jnp.square(domain_extent) + return normalized_gradient_norm_scale + + +def denormalize_gradient_norm_scale( + domain_extent: float, + normalized_gradient_norm_scale: float, +): + gradient_norm_scale = normalized_gradient_norm_scale * jnp.square(domain_extent) + return gradient_norm_scale diff --git a/exponax/poisson.py b/exponax/poisson.py new file mode 100644 index 0000000..e5a870f --- /dev/null +++ b/exponax/poisson.py @@ -0,0 +1,102 @@ +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Array, Float, Complex + +from .spectral import build_derivative_operator, build_laplace_operator, spatial_shape + + +class Poisson(eqx.Module): + num_spatial_dims: int + domain_extent: float + num_points: int + dx: float + + _inv_operator: Complex[Array, "1 ... N"] + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + *, + order=2, + ): + """ + Exactly solve the Poisson equation with periodic boundary conditions. + + This "stepper" is different from all other steppers in this package in + that it does not solve a time-dependent PDE. Instead, it solves the + Poisson equation + + $$ u_{xx} = - f $$ + + for a given right hand side $f$. + + It is included for completion. + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions. + - `domain_extent`: The extent of the domain. + - `num_points`: The number of points in each spatial dimension. + - `order`: The order of the Poisson equation. Defaults to 2. You can + also set `order=4` for the biharmonic equation. + """ + self.num_spatial_dims = num_spatial_dims + self.domain_extent = domain_extent + self.num_points = num_points + + # Uses the convention that N does **not** include the right boundary + # point + self.dx = domain_extent / num_points + + derivative_operator = build_derivative_operator( + num_spatial_dims, domain_extent, num_points + ) + operator = build_laplace_operator(derivative_operator, order=order) + + # Uses mean zero solution + self._inv_operator = jnp.where(operator == 0, 0.0, 1 / operator) + + def step_fourier( + self, + f_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Solve the Poisson equation in Fourier space. + + **Arguments:** + - `f_hat`: The Fourier transform of the right hand side. + + **Returns:** + - `u_hat`: The Fourier transform of the solution. + """ + return -self._inv_operator * f_hat + + def step( + self, + f: Float[Array, "C ... N"], + ) -> Float[Array, "C ... N"]: + """ + Solve the Poisson equation in real space. + + **Arguments:** + - `f`: The right hand side. + + **Returns:** + - `u`: The solution. + """ + f_hat = jnp.fft.rfft(f) + u_hat = self.step_fourier(f_hat) + u = jnp.fft.irfft(u_hat, self.num_points) + return u + + def __call__( + self, + f: Float[Array, "C ... N"], + ) -> Float[Array, "C ... N"]: + if f.shape[1:] != spatial_shape(self.num_spatial_dims, self.num_points): + raise ValueError( + f"Shape of f[1:] is {f.shape[1:]} but should be {spatial_shape(self.num_spatial_dims, self.num_points)}" + ) + return self.step(f) diff --git a/exponax/repeated_stepper.py b/exponax/repeated_stepper.py new file mode 100644 index 0000000..f7567ce --- /dev/null +++ b/exponax/repeated_stepper.py @@ -0,0 +1,58 @@ +import equinox as eqx + +from .base_stepper import BaseStepper + +from .utils import repeat + +from jaxtyping import Array, Float, Complex + + +class RepeatedStepper(eqx.Module): + """ + Sugarcoat the utility function `repeat` in a callable PyTree for easy + composition with other equinox modules. + + One intended usage is to get "more accurate" or "more stable" time steppers + that perform substeps. + + The effective time step is `self.stepper.dt * self.n_sub_steps`. In order to + get a time step of X with Y substeps, first instantiate a stepper with a + time step of X/Y and then wrap it in a RepeatedStepper with n_sub_steps=Y. + + **Arguments:** + - `stepper`: The stepper to repeat. + - `n_sub_steps`: The number of substeps to perform. + """ + + stepper: BaseStepper + n_sub_steps: int + + def step( + self, + u: Float[Array, "C ... N"], + ) -> Float[Array, "C ... N"]: + """ + Step the PDE forward in time by self.n_sub_steps time steps given the + current state `u`. + """ + return repeat(self.stepper.step, self.n_sub_steps)(u) + + def step_fourier( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Step the PDE forward in time by self.n_sub_steps time steps given the + current state `u_hat` in real-valued Fourier space. + """ + return repeat(self.stepper.step_fourier, self.n_sub_steps)(u_hat) + + def __call__( + self, + u: Float[Array, "C ... N"], + ) -> Float[Array, "C ... N"]: + """ + Step the PDE forward in time by self.n_sub_steps time steps given the + current state `u`. + """ + return repeat(self.stepper, self.n_sub_steps)(u) diff --git a/exponax/sample_stepper/__init__.py b/exponax/sample_stepper/__init__.py new file mode 100644 index 0000000..63f198d --- /dev/null +++ b/exponax/sample_stepper/__init__.py @@ -0,0 +1,32 @@ +from .burgers import Burgers +from .convection import GeneralConvectionStepper +from .gradient_norm import GeneralGradientNormStepper +from .korteveg_de_vries import KortevegDeVries +from .kuramoto_sivashinsky import ( + KuramotoSivashinsky, + KuramotoSivashinskyConservative, +) +from .linear import ( + Advection, + Diffusion, + AdvectionDiffusion, + Dispersion, + HyperDiffusion, + GeneralLinearStepper, +) +from .navier_stokes import ( + NavierStokesVorticity2d, + KolmogorovFlowVorticity2d, +) +from .nikolaevskiy import ( + Nikolaevskiy, + NikolaevskiyConservative, +) +from .reaction import ( + SwiftHohenberg, + GrayScott, + FisherKPP, + AllenCahn, + CahnHilliard, + BelousovZhabotinsky, +) diff --git a/exponax/sample_stepper/burgers.py b/exponax/sample_stepper/burgers.py new file mode 100644 index 0000000..c97a8f0 --- /dev/null +++ b/exponax/sample_stepper/burgers.py @@ -0,0 +1,63 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ConvectionNonlinearFun +from jaxtyping import Complex, Float, Array +from ..spectral import build_laplace_operator, build_gradient_inner_product_operator + + +class Burgers(BaseStepper): + diffusivity: float + convection_scale: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + diffusivity: float = 0.1, + convection_scale: float = 0.5, + order=2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.diffusivity = diffusivity + self.convection_scale = convection_scale + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=num_spatial_dims, # Number of channels grows with dimension + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + + # The linear operator is the same for all D channels + return self.diffusivity * build_laplace_operator(derivative_operator) + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ConvectionNonlinearFun: + return ConvectionNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + convection_scale=self.convection_scale, + ) diff --git a/exponax/sample_stepper/convection.py b/exponax/sample_stepper/convection.py new file mode 100644 index 0000000..1a31bae --- /dev/null +++ b/exponax/sample_stepper/convection.py @@ -0,0 +1,75 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ConvectionNonlinearFun +from jaxtyping import Complex, Float, Array + + +class GeneralConvectionStepper(BaseStepper): + coefficients: list[float] + convection_scale: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + coefficients: list[float] = [0.0, 0.0, 0.01], + convection_scale: float = 0.5, + order=2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + """ + Isotropic linear operators! + + By default Burgers equation with diffusivity of 0.01 + + """ + self.coefficients = coefficients + self.convection_scale = convection_scale + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=num_spatial_dims, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + linear_operator = sum( + jnp.sum( + c * (derivative_operator) ** i, + axis=0, + keepdims=True, + ) + for i, c in enumerate(self.coefficients) + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ConvectionNonlinearFun: + return ConvectionNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + convection_scale=self.convection_scale, + zero_mode_fix=False, # Todo: check this + ) diff --git a/exponax/sample_stepper/gradient_norm.py b/exponax/sample_stepper/gradient_norm.py new file mode 100644 index 0000000..2952b39 --- /dev/null +++ b/exponax/sample_stepper/gradient_norm.py @@ -0,0 +1,75 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import GradientNormNonlinearFun +from jaxtyping import Complex, Float, Array + + +class GeneralGradientNormStepper(BaseStepper): + coefficients: list[float] + gradient_norm_scale: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + coefficients: list[float] = [0.0, 0.0, -1.0, 0.0, -1.0], + gradient_norm_scale: float = 0.5, + order=2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + """ + Isotropic linear operators! + + By default KS equation (in combustion science format) + + """ + self.coefficients = coefficients + self.gradient_norm_scale = gradient_norm_scale + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=num_spatial_dims, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + linear_operator = sum( + jnp.sum( + c * (derivative_operator) ** i, + axis=0, + keepdims=True, + ) + for i, c in enumerate(self.coefficients) + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> GradientNormNonlinearFun: + return GradientNormNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + scale=self.gradient_norm_scale, + zero_mode_fix=False, # Todo: check this + ) diff --git a/exponax/sample_stepper/korteveg_de_vries.py b/exponax/sample_stepper/korteveg_de_vries.py new file mode 100644 index 0000000..24ef04b --- /dev/null +++ b/exponax/sample_stepper/korteveg_de_vries.py @@ -0,0 +1,86 @@ +from typing import Union + +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ConvectionNonlinearFun +from jaxtyping import Complex, Float, Array +from ..spectral import build_laplace_operator, build_gradient_inner_product_operator + + +class KortevegDeVries(BaseStepper): + convection_scale: float + pure_dispersivity: Float[Array, "D"] + advect_over_diffuse_dispersivity: Float[Array, "D"] + diffusivity: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + convection_scale: float = -6 / 2, + pure_dispersivity: Union[Float[Array, "D"], float] = 1.0, + advect_over_diffuse_dispersivity: Union[Float[Array, "D"], float] = 0.0, + diffusivity: float = 0.0, + order: int = 2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.convection_scale = convection_scale + if isinstance(pure_dispersivity, float): + pure_dispersivity = jnp.ones(num_spatial_dims) * pure_dispersivity + if isinstance(advect_over_diffuse_dispersivity, float): + advect_over_diffuse_dispersivity = ( + jnp.ones(num_spatial_dims) * advect_over_diffuse_dispersivity + ) + self.pure_dispersivity = pure_dispersivity + self.advect_over_diffuse_dispersivity = advect_over_diffuse_dispersivity + self.diffusivity = diffusivity + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=num_spatial_dims, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + laplace_operator = build_laplace_operator(derivative_operator, order=2) + linear_operator = ( + -build_gradient_inner_product_operator( + derivative_operator, self.pure_dispersivity, order=3 + ) + - build_gradient_inner_product_operator( + derivative_operator, self.advect_over_diffuse_dispersivity, order=1 + ) + * laplace_operator + + self.diffusivity * laplace_operator + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ConvectionNonlinearFun: + return ConvectionNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + convection_scale=self.convection_scale, + ) diff --git a/exponax/sample_stepper/kuramoto_sivashinsky.py b/exponax/sample_stepper/kuramoto_sivashinsky.py new file mode 100644 index 0000000..fe2c2ed --- /dev/null +++ b/exponax/sample_stepper/kuramoto_sivashinsky.py @@ -0,0 +1,140 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ( + GradientNormNonlinearFun, + ConvectionNonlinearFun, +) +from jaxtyping import Complex, Float, Array +from ..spectral import build_laplace_operator, build_gradient_inner_product_operator + + +class KuramotoSivashinsky(BaseStepper): + second_order_diffusivity: float + fourth_order_diffusivity: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + second_order_diffusivity: float = 1.0, + fourth_order_diffusivity: float = 1.0, + dealiasing_fraction: float = 2 / 3, + order: int = 2, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + """ + Implements the KS equations as used in the combustion community, i.e., + with a gradient-norm nonlinearity instead of the convection nonliearity. + The advantage is that the number of channels is always 1 no matter the + number of spatial dimensions. + """ + self.second_order_diffusivity = second_order_diffusivity + self.fourth_order_diffusivity = fourth_order_diffusivity + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + linear_operator = -self.second_order_diffusivity * build_laplace_operator( + derivative_operator, order=2 + ) - self.fourth_order_diffusivity * build_laplace_operator( + derivative_operator, order=4 + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> GradientNormNonlinearFun: + return GradientNormNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + zero_mode_fix=True, + scale=0.5, + ) + + +class KuramotoSivashinskyConservative(BaseStepper): + second_order_diffusivity: float + fourth_order_diffusivity: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + second_order_diffusivity: float = 1.0, + fourth_order_diffusivity: float = 1.0, + dealiasing_fraction: float = 2 / 3, + order: int = 2, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + """ + Using the fluid dynamics form of the KS equation (i.e. similar to the + Burgers equation). This also means that the number of channels grow with + the number of spatial dimensions. + """ + self.second_order_diffusivity = second_order_diffusivity + self.fourth_order_diffusivity = fourth_order_diffusivity + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=num_spatial_dims, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + linear_operator = -self.second_order_diffusivity * build_laplace_operator( + derivative_operator, order=2 + ) - self.fourth_order_diffusivity * build_laplace_operator( + derivative_operator, order=4 + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ConvectionNonlinearFun: + return ConvectionNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + zero_mode_fix=True, + convection_scale=0.5, + ) diff --git a/exponax/sample_stepper/linear.py b/exponax/sample_stepper/linear.py new file mode 100644 index 0000000..7415cd1 --- /dev/null +++ b/exponax/sample_stepper/linear.py @@ -0,0 +1,314 @@ +from typing import Union + +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ZeroNonlinearFun +from jaxtyping import Complex, Float, Array +from ..spectral import build_laplace_operator, build_gradient_inner_product_operator + + +class Advection(BaseStepper): + velocity: Float[Array, "D"] + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + velocity: Union[Float[Array, "D"], float] = 1.0, + ): + if isinstance(velocity, float): + velocity = jnp.ones(num_spatial_dims) * velocity + self.velocity = velocity + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=0, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + # Requires minus to move term to the rhs + return -build_gradient_inner_product_operator( + derivative_operator, self.velocity, order=1 + ) + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ZeroNonlinearFun: + return ZeroNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=1.0, + ) + + +class Diffusion(BaseStepper): + diffusivity: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + diffusivity: float = 0.01, + ): + self.diffusivity = diffusivity + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=0, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + return self.diffusivity * build_laplace_operator(derivative_operator) + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ZeroNonlinearFun: + return ZeroNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=1.0, + ) + + +class AdvectionDiffusion(BaseStepper): + velocity: Float[Array, "D"] + diffusivity: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + velocity: Union[Float[Array, "D"], float] = 1.0, + diffusivity: float = 0.01, + ): + if isinstance(velocity, float): + velocity = jnp.ones(num_spatial_dims) * velocity + self.velocity = velocity + self.diffusivity = diffusivity + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=0, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + return -build_gradient_inner_product_operator( + derivative_operator, self.velocity, order=1 + ) + self.diffusivity * build_laplace_operator(derivative_operator) + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ZeroNonlinearFun: + return ZeroNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=1.0, + ) + + +class Dispersion(BaseStepper): + dispersivity: Float[Array, "D"] + advect_on_diffusion: bool + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + dispersivity: Union[Float[Array, "D"], float] = 1.0, + advect_on_diffusion: bool = False, + ): + if isinstance(dispersivity, float): + dispersivity = jnp.ones(num_spatial_dims) * dispersivity + self.dispersivity = dispersivity + self.advect_on_diffusion = advect_on_diffusion + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=0, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + if self.advect_on_diffusion: + laplace_operator = build_laplace_operator(derivative_operator) + advection_operator = build_gradient_inner_product_operator( + derivative_operator, self.dispersivity, order=1 + ) + linear_operator = advection_operator * laplace_operator + else: + linear_operator = build_gradient_inner_product_operator( + derivative_operator, self.dispersivity, order=3 + ) + + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ZeroNonlinearFun: + return ZeroNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=1.0, + ) + + +class HyperDiffusion(BaseStepper): + hyper_diffusivity: float + diffuse_on_diffuse: bool + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + hyper_diffusivity: float = 1.0, + diffuse_on_diffuse: bool = False, + ): + self.hyper_diffusivity = hyper_diffusivity + self.diffuse_on_diffuse = diffuse_on_diffuse + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=0, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + # Use minus sign to have diffusion work in "correct direction" by default + if self.diffuse_on_diffuse: + laplace_operator = build_laplace_operator(derivative_operator) + linear_operator = ( + -self.hyper_diffusivity * laplace_operator * laplace_operator + ) + else: + linear_operator = -self.hyper_diffusivity * build_laplace_operator( + derivative_operator, order=4 + ) + + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ZeroNonlinearFun: + return ZeroNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=1.0, + ) + + +class GeneralLinearStepper(BaseStepper): + coefficients: list[float] + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + coefficients: list[float] = [0.0, -0.1, 0.01], + ): + """ + Isotropic linear operators! + + By default: advection-diffusion with advection of 0.1 and diffusion of + 0.01. + + Take care of the signs! + """ + self.coefficients = coefficients + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=0, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + linear_operator = sum( + jnp.sum( + c * (derivative_operator) ** i, + axis=0, + keepdims=True, + ) + for i, c in enumerate(self.coefficients) + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ZeroNonlinearFun: + return ZeroNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=1.0, + ) diff --git a/exponax/sample_stepper/navier_stokes.py b/exponax/sample_stepper/navier_stokes.py new file mode 100644 index 0000000..48f360e --- /dev/null +++ b/exponax/sample_stepper/navier_stokes.py @@ -0,0 +1,127 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ( + VorticityConvection2d, + VorticityConvection2dKolmogorov, +) +from jaxtyping import Complex, Float, Array +from ..spectral import build_laplace_operator, build_gradient_inner_product_operator + + +class NavierStokesVorticity2d(BaseStepper): + diffusivity: float + drag: float + dealiasing_fraction: float + + def __init__( + self, + # Does not require D argument as it is fixed to 2 + domain_extent: float, + num_points: int, + dt: float, + *, + diffusivity: float = 0.01, + drag: float = 0.0, + order: int = 2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.diffusivity = diffusivity + self.drag = drag + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=2, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + return self.diffusivity * build_laplace_operator( + derivative_operator, order=2 + ) + self.drag * build_laplace_operator(derivative_operator, order=0) + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> VorticityConvection2d: + return VorticityConvection2d( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + ) + + +class KolmogorovFlowVorticity2d(BaseStepper): + diffusivity: float + drag: float + injection_mode: int + injection_scale: float + dealiasing_fraction: float + + def __init__( + self, + # Does not require D argument as it is fixed to 2 + domain_extent: float, + num_points: int, + dt: float, + *, + diffusivity: float = 0.001, + drag: float = -0.1, + injection_mode: int = 4, + injection_scale: float = 1.0, + order: int = 2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.diffusivity = diffusivity + self.drag = drag + self.injection_mode = injection_mode + self.injection_scale = injection_scale + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=2, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + return self.diffusivity * build_laplace_operator( + derivative_operator, order=2 + ) + self.drag * build_laplace_operator(derivative_operator, order=0) + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> VorticityConvection2dKolmogorov: + return VorticityConvection2dKolmogorov( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + injection_mode=self.injection_mode, + injection_scale=self.injection_scale, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + ) diff --git a/exponax/sample_stepper/nikolaevskiy.py b/exponax/sample_stepper/nikolaevskiy.py new file mode 100644 index 0000000..a99ba17 --- /dev/null +++ b/exponax/sample_stepper/nikolaevskiy.py @@ -0,0 +1,141 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ( + GradientNormNonlinearFun, + ConvectionNonlinearFun, +) +from jaxtyping import Complex, Float, Array +from ..spectral import build_laplace_operator, build_gradient_inner_product_operator + + +class Nikolaevskiy(BaseStepper): + second_order_diffusivity: float + fourth_order_diffusivity: float + sixth_order_diffusivity: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + second_order_diffusivity: float = 0.1, + fourth_order_diffusivity: float = 1.0, + sixth_order_diffusivity: float = 1.0, + dealiasing_fraction: float = 2 / 3, + order: int = 2, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.second_order_diffusivity = second_order_diffusivity + self.fourth_order_diffusivity = fourth_order_diffusivity + self.sixth_order_diffusivity = sixth_order_diffusivity + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + linear_operator = ( + self.second_order_diffusivity + * build_laplace_operator(derivative_operator, order=2) + + self.fourth_order_diffusivity + * build_laplace_operator(derivative_operator, order=4) + + self.sixth_order_diffusivity + * build_laplace_operator(derivative_operator, order=6) + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> GradientNormNonlinearFun: + return GradientNormNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + zero_mode_fix=True, + scale=0.5, + ) + + +class NikolaevskiyConservative(BaseStepper): + second_order_diffusivity: float + fourth_order_diffusivity: float + sixth_order_diffusivity: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + second_order_diffusivity: float = 0.1, + fourth_order_diffusivity: float = 1.0, + sixth_order_diffusivity: float = 1.0, + dealiasing_fraction: float = 2 / 3, + order: int = 2, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.second_order_diffusivity = second_order_diffusivity + self.fourth_order_diffusivity = fourth_order_diffusivity + self.sixth_order_diffusivity = sixth_order_diffusivity + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=num_spatial_dims, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + linear_operator = ( + self.second_order_diffusivity + * build_laplace_operator(derivative_operator, order=2) + + self.fourth_order_diffusivity + * build_laplace_operator(derivative_operator, order=4) + + self.sixth_order_diffusivity + * build_laplace_operator(derivative_operator, order=6) + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ConvectionNonlinearFun: + return ConvectionNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + zero_mode_fix=True, + convection_scale=0.5, + ) diff --git a/exponax/sample_stepper/reaction.py b/exponax/sample_stepper/reaction.py new file mode 100644 index 0000000..9ea8645 --- /dev/null +++ b/exponax/sample_stepper/reaction.py @@ -0,0 +1,350 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ( + PolynomialNonlinearFun, + GrayScottNonlinearFun, + CahnHilliardNonlinearFun, + BelousovZhabotinskyNonlinearFun, +) +from jaxtyping import Complex, Float, Array +from ..spectral import build_laplace_operator, build_gradient_inner_product_operator + + +class SwiftHohenberg(BaseStepper): + g: float + r: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + g: float = 1.0, + r: float = 0.7, + order: int = 2, + dealiasing_fraction: float = 1 + / 2, # Needs lower value due to cubic nonlinearity + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.g = g + self.r = r + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + laplace = build_laplace_operator(derivative_operator, order=2) + linear_operator = self.r - (1 + laplace) ** 2 + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> PolynomialNonlinearFun: + return PolynomialNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + coefficients=[0.0, 0.0, self.g, -1.0], + ) + + +class GrayScott(BaseStepper): + epsilon_1: float + epsilon_2: float + b: float + d: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + epsilon_1: float = 0.00002, + epsilon_2: float = 0.00001, + b: float = 0.04, + d: float = 0.1, + order: int = 2, + dealiasing_fraction: float = 1 + / 2, # Needs lower value due to cubic nonlinearity + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.epsilon_1 = epsilon_1 + self.epsilon_2 = epsilon_2 + self.b = b + self.d = d + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=2, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "2 ... (N//2)+1"]: + laplace = build_laplace_operator(derivative_operator, order=2) + linear_operator = jnp.concatenate( + [ + self.epsilon_1 * laplace, + self.epsilon_2 * laplace, + ] + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> GrayScottNonlinearFun: + return GrayScottNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + b=self.b, + d=self.d, + dealiasing_fraction=self.dealiasing_fraction, + ) + + +### !!! Below models lack validation ### + + +class FisherKPP(BaseStepper): + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + order: int = 2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + laplace = build_laplace_operator(derivative_operator, order=2) + linear_operator = laplace + 1.0 + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> PolynomialNonlinearFun: + return PolynomialNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + coefficients=[0.0, 0.0, -1.0], + ) + + +class AllenCahn(BaseStepper): + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + order: int = 2, + dealiasing_fraction: float = 1 + / 2, # Needs lower value due to cubic nonlinearity + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + laplace = build_laplace_operator(derivative_operator, order=2) + linear_operator = laplace + 1.0 + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> PolynomialNonlinearFun: + return PolynomialNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + coefficients=[0.0, 0.0, 0.0, -1.0], + ) + + +class CahnHilliard(BaseStepper): + hyper_diffusivity: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + hyper_diffusivity: float = 0.2, + order: int = 2, + dealiasing_fraction: float = 1 + / 2, # Needs lower value due to cubic nonlinearity + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.hyper_diffusivity = hyper_diffusivity + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + laplace = build_laplace_operator(derivative_operator, order=2) + bi_laplace = build_laplace_operator(derivative_operator, order=4) + linear_operator = -self.hyper_diffusivity * bi_laplace - laplace + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> CahnHilliardNonlinearFun: + return CahnHilliardNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + ) + + +class BelousovZhabotinsky(BaseStepper): + diffusivities: list[float] + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + diffusivities: list[float] = [1e-5, 2e-5, 1e-5], + order: int = 2, + dealiasing_fraction: float = 1 + / 2, # Needs lower value due to cubic nonlinearity + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.diffusivities = diffusivities + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=3, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "3 ... (N//2)+1"]: + laplace = build_laplace_operator(derivative_operator, order=2) + linear_operator = jnp.concatenate( + [ + self.diffusivities[0] * laplace, + self.diffusivities[1] * laplace, + self.diffusivities[2] * laplace, + ] + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> BelousovZhabotinskyNonlinearFun: + return BelousovZhabotinskyNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + ) diff --git a/exponax/spectral.py b/exponax/spectral.py new file mode 100644 index 0000000..135b23a --- /dev/null +++ b/exponax/spectral.py @@ -0,0 +1,412 @@ +import jax.numpy as jnp +from jaxtyping import Array, Float, Complex, PyTree, PRNGKeyArray, Bool +from typing import Union + + +def build_wavenumbers( + num_spatial_dims: int, + num_points: int, + *, + indexing: str = "ij", +) -> Float[Array, "D ... (N//2)+1"]: + """ + Setup an array containing integer coordinates of wavenumbers associated with + a "num_spatial_dims"-dimensional rfft (real-valued FFT) + `jax.numpy.fft.rfftn`. + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. + + **Returns:** + - `wavenumbers`: An array of wavenumber integer coordinates, shape + `(D, ..., (N//2)+1)`. + """ + right_most_wavenumbers = jnp.fft.rfftfreq(num_points, 1 / num_points) + other_wavenumbers = jnp.fft.fftfreq(num_points, 1 / num_points) + + wavenumber_list = [ + other_wavenumbers, + ] * (num_spatial_dims - 1) + [ + right_most_wavenumbers, + ] + + wavenumbers = jnp.stack( + jnp.meshgrid(*wavenumber_list, indexing=indexing), + ) + + return wavenumbers + + +def build_scaled_wavenumbers( + D: int, + L: float, + N: int, + *, + indexing: str = "ij", +) -> Float[Array, "D ... (N//2)+1"]: + """ + Setup an array containing scaled wavenumbers associated with a + "num_spatial_dims"-dimensional rfft (real-valued FFT) + `jax.numpy.fft.rfftn`. Scaling is done by `2 * pi / L`. + + **Arguments:** + - `D`: The number of spatial dimensions. + - `L`: The domain extent. + - `N`: The number of points in each spatial dimension. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. + + **Returns:** + - `wavenumbers`: An array of wavenumber integer coordinates, shape + `(D, ..., (N//2)+1)`. + """ + scale = 2 * jnp.pi / L + wavenumbers = build_wavenumbers(D, N, indexing=indexing) + return scale * wavenumbers + + +def derivative( + field: Float[Array, "C ... N"], + domain_extent: float, + *, + order: int = 1, + indexing: str = "ij", +) -> Union[Float[Array, "C D ... (N//2)+1"], Float[Array, "D ... (N//2)+1"]]: + """ + Perform the spectral derivative of a field. In higher dimensions, this + defaults to the gradient (the collection of all partial derivatives). In 1d, + the resulting channel dimension holds the derivative. If the function is + called with an d-dimensional field which has 1 channel, the result will be a + d-dimensional field with d channels (one per partial derivative). If the + field originally had C channels, the result will be a matrix field with C + rows and d columns. + + Note that applying this operator twice will produce issues at the Nyquist if + the number of degrees of freedom N is even. For this, consider also using + the order option. + + **Arguments:** + - `field`: The field to differentiate, shape `(C, ..., N,)`. `C` can be + `1` for a scalar field or `D` for a vector field. + - `L`: The domain extent. + - `order`: The order of the derivative. Default is `1`. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. + + **Returns:** + - `field_der`: The derivative of the field, shape `(C, D, ..., + (N//2)+1)` or `(D, ..., (N//2)+1)`. + """ + channel_shape = field.shape[0] + spatial_shape = field.shape[1:] + D = len(spatial_shape) + N = spatial_shape[0] + derivative_operator = build_derivative_operator( + D, domain_extent, N, indexing=indexing + ) + ## I decided to not use this fix + + # # Required for even N, no effect for odd N + # derivative_operator_fixed = ( + # derivative_operator * nyquist_filter_mask(D, N) + # ) + derivative_operator_fixed = derivative_operator**order + + field_hat = jnp.fft.rfftn(field, axes=space_indices(D)) + if channel_shape == 1: + # Do not introduce another channel axis + field_der_hat = derivative_operator_fixed * field_hat + else: + # Create a "derivative axis" right after the channel axis + field_der_hat = field_hat[:, None] * derivative_operator_fixed[None, ...] + + field_der = jnp.fft.irfftn(field_der_hat, s=spatial_shape, axes=space_indices(D)) + + return field_der + + +def build_derivative_operator( + num_spatial_dims: int, + domain_extent: float, + num_points: int, + *, + indexing: str = "ij", +) -> Complex[Array, "D ... (N//2)+1"]: + """ + Setup the derivative operator in Fourier space. + + **Arguments:** + - `D`: The number of spatial dimensions. + - `L`: The domain extent. + - `N`: The number of points in each spatial dimension. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. + + **Returns:** + - `derivative_operator`: The derivative operator, shape `(D, ..., + N//2+1)`. + """ + return 1j * build_scaled_wavenumbers( + num_spatial_dims, domain_extent, num_points, indexing=indexing + ) + + +def build_laplace_operator( + derivative_operator: Complex[Array, "D ... (N//2)+1"], + *, + order: int = 2, +) -> Complex[Array, "1 ... (N//2)+1"]: + """ + Given the derivative operator of [`build_derivative_operator`], return the + Laplace operator. + + **Arguments:** + - `derivative_operator`: The derivative operator, shape `(D, ..., + N//2+1)`. + - `order`: The order of the Laplace operator. Default is `2`. + + **Returns:** + - `laplace_operator`: The Laplace operator, shape `(1, ..., N//2+1)`. + """ + if order % 2 != 0: + raise ValueError("Order must be even.") + + return jnp.sum(derivative_operator**order, axis=0, keepdims=True) + + +def build_gradient_inner_product_operator( + derivative_operator: Complex[Array, "D ... (N//2)+1"], + velocity: Float[Array, "D"], + *, + order: int = 1, +) -> Complex[Array, "1 ... (N//2)+1"]: + """ + Given the derivative operator of [`build_derivative_operator`] and a velocity + field, return the operator that computes the inner product of the gradient + with the velocity. + + **Arguments:** + - `derivative_operator`: The derivative operator, shape `(D, ..., + N//2+1)`. + - `velocity`: The velocity field, shape `(D,)`. + - `order`: The order of the gradient. Default is `1`. + + **Returns:** + - `operator`: The operator, shape `(1, ..., N//2+1)`. + """ + if order % 2 != 1: + raise ValueError("Order must be odd.") + + if velocity.shape != (derivative_operator.shape[0],): + raise ValueError( + f"Expected velocity shape to be {derivative_operator.shape[0]}, got {velocity.shape}." + ) + + # Need to move the channel/dimension axis last to enable autobroadcast over + # the arbitrary number of spatial axes, Then we can move this singleton axis + # back to the front + operator = jnp.swapaxes( + jnp.sum( + velocity + * jnp.swapaxes( + derivative_operator**order, + 0, + -1, + ), + axis=-1, + keepdims=True, + ), + 0, + -1, + ) + + return operator + + +def space_indices(num_spatial_dims: int) -> tuple[int, ...]: + """ + Returns the indices within a field array that correspond to the spatial + dimensions. + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions. + + **Returns:** + - `indices`: The indices of the spatial dimensions. + """ + return tuple(range(-num_spatial_dims, 0)) + + +def spatial_shape(num_spatial_dims: int, num_points: int) -> tuple[int, ...]: + """ + Returns the shape of a spatial field array. + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + + **Returns:** + - `shape`: The shape of the spatial field array. + """ + return (num_points,) * num_spatial_dims + + +def wavenumber_shape(num_spatial_dims: int, num_points: int) -> tuple[int, ...]: + """ + Returns the spatial shape of a field in Fourier space (assuming the usage of + rfft, `jax.numpy.fft.rfftn`). + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + + **Returns:** + - `shape`: The shape of the spatial field array. + """ + return (num_points,) * (num_spatial_dims - 1) + (num_points // 2 + 1,) + + +def low_pass_filter_mask( + num_spatial_dims: int, + num_points: int, + *, + cutoff: int, + axis_separate: bool = True, + indexing: str = "ij", +) -> Bool[Array, "1 ... N"]: + """ + Create a low-pass filter mask in Fourier space. + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + - `cutoff`: The cutoff wavenumber. + - `axis_separate`: Whether to apply the cutoff to each axis separately. + Default is `True`. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. + + **Returns:** + - `mask`: The low-pass filter mask, shape `(1, ..., N//2+1)`. + """ + wavenumbers = build_wavenumbers(num_spatial_dims, num_points, indexing=indexing) + + if axis_separate: + mask = True + for wn_grid in wavenumbers: + mask = mask & (jnp.abs(wn_grid) <= cutoff) + else: + mask = jnp.linalg.norm(mask, axis=0) <= cutoff + + mask = mask[jnp.newaxis, ...] + + return mask + + +def nyquist_filter_mask( + num_spatial_dims: int, + num_points: int, +) -> Bool[Array, "1 ... N"]: + """ + Creates mask that if multiplied with a field in Fourier space will remove + the Nyquist mode. + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + + **Returns:** + - `mask`: The Nyquist filter mask, shape `(1, ..., N//2+1)`. + """ + if num_points % 2 == 1: + # Odd number of degrees of freedom (no issue with the Nyquist mode) + return jnp.ones( + (1, *wavenumber_shape(num_spatial_dims, num_points)), dtype=bool + ) + else: + # Even number of dof (hence the Nyquist only appears in the negative + # wavenumbers. This is problematic because the rfft in D >=2 has + # multiple FFTs after the rFFT) + nyquist_mode = num_points // 2 + 1 + mode_below_nyquist = nyquist_mode - 1 + return low_pass_filter_mask( + num_spatial_dims, + num_points, + cutoff=mode_below_nyquist - 1, + axis_separate=True, + ) + + # # Todo: Do we need the below? + # wavenumbers = build_wavenumbers(D, N, scaled=False) + # mask = True + # for wn_grid in wavenumbers: + # mask = mask & (wn_grid != -mode_below_nyquist) + # return mask + + +def build_scaling_array( + num_spatial_dims: int, + num_points: int, + *, + indexing: str = "ij", +) -> Float[Array, "1 ... (N//2)+1"]: + """ + Creates an array of the values that would be seen in the result of a + (real-valued) Fourier transform of a signal of amplitude 1. + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. + + **Returns:** + - `scaling`: The scaling array, shape `(1, ..., N//2+1)`. + """ + right_most_wavenumbers = jnp.fft.rfftfreq(num_points, 1 / num_points) + other_wavenumbers = jnp.fft.fftfreq(num_points, 1 / num_points) + + right_most_scaling = jnp.where( + right_most_wavenumbers == 0, + num_points, + num_points / 2, + ) + other_scaling = jnp.where( + other_wavenumbers == 0, + num_points, + num_points / 2, + ) + + # If N is even, special treatment for the Nyquist mode + if num_points % 2 == 0: + # rfft has the Nyquist mode as positive wavenumber + right_most_scaling = jnp.where( + right_most_wavenumbers == num_points // 2, + num_points, + right_most_scaling, + ) + # standard fft has the Nyquist mode as negative wavenumber + other_scaling = jnp.where( + other_wavenumbers == -num_points // 2, + num_points, + other_scaling, + ) + + scaling_list = [ + other_scaling, + ] * (num_spatial_dims - 1) + [ + right_most_scaling, + ] + + scaling = jnp.prod( + jnp.stack( + jnp.meshgrid(*scaling_list, indexing=indexing), + ), + axis=0, + keepdims=True, + ) + + return scaling diff --git a/exponax/utils.py b/exponax/utils.py new file mode 100644 index 0000000..faab6fc --- /dev/null +++ b/exponax/utils.py @@ -0,0 +1,365 @@ +from typing import Union + +import jax +import jax.numpy as jnp +import jax.random as jr +import jax.tree_util as jtu +from jaxtyping import Array, Float, Complex, PyTree, PRNGKeyArray +from typing import Callable, Tuple + + +def get_grid( + num_spatial_dims: int, + domain_extent: float, + num_points: int, + *, + full: bool = False, + zero_centered: bool = False, + indexing: str = "ij", +) -> Float[Array, "D ... N"]: + """ + Return a grid in the spatial domain. A grid in d dimensions is an array of + shape (d,) + (num_points,)*d with the first axis representing all coordiate + inidices. + + Notice, that if `num_spatial_dims = 1`, the returned array has a singleton + dimension in the first axis, i.e., the shape is `(1, num_points)`. + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions. + - `domain_extent`: The extent of the domain in each spatial dimension. + - `num_points`: The number of points in each spatial dimension. + - `full`: Whether to include the right boundary point in the grid. + Default: `False`. The right point is redundant for periodic boundary + conditions and is not considered a degree of freedom. Use this + option, for example, if you need a full grid for plotting. + - `zero_centered`: Whether to center the grid around zero. Default: + `False`. By default the grid considers a domain of (0, + domain_extent)^(num_spatial_dims). + - `indexing`: The indexing convention to use. Default: `'ij'`. + + **Returns:** + - `grid`: The grid in the spatial domain. Shape: `(num_spatial_dims, + ..., num_points)`. + """ + if full: + grid_1d = jnp.linspace(0, domain_extent, num_points + 1, endpoint=True) + else: + grid_1d = jnp.linspace(0, domain_extent, num_points, endpoint=False) + + if zero_centered: + grid_1d -= domain_extent / 2 + + grid_list = [ + grid_1d, + ] * num_spatial_dims + + grid = jnp.stack( + jnp.meshgrid(*grid_list, indexing=indexing), + ) + + return grid + + +def rollout( + stepper_fn: Union[Callable[[PyTree], PyTree], Callable[[PyTree, PyTree], PyTree]], + n: int, + *, + include_init: bool = False, + takes_aux: bool = False, + constant_aux: bool = True, +): + """ + Transform a stepper function into a function that autoregressively (i.e., + recursively applied to its own output) produces a trajectory of length `n`. + + Based on `takes_aux`, the stepper function is either fully automomous, just + mapping state to state, or takes an additional auxiliary input. This can be + a force/control or additional metadata (like physical parameters, or time + for non-autonomous systems). + + Args: + - `stepper_fn`: The time stepper to transform. If `takes_aux = False` + (default), expected signature is `u_next = stepper_fn(u)`, else + `u_next = stepper_fn(u, aux)`. `u` and `u_next` need to be PyTrees + of identical structure, in the easiest case just arrays of same + shape. + - `n`: The number of time steps to rollout the trajectory into the + future. If `include_init = False` (default) produces the `n` steps + into the future. + - `include_init`: Whether to include the initial condition in the + trajectory. If `True`, the arrays in the returning PyTree have shape + `(n + 1, ...)`, else `(n, ...)`. Default: `False`. + - `takes_aux`: Whether the stepper function takes an additional PyTree + as second argument. + - `constant_aux`: Whether the auxiliary input is constant over the + trajectory. If `True`, the auxiliary input is repeated `n` times, + otherwise the leading axis in the PyTree arrays has to be of length + `n`. + + Returns: + - `rollout_stepper_fn`: A function that takes an initial condition `u_0` + and an auxiliary input `aux` (if `takes_aux = True`) and produces + the trajectory by autoregressively applying the stepper `n` times. + If `include_init = True`, the trajectory has shape `(n + 1, ...)`, + else `(n, ...)`. Returns a PyTree of the same structure as the + initial condition, but with an additional leading axis of length + `n`. + """ + + if takes_aux: + + def scan_fn(u, aux): + u_next = stepper_fn(u, aux) + return u_next, u_next + + def rollout_stepper_fn(u_0, aux): + if constant_aux: + aux = jtu.tree_map( + lambda x: jnp.repeat(jnp.expand_dims(x, axis=0), n, axis=0), aux + ) + + _, trj = jax.lax.scan(scan_fn, u_0, aux, length=n) + + if include_init: + trj_with_init = jtu.tree_map( + lambda init, history: jnp.concatenate( + [jnp.expand_dims(init, axis=0), history], axis=0 + ), + u_0, + trj, + ) + return trj_with_init + else: + return trj + + return rollout_stepper_fn + + else: + + def scan_fn(u, _): + u_next = stepper_fn(u) + return u_next, u_next + + def rollout_stepper_fn(u_0): + _, trj = jax.lax.scan(scan_fn, u_0, None, length=n) + + if include_init: + trj_with_init = jtu.tree_map( + lambda init, history: jnp.concatenate( + [jnp.expand_dims(init, axis=0), history], axis=0 + ), + u_0, + trj, + ) + return trj_with_init + else: + return trj + + return rollout_stepper_fn + + +def repeat( + stepper_fn: Union[Callable[[PyTree], PyTree], Callable[[PyTree, PyTree], PyTree]], + n: int, + *, + takes_aux: bool = False, + constant_aux: bool = True, +): + """ + Transform a stepper function into a function that autoregressively (i.e., + recursively applied to its own output) applies the stepper `n` times and + returns the final state. + + Based on `takes_aux`, the stepper function is either fully automomous, just + mapping state to state, or takes an additional auxiliary input. This can be + a force/control or additional metadata (like physical parameters, or time + for non-autonomous systems). + + Args: + - `stepper_fn`: The time stepper to transform. If `takes_aux = False` + (default), expected signature is `u_next = stepper_fn(u)`, else + `u_next = stepper_fn(u, aux)`. `u` and `u_next` need to be PyTrees + of identical structure, in the easiest case just arrays of same + shape. + - `n`: The number of times to apply the stepper. + - `takes_aux`: Whether the stepper function takes an additional PyTree + as second argument. + - `constant_aux`: Whether the auxiliary input is constant over the + trajectory. If `True`, the auxiliary input is repeated `n` times, + otherwise the leading axis in the PyTree arrays has to be of length + `n`. + + Returns: + - `repeated_stepper_fn`: A function that takes an initial condition + `u_0` and an auxiliary input `aux` (if `takes_aux = True`) and + produces the final state by autoregressively applying the stepper + `n` times. Returns a PyTree of the same structure as the initial + condition. + """ + + if takes_aux: + + def scan_fn(u, aux): + u_next = stepper_fn(u, aux) + return u_next, None + + def repeated_stepper_fn(u_0, aux): + if constant_aux: + aux = jtu.tree_map( + lambda x: jnp.repeat(jnp.expand_dims(x, axis=0), n, axis=0), aux + ) + + final, _ = jax.lax.scan(scan_fn, u_0, aux, length=n) + return final + + return repeated_stepper_fn + + else: + + def scan_fn(u, _): + u_next = stepper_fn(u) + return u_next, None + + def repeated_stepper_fn(u_0): + final, _ = jax.lax.scan(scan_fn, u_0, None, length=n) + return final + + return repeated_stepper_fn + + +def stack_sub_trajectories( + trj: PyTree[Float[Array, "n_timesteps ..."]], + sub_len: int, +) -> PyTree[Float[Array, "n_stacks sub_len ..."]]: + """ + Slice a trajectory into subtrajectories of length `n` and stack them + together. Useful for rollout training neural operators with temporal mixing. + + !!! Note that this function can produce very large arrays. + + **Arguments:** + - `trj`: The trajectory to slice. Expected shape: `(n_timesteps, ...)`. + - `sub_len`: The length of the subtrajectories. If you want to perform rollout + training with k steps, note that `n=k+1` to also have an initial + condition in the subtrajectories. + + **Returns:** + - `sub_trjs`: The stacked subtrajectories. Expected shape: `(n_stacks, + n, ...)`. `n_stacks` is the number of subtrajectories stacked + together, i.e., `n_timesteps - n + 1`. + """ + n_time_steps = [l.shape[0] for l in jtu.tree_leaves(trj)] + + if len(set(n_time_steps)) != 1: + raise ValueError( + "All arrays in trj must have the same number of time steps in the leading axis" + ) + else: + n_time_steps = n_time_steps[0] + + if sub_len > n_time_steps: + raise ValueError( + "n must be smaller than or equal to the number of time steps in trj" + ) + + n_sub_trjs = n_time_steps - sub_len + 1 + + sub_trjs = jtu.tree_map( + lambda trj: jnp.stack( + [trj[i : i + sub_len] for i in range(n_sub_trjs)], axis=0 + ), + trj, + ) + + return sub_trjs + + +import matplotlib.pyplot as plt +from matplotlib.animation import FuncAnimation + + +def get_animation(trj, *, vlim=(-1, 1)): + fig, ax = plt.subplots() + im = ax.imshow( + trj[0].squeeze().T, vmin=vlim[0], vmax=vlim[1], cmap="RdBu_r", origin="lower" + ) + im.set_data(jnp.zeros_like(trj[0]).squeeze()) + + def animate(i): + im.set_data(trj[i].squeeze().T) + fig.suptitle(f"t_i = {i:04d}") + return im + + plt.close(fig) + + ani = FuncAnimation(fig, animate, frames=trj.shape[0], interval=100, blit=False) + + return ani + + +def get_grouped_animation( + trj, *, vlim=(-1, 1), grid=(3, 3), figsize=(10, 10), titles=None +): + """ + trj.shape = (n_trjs, n_timesteps, ...) + """ + fig, ax_s = plt.subplots(*grid, sharex=True, sharey=True, figsize=figsize) + im_s = [] + for i, ax in enumerate(ax_s.flatten()): + im = ax.imshow( + trj[i, 0].squeeze().T, + vmin=vlim[0], + vmax=vlim[1], + cmap="RdBu_r", + origin="lower", + ) + im.set_data(jnp.zeros_like(trj[i, 0]).squeeze()) + im_s.append(im) + + def animate(i): + for j, im in enumerate(im_s): + im.set_data(trj[j, i].squeeze().T) + if titles is not None: + ax_s.flatten()[j].set_title(titles[j]) + fig.suptitle(f"t_i = {i:04d}") + return im_s + + plt.close(fig) + + ani = FuncAnimation(fig, animate, frames=trj.shape[1], interval=100, blit=False) + + return ani + + +def build_ic_set( + ic_generator, + *, + num_points: int, + num_samples: int, + key: PRNGKeyArray, +) -> Float[Array, "S 1 ... N"]: + """ + Generate a set of initial conditions by sampling from a given initial + condition distribution and evaluating the function on the given grid. + + **Arguments:** + - `ic_generator`: A function that takes a PRNGKey and returns a + function that takes a grid and returns a sample from the initial + condition distribution. + - `num_samples`: The number of initial conditions to sample. + - `key`: The PRNGKey to use for sampling. + + **Returns:** + - `ic_set`: The set of initial conditions. Shape: `(S, 1, ..., N)`. + `S = num_samples`. + """ + + def scan_fn(k, _): + k, sub_k = jr.split(k) + ic = ic_generator(num_points, key=sub_k) + return k, ic + + _, ic_set = jax.lax.scan(scan_fn, key, None, length=num_samples) + + return ic_set diff --git a/ks_rollout.png b/ks_rollout.png new file mode 100644 index 0000000000000000000000000000000000000000..516097ec50f64c10ff916f14df8dea973c1d2c3b GIT binary patch literal 197512 zcmeFZWmH^EvnV`+4=^~v-CYti$Y8}m!JQB+xJz!H=l$+k z_niCt`*(lry?gbl?yl8U)m2@xKdLB6gHef50RRA4Rz^|{06-810Dz4sNPl~{+zqDw zegr(EbUf6ZtvtL<-7Eo$rXDU1&K?f7W|Ur*Ztk|uPH)-y*|}ILZ9F_&+yyx}9RCY| z-Pz5WW3~!c=q!CsQ=S9T!j8(L?lQlMoQj}|ju8X*)2`cwoF6NN z9j`Z;<{o#y9_R0T$@jbNel1q(w>Yi0dmjJqhX0g8@QQQZ|L8qlJayEGBL6=D{~d8R zD|r2X&C!*ym@oa`XZ<_As}V!^|5)cg3!W8w{+ju}P~fPRNc?X|{RLk2wDb6iXsT5c)s+- z>%Z;)sXEV@Fks6)_sJlaCh$YmUjbKq+Bbf+447yMm?(YW{qN9U8xlnxQF5Qw4n$5{ zW=^61oD#MaLHAUEtv}ZXuepTIqR`n83!G^UJ-YVt#?ZwIM)@Z_F_ynaBng01uRhq< zxY9thj`tn;CZFDcBk!#WDZFh~51h=%Y}_>Q&+}r{WPSBd2J3*15>24cXC&B2kkYD8 z$EgUipktSHdB=RWh|+rP{co*J%v$}7EAF9gZ}qPP&rd~c#9w+S5H(Yak7;i{9Ulr6 zdbJ-(pB0K&2|JP3x^@oVpR@nLZX`3tKYhRHfV)!fhqCMFc(wcPD0bJm33qd7hJm0`ThF6{ZkDk6 zb0iAn(1l^9KKp2{TWk7*ge@39m|t@(YG(S&kp^R;r=C5dxtgdYUgP5QE}-t>atT$A z1VRF_FWbgT;6cwvhZ28)WGyL0dD*^ChwtV5*%Lop*uFH$GDP79DBj++XO6ASX_>1i zyXy>G&+ttrFhAG6-Q9V>Ol_=|Q@(j5sddh+ep}#~)wOuEX}eIS|AOSJ{d&u9JRkA& zY2vgK?X(lV3VgExbo%&`;o|&Fz7>=8?K3z0WFH4E)}6%Xh}_BTN8O*Tj`i8rt-c|- z1WUdlS*+<3h#LY+BSc=hfMb&>+>_4UNenMHmv0eY$<{1P4_r9v>scK}o`y+loDYPm zkgKL@A+J15Bt5o=pXr?DiUSI)cV-5UWyBJZRv0r=LS~^v;O|UESq$)RBkdE$f7qC< z19&hv2LOTGDj%XwUC~#~8P#_I!We(|FR`T4s|gs%Wf=S*XaQqaG0YG&JOH=}sHE09 z+~}+SVQ; zzW8m)I7h^ewZ*uF6oy5&R2Cw&QM>(N*@(chQCP##w!_Cz(Ma|8H(9PfR5ke#HS}{k zdj^?c{HRh1eKFN{G}`3AH?vFrxz&qqhMU`ma+McTd!=qKF z3`!|@@$#(~cU!D@LUBaEgM4#}_@#_j0sILvJitB!qXeugHieF$ zc{}eKrFV!TfaG`mc+|(V23H(J#xNK2{WVBb16(*E9F=`^eJE9Vb`$IwfRA4Oi2rhh zY}9>ZM_R7fq`2wUS9A4~I7S2akgxt;cn@pvB9LNRdd(Yc zi!}8cQPQB_KI^lP6c%@cfB!z00@VC-E<6^M%ub0wZ^9J2;la$DKu4r zwzr}%k$z`_U<>>GOB7$FY07sf%b>w!<*x~5FbQ?5x z;{F?3Hav6yZjELOz%M*B){+=H0joj+Oh9Gsvj%v^UxQxjQ1UW21T!WX&vf1&ia3M( z$_%={1BTJXt0H5us#dKh{u*!1_Palz7gm2c5G#!U+zOmOryc;8V*r%Gn2rGNo&Y2H zNH02a%!zOA%NS0iHOSe{r{TcLA!z<1D}oz~T(x0DPp4X(g?`F6WjJ{#r+a78e*CEW z70c-?-WJ6l9yV(2=J-Yay!Y;l@P=+e-#9#BH8N60ai*^lwm7W{zWHX@G;H*szFNI8 znnr0_8(1i!D0p0=(j9ex$UWZi6V=_-O4fGF$p#7!wnVTve+6HqU< z3_Ld7ZS<)|{fLV7BbBlU>Ely@UbqHTiB?3C8tor=EDh=lpEk#)ess=aN-c&iv z?goQ0DV_K2i}%c4Tz0`37tIejW zXa?c!H{NQ^Js2jWj^3woxSPmzT%Rf+kxRp|`&lXcATv@;kE)g5E!cMtu0C(=wJlk2 z4QOhb**@Agc5)kh(YB+|^J!sM_e1nWMR7wQJRb}9c)8>%LTgKJq=;aKK|a%X4*A=_ z!wSbA6k$d0bZ(Z>EISS!hMUw4_#!%U41?27L_#n*WvC(X_~_WZ(Z)Li*W9&_9lpB4 z(J5|&lM|?s=qIM@Lus-b#&gXdgGs_DCm-K!V`)+^@HkO1N@PwkMp3g-id&LKpFa{u zpWoilGT+!IFhH?EW-=}aqY%lY9eLbF1X*HB|2*1x+cK-K%oCm@B6 z*vATd4C%|ZEo={8M9%Vj8P6x1bySDKnIm1=G-~j<$dN z5WPoX8XtwII^w-3tv9iMcdruA13_~%bU(n9`4hU>iA5pxV6me$S-Lv{RiP|EQu)oW zn8&A2h=wA<)?b!-4Y<)~^ymQMH4dQgYy};FC4jc|hlL4FjjAT#69Pnw-Qfww3-t(` zYrg(684)i9Z9>b_?-CS~(uV(eHPGf8=X6%`*XDABbzkg*S}4}4ppsyoA|H}Ywln=W zO&}P9SHIC=3t9XaJ9~7G zN@9+u@_sQ#iyGu{{-VMwEfC1Q+dHU;*6FIkbsKHYjCD5NP-j>n)t}1(!Li~Y1&)6l zmt#{K!Q(<*Tc9r&*+G4GgKwTNc*<6m{D@|P2nl78+O%`n;7HAwE{vo^^vRzpMAR>8 zKyp#F6?)J?bO3?TlIjnKB`ni$5CPw@?B$DD#nm^#=qg@$)C;bch!AAyW2B`Tt_4PV3 z;jE87mIQGjBDa<&oC(T4?8deqbNhwdjSg4=LS2pGTwV@Eet&sy))_B5zu}XmBBA0g z11SW(E#gNZan#yHyQzFl`&hU4ksc16`y;2Fx)5b;Dh1s0`HkyNb8UjEvc~;AHJ$*R zq>um7T0#g8PArQESM*JHD6^uvV6JG>kreTOz{Ax#VX!;JYSQixIF|7F-d9!EYw$L5iBv? zuXiv(;}N0UeTh9#vTNVz3MS)+ILsF>R7goivHFk&dYSamz8a6Aa0PRA{&A&T$0xE=4dVye#Gc>XK&a_8YEO8dRk20NPGE6?2=bxHdzl+kzh)=q|+4}im z7sO8fe8%|XIDlx&dN>n^sW!UX*e4uSM8Z4I9~FS>ghNK&;|uiR12Q=^!SJL_0NYkhQbg=4TW9U!rN z-7X?+fdC+tTywS{%LNb*Td0IvtjsSE4d(g4l1a^?FF}N)k0pN$?A{cVg2_b@i|6@B)5smc?V`(hDlC|&}DI6<1>u(5Iz^4)N zyg0mfeQ5V2zu>?}fA)Q?y*tKm)du)ihaR7T;=dZHs z=^5~_@gyh5S%~Fuz@JP5)AT3lCEvdDBdRpTT5Hw=m$PZ}=d5gG`q(ihg-QIn-J2aw zfkf$JPGeB0)kXKC!-~+%eUkw^HWVJWsl8ht61j4I;Y(J#*&zKVJ*{dzXZ>FT)cDzC z5Sc?+w%irZPUa}C5K^eg(F8hm3uKu*!TK zP@j`#o(!3|a`yxpOT&FV{TtG&+#Hy8{PCTZeKz=)+5i|f3c+sYZBR$~%(xhcF%J~* zdmv$!&~iM!m3nw9DrAt4u#6tsT>6XANvP>vAe)qNqUK_`9Xcq5>boa>+Mg+<(YPhVq*^@FV=QNyK5Lspat?d7stJ_r(U>jQ8 z8N13Z#$9XZhlp$=sNClnUybi=JR&y15efJ7`uf#>rDh>7W&G%# za^a??aST?=yM2}~Xbp7a)XD0GcP2~aa+&KN-{FzaHL^m+UC^;|y$IN=<9hCR;}f9w zAvgVswB6wDC(p_euE`%k=XFu?mf!2^8?ygOeBi6s;KZt{di4w(iFDrZn`wTyw+Th8 zKVEcRaLx1q%E*ML7q?WA_Z#7NuHf%U!%G?-RU5k^^Gj8@8pPB~0E^bX=DP+yQa7>f z5ntS1s%BmIge}vn0H`^ElS2rY9Qx^G)W71GZS&^{2L(p9p(yKo2#)rRG#<_OX1LD_ zVpD*s{aUF2U^m;yBR@K$j-lYNn8|Ev0os}6aUi2e2|y0jWhN6#=C`04r4+ZPhgrSb znJql7tMTJTqJlvY?|aOq$~Y>_KPBiWIjU$Nqn2k_LgepNWXy_&=_ ziHxzswt?sQus`Y|h$Le-3I~uVMp2EA6Rdukj+8a|mavOkadk%5WnBw@qB6h)n7Y85 z>{(TJQ+(S}dfoO#UBx$PWkBi8&!X+B*n`B{-Vs;2g(owN)oBY0m^oRZg}mmhCwZ>$ ze@J(@4>k(&fv1#%@byZe>CGQnhM==FL!{8u3`l(5Hu?%?qL`D1;a0nbnCF?otFM3Ab6}93rzAm5Nqwz~L}gn#kA#|QPHqpA zbVfCz8Wgvo-&lUd0|}$ETrUG?Nb}~CEYP4QeJSa(l(Y+)xQ9)_bycs7rm_Z#_($5c z?Q0p_hIk`^72M$x8TL>#4rfTl;**M+0EsjLDHw8DL|Ierd`uq*NWz1!o@yz?YHrQI zs~i#_iSv%vXlZtH{b({bC9uUbcHrxP$;UC)6u~X6(yLz4a+Qonnghg)SrGe7w{}wZ zy%|dE9sm}@`&N}Rf^%P0oj6+rDl;ALTcY)8|FRct(ZkdI4kL;IKsh9nv_s?KIi{}3 zKw5xTfrA}T1(b_P!3hCrf0_0n9sH$^K&M%;W0)p;FR}O(|HP(SO?*8RD)wsVXyvL` zLYkM!-7DGhZpx&Xm-zLaN!?$sLEd%f9P*pn)L@RtRp5%vS~fv8JK6`twa-`<*E2|b z8KteQ0gLf9X(?LR4rOpI)3>rLam>zF?WXUh3x^9Ydq;9_jpt8^|3ga`;`8k>QY7N! z%K{a2QUH-Z)-Dql+ojG3m4yAte5q;>cyvi{SXJh-ik(M$v6lQ%9}CW=l8^T>khw0* z**U(2Yr8yTD7`|Wfe(&cUoL^hQnTh!MY9hb=r#{EM}(-gv5l<3Q-k>%LaD%8uGMly zmdSn=$~+CoV)wk5L7B&L5LNi(F1(;zR0z}fX77(uyB;Ezh(z_4^2Cm`{`=2WN?jdk z0!*ovoQH&KGJ(aiILKs>%y%<;Ot;>11{QAHyMI}~7Uy`J`M-N0}huqLCoGaXtMGri3z zR}%^&b69=IF3>RD_k%YnrUk9iJ?DA(JMo)+wHm*Nc&?nVfm487kNpRDOT6KZBqvxl`ghsg4c)Cj zPqu%W)@$njIs;EvY9CB`CkV{p@NlwQKk(Ec5hF8PXhTERte>d7&WOCxd`JaeH~g0b zUl#|GNdwOGeB7$m>E~@y9V_^cZY#SzX$Ro*!cJrR#P!c41as!k6cw@<`S`X?rH{`= z%pSgd$yAk!B$>6SM10TRxaCqfX>EL0+<9Bwt-%gVd&5~yQzl=&KPs0wJ4sGeTs7gC zn2Xi~7p_mn)~!@JaxB2OEI-6x-Q=SVpuXFEebIl}ko-ZN$O1`tv&9c57M{t$lD7b$ zsebLCH_XyJ-y~Lx!IP1=8cUk)KUlQk}$n`Z`el(}$Sg>Lhfi8#sp;NjqkcT{m; zOWba`&zh3|bemqd13DHC98mBfVb;=&i{RaBJiPv3tjPEO)Z z(p?EBFe`>qcCDs}LwxX_&*?>t{#ZxJgQ}_rnhu`%>APiQHsxR504s;IvrvrkP4e?4 z?i_F^rYslG?7U}(AAz(?6e|>dHtGGQS&)W94|46!MI=?Fh}sp3ATxkEtyESIkv4T= zdT>N!Zvay}jR;@!$EtS zhMc7EOCBMn!~&5prD0(yzH4wzn1{!`1*ldns{OasjyR*wc1Dh4eOAN|B&a{A7Rg!to- zw@G~$WJChNJok&PMdF0-6fpG(#N7H5A zN}IxtFkHR)RqvH`%0$q7MB>RZD&~dYMe3`(BgQgHS=?7dX)1ScH+uXlzrXreW}_QT zbjei;fYTxX08-D(ro@8*aCMsNsBwd@5ClpH@ckZl%msEYN_?(~Gy)~nv4jwLT`0v3 zUrF2~i;7MdW?t$I17!mcoz`-M{KQ7FPE8+4tVlM!2u?Qy*E`^HKgaj7n1nCX(*gT+ zs|Qo|7N6!AG(`W-S7KCcu}Msx>o+;SEg;d}FsXSsZTNd2MwuM(h!4py!y7{6U)A=i ztr<^pUmFOVt(%eqd+8Jff4{K^n4}zq4I{Ldyv*FyzqB}&&+fG^`a3ie^mv1wn-f|B z|M;VFvOg}}Soxj~ZH|CF*FX7Bk<*ql{JIv=I*eQ%+{hj-NGzhXWl~4o&)ReHq4e>3 z;bT};Yj#|YJaXh9doJ-eKQ)rP5h9P}a%(0R9z4}4^@P6ZRLVNHi`0;2_se0?HSqBy z*D>&E+A!A@GEzMIC_4Q|@;}}T2!~I3M#K4viL>*{S^UY{0P_{UW9OYvSE)_ip@ z6ZYXd7883k9&Qm4X3;*9I01z)7{pf+OAM+QwN0%#FsCH9cFeZxJCzX)M)S ziW{({;7n!ncC}d+*#Lc8zUn+XKN8g&{n+Y~_C-CnSNo&eq6-MfJOS`&K}S0@bGj64 z?(@);LVfrm5_MTjm%FK_xY+*1LKC;x$}29ckBli6sXE%}2J_Ub!Th91mYrTU{vcAY z`r^0qLSLquw3*N;Vx!agKrM(ikg;a2FH=c+y{P{wO~VUJXSBbgv+)G0&nhBRQCln1 zGI1ymEOYMQSbWi6giF})CIBl9!wPbn1|deS={d;Cl@h9qqtQ(*6Haxt8BRuB|Re$6#4e$ zryu5sfx+dO>GM~I=i7x?xEtqu`)4J=40vGM5g*JKQyz_|YI3NDsgnThm_X+sHW+O+ z=Z*fnVyyhl+)fPa>c!sU$wBvTGKiO=GUhY^jfhWl$Yyizu!Ugt+MOur0Id;LIA z7fFw;pi7M?4=zZ)+|~hM&NDaHJ}sV0oC+(9AlqsR>tub^foFkI3B5b(OQP2)H<3Ys zBcKSkqVKCMy74NBdeLK3YsNLS&K{`)Nz9{0riP$hX&XMSnrm__Z!Ab-LT~VC#4>Q4 zB#nc&$(>WJSF1K2g@$*uZk*AQswO~qCyct`K5*5yY7prSX+1inKp9ob8@uEp1(!dw z_{UpAFCshK*l|=nw&>4O4P!6sgkA`5W-n<90q4FIfy4s%^gb6<;<< zRj5__kff?~jLz;}o-tI)AI2BoB8;tj`7=q`eZ{xIwScK~Vichj0oreTD(m^oCP^R2 zP@QeS9MOqCqh=o=bZX~jW7o?_Sk{HFW$kO&xSr< zs|FgwABC2^ZD)-o-~r`v&0`LpGmB4bt9CGAcSJd8(fsw~0k?6w@4$w>=GAib_cv=bF^hlb zU3OCtxSRoTH4S{K-lXLOyPgi7yk-=6smA0nB;x1!s@)3=r@SljS8edv?DDcthYrUfk>fb+Ejxrd2dV2hr6kal!UR9->S~r@{HYe?w300bdgJ- z%GvVhfEmg@RU?cnnAoH+DDN@>(WbPayQuXfPNC>)*(TUU8;!(5`iVmkXMzej(4ok& z$rDm?E98TI1iNPI?~|n;H8L$Cyx!^hWuDksrXAf^u>3*Znzms6svas&n05iHZU}MA zMKYWxsn!mR)EIeXLafG%!$_h%IoOTYlc;o=ob!BVhgc1c<8D(%U&_4e8^$pGnPw(N z?I|R^djXgIF^Rsp&=HZCFIXv`RYchrxfOA|1JFhzGc#G@xL-iL#Ew1f(>v2_ZmzLK zSFCT<8&Pb+4v~L8J(mw4H#A!;sOTxeLqA;vHXNU_=>x}p+~Zn-u|YQT74bgtJ07?T zMDWipxPYB|I3<$%i#J-}I&zSb!sKp)&-ZU51NN~KZha$j@Z?sp6FcWKzJ0#+_m2Z{ zumhsRl%d-098NZjT5{{}8Hf`WT*eM47Jj;Q-!YtAGp7<=Cc-WK&_fb4Kwf^~?r)rO zw$ecxx~=zy)5q?Xa;wkedcX?4@dQF;F4i;cDIYJtS=aVJzkgIyV8CdR5dMbRPHuAeU1^06mwf@4rXQIlp{8Xn-6vOCtGs@yFuCArVFVYIQd`5&na+h z&fo@(U86$A<+`7Ebx>^Npm&&MX0*L(sRr1L%_=ePRu&+Z`x|pE3)pjTXi!P#`HxtZ zXz7g&BQk1EujeFBu}=Iq_wOs?sqdGdj0V2Rk9-EMKEhp7f|T}xM))?KXbIze*9Ke5%;wUiNp_40(uS_TQmlu3hGqO>hrjS`LF?eH(B8qJ(MdLO z4&8LIrk3k3gHqitwS}M)+-3=i0X!+e=KP>?c+ z!JJh>-94Z3VrM`lar9w?a1^c*Z&D*pan5H-K5 zldMYwtahGdBKhnLO@5VBDWTa8QRihcyag6CjJvDOdB+jMrP}Hz8&kdl9?IVOPMAgl z^!bjl{pF)wyd}skeoQH|YG5Q~$gn~&qlgl} z4bcHHqOJzQE!}`F*=O|%i8|zGrlr!~6KiqT>gL)WbOTnj%De{;`X5sK->!?O|7+v8 z*LI^GTSvEjUlxHDC;-c+#*JIxsIF~x5FymRwLJv@mW9I0Vm=Li9LP^63+48y7=SI9 zXuoz?O5rV@%;oo9^Q*VP6TsJv$Qhc4l(opwgx;KV?udm2t^4ej6{l0Qm}E^7_Vto__Bd&)%1~%U>?-XO_{6JOVbRqtZDzf*E7$aLNzB_{BPkje{XWM z`>#!6Sex^DpV^wZf#Hvl9SSt|{aOZKCSIhb@^uAh(hfcZ`1VU>w@_Fw|Id%=)zgFx zCu9Rj>ryCYnZN{Y)H+$*htnxpKXQEHlo2zwsuojf{5~ltp;)yTj&~?oEe9HVtbFC* z+|FNDbRS;+uy3=Ad?rNTxK06Rg|+1pJHoBML|28Waqj zAK-qGhkfXm9{+DE=|-`RDR-V#`Y(kZB?Fm{V~(C6hz5XOj)hPyfsjocpA3i^f-5Nc zBTp(*Ei@S?dEWBFvBogfxa~A2cF;T^f<1DcNMrQQ(4x+ z?!)>tY1f2Bk3!RFNsTWufkiJB3JW7^M|bI|jA{=xf^U) z{()>hW;wdxc^TG(;=#KX=M|#FTKV*`GNalgx^wbUwpbqqd%1tYT3Ph6C}7rAuQAAI zv6xc(8G6)Wd$UUSlJ45893t!sm@<_~p%S@}O`(za4DUt{9OvA67#yWhjO(1L-80(3 zod;CiWKQ)3?ll+`_;0y{07fq@<1bJs#DXO|elUm`<_q{LT6%n@U61 z^r)ycm9?SBsn_c*Drf$fgaG)SHyfAo(ZxfGj8(8z_miz=ngCeNd!N_aqx1EHW zQ0dwp@l5DCt^ufF2^VjZvl7N#wath$00o`d3b8%h4Hrw?*G~q+Wk8{T^WS1F9bdkN zx%V#X@wb?G9sD}lpW_kGr>NDe`GT7Fsb*n>sjQ6s!mm?k@-Lfv=4l^EeSvr?a;OluvAF#3`qGPl!8RK?W&HFF#WU^6Ns~7(NavD z9?o4~miYB$@515em^)R9fXhpG zG6Rv({|dL`WmocrsEXv4RT%L6A8-2Ve*^8X%2UU~6%}eWIxh8s>z3iA*lqFm7A?)vOQRY*yPm%4i|C zry*n|^S7M!Kfoo3soWts1>XC$voneL^uWg?F4iZcFSy&qU42D=UUfeh-9MclHx!7FV3NuND-*GOZR(=-G~^miJ3YSJu%n5;D$x>E zNLkQTKCy-z%;hNxnR%z^R|t!joY)(9b-|SU$pNLi)275?SnAW_-b1m@-;0F z(o{MGm-Y+i%&mh2K>R?Q(l91#JN8Hqz`5FV(d)wJy@1L%E^S&W~3H@ycc0QNEf zcB(i+fvd?8!_BAnmXBUhR{+O}_%8 zwRaslc`J`+-X~hR4c*8bF?J}3DdL-&xnXt){XD)D4vm$rF7DiUvR&PC>{>`DuRHRS zjmOU$A^Ni~`YbAv-_yS8{!cB3*Xi>@qRZQY zDVi1CI!5OY-3D;apG-j|mz*U&78n3Yj>bo#Ti=+~G?>Fx@?SlqC%^dOq&xJ>w+rKSv&}CU2!)|r?LA$>K`k~GM8R(Nud^fCm--!IJ zc@nW)K_Qn}Ar(;cQMsN&AsrBnTOt?>GJNag>LAntX4*k@L2X_|L6@uj4QNcFr%@B8 ztR&GgKPjU$~>`+0vt-Wg!Q_$TP)Rj?5WOG zQM$aR_U~2%E0w`aF%3SBs>>TZGJ!|ObvCe8q=%aWE4hO&AyQn_t z^OC=O>cbe(C%0|8w<`&L(LVm6q_ z8S9yNO}!IFhQv>fmN^ww_*ox9RrEJM5_HS^_BV-m*>k@ccjH{H{!Z&uSR=ElGo-Ps zt}7#lW|Z3;fv@G~*cNc2=#+r4Q)RB~Qr9;(C|A$m^NiBXv9RET_f-4feOA^HfpE8h zl)Chd4dPmTsfZM=#Aq)uz(&bdURrf^IP&&o5;duxogGG|VlIR{52*X>IeKxt)|_u1 zxb9?FsXtyutk6suPEJ!7{kizNiy5amYmK7$xPrJ~nl#T)o)*mEt1Ai$Q*y|qUgxb| zT1`vC*`r|ES~-UCp9r|*Bk0JJ{&UZebLnb_-QR+<)uxA_;SgykYl9UB)mmzNBXXPINy)mvI^niLvk`SrtfjAxy+p{XDKn71&oavT{z+aI1=-jEUZ{sbB)#$+^9Z&BZh&_YDo)@xB% zTAjK4O+B(;jMq?wQIN{dBy-K-BR!YGoTWFO{Y7n1>i3iuEdD!2dCSVPIq)V1xUob_ zuMV{Sxc9)(JnAzF_v5W6eYqI7xEqRg)HfZb1y3@?<4bNgc`w4ALrt6m=&r?S*4Wr^ zHaRfxhJo40TVk2#^{elfGk z@8H2S7fid+8>Mb-zTa1~qL8S~ocq^z{bzZ1B7;ojzlWqlqMC>@mx?uvi(sfeE=$;t zD?fRZKRB7*c={23XgHW&`BCWdLa0>fAcnIv9Tus@92Sf6_P%m<)G|y7R!#2FZgqT| zWPM51ONrQ*R-+A)A?-4XZD8)@KjUbsTwO%PG745b%kfujar`$3v;OA{lha$i^@0e| zJ!V?MJ~Ks@Q}S<2?+c2sct_RBl2$kXD)32&{cZ*wI)6&%DZ|LI8yW5=uozW@M|A~m zLV@=0OI2iO(A@Z6RXJ8D+a>Q-U=wSbRv+m&isrL2sWRVv_Uve=`b9%J@ILKUTaKWM zi=AYGwzogBG$@}d8yS76q8jzQkFiP3{hXrX8B(l9QlSKmNZ6NnkIWH@&y8WSQL;RI z;rkVDGw){l8Lto-?Kq&G4L0L_q$V89S@P1Jkl_r6g|7q$CrG>?*C|kPyQJZl>5?#Q ztd^+!@~a_D10lend=l~z_m4gO9c*6yHe1XjmwH(w8E94eL+eHB>${M)pRM}zigrhi z++A~CEaXybxTUMM<2&_0w)?>-2{gU2{R%V;*8AJK3beCN<42Ezr-!5TBB?XxoW5oS z_-eeEkY=k=4nsT2z`g>sBooK))*sIktE{5wK@9Sp6EdapZ$d&thyolz%`-7voU z=563D@!fu1b_FFUV{@LHk@;s%^9JTVFke(w8#&Ll?%SpOy!Z%#2=1HFtteF%sei~n^dmO!m9Y;5 zrM8NWH*bWXr1_#SW1HZQ418-R=JesU@<~Y>w^*<~m75DCDmRD`nDy&7Hkoje z+Z>N5v0-SPirysyP*eX>A(ma=acjddnRsJTE+!L78MB2LMmcDh%uW01yD`8B$55XL zkkxX;S}n_a=0t!{2LJ(zm@(IdiWKFg6t6nDBAw^SL1o9j1tol%uaxi4Oyff7D0c&V zrlT>7JjuZOLS}}Uy~mB?rI-FRGJE|LXF^?PNlYf;p`K9FI-rgG=PT6QH!E$L{v(j{$A87J;zUPgy~KdRpN!itDTok|gs-eX46ze%=X&f>0g7izm;%A-Q z0%c87w9EQozWuiFXYP2 zT9p5j)p#(1xVtwOuP(LOU3AZwC+!tBQy7xK1<}xmoOcyVU92nGqNFr?=YWrnAKjq$ zh6Z8u2-EO%j`wEwH@C(IM;t5=cy{yD{=Pncvg}xjAFhb9%2Kq;aMOkc$uez&kK&G! zRMz`}jFUPPZDMthq7s}YqdrzIm@{6({ z)OiNVPy>>trHjpRd|WRAeTgYwD9$Sa2+2Hr_>jc*A(%)?4=Bud*=U&H#V2|I5CMR@ zndOdc;Dh-K-9c+v``XVFD;jg$>XaP!NaeS)Oz{H?r9}LtCs-!tZk|yequGD%RMU{h zb<;agS2=7DP=UNX%G0Nyn?a!7tV#6OTO#`vp^=nEt-7lu!7ab`PM`HG?x&R=;<_4? z#>b@_5*>gGSIOd8;urE2UnMpch6-zrc&$3ZWR2rKZATeoJob(7?%$Wk6}!=MitqkN z--=>^%7A68TC^g+r#A@$#7l;1@F>(i6p3Tw1Vb6W$-T8uIp}CYQew7p9+Eu8y{bv%TE*LPuKq`hCf}Y9O`d7$1+QetwsH5I*H&rk9b}t z7A>IafU$r;HD9aXy($+rK^bLrgaM2OFy=+y>n+UrsF9IHedpT>{}rQFeZ;e_$-9%J zA`0q_2I~1{MOG@VE5ZG}H-Cx^8*%V3W(W5DSdA#CB$@=ZYrZqNT4Dfp)@XlMnnhsfOneS8PyVUGab3Q0 z?o}+xzi4^bwb!j^q89&cLFC%>th7eIsmrNHQw#ZN)z1G%YnxID(rVzwv#?{DeDKvw z&#@u1yeA1@Phj1rwFz#hSuM0V^!BvNX{es@?xsQ^loU3;(4yt7KOU!-*2f@mU74&E zsEI@z0H*;D7GX-(AiLRPYBbESj5)Y!u038X@v@LVypuVpf<0B#s3Q`73TmN^h`P>_ z>)dLm(qc&8Q2)lqIC?aiudp09>!8n8G3Y+cqC{L&2F zBadSpKfJ2K*S&6YGU^~%wK1q9iRsSTuh)hBZ@Qj@d_z<7E)10v=@P^=7arz6MsXE_ zYdv?hvQR?uV%`l|I`jkV@N##xqJ{!wDzE@9MYoxsjHIw|c5&KJBhS57Y)p)FUJT|L zR$Vo9GbXqsw1RSz3ncl|YoD~X3d?FYQlv>@QePI&E%6y5#IP~=zl?tZFUV8`SLh9v z8>XQwkStD0Xhpm5_bpK6S0m#kGlzvGS#_D+Y^Pv6d9dOq;2(WD#;>$a>=o=Py+<{X z3`m+Atzo2<1%^=xCVoSF-};2=JhA%zK;0lVtz8UCS!s^qnDsXkR%mqx#h7V$LEK8`Bqs7=uEZ84cqm_ zv8?ADRar$`Xq<_!qu$c|5uQ8M7lOEbD)9 zoJO4~q!s$I#9!}?F^x*h+}`Qa`wQ#s0DQvyNA^TpkGM^|<5tunBuYp$8eAqm0&tP@ zZ_x@HooN@_sl1l)u|LcsFn*RLipW#?y z-I3(m_{=2sYdbmw%RPqgr_KjT1jw8e5+nu;msgGi@~qw}Db7+Ld|N5uBhYkE3yGxl z^zAZsfQ=_qH^!~=jD(Um|(Q(ts5*bc? zOO(bJO&w}YT1lE>xYb+8P%nX5n3Ykm;pTU;#?9~6c{*(i?;79TKN%f^j@nK0_#bzj zfIkl0-lobZY*N%pjQDDE&6v{x7E($cjFb5Y9j&PZsR?-&X|SG@6xdnb>_0kGEnN|9 zU6byA+Vkv@7qhDPGv;4OB&w04?>Z{)`xlnn2$R>HNZ2oWtaI7?(`*c1;pXDG=t(n|4GMU(Hw_km!rfqS6)A!~Ko`VGfsr5!_2NQhb#^{%ck(gFh6gtXg z3{6B7ALYy^(S@~qv8Wvsr@;&t`W5`JH?X~nqc1+Gn zc9?t!VHd|?O*t1yj{-0DGSO6Mzq-e(q2uh?dgFg5cGfj4Tmz>VXb*sCxc0ZYTK+JF zhA;y7P?z4bWDrt2dqE6)jq=tlhlxJ8myJ(&X=5Rn7lY8qtPA@JLuk3m#M2Sp7TpXo z?nllDmxx(MCn7QJbZrDvUU#LfoW>??SQbt`WY(k0A~5APN)^T-DLtVuUSdhsgWu^M zkfX(etfL>GFhpg3ZxR$HMmN+n83a^1^kk;K=(y)>=U0F6)0N?};R#g@ zA#ZiYdN-qAwn;wKV2YrloQnYMj6m_+paSql6_m(==&$l}Wuflic)0*_E9j*2<9G zONLd0nIM6pYxUw$#6yoQ1;+#ZX$5gc5I#M9@6EcG#h^rLb3 z8y}8ddPKa_2Z!ER(fGkS7?*Ys-E=O%$7G$0T<=pQ z-znwr#Xe2j%M*V3?>qkhXXr3K=~GzCVd4WF37xq(=c7J2;T9+(E8MkE#zSa2 z3rX2%w5&78Vf}a^Y*mqK{@bHCW38rIiF5n-zqkIxtkxZ>GEo-VyX_YRIe_U=(zHFl z`gM5H*oCAYLw7TnDlWYrR4f&Di{KRb`CBZ-(TFo-Oys@4(#Omnstg+K5;Si!m_J)| zAGCHy@Lz9!`sd;hce{3@lfNJw8E-66_JSy;)Vvw@kkGj$oHbhFHFH0E9t33qwQKsnuxLGqSxC1%aR29 zki@$u3#SkQHktSsy-)}?H1{E8K2avIWWm7a^tr!Wt=W4Ot}gN#z8h71M(o$6DV0RU zH8v}CE)tF*B)l%l_#mlsj*>z?>4A3DQA?0<%UHRaI$)Fe9R~%jRCOD+G8shc&I1GJ{G#dibz|duDdEH3S6}1KCO5_4E)CZ*QnV6 zIv38Dt1MDhS#)#SrSlWMbTW~VGi?HI0_bN(I|XtDss#qyp6g=%l+6yl3x5;-n9Z|* zr}tK`doK%&hEx9+rPKm^k}OdjN)6U5S$kK`AAD3&<5)v&p2M}hMLE{u8MC~ zNC=g;^dFb%AG%DyS$Fz#fmp1D*LFDvuyT5*3d_WA$gH`Q3(qqWhL759b`-n-(0E~VpGG33vCFVbIJN%tTtQO@r=DqWlHN)_+iJ{Hpg6X)nT zZ^?O$)XoeULhyiHI&N~R>j^NWeN%!_+%f5#24T|CzfCb0Mbr$7G7azaQwh!t-=(zK z68IDDYrsg`uU3?sRvi!ZtUvV<>lXTBB%ewrX;4IU05J`MarpxO>_7^;0?AiD=6v48 z7&7Q3nm?ZnV`X)ss7*b*rEW;<6|NiSS8T2Lh zcRJBSI&&qArDLs5y_g$aP~#%Y-{m>^ zZ7nDtrV8%svRo0#tcG#>Kcpc^fkrrspgL2eb(7U1iin1cZ+=F&_q~Bh6Bt9jhHiaB z<`AVQ6lA{g5RBVh`r&D!m$A57RX_CCdj9)ryAzdPN9YZ|#9#iquR(uZb`As24=>Vc z{E;Q@QzSnBc{%^1e$6R#wRuMogc$TP`RWVksd>HbrheV;e#-dvq7d{?;gaNa$E@Sw znY1|M)vDkofAJb!MY&tIwd>*x4zpE&q--TnWvC`qLTL9;EI&}BQ7q>atAC~dM1qaY z0X*+WQK_^G*uJtXm4{Fs5t9IiS9y*O{C=*Xtlcp>X&cZ>rlvbU*1M2Z;=j{;uUAW~ z-zC>>8F2Nctfnm^VoJ^`;_W(c{SHl@<9bJ7pZM>c<4>M4BGxwWZk`PP8=ln<2p;dl z@qY>DchVagos0_lF;Tm}UoPF$HFpHYGFR!>jV8bYkpg3&uN`YRm`YupixXLIs!Z4zWHT?2dLS8(ZK7Rm&;W4rev8PiTjsPM zcSq$$_3EWtOK0vzQsjkmB>W@xo;v^Y=kjdxi_vH~ba< z%zH`+!ry%C+%!skP3ykWyFiqwP8f zQnX}$+nb!q2temaCStBnmeoBw(bEjqic1*;?4gl4{XkvvAJeZ;vmetto2|T743vF# zON$w41jjW4grNbqfXE*giF#+Xx?J~?ixr+7@dFTN@Xckc{4$$i$tq8eF*ShnIZV^5rM-R}du6a^=DhCbdIQ8QP5vX3jQ-39HnED%{$B+NCxktelT_pW%> zR>O{E-!#ZOhW4#0S2HQpBd57>jE*zRVb7f_m%i6Gf#In7zBdUW9bJA?TdP9xp~YAa zhKWEjMtMH{m+&RZdg6AHNwul-hh*g~{!?jhRO7&29wHEpQIsTg>DcI>$V?D2bm@)V zbGaFq-F&ERvd#aweYYa!CeQ%t)YiGWP%tE7YxZ)Lak| zwJUN#5!mm-Ix6vW(XYE&vZRE_K*&rI7|Y!J@_v6b;X0vJLFIW`D=<-He6R0hB%HQX zo=yr&Bs`OPh1KV3D|@{e~T77PIzpUPrE+6-h?U>A@)J zw*=e#ZdJSH+97hKDu05PH^7w5Ri6$Q0&XEFmCE;SbDh{Q|A3OW^IdkE?039#6s$i9 zgJZF5!DW!F;7sjD7>^*~wzCIi1R&~2YfNAbErdr=I-^R<04QB8k|n8vM1vMD**lb9 zRK#o-nuPDufZ!QXzjP>BiLudju6;Z9ynnl7|N8yLn3mz69Xmf9o0lJK3>Mv#&3Vsg z9bvMA)@}t!+LCoVGzJRlRBF7YGa*KTz@nkpXr_timS2)r`na{VTequ1-i_T@8A4dBuXC;7=G&cL}d|<>&vM@A2IadavNFYol(WpgWw`_fK0mFSHQ_&uL~{ zB*t9J3TN{bdH@)SyrpFPBpwFc^g(Ly!Nr@RpW54LlNYMiW`)4-4W&L~0{Hvpyy19t~oXA<~N@#eYBrC=y z2)@-``5bUuc%}w144Pbpv9tpXCys<3> zkO1n9v;aH*+Q5F?EmQ}+weX{0wUY~PyZ+H7uW1DsZJKBs3L6~aw3=R2{??y3bZ+9` z!QjQ%H2p@IZ)Fp)ut2@*1}9h(Ay4EMCxi)rNUlWHgu;|eu?{|mAS(&b$hZJ9L82`d zvg%)s*g1IS^ozn!CVzFmo9X*hXakPE~xgTaLzi$DikS|GA1 z&FCtm;5=gm?eJ;3%fTng{?+HXG2m!j^Z{E;+yQ<(*sKLU4Fyrc{rt8^u?{AmTO}el zFGT`TyY2?KcQ#*EgJNi2wrK|Sx~_Qff`&4NygXt@3V_DR{Zi>5-^yF=ga+43Gu^b3EC;Vho&I_|s#! zhokj`9>h$@H%1C1IjKUM=YUVwiqO=~HZfAjq6ivatjx-SB?mV%=H-D{Jxyj^w$H^@ zm9daT12LT<`ftN|m(E*;z|?}xuT@@bg-~k+9fqnFx{jcLi75vxMSDT|U_8LoW&7@7 zVGF*o3nFAUbH!9`YETQXrG!qbR`x{om;29WtU|{l8E%kn z(Q9;|c5^*4Fj|cenva!jE}Dl|d+Cj3k3O`)ILKZSw6w_^8O}k}l&(LYJAOfhHLKrV zJvaOsZ7!PUG-#ab;&L&dYB8+ul>B7`<_nB#Hn%O`3@s3GwL+G+e>1Z2%wYN_+$a|N zmZ2KOAXuXj8eD1YP-+}G5Gn0DXlz$+A&Vv%W**6%k0x2$ueV$JVI23+dXbXvFj%t! zq>u_y_^of{nW)7!bwWJeamxwIAX*XaDY99v>PFKr( zU*gmK^{DF*BT5Po+8+Zs*P|_TCK2Dk54>p8nr{%!$vV7DQfMrcbaALG9LnG-D+Cl( zw+G2AjydObumuTxGj7%xNY$#4{Js`A=5o(@_uTz@yE_zkI`nqv9`$vLW=rGsNFr{r z#%%ZCK;DsoO{{C*s0&%*vipfR)>Oeda;Ok1VJqdY99e4oJHQ*_Ka zPxC$~_EB#63{Xro*bXEF%x6pE&TV@s25VfC8e5i5#LDEt;tO&!rcIoifv1cG@`-@DnqIS??Zs_=uS_eklnya{VefOz z#kXA2=?dDi%@C&3G24u3pT^G57{Bm}hep)~n@`?_Bt!=*G3jM?`q0P;iETL}4embf z-Dp*a`Iw7_kgpy3%T)jRgV=!3#@aR4acuoX zbQ?aK_@LOWfV845vhb>O3q#MWJV%`oWc`DR&(R6hrC(=&o*8sknNKjOqA{0A^##6| z3du^%*UYD{E-e2`jb%}o^`KGuRTXw4%7vo$0rL3GWej-|?KVdjHlgp5;n_~eA39(6 zIDJgr50i6qDK}QKCm|F>mC#_9Lq;X}D1mO>#8ENqGVc8%P; zr9qNR!qDveZ+%-w^(XaJZCbPhgGqY@m>NdvTps*I-eJ(Z(u+U*X7E4v$oq8Qy#jYN z*7Ng^INN;s8Wj5aFYkrmWklkhrQIp$_Rq_@eme5&y~L}{(*gYdfw%A8y`#q^lMvkg zaK-Z@{2Cn=dQApV{r73V6(7k>1T2Xif)jbMD>3wStW=bFF@l<<%GC?Kx$^Yf4XsYX zA{)oUo#%%dn@gjn>q3t{PVPLietu5GAfUzs6hCMzY*-QW=3)ycvt!dfTpXt-Ub8pe zubc1kRw}XhFOhyf5VoaY)}JV*s3n7bxL{qSYfu)15oK#itjijaV54-o`yTU?#6roR ze5#tF5XMLrk$oGJP&|_}?Eq7>gTp*gwv`Y#zdm=b0g&^acF%xC_P+~{`W?kv3w*{j zMTh!XXaelvl20pH=PbYcPiG)4fw5LK-^OD5{9W>@a-CX=0ML(B7a`u&`D=>{-<++g zyy)4@=WTU-`$d!#2(sRSb19t3NFwT3Va1qb~- zm>anha%RQ3=1jDmZ3YWXs}q{!u@VyZSl%tKQRq3kU(xf`L&>=1(+gyi!mwM!uw}w1 zQTE{(;i^+QZ_uQr)xDd{Ds4HpBIo<%hWpzJ;o@7<0}6mpu`gof!efN+3^c>JYMm3< zznPGkF7jJ7r^)>fDwh^w7h8lEmp{6Son4mUBN686=Idh+4ZMN9sQW)Q2R%Bk=DptM zMJ|wUKK`nZ}6jP1*9oXD5_kk4aUny^HB<_jj91q;Ol)MEbaY{%?Hz-W{bzKM_#0v}~ z{oa>IO{$jde-Nqe%cP1Priaw2g38uAgdR)D!wpMBi02s?cIN5RZ$7LrL+P`xnOI2N zQ-gk79JOoxJ`l+Vy(;2)9!9o-*wYj5E6nH&cA<{-@F%_FtI|k#yh5`JXfg zvctCguKsI!1OJUypXOg}f^KXs;m2EdzqkL8_ZIvP{&Vu>baDqZ=mPap;q^Doz&ro` zcgD|w=Utb`=h=Pk*j>?M%RTDYB_n;mu*V!PZ4sY9ZsPvqZ6WAkfI}Po9a`MJA@Vd% zNktfq6q&mX;b07wbxjG9RK)7mi~ECl7+)`yyn-I7$ILY60Y2Rb+)bCK1o&2*^4JX~ z61NS!5Re?wJ13~VlQI7H1+~`iw>?T(%yk*ZjT#g&Db$&FO_qd=tsUVmbkg$+bp@NZQ>X%>*}MCFs3^&nz<)D~uXji}1b2lI<;v)#Fk&5BrdZwZe|GJ@ z3x?ONieNnoL-J~@&FYK;fW~n?eZsfbBtnM{AGeH%-l(FoanBh2r9{7{XU|51mQiVJ zY%7TI-C;%jp3$EJWZ`tpd5%`Y%mgJ)Y%k|QPPq7cByw!^1KoGDC;Vg zoh5Ur#wJV|%bdEl_Wh$}IeW~j?nzG83~Xy_ zw{7OK-@K7plWDcPNvbxmt6U-Kco7-vS}gnaP7NX&dA4v*U@FlTg!FRV?fc-?9w;Q4nEll{z9ONH{wFZ`fC$#PK`_e)HgKhLk`6IgZ}IWaU3| zT`@tY;;)r@kuykW^5A&khPiGKX#1{%hsB&PRvo`6IwzCzdhmkvELZW$0I zl3Ju%pxvIpGO_S{AWuw*UIUnjmIY|14NEBSH#kfK>*+E9u2R!B(^ge=%?L>~*%0zS z2Jp73;|b7}2|Hondut?n67+(;%^*^Y;M@Wvzu!Y2_fH7M^|Asv%T7zea?0uZjo4K?*|vSm2QX?qX!A9<2M#%j`;xjmQt`ZU!onxKJp?lMv#NB*dMl-t1}Kgz8*QW@5& zzN0Bc%o_6Y<*qDZ>&5)ckNuS&Bf3QoNd*cL%b1@KQ7jfWm<1M`3D;N6Z>zC&0F=mU zk`0kOd-`09`gRHs7C9+g!&j&~&MW(Pw| zJxM~Tk9X|~i|yXqo=xJtI0!TzL`xziKHSasCpU&KB@Q2(@NlEqeAuH_qu@vM%QIer86?5DXZF{=jQ|8mk}=&;vc^+FAN*zZF`i_ z{^Rn05|P(ID;EeM5Y)``))$$8-HE&$4Z?mP5AV*NI-u zg;C)1YXK_XGp2#we~j;067e1-yOvdg(+2YvESzW$9sibP=>!z#rc4(Sjzi@BGuU?E zAfFK_^4f8J9}oJ-b-3g(%X9uA?wVu$r`Ul~Y2!rlur+b6vQw`qOh2UCjM*_BC5m(R zW!HCvxJ66iadKp8zh{sV13}BO@z6H0jAuQ-PX|=nuYrKhu}A8pr%&DKn>8V=PCbcG z11AKK68k6v`lbbwY7t=~i;ScYUHPq%xoj3KP|m(B-MJoRWBjV0+ ze@F4hOs8`gx2w8h8;PbS-dgS~D>@3@{6g>zV4yK+ckUES82NMx4UIuCpPTgQXAAWb z3rn5C)AuO`$ZScM2$!fu5Jfa9MBsDv=^)@KX&(4HA0k>^R3niV-~fE|Y1U)NPHQKI zcy}#g$0n(e-v0KvCo8VZSlPQ8EiQG7X?+VdX9uu%|MQ_}!y%lMJ2L&|`ugBMf7u@J z@IT?sr%N@n9@O1TY7nQGthjpXsc1M+%@7ltWpNP z>0Vo$*`!h47Fe<)h^w!4(B5+IkX?+VU9?3908$Cc?JW)xdhm|h>UKE3?-z+X))$9I zT~TH?m=B?+7Y)NTTzoX4C5Kwl<(3|xC)L3Ht1~&m} zxl1%?E+=mLyGJ_V*SST^L|6Uf-AF5(_9h_9x z7H{uKgwr<)Jp>e`y(J>AgwE-N&^#XFildmbTSuv!WtmwadR11Bg%U8e$?GL#Qf4VS z^P89BVN2i4(Kl*^neF0JmhHfhn#oXPcO2Z*L@*?Bq{yg#KQtt0YejGzs(=em)^kgU za1E&YuXE=UxXR^?qr=}AnIXOAvRe-qdZGp5GTt}LCA#!eUhVC(P3@_jf{>DKnpaNK z#EOmE78jgy4OxIrTY^i=2zZSu+UpuaUuWOceYY=Fzx7|K?v(?_Ooap@3f8hBxF(c5 zt;yEBYu@h+BE?WF*y&1`>OG+gK}XD_J7Y2D&Bm|1P>$T(zN>5(r0J3Ec|Hn+b# zjNQA6WLIHSZ@=w{_`9>!5qgdNBL^X>fvZE9qyWi@0f>Gg`n_ez%|KcNPΠ^^ZF> z{H%DT#HqebwcO$CPnwy-2e@^kXNRN9*KT#3n$5dTuScu@nYVjA!t3kId`8)a4l5of zdh8yb2EU|1>+QM*S|G^OJt9xF_goF<C-sjHqN7_XYR08?4Y@DF)nScQ_s#ZW{xHEZ+H>>fL7ijUZV^bws= zmQ_R1v0=qO^RH4-7C7y~p^5rtA&L2vyd@ko2N0_&5Sd<2kxphRPT%!$4iZ_ga`jqB zERw{W@Yw-)7qdNAmbTunv9&wYXC02jUAcH1k{g?heT+MkOewmfT@7aZ1hXQD2SA+d zOc#;wiX5^A2$gy~a@|+Ptg%-;j<8d7=%%fiUAWQ07S8#M*|gezwIC_`Km|_letg@1 zHfCub>!<)HX{@#=BbOFn0PgzDt-9k}8ZP0tVv`H;2@jCg&9Adn^mM~2_qjMc{r;G5 zCjG@pb0kMyCT#bB$~D4GBz;$0O>O{FrdV#D2ezIuCK5PP+JaOMxBrB$l=UraV8of$9_I74W`IxDUjo!qSQ zZ=IL&d!bDlzwY&o|Fe;bc7PRv9h4nKFvcVP70ss$FsCk%a zBuJPrk^#+?he=TGbez3%NQe(RujC zgw&hR>4*f+x#l&N-P<8zf%K*_D!m1uB-LgSwm^k~JH~Rt^c=cwXXI6}}vQi)-W)PGM4A1suV z(XqL%;AX5O6!|hsIM=iIFD5$`RihxNYo=6fALq7rR$xu%-&#T(N3>`hY<{gu9)*3@ zl|5kuIw4p8XYv(vg8(pG+uTIts#88IZLofx6)-YT4kKk>nnqzLfD7>>=QGzg{$HPk zwfV$JdhR@mMg{p!QAEV+huiSBOuU6+-2WByjx&>HBJE&o%QC(CC{1Kqx>+M%8UHUi9Z)cOyabkx%WTLj+jCa|s( zt6}sML4p%o-T4Oe3RFP&k1N(bYiaUX z`i&v{#21Xt+Ft15cHb7i1Sg(x5GibfHmhka4*$7o7!=3@$Qpc4ped|TeOAq3H!pzhva zV^%6)L^Ge8_B4HtSO6Lz(l#2{8T6(b(hXXoHWFF@lA|lrlD)Yp)T;1zmHCR=u5Uf7 z%mST~^SU(8-hQfP;=F~2ER1Pd59F3hk5)3JyGTX}|A4Xw4onLwwU0JG{?%D!uu>fmw#q@!Jd8_)VNV{CR9#!II>GQxp^UM>7@`Mu@)S+@q%6syKUD;DZxQzx9?D42uthClb} zM$QVSM->ro69`3n_pmr4iGY#;bhadY_TGl;>H)^-XY3}5*Rv&yhF_|0ihrab#UC1@>eu~K9}$U6X4zV|i6{jI{lX>=OyHNk zdL$2>TqhQN3^7_x>S2z>BWf(2^|@WXO)B|u2?g-r(ANVI(dINxb8<- z_wBuQBnpjepK=8uY#X|fg|GrchDUMDz|xK=6rVkdFv824p__&s9R3D_E-uChzbG?(7jCJf8@HP zc+!b8eZC1p%giKr2xP{(Nn_1v#9nT63xn9>wASG!JD!BwNY5Svfhc60nQ&q%VWFc! z_$7y=rOKE03%F7U*hcx57>obfGP4XW(OC#!jy&EB*B)td{=3;EW5 zOhd(8QB5+jiW?sKy@c9Yx=_`J_P=7axXw$vku zoWQr_3A-tLNDcVE_VluFb-R82iLx}J@2u>QWwY6<^K>)_7~A*Iw6X*Si1w{V78gHa zFRt%rk8vfi6A@ci*5XO7__B?C(+cZP5+(*BJN|>;5({8)o|VCk7P4LiLgn~Lj2l|z zD%D8au3eOaWsR09lMtNDvL|N$p{G2mkeLK zB@&*&QW@e~uiNS_sc=5!~ z-;bti>Cowp42nAKA#!Lo7J|6~;E|&hMMly3&G442XzE_oRjv$XnL2ulU!{XyXrWUn z11i^pkv5wbJ!0tPGu76o(xxxgJ&ui@p|TR$$5k2tt^#9y3B~8^w~M$Cq5G1pT4zp= zx3y`3p6r~JzR(mx!>a zo|aaVopXjAh=xrMmfY@eX^FO4#A#1;4M9ZLdhaT!*^F_R%eg{io&2!#xmG#^Z`GNK z0??Q0*=FBWvCeow=qO0#L% zg>%k3JA{*IIw=@*7JDQPOB$=dI4|T2+X?j-SoP+TJodiVYB18Yv&1 z5bXdI{Poq;Q?i)MEID2zv1n2nOd-}uv4KKn$zkHi9Jv`qq-eB1>w4mvATf!nR zlCeLA${pFaz-JjX_eQH)pA%zLt3U-O;T3ZBX~V8I&^lR!e-I@*Ij#?l^hGVuu>7A2 zPpc{!WV5De^U`a$=kniCd+3XL)vbwG4cMe8w^Y(}u_RJwrof0OUB9?-5U45bqpUC3 zd$zS6^-=Pgj?%QiGxt(#M83juhx(&(Ue0Oz9P{7)?(^!K2{SKiFMmoBPO_PTfQNeC zpA&YBvl5L-KbO97u-PGHv87rC6PYVzvHz_by_JM9IUe5XAd-dfB!eTGFp7o_(j{QlUfM^DDt zuSkr7bNZDUi(}(WXj&IrUyvB3zpx|Nn}QYV&|8iILWs4VqAv*aLNKVKw4aHUJ)9_7S8aJpEF+TgBOLYs;^jnek8DuGnydhi>UQi7>Ttv;CKFlbJ2H z(J$tY>5YBCAes0{wT2oV5J~+qBnC%RGW1A91JC}O4;N`{))Yv{(6b)8N@eVKC1sTd z#?oj#Rq>)(q^%7XqZ4_O*y)tAkv7uoyzj$YS~`^oI_NR5l-ztrTa?-AqlZmZJX<$; z;8=lICSAFdyKezD@18wBl9TK|XbtE4F!w&%wu zI#t&3mgE(79vTnBoT_b**?J>}KW$;sMudGf7O|~AcmjVy0I0@ruFS8Bo-qVfHMeT|vSaN+4F{wuzNv03R0tpIJ*mK$ zO4S6T$gnD6Zke-;NeAXRm@?VS)u&B?Q#H!4q z-%^%pK9cwBU*R_td!+TgF@=^et_v%iq!v|qX))urnJTIQ^I?SI^YAH+byrVO>g_E& zg$p7b?=&4kWOPnzR3787CIfRSM&2}&O2A62L!%!aDAUrLgIkkp%EXKx=v3F3%Pot` zpkuR^O2oq-+u6)0@!d#&l3BcvhtdaM8sljJfBq6#Tv-d6-L*L=rS2}9KlXn5hL0h&cK z(UT=`_q`@5haut{Qca>K3F1}iFr-#zz`y5V^vt!&c}9vv+4Z%3-9Lh89L9Exn^qxA zjJzNAzo&y(Xsl)7>qM){>>}WOWXf@@mK_rgtKOgCJ-Kkt_gIRAY>IQZ$2z!xxr*9w z!rTfu3rYm7BwLa2HtP?POp?P0$;Q;c!(`zx7IVISanYjua<<&be#=oSDeKwi)H(V+ zE`m{iXrjQH&5aIyqhXOhr^Kq5-WMd<35>z=N@&EM$0{R_A{U?Nf;x296}N3ZH3xI$NTbb?lJeVHRDr@;m~0$p5CIg`sCnzAEz&<(b0&PN0I0}kV%g}o zK3(o!NIpnpotPrgX}<-7vvVw0O_3-kY(97X)dX`_LJ(Q^ZUmH4p*3aCs=bo4(T$tg zl*(*Oz*IXXlb4RZzem+ls#**b7l{;#XNQI8nZ%Gdm#xCdM&T;`$oy=e|l02 zvBBbp-Z)Q6i>ZEa+FznJ9jI6f!oD)q6mbkdU&*79UZ^@))0^i3IbxQ?+!?k?_p?Qi zjELD`42fStdKs;Z&&w(s z4%ueP`{>bVfI9?ggl#FhQwmn=d@S=p^eA{W6w}RvS|^tPiXU}iGkr&$1*X+LL|yco zO@zBMvPoOO=I>v4i<_Ah`8LgxV=y(gN5)vU$~~wPZiI0WhL;&jJ5&|OMp+Iq00?^? z{yH!NLiLqx0k+J@R!A7fArm0!5fhN{7q=#ysV6oiLan@%ApJKz!Gs4QKlhYj|MyrG z;d$RcXB_@HL6qCH0zWDZprb4;{togTq{Mw38yU$u(p!pc-$d}WP+7KBzorfv02rAc zkiY1aa3In!i#EDUy(1fNLHWu!aa(V?8N>h@w#PFJ{F)lYl5-h63_RN== z)N36hMVuLZm&lS}ZIFJVP4QyrA~AqnV?{zB+gWY%Z0{Rbv92A8--LkQbtJ3Ecl1;x zrDy_0B&1H*E9H|?W;e_u>3M71;^*J{t8inZVwGLOUavK>gKoT3m$;+4xPD1XUlH;R zg-p^*L|eHRx0VNiv5a<*c-*!(hTM4W8oyFMIId`Jl;<17O4r%1k-}^}zIlHEAyeyZ zNnsO_LS>g}4l4EOlI{AL$^QhgD?=&aUmXcd0R~MXBF)L9k-BfXatfXPrU@%k4bm`j z57=jMP%&4y{4DAHsTcQEYFEJ4-<6N_8tq*_r)lhFm?pdb%U9G`;klgHz7f_nEb>AW z@prbQ@1%a&{2m*Uf_GQsp!lA>LbLF0(Y~$+@M*sc5>riGHg}8v*)6T*QGum#k^88Jz~HkexyIbyTo!G?eusf=vI42vHkdwi^zR z%tOe|r%gdpMPCGt6h~C0p3tRhPw5g@j|MV2ae`3dts$WSJi0VR_ku4Zsx>94!4a@wcT)sh-|@IvPtRgcVZxdDEvASHkqaVUf8uMJ>zvc}DxX(vMm*X7n`r-$P;7Ub8w>#$X9)$oa z75Pt=G_BC6I=v2pkB>|6Wbt+jqxhsJYI_i_OdF$>g#AU|fzM(Kx4Uv=dz!*Y8eO zh-Pz*6fz())$oqAly)&8eZ>Mt%MloTEfGhq-(Wl|7`5eWTM#>sac&3TvlcsBU>2U z^09`y{S=|?(^~_!7u!b|Eo1_wffpyP50Gsuf%8G0q}FNbe*eKU_ufOUnu1fzfP2B8 zG<3~bV+^M!HT`B5fCW7q*S{fSi{jX14)nAJ6Bt}IT(v3vA0?O+uUygGa7?+eJ+e#1 z4D(=UYHJ(Cq)Z;!##gsR!D&X4Vg`GzMaZ%~AL4|cS@q%|*-66JwmJ3QL^fg2*%}zV z{qPrio{Q8#8K7M0I37}%Zsd? z^Jfw+A!CzEkMwX2_caE44jT&1Xni&=#HD*jrWFnm5iS8Hq#X2WajKJq&-P~%LQ>Ti zv%h1?T5Or4(~oOw*z~yJBY^V8JeEjHKg%SzrRgfgl@x&tfVe0(TaM6hTdGuA!hds* zL{AW2b?h<)X?ao{B9l%+Q4Rni(NrWtsvhj-LL7#VffG%-qYF4MoSvFMmBc5BRUpVm z+X`V%NHb8TPha`0?8s6QdSVReA_wk zZ{ZMS*LHo+kVAKOmnb0J-8po(bhm;GjdX{!4Bg#b(jeU(f|R7>H`o1q@BgsRF4Ic!9IQ2=>l8+n?sE>$5o1ZHLtr+0jt7NQ9(?^4UlL~zahk#Fmy zln1M_PadCa{d+`m%9Hg=wqr1Pd_$NA<+zn*?u|7jR>ZiYxh08+^!sGv(V#Xe zmX`0Ar#lEAiH{`)GN@_DU66E0{ z6ENinB)3SIxVL%d#-%%eBV?FOm5w0!DIP9r{*5{3n8^+^(w-{*ktqyTpd)3EE$LOL zK2i|bm_#Wa`M!QOBY6s_TZjtvE>eI%U0clY1~s#_2;F3d$$e9;3+&($cTz&CPOy7D z^HHEv;?UGj9FREjk9~qoV-RyhpG?6jCW+Lhf3ojzYPqxYiZ-}OIXD+1Ihl@WDx_m1 zr}#D5B>%uU9pQT})npYnS#>BTyUBvdQ!-`R}m#F#hC+UNd4#U8$$9Z+9t1pX69pS4a=H2T#rXSUCH zj5f!-HBw`EcNr$xmoi5*O_`0qiz@&{TB za(0x{z`v?vqj^?Iu?QU0hFi|%+Heo^wZT*CW=xugs_7cRXHlHzT1?o|c{zWpVh($r ziV10l1+YK1V2!T6B@l0J)X_lh<42M_9v!PFhZ5>41FA)$AP<+n?Tl$2IDPuq_Lesr zSNg7+{fo2jEb8W{Hc>%L#RqqN3P`e%c=J4wOh*f}7XZP7Co)7uH;|SmN!QXbbuB`S zAQBJ9m+K8fVCNG4md6xuRA#zv!x0Q0y{iu>y86Qk-fFS_`#q0!_cye8p}Un@+UttM zZiaOMC*$O9<^M;NiB7f7YPQde$}q@bZQUhKN$gZ6#EXby!E|{Ww2szBF4_D6S?du= zYAehq@4DJKy3Pi@V!2koaKqvob=!25&DC9+M6qSyzn$t=x5WXLH|7z~@`2y)0ZCbF zAVpPTXR*kH!O)Z%rjV>P;HMe>jo4NS43kaFhyG}Qa`(euqz|e{RM|-WzSp>XhaRx= zT{93pP1gCq7}B8;5yMhAoKN5!+m9>@^bYP`kI)Tf@D2`q(a`%B7lj)rKt$s76T_?< z%NYn$qgW=7{l)rJF~|LepIdt;-&@ZoFn)!;FL+1RWY~+ZhW$AJRb4@RuqLF^o%)xe zn4&cgKMg;)LV>$X#-dgsn96HgJ2xrC$nFk&=YH^$)%ru*sQR`79An9vBu_t6_Im1Y z2B)%U?6yfEvM*&;Xm@n!b8RZxWTJSII7t@rPb}b#pwO#&a*q&76hH&<2lIJkum)YD zU|-D6jzbJqS)nXut6<(hMUK=voQ@RM#NSEP$ch0J_0!6iSX2Wl9*O~z8`j0~R6o`x zVpvRw{yj2>R3veQRtM=kC%9r*spC%RV@fYNbIJnuhIl0?u`rc6-Nq^vE#2RKTqov* zd?wFyh*9m8c3T+vP3R?VMzph`Nb-Yr8GBw?*f7b$p0UUi76#wjvQ3?mI1MPW1-lB4 z^o`!K=HwOfUS776xObLziI)KMjHeEh0>^!~H3dhO^w~jC02pTR%}2q@{Sf_}7M_|# z1kkIom;sXR_x5kX_jjj2NQEW)7m3aw>+GSi zOCFvLfI8YC-NWD|>lF9m?OvfDt^-$FYOKRMY}*&76VNi9_CLvKY7sb2>d9{^+tbn= zWBZ{Gv)g_6ZKd^>43>}o0$WvFo8SzY=@6?n<&XBj_DS|W{ zpiN19F#9}z1ZVF|>3Apf>W;Bg##sd%x4w=+EE|%Q1^nO=!Z!Ke)IR6l9#{kB^^m!J zwDo`e)H|79Xo%s|gHF94l20jM7nWsOf{Ohk4dpYc>;C&;imPV@&wz zv&bU_Q@<;gN1`iUH)&g{Y?w%33#J^=zFeTi8~M$v!mYq*;V=}*JA;#iABE#FR?Stc zyjIcl?Ie*rW?_WDwY@rsLqj=Psd>^I)C8}EQ>yp*=ITm}U_|0ZdWKbb81Cj$nPj)X z2||Mr_$R+Xi5 zrr`4@42J`YU@S+W5NzWY%)_njBY7fJY9KyN+{3R_&0vbMsiSwCS~@nbw$?8RT*LDE zV*TBC=CJqnAUY-6avf%`{$fu~eui38ML|RvjjUm;#>Hs~3_qoHx#Ylg9myiEy;=P+ z63R;kfFcq1lIl68VhG)SVabf{MmCcz*L^C>9JsqK=jiv52-tf?ng_&%#PWraf&;nq z1z6WKdV!rJ+t^PHmNg@rfnxt}ogUzC;>sEQkRXR;3+|mI^^D{)OSkR?qly+0r06O{ z$@6hra{tLMh``wtSl&sQr*O{PIYc@4Odlr(j8gAEQcG_NAW5GGEt#%4swITJW@ZTy zH&;b)0+V;R&4|5VYWSDQ-=83q+@eGq4x*h5X1MtDB@%lBa}c7v$gCZt4{D~MbDlgZ z*XvkTR#|Q&3a(XjgB=}sdsLQyh3AC1mnp1X&r!DhedmukJ{!meK@02pH5wGYN$Cs8lkG8Vat2rO@|^v9F$or-7Chu`6=jyO z^Fgdxb+S~^?ZG|kXXxet!M2uy|DmBn&XJ>-Q znM;x89|IeYcm)k}ZLg8Fg}pI9a+49*Ra<1_?0_76s%n!xb$@xCva|hm*|NV&#ms^0Hpac5R%Jm<9DU-TNkZAVY|J>eK9`Lb`DF-7ASyA;E8iEmmTZnx}P4CMkfZ?o8~nF>yg zD=n>`yVaq7RiD_s)~gFMHRI+*77Q{inoo-IW30!v%SCD|#U(0j@_$sDN?`HYA*(|m zUAkp|)WflIjK!heUY`;d_)`t#DMFK1Yj!nCZ?}|sm=0`$ytZQ7)#L4rNDSl^W7Izk zL>_t4tOVn-I{-nWCk|dbI+mpl#%j>IJA4v0GmTgmtaaTJ>XP{J5_p3Sk^ig6v^m`- zYH9o4#Bfoi+iCw%MY><0dQ5DOvMDHpH;-YM0w=2HtrBnb?H@iQ?B&oWTxb`+TW&u*rK4@8Iv0 zb5PD9oz_4Ca1t*4w`xRuGj{&Y>PP2TWHckN%H@|T@W&RL>*mV%#Xsg;N2kSp=Da_`c;02M}cHaBc_6^tx=sh z?ggq<5dzb7(+L0qUAZuKX*N;1k@?hnHc77K@ylwj++PB6N7iVe^gbC8BrTPdszelK za-eL%uN;inBdIhBFMSQe(g59xbOV+jeyqo|GTah)PLMejuz%%xBR{r+VgmLNlA&ytnFk|dM0#u+Olfg z)=2QqTV(8|JP&~%5a|l_HlT{ z?)gvHV^DAR4N|bYrAr@bekdGnBk1YiX~!O9GQB`!WG&%k-s2WiZsMGL8i#c5C=oj& zi$qpt!n_tmy3=JE6hArlJf>+FjMLPV8)Il(bbnvm3+GGP2&f&Ks$Xbso)Q-Qf3G&e ziBAfAsSZ{h;w(!CHS00pUhLp$79wpbq#0S8-{%>+5v`u-Rvks__r;e^#WxUP=4_DI zPwnAVI34nDF=|g%x4iLG79}HNx@j*T1$v2qiP(l#%GKOjyxGBP6XH=p#HP{?sYdqU z$$jHlz!I9?L?TYZ$?-K#@uXu#4&%_q*lvj#+H+se-m4i9e~mWko5+5!jNb6p=fKAC^^8uk}6|7zs=NG%>ecS`aB zCZ8m${4NEiubZ4)yE#ACdANy=#kzkY#y^Qqp?x$rVUcpC^`zqgA*!eVpkt&`nHH=1 zmaIHrj1^EP4yZsd%5nfA4~~RhNA@*?b|AQ)sf&NAyfD+@Cakt>gL1yyJV|>!`$N^b zA>P`kzh}OR5n@th1y(%Y*M$>20|OBGm+kt1~it=!3CClU~x54 z<%}e71SYJ1g1+P8B46oobMM)L5$EzH$d>JLEH}|5SS*z>k$7WfZ~KoE1nA&!DGGkt zj@@)2$eybr`EVBh9A5(#V!G}R*>Vtv5f%285;#WyKfu0GZc@G$x82*eZLlRjeGppm zxOIQ4kz{zBHRgL{XW;vI8R{{7XsY#5?RApQN{GWGLzz4?FG>;HWP(+AgJ|1|?`cT+ zpbXD9wheoFhEYNBgGb@Yah2N{ORaxqEeH0R$j%3S1}I6eCl5$jUcbJ8nnn+dU}XKC z8$Xb~W2JKLfl*$vkG-OjClXJ0GT1b#%pV3RDUm^Dv zrf8Tu;qVC{elFT>Ch(y(kAkvlVC5fzBuu=E=dC`!*u=2%`1!m1;y(8eAO?;cWegjm z5Mrp~t#pqXOcyI&-eeHkF?@M5^T>_=-jL*HEaJ2J>fm)GqIzs^sM)vOvty%=#Q>_` z*^1@3$-dosG+g5Z`_vO4x77u=_j$vZqBH98RB%53pLaH=flWpL5gLvVi;{#k4?Xf{lD z0!BJikFbNlS^M6MD;i$9|BnykjL5T&>h7qZh2nat>wNpVIi}>mupj9(Y9HOhR8`4& zj6=ZYb<(F!)c5V0f&7XGdz8(p!pt3?!W2L!7qhy?@JA>IK8!s%I{Hiy5pUd8dH}t^ zkw5dMys3ZG3ftWW^(4Mlgr48shl{vrI9kIphBZ;60!m5JG>psyUpR0wKZ?#_KluBG zi$zYA)VyCHPs%az0R0_aV*_!TdfGNJn&NK-f%*y9J_wMArWvEGAknXxoMecRs^6?4 z9?B&=UnkoPk@sc4Us?l**zZ%Rn*tIp4T7-%D%6ta+|wx1zT!HfZUO&du8-f%l zbiHe_EO%^pO#;2#pyb=(xq)E2-GX@NNNs`pX0=XeWqu@Yn34)OQQgn6CIUzNK^5o- zmo%fEeq{O%Ex`K09NfjaTCzmC&QhS#w1Ge_Nm7gTQIe`hQ(fIYDHE^{>vF_ErnJOq}X4A*r(#9Va)}CB}hz{!yhM8)80ECMqsZ)^{Dej`2XI?jHla#9TgMX%?w@Z zA4e}zT*iU*%^?;|!~7Pw%w=@vOU|O=nEUPC$mM(a>N-KIG9}7lzVxJGwXYg>qf{un z{AV1dlKNxXR4#{#Y#B7Fjj^gYOBgEo$P|q0IN`)qu0C8GIG<;D%Vc0vim8xOHZ@X5 zpIy?ruzM(a4ta)@7&&atIN=Cuix?j>0ilgrCus~PmGvztxDbU<$v(`nKTooAsHvA2K3Mnv9`JCV77Ta*t_$xSE=b8I ze#v!xJtiC%7zFsaQLQ(NzTJLgJi2Nsef8~XEq&B0aeReVB1$zo#vJ|1(Pxum-ag40 zb(4icD)xtEP?P95R7Z5XYtS9=rbVnh1*w4aOIf7hmrg}I;zxRwXgMpQZOrv`04u#k z)a=}v znc$KmS(}Z`=Kr0MN{`)d{l;4K1jtU;WY>g_%V?EDec>rzW0n+tkNmomt*!K2A698( zSl9y|jM=`(X9DivI^At)t%p{gj}ElL#b0OB@_*{ke_`PgP8B#R!0G)qN29&p$WL5# z$*)gbg^N&dMusG-8p1NBs)6G!WNVAKj@cLRX&%;v5KJd)&P`f~sj;fS!)0sM`Q)FW zMG_`K!zP!FnD(*JaIu;jLDKfgA4_bSgK@A!evc9w+U}L0Q%yN**@?7A->4H~TlkAU6`xV@ zR7#zmMJ^*MwB&9QZymnkrL5E742$X79*$MNpw%7ImGEosFU`VVNTNfPu|{Ke4SI^fa&jD-7`4em(Ykn)@6ljCrkWq)<8mGV>O8{d z*5OZ7yK#-1Y(r!$dok2yGAGL1TC!z@Qrin&;q+gtt%*zY^uuYN{n5%3aZ~UQ9J3x- zrR>We&cW@4h~gSN^X;xdwy@~%~Mb`#L`6E735sswP_mH$bZYFXm`JaZs)SqzQ_ehrhFo=ujYm_#LpGhfnJ{hRH&EaE7J1K&v&M-bNs zb6K-iG-`&Fvg}Rn;y(+HhC{t+E%?h^?Z2RB)=|As%x=NH+z(+Ep{YV9NIU zGBMG&*6nChvAx;P{$J@AA_n9%P=29$K30%9H%3ns*|&$RRTV*)*s%@anxOS|h7Ky&za#uzsf2}RGh<8ucRyM0MlZBGe(;wtx=h(s9nW7BCI zAhnC|OZvYBb-6RDbgH@9Z!I92l*@N`(GQlaurozPdOEFNb_67pi}@f7_b)Rw@GH-< zZ98Av@byz%l%#IpA59D6#?3jCbs}*j0>GeIBls>Gfd?$MP^x>-4fnj9N}0#=KeP`? z6(3y))o;`DaYf90@c=31P&xyeKAbA;VEV#Y;v4pub~MWk+TcXj+??G91(99CjDNu? zFLdoo*z^jhNn=$=*Z5p3*%!bgFgh+A<7PX85-2yc9;6Dc#7S?d`Y=z=!clQshAED^ zh?);zH1E+$x~xyr2*gzQX*|N*0*(W~tOl!xcoVc=CNp$T85ZRAZkN~=pkb0BTd8A&{>d?EWK1}WAnu0MH68Z6X{rx;iT^He4wa_Q{pN^0WxPJM3 z>k3CSKphu*FCmU2zMCOwewb7ANitn;5ofjeTKCrket@aYNW~iE3NkV(x0?x%A008> zMxm87n6*rL?T*qx+kFMt?N8zpTQYxHXk7_ znj>+xTIM8ef}@bN`Pm1Q+bp>!+U$|+4jRL4wUq6#BMF{)6_lYZlJj)zSA623&#Ynt zC}4>E_14LpK+wt9-Di0(*elq}9v>muoGzOnl_ba;4H|BIpF(G(!ex(9NcRBl3k#pZ zi-K{;6kNMduo<%9&4qhQahN=UDJ4l*kiDs4oEcb+vZcK>Hi3wx+4{~PQ#QIo8KL67L@k?;480P zrzF=n>O2}M0FfcNe+OM=m4G5AdbGD=!V=h+yuSI`s3~pbY>-g_1WoJh0X%eIzQqG} zyzs0drwjnVDbA=Hh7I!DG0Psp?(yJHugB3Z#{Eu3xdXE&GFmZxck}ZK$|3K*WbKXr zeH(K<0*m{_>yCHeVeQZHVk$Q;hA|A`s?l`auY!dsg8U%}{Hxmtz2$N7IFHE0#H%OT zFRSb&2y_U*Q!W%pBei1Hr^L90V5W|98*P0I75bsEWQfAvgHg^8NhWb$hKGFKy#}qB zP!YqjAI4;z^J^z8+fY$nMPG}!;AG#{zxK&}qco{X+Rcb0{GMuNQodrJ(DuoRqvLJL znS_`=eup*cKHU{7;{3kwXC+%Pm`OI$Ke@Yg)7EIg;PQQ1Xz)I1L1I)~RjZ|C<Z{$Q*ggJx0UCf6Ow|X@#%Xp! z$(7ERObFeMxS!P?vX0#KFXPH&cS=W1uH8Ul?)(5r?MLWE{Sn$TKoB^sgG8A>;oCp~ z(j~ft>ZE4JsMf}cqp-*^Z-Rh$;N;LHGW176#W3+oS;Us_jrETGA=erB%D@a;IL0YB z@BT?H;HCt7IzV$BPB}fFlwS-1v5!GA$y#+S~P^6f-nmYp^G6 zILeDiUJeSSj{%^{vW?Rl1*NK7-=)LmGDHyS6lgMAC>N1hsNon=ZhtB>s&3*>uzXUo zv&vrq2W89!3zpi4lf%64gj!9QQQzj(nrThs*YB=V%$K>~uMtMp%jHR=)vdY%V`^+B z?ujPiY4lFqO*pouGfW~<+Nkn~l;4+1n>4i2%t^E)-;d&wD>OtOy*$B|rwrIYYG;>) zYff83(!Ay(&YzPBQh0Ej0Aq(Jr*sZk#}XcCbXwDr?T_@8_1wz~VDgYuOG%Zr5r> zZs|LBX6_NXJSbN$eV%R;39w5GLNT)n_j$j7c#Lybmfb7>a6^!=$g~9G^juM}+;z!U zESrm|{k-9a@;0h)6i(c67S~UrbM2q>Yo6yie>^7E7a77J33m;s^X&MiJtS!Ejw>3Y zPJV&6`g8**aeUsiTw1*(wH#Lmn84A%(opVikq5%&pKlXV=;mHR4`yW%zhMMv zPoo`PCqaI8-TF3~Ki%${iClj6>oKz1OvBLv<|64FPb|ZesZ}MfURj>UHXwuJ#PLro zWi=Zso)#|$vlW|u67nZf%d;IuLDkDrIT&nL)Ju%eRGY`DAxstmfQ?{g|0fhyJd|Q6 zF8lok99g+&T;mpV;eGvQooFXr5vUtOhvdQYpXO`yIe#7eF>MUn6JEoo|I%TR+G|ev z1t;!xlSmo)mP);FplFoEVkeVyt&A!zB}?96A3|8RYeUrVe;%O@q2H;$3JJ_2KTJfi zAk`b)HiHLfOr_njh#iX0c}~Sv zG|UK?DNWplXS5p)Q@vw7l(U%O4dPs6&RDTy!uH^dmPU__J8o;}KpBD1N}{YDeoX#F zx`L*Cbv<-+1dx%~w#5wHBSs`1ymrV;`l8}KP-iYS#8}3(FpD@~N;$yBf|WfvL!Xu` z8I+IAB8LmA!>op*6ZcI}fY7I=9{F_D#1P)MJ?Ao_l80urR_XtVN032)6byqz!#UG8 z){A-Vg~LwAv$t=t#UpqjB-@iLMnbEh?jHwF<&b~2wu?TGO~#Vx#wYg2n62ewQ=<%} zUs3M4$BDnQVf!0jiQe1x-xg1Bh$JRU*Rm~9Xj>sWGQYhXx`M9OV>D8cwnlx^O zM=sEjko#{oe5sZKGWzCdpTx&0xVN;8;Onfd==7!c0HnU6*?e5iP8FYyfgagtQ4WyT zFEZfNKJjHWL`~_}jyNaLrsz59M3FAw^j=_s=rxm(F~l)h-{Kg5J=uebp7u7gxU$#S z?C~rxAm8~AP*v4e{e2V(P;3syYN9|SS?Cr7v7ZFIe${~h;LOZ!v%D6=g;Eg+!FjUA zo+y(S60kYc1-hW?JUoi&WTq4QGGC)wdOx+ZcOB}S$Nqow1Ov(EYNs$9+X{DZ`^?WU zch-_84#0EbX6UqdoAeY@?vHxid=Jrt+)-lg^i8l}ooTDf9CH}b@DYLeZHvIDOMec3 zm@7xZZ*RacV;mSiaNc8HqK?fqhcG=PZpNaCZ7K&Z{bid7goxkkGIaihNT=zPmMN!K zEIpgnB~@8w@z0JY3s>WsyV5%wrthU;t0&IWH+aZ!=n{5zsD)oOhRaE~W!>w338-I;A`H!zN`>Pe7|UNf+L#F-6C~yo97f>Bhl7PX zBH}|%C;!m^Qm#p^2Ly)b3g)M-1JrTYv22z{Exw}n>37}gO?jY>bz~FqpIVODL<
W@|?F?yN+(qcs zecUW87vXdw*!}kTBuH5d*4`}FV%$N&?VR~#T%&Zhyrw{gd5!qVwZlEHrBH!kfq6m4 zg}<^-u2fxvVzFos_3$RpbedWeTmkq0I^!lJ?AC9YhY? z_!{RQV}Kc{a0VkkGkkYpvKFf%JfJX=bnwaGL3`R`*+Im(%4f56gP?a1W0cEm@Q(1O zUOE;&+HDABFqkFRVe#{F3Zr0NU|O>iKh}q6r8S^;Xr)ZU_(8*@eroX$13>O~sb8%F=tP*29hpf3|YK&B9<%-<_@%NE5L-p&LP84+bB?M;Reu(IN*+wKd9fC zA(a_VJ!m>5pA!$>z~B(0c)U9z7P+}5750@HE#-A2o@r_r*lGh?#V^rPghC8-As zly8;2s<>w;vi<)aULL-ekb_BN8)>))gVc1$j7g}7z3({^Q<&##5LF;Z#4G9v>V6yE zA1bgQ2w-SCX(|3o2B}*4^pg`a?iRp)Eh>>d?uyq5`SVXRsn#se*~wN59o~JFw+g7Y zd4p2e7tY_-sfWdX*}mmfdGQ{HDP>{LIkSN>{@T6Nsl{hR@WvPU%;xK%%_ZO4;nv{v z%lh;Vi~)iA^)C|){_wJ~m2>j;K=rowa7FcY^<$dq0X`5QcB?6tde`%~P_sMzx*`@G z_&WV2(75=pdGTWMywUUIbM^ONykcuQf)CzgeP7cWO%}kFVB6%<5~kOEqM3-+t;jxmSi`0;fT}?y|hqby14v-IQ&! z#8HOkMp=Fk$co6`E~o%?`g!W6zNhTIieplv?{npMejwN8oLi3yKK&&i=le3vFmFLt&bSBfMs&_P8})M zcg?CojE%;s_tF4=)^YE;?(?_nocFS41(knkJ?EzAaJadD3eOPKV?!?tqU4q>bo1fj z$Jn`L?B2{l%0`OCazS2?rIYpOkG^7%t#aeDU+@^3+H|_XEAMpx9KXz3= z!Jw=FHuyfRA+TXOd7+YeQM8lu1vBQQvCU`#luAFCN6(kMhgUp?15}|)m>ZVxR@9>o zs*^PB>?M|!2k0}U)*V$LVebU^jAJqjLBG0K$T1C6@oQJTg4?)FL!4=!0QIW4N2 zQ(?+GLGN1@n!TQnnyaou^OD8=OcrkqkG}rT(07Ho`tl;u{j%0WvGw+{6@T$B_X20l zf5|6+=oMC18DaN!cvxKeMi;mrzy062eO-Xj3o?#K0(aG(QJ%}hZc8rX->$^EA79qd zcluHtv8(_tM4td_;&k`=L{Voab&D!sqxfw&Jm>xWJ4rmQ(}vxU;ad=HI9wEdpryJB zbh;?&HgQK5LY=z=-0lP4i$|7tBQUp~#sDNGlS9}@!sx-5h_Xf+a9o3c;3U{WUaMA$ zGo4hs>ZgNK@MrqAx#LE-4?vvnz12aH3TL2%`LRSW`z|VVe<~dVIZEDzQD0n@XAdqq zuog|)*Kjlrfs%2#Z(LFS2aG)QUVuEwu_z-g1gj4%-2B&iv0X@p^P&`( z?#k|f#6BKK`YDonkh!_402#xQJDf8G=Gk%OLWzS3^MAeNPmGjlgzC0Q4 zS#)B1vO^pQfOx2fN_~NrF{Ht?DZhWVC7_HNX4TUi#K>{d{mLd_{=x#?m;fv2T0`|< z4ZgPE>4GSz5sSGZGLy@|@ke&^@!J!8GqT9}TS%q-&Ki$+kn6kEL$CD%WO1N21UD@- zK>H7o({ik@LsDCHMeZ9Z-bA~p)2;Mfa&48e_4+0Ps8I`5ZUobkyoZjQl*tiAB=RLcdg2nVM@&o>jn+~V*D{U}^s&Ew2nht&Q_YT20uE=_4>Z120AKrECVi+vEKBGy8)W z+hglbc2vm2#_T<=csM)0cshtLxSwumAu7(gnMNzhOOFwaN<#j<;dV|~4&_dEm?2}U zz@U=bcuIbRadEaqd{Dw;6R-BaEuuGN--T6itdd%hY{^gLPG;entW>HcAUQbsS;kJ% zLeA{JuNmye&`i%ZDTxI<W|coy6;8UXFEicvbYT^0_AsfXc)~@O`%0{+t4wKHGiD&+autbM z?!INUadqL~xohDJ+6>O>x|4hAu4umJGv-_L7BQVsw*G+|A#$o4KwryH#>h|KaFywy z(<3nz>LA()=NiVgxF2QYXbk_fqUP9 zyTVhANf#Wz-_`Q+@30XD6ZEM8fsTHLIw-EnZ3{8l0M%A4xzwmcD_4u*0q2w-9py=M zVM!c;gCr`$zETkp3>~dXr4{((r zY!|tPNUmUwlOT=wt=b|U`uSro8K%wztHWedpI#MOR4muWSW`q(_PxXuH@Pv-UV$n2 z_?HNKujJUWYjfI#+#Mb=J_S*C~@Sf(RGpocm z8&-~)N}E(wEvF#R5BO%0AZ?dSjS_@g&KF|t%-8G6)px+sj~mR~#~{Z9+uQ|b5w`8Y zCLe9N8VHIlOv5n7GR&o*pF}G@3Jk(<4o@F7W-H3>mQWN>sdM>dFRaPT(EB4{cLiN( zDh+`p3l|Zk*#n@HVj@_ar&?CRmWt}qPy#O7`0~fE+yct~E}pks%@O=V9os;tZGX|@ z+X+3U%BLPA<|M%|i)zm&tC$K+h}t4WWF;bBTz{en^!SRmX|ch)s5FN(+}iknsHeaCyg?QwJkI8v~>9)YS4Cny=6mt2l31iociO&#mq& zI~j!3(;n0>dkX_s&3yUcA9@PcY+x6@2lQ2+fF+-uFWpC9E}gdSoF0ZupJkqldT!D$ zVfZH!v8M(xG)wfa2fAMo|7*QBx*-V%nm+FmU9t82Wt((^Vc!Dr#s1QX)w7C`gcgE^X67yKJ zqo~=->X-wK;RoX6ifF_7laLf@n%LSf8it=N5R4;7IBIA85NrhkTqYxdw^RpkN}q z@Rn8m(pFVKy;jnb2F%it^i8oR!S35SDQxXH#@;s3MvfLX&$4+$%xvV5 zPcWwQ4j-Ja-m8B;){uRK@1JIokIkG2HRO&|GN+S_#fC!WZeCUlK{JC9C%QL7~#)JT=D=D8b zVCqfGN4gNpMfFVt&&c^oWUk{XD$&Kz21J_A>V{b@4pK6#q0J_B`+a>lH6(d^=fS^B-^VM-wH- zFv(rTsf&l{w3$l6@Jo6ti_^v2d=-I##@gG!TIJgc&gT4Y(RHG&Zr1Whl8ywekLsN9yMU}!+N^cas@MJHN(BTp`IMG@0H|KY1Em~aI@UC3PKLzDSQ9vc*d{8 zxkR4lhrX-Xog#!mDjnSnX>${{9({`aJZPqx-?%_hDZTvH9Pr*tmhn}wT2xEy7+bGd zURUJ=|9n%kLUSE=Js^!`M7B*BskeT@e60;+Y^6&@_VM1=K@-~mRX)r3tAFA`gpvP= zc!N=F{X{bfT@j+50nkRs)vrYv94EY>ZR22Cu5_OA=b>u50ZeT`Jpf|CX150X@J6&< z24GJkQr++qZWlOPs=}YLi6gAKlrmz^4ZvDerEj-x7J{Tet~u`knV zZz8+SRZJH%oK$kP{(Ad(OtN}r3=6Vt!arg3?C@=8b^p$~MC-Xk`|mWV`u|ok?GJpr zc#$W@)XzvpuRm-qt`pa$iX!bpVxk0XfHy^8H)2D*Y|MZQkB|* zWv#V_SM!N_D&4Os7b`?Hh5^ix_M7o5R%0zx=8+$M-IPDPU$Vy`SfDj8tJ8@N?OCZZ zvWeO*`-H(eNL^uMY%$!AAxG0b&X0Dn8v;a>_u>r{e9x-(=O#oR#>}YcciWybh>%eJ@i)b^YNX06~yz{@N=*UPgl|E_P@woWz5F#UhGQ z=bR#M<&$wX7KYRts{GdJH=Rm0zMo?xh@^U>zk(f|H zU^d91C2vCm@*^XDbndukCZs}i4m#`|Y|GAMV$}p?(DUVtTLPF6G%DG-ZXN_eurXC6 zR}fQDnGn-sx24B^S(X6gV8azcy(#bh+V^G3@KBURlQuq&f)Hbb;Njdl`9xV*hqt%E@0I|!aGv#f1CzyIsH z+5gCZ$e34Uh)ADaf4iV)VCWdpee@Lne-o_}`Oft1G;pZq^`vHZ>2=i6 zUqN-lA#sw#<)DJ14k3R?W>Wz{_x}oK){NuCs|^E{9!c{^;>g{*%&`uv8CicG$UMF5 zQE;^LNq&g^L5=&SS`LtDw-{N_mLc`2QJb2?VZfUAksU0-?rZ> z#^we3+w>rm1S{OK%HiZBQrLiZN$K1(+>2awOpVFKpNzQ&! z7x@ON`O341M~`8NMQx&aq$;4|n&pIrZx(@P>_NAJ&jCg>^%a;qPMskwN~$GfA&+hZ z-+uH~E=tbsN8}?T(=JN30i@wF${6{k;1&-RG!9G}Se0H#PaSU``l0fu6YWLR3RuE* z##gli>#q2wi7UZOc0*w`je}Pzx2_1!Z!1r4W^3&?)o3qqoQFnVF0Q*>dK3zCW>$^8 zhcakUCGsQj4Hz@fxP#Uc4nJo>WGxHA7;WLKlejp=81Pw|bRPY~e? zWP~_>V<)lBDq2uO%Y}^QUUIL`yOs<*rZ|UREMsh9Y=k{x1ddj}4|K3@#1Ct2FYu19^9JrLrd(y!h@Cd$$ z6M4Y`c}!W26sy+h%(+dU*who%i+KMvIu?v<__j)hh2V*rA3QbYgxXZ*_8d0AH+ zTGAnt{==E^dkcCE&epgimf76JO3Ha>OE079VW`;SG(>X1Ez*!u- zNvb(8eH&Aa{NlDwI3s8x=l$c`&Dr~~Q#GHa_hEiCZmlCj7Tg^&pGl_?xWQno{{flO zZy{5|h98Cux3r$3r9<+b#_|U9o*O=rxOZNmY&@U)?flOwk8}Tb>;B427a@DQao^nN zzlr#Gp7(f=kLZ4vx&(f>^ZR%gHuYNLeOY3Vz7wDTnYcME_Ii!|W%6s7O_q@vR~fYMk!sWLC4sixUtlf`OzR7ZVcDl0Bj4Jtet0?E* z6j9}Xr~VWd*ANphTjnc*p_8|3Ff8-m)|bXK@`_nh?h$Jrod6N_w=N-U5@^*wF$}NB zij-7tGQ(1Fy4ui1=e^?|zk0=`U}-9IWweYTqoLVii~@!_@=dOAOeDzi0|!yjR8dGv zAB;0z^-x~FkW7q{0RMsYwa1#J(6(Gf8!h(~Z-r>=Il2Qal>Yg@=UN&6D6{?b_@;9m za>vuHwWtAuNo}1`J)nM3jQtspuI#}h8HK}~0v|^!zpo6jN3bkz>!2W+7-4<=+n>*S z_=^_#dZ<#nD;%kzwv^gu_@Pow^Z`47|t#EN!aX} zE~(FvJ8Ya~85Nsk3s4!y3TL&?m9+i6z_)yLtt;H-?@4JH?FLI)N#;MttyE|M=37grvTx<>b|d(dDM7EUmRu9(5nL2|8}b#B&{$hl4F*M>SlQsQ#`s*U^G81 z>q|`|S4J}zcVOpD5M8PNy34M7^6v-;qLF?^dG3_%i^^1^yD8liZaTn&ec>wRK?RMv z&g)NvSA=XtRq%!MnX)@uaIi0?O}hmLZC^;Dq925!9YCjCwa0V3a*`H{ei_XI-JiDn zejUx0`j3cj)CGKx?F#bM4m5LFkCdF7i(=j5f@vEoPwyIYHKB%%zeP`}m6}IQukJYO zXR!S#Sfb5JOgrD)=>B`N_Tld2?|XZ@*!a%Lbyp0sr15dTa^mD7m)Wv1qqD-FiqB7D zx0wONEGgUg(U|47_@pXs7?>;f6)9?tUiy0B{ukpPgo^;ctsER6mN6EtyX(DF$JWa) ze5)}smWg1Be(Q~8L@evuA>@*U%P9oy7rJmQG!eiRC3Q{eW!tnI z0oBT3gU>HqF{40l`?hoJXMuL_5tfZv{DaU^%h=s+r@Nv6TMM&i2MAt6d< zSamzZ{z!ZpLn`ZppQ@b8yK$+MclWNUMCocuweNr>d2`mlfVR7*Or3YnM**cNr_?ML zOavI)vxnrLRZ+(_^*@HcO@$)M02S_EW_(A+n>!=VuEqQk^fuQV5M9DP5c#a`*7fKVFg6L^!J&_e553ky`CH0nhv{Wo!3|e}BaCumxi%AG{ z>r6Ot!`7UjZ+$f!Pdl(<{={|}`G(MtnhM^X?$n;wvJ0jVC=Q)KHI-sY8{2d@Qhdr_ zG4UAjCT*R1uAZgFA7)H}ud-fH7VH+f>n4BbFRoOTtFs>nivRd)EH!F(;k0w~?#xUA zjYq+dG%)bg97enQOHQk!ocfu|4I`Gj+cMkBLtHvuh4`y=uh{(1Vs{j6m@1QPRE84S z+x%*cWdIe+p*gB8HK&oK!2_Wyy1X(o!-QOKJ72Fk_k>oIqj<;SU=WJD6MG^BXQCw% z+dQL$kHw3FcSSOA-=FBAvxkPL1{pCwW6Hrp#ba19H;g%Aadm|gw#Vl*@?r;=hc)5y zN!6)owW((Og+SlNYDpV1X$M@H1Y9@U?ET{F`SGVB_O2Vd-wF6cz;pD7HRMpYd~-+< z@;J>a(x1wVU-cby>zG`i%4w=fNo|RwkiS_aQ*fI+ag48%6xWQzj$~rf?vO#BUm}HW zYaFD5yT)o6>v55Y$82*50b=-pE08nH$Cq|ug1*S8aCGsaGE`eOohh!h^PQjZFn@!% zyw@1g>CE)%LH2BptMj$^@ML1{`my?n*LR%oS#VGLeE~P!WhBw6>9A{ix!|z0?P+m+ z@wM(M=W^PQ_gv@DN*696d?;#m@6%mh3+YFWRCS9d{e+WJMk%RBHSSx6Rb<4%pIT- zeZX%<(r@LnW4tudE9K07SD$%o=P1YC7nO3XYkW2)B3%*-SIDRuhl)wjfA&zIEUN*R zkf|vk5Ocv)*DnceZG-ty$@5zc|A;W1m*fG)NR*Q1xm~F>tXsIAvV>!iYjqbztdiACI})2}QfT2G=!y z3>w~-WIV2jSn^1w>N>?9bPVSHX!**dG$8g;b{!EJl{!5J?`@|P6R2^Gn5e2{{|h!2 zVmL8rd(6ou4)+Z0@r49gf#{x% zdhvTs2za!`z=`m2o&j9(RwOC9mMRX|YklTleD_;Gs*uxHTtO9D8gJN!7)vg}KlTKc z;Rt$E3ABXm1DitPTNQl_*Dq&(Z|rfTo$?Et(#w{gAa5xHU4SH95qQP@*q%<``DqBqb1m()xPJ}f)1xBnAhgLBG1p3T{ zrYOu!mi2WnG*Q28dtP=`paEdA1^mIS+km6 zYW8~WhC+l$#U~d)gzlj@9LI$QdAeIe3xKihRxE{2GGmGv0FsFywL{q&d#}!;%dr6S zdz)s=Kdnm@%l5QU)E5m{9 zi(c{Am@nm6A>t`Iwo5@}doy&);+IiugofE`TW3BYf>uyR9GuVN@R*PneyfpUcg?T=i7N)EA$go$fp==BvTj z##|?%atujK!T4PLUX+d8?v1lF7$R7SJdC+_=G)2rw>1x7;N+v`M7;c^1GS0dUT|l` zNC!HnC)%kl$4#NtkFoM4@9o)q!Bbp-waxqVeI6&HXdIaLyiJZAJ)|~ zzCM`a;Hg0)+Af)h$5BXgsV>6HHPb=1ahD+@3fn*%TRWN>Dbc};Y$Wz3EJ$zDebR0*Z2(&qaeqP9x9yv`KrVc8VST+(YkO7*0QZBn0Ab%k6Z99q&Vt z0N&tFM5_d!2BQT~8FI=>gNsNYYi|g4Jv?^OLB#|-bt&42gFKG(my&4+oevMlMK=LT ze(7;DiombQ;!}*2{r$DdhD_;)6*MxlM>l{aJp7~OlPQB$WKPph>!uA1s#U&Vtb9Ek zWoP7ia~&HOAT@wz+CEMQZToC-{Ln~XUo?K>XA~fDp;8IPP~R!MN{3z^Lx5)-!nFJx z2`1zZW8%-((NE{WlZ#0~Q}JG8^`pj7QC69$9uhrjXDF*8t{}uw;7JQrI<_ zDVPWWx~@+p<7x^<_V$Oxj+CEMJF7<9E+rPAqLP&7oW2J&iV&Mg2*6$;U%1ZunQVLb z{SE4J`A--%Z68r5By)jge^z4ApaEt^EYSVZGgO(VjU}reGRO!o+v4Ygn8?H0T8E!uwx4h0q;7L|(H zf{U|lQiw#Cc=S5;&X&^ao7JP56TyYtmLA~ee z?G(-ukrH|*PnCQ7g9A&ZGc`K7od8tghFOOTa+{18p#t;l-UNC5pkIZb#giuDCwXkC z`W4b!++~RJ^^zr6Y`=Zy=|7PVAnyRNk^VCcb!*=0l?S&ixM(cPHt*in&pUn8(!ib> zztpt@vae?YdU38Cf2T`F?IzSBkgh?{ger*AhLh1X^c`C4DVE_vABJI}O9=A}_UrZ| z<@9-cNjfKLg?(n1CtBH!A!S2@97wW3txhwU(xEBOK_c;)u5N zT`gT5pX=4(K3!2-v(W-LU6zO)>@_||p$>!s2g+EOuD6us zWxvz(oYj0x)!BC!jvF(-q_#ZQTQofPr)tB17T?%o-=fp_Z!!b4#TT#=Ts)TZ7fJAp zKm^V~11D6f*%f75ALTNRgMZ*1RJ6P=rVj_or0|BK6e zGM+#^DfGI|d=ckC;)I^5kcdMNi&#cW%#wETubZ2Ps6mPV5t&W7qYB$ZRW z6i6sjM12twq&z-Lm6WDL4zh)c%0`v>$yJ&Up)PMq_24`K{>UW&Jb2M}>!f32QwvmO zp5E<^rxrH;%FfhgL9qG>mFdS-3gurFIcwZ0%e-corw)T!nS1u z5sL~}fBi8Y({`O&g=6%idAl8JMD&sTrg2FQnnCL^l8=!b9#cqiOb5NZE9SL4y ztl_feRm7#-D##7(5>Um4Lo%LrJav}jIg?JQM_>vYq=<8v&l*0^3BzrYs((}2LsjUq z8tDSae>^x~_*i%F_P<4(b^e)b2fJmzzkkBik*-bE20f;TmjcQ}TZkhQ2lCT$BH|j7uKK4saJvJ>>*g2VnN>6p3 zJh!aY)rh)(P>aK&BwsI*(?sw#ku^SYo&&`=#SiG~0^^1A*bNDy$>`KFm4Q;=%;(@y zPlK7P+`a8i>Zm0COCxaNZCuj1DoI6!>))=He2Dm)SDTD4GmIy)LeD&ozKm08tWE00DUo97FLOOn=V_r$@=b zgdv&rPVs(Wbk+Ts4|+Ywk;9td>aut{(f<4S;d_7Qm>x-g6QZAlYtDTcWf#&MQ__R5#l`sa{@k)K{R6Ch%!$sS;SIKUEbrtkG@rIdNIN<;5n zH$y@jW3TPa-#;S`#DLeK)PKvhP#7poj#_{8Mw_5O02nY`HZvx)L>a}#!k{6(-$c)_ zEwRCqdm)ajHlBI2Op~K8Ei0Mh5Bc zzH!^J)eod>sQ~Edjb5R$_+q*8N!0O^G4w@KyQ)8D1IGnc5TLoZIIIuKFou>l0?J-7 zIT(6Qa0PZ{ldU3q9|Xb_wuGgK=5(lPZjd+a3*h7nM#m+2^jM^#*$hXGWd5Yny3YiV zs~`aD0_o9n*|0bF_}FyIH2vX5!MGJ6NllOVJ)}Tor=*pxG63x#gN}K1_AtOn zM%sN5WZ({bF&*-wjPL?zzVCly`gZY`RO4{sGwm(UR4+_%OVYprv*4RH7Nihum;G=|2*VE-GKjb9*9P z4I$!*DLl^7T{DWy3ou(Vp@|@W%-XMVrp6^ZU!yg{wy;7npe0%$A(>nchV3F{90sPu~N0>Ct)-d|4m2yTx$pAn& zq5OLSp-VW|tQ99LnwpbpH5Z34`z9S{7VGP&E7sRBnX#>XRU#Zy3!C@36@&oDdmN}7 z(MnoHl3@DR_gq`b8G)_DWD7*xV&+2;rP7!RKN*mD8;9LA9OdpL+82jJSj%Zf^CO-s zH*qH(X?c2BNo?!F&!enXOD`hzgd^RR(G{K8U*~rUs~l=0GqU0hb2-k0_}n2oI{(N@ z8y45EgL1fHZ+89@fDRBSFI}f)J6C@bqEg>Dp>ZZGMo zA!qi#n2y)QHD&H?X)=IllDM+dAGt^q52;s~Bl9;>1ffM|(5&7iL(+2*JT`&?VpScpN=6-{Oi*+V!>} z6_d3lO~-&MEu5QO^P6P9eVaZ)Q}%|j*pa@`M={SSF*M->*UCfxsLO}4j-%W}*S}q2 z_29$7&)#)aXMFO^D;V>D(MF-EgpP+dmk!`_v+I>0sYikm{7)Daj8RzrxSPS41W9wh ztW&|{pXzBeTX7OW#Yikp_}PDmKgrdI>iEo<`|%H9I3wXlOg`-QUUBF}JyTqWG)x zvV7ThQ8dA5rgXBrw;2n&xPLoLe6Q`HqW24G$p3O1H|+m*l|NTsXCDOxCF--0R{p=u z*m<+;isX9d6N*j3UE;vBnB3}7CWDJEX9Lw`wGQ`S{`9$)agdq|s9cLv>xLtJcmcw= z+^&-)#%HC=Z>EB!R}mGcs2p#83ewit8*kYm%p#X@sNfQUMH}bYlGV3viuqRN)KIlt znW)!0WTo6;>4Xcd=9x9gpu$N%5%>a>9r3@UP){A`MbEV$i?caUYegCLfA;NdFN{0i zgwZ$A`8fZ4450q@Z9d0({_x>z6$x)%OC@sJdk1J+H4u?IUwgV4gp*9C%6_(Bx$t%z zcBZA;Hg?y8{B2h{RKEie$&`FtqmwZ_<#2ef*RzPx_4Aj0!6HSMcI)v2l9pJtTqJ^o zlbnMNC_${$LOjb)S@|Wvm1O`FEq4hx9_i}02A)N0>A{5Ch#<1P>0`vH$f(E}I7R!X z#_KC=TB#YCGIW+3alh?YvClD4C5(PVQnK(S9(&6PfigIBi_`lEB08fz6#A39bvE(Q zjx_LlnCypn$&(beD?$sL8r2pO<_u zaAZOLF>H0+`rlF)aFVU-M{{oplzJYGwd%Id{IILTaApHdGsQPb=c(mXq{h^K|oRL@AI?727rDOXB0jylRp=7)Gt?mTI{_xgPSCjyCbvjn4H0U{NkD__l53aTsK#%H_w~$+9b@lz~eKb z#hJs^x-^@$k@HqVA_Rz@zJby@q3IZ|pSQaMhzs5Q|1NamW4j8AuOzIlgg7LXm8#jO z89LY@h1>C{ZDUBbN+_h`4*r#x9*GZU*M>%-VB12_po3+4s0k8b%B(guK|PLxO=^^% zkUU(^;>}+dAWee`;3L5`T&1^P!4gGxuua%p-D#E0J!zh@eQ7Ch_S^cQvv5PTmWL*g zK~Pt)2E}{!JS@+wAV3T}e38Kc|uqkfH6n$KQe);{hR8E#&2*6?R58934Iv<*` z*GF;|ap%{r&U=**Yz`S4SJ^XjnEYSyd?Rx(;V>!h`iFi45!Uf*oB&6mEKX*$ zyn>)%7SiRhVUVTAOjKF~O&x4!FhGogQzdzbfNI;Z-!Tqu@)~9ly(@Ugj@lP<+OvXN z?==_-;k?rQWd+xCHO!BeE6(Lw;%ylUm{;wX>(*}hV``R-D|k6dP?a8n(F9`G=?P%w zLakV5+A0&^5p~v^F|X==KcUd{9Mnk~#sSJu+zRpE#YONHVnT;e(D!~TKk&s! zea`<;fKKgfk@neK_x&79dC5 zZyPR=CNeeroBiKZfYZ8AQElDF4)gBjmcG#i*8^1mpW5^LP|p_m$!2CwZBQCO`3`7N zLatu$u>>>Z6?z(5JdzfH053(mE`Kq4Q^&a|4}^Yva>{ZM*OU&3o(a($f+67Quww?~{RWmM$O zFaUEhs=E4Wt7Q@j?8iu0jv=Y4n+C=jo|U%J!JGlv`@HNvb8k-FwW}}xYC0{ZuQv6L z`%Y_%8g!hPT61}8>&9L)v{Xt{buGm@|K~2u;wv5cx4lEVlp#lUmu%EVzNdRv%@)wS zrtT>b2Qt?I|Dnrh(8NOmd{5|AZvAUQw;89)Zi*r^%nOF|0`T0;V_2=T)q-@Oy6xpU zmb1IUh`v$Nzy~WMu{9%XaQ7g*JVd&YCN=Xy=lgoo3ZJhRGQ1IYBzQw+{0>X_Ak1s&r>+@j|T z9FxmfSx5uMiNg7qxpNBDqgPc^O~GuPpQ}?M-mw8Sv3(ms71yVHYET;p-HA?*Qx|}< zH-^HRxkja+Q3StL0?q0EujlGV+)C*_b6XYQ%!+T;3Lo9uFNpU?nDPVyc2Cl4gfxVr zs?0FY{2^Z9{4Y#Ir* zkHDR0`XMgzJ3_NKQ@LXTuMYBa6p;f@@lN%GocdzG&CvhTJk292$Q9#gpn8tY)sd%**?>s>`O!G+UBb9+sw-ONSkayF6eKglX!6C`5J>^Q#nGHla%+qS6c3(gZX0FRiNL;kq z-kBt3FU{A*^1uVATp$DCtvQzHE>(r2xpE83(9lcDe8<*iGG;gHD?8JUZ%t1$S;Bqe z=7%5L5{cP+W_IY`9WHD(m!xgaa|zSbkR-Y?J0!vthdH+oww}d|5}}IhI{m)piJ%lW z31^vzgg0?JP^z9{s%4PA(@oQheZV@vU_oZspv;RMY8(cMb!DeJ*xiwki)J`EI4iE) zqssQT&l)l1BcIOf*yn89okL_Ma1T38htr`AaU1I-!_~ChRsE=HsJ5L5Ky1R__abG` z4U^1#r~_wrN-*6-%uA;?%ZM&hVdz|Of`;%VdY8|S`7Go(V|&BNkCQJr-0izdtHPLj zpCqYtJCk_R@bSp+bamkBQRwn(-IQ&iOfG~x%d0+Ymt$^_(Gw*@ncVsP;kk8p0Uq5{ zhw5o_KlIQ`qCS2)2ENTXy6F8?P!a&)R|ASOZHCs#WOVfjjt7I)%t3k+cIV#C!xDxjreUn9XVy4#p=?dvsr7U60Q_a=5(Ci&2s(Z1`~n5G1OD*!jc=nlq;Myz zqcm&k&iYlt|65gN7t`cmcX`j|J4B@LMQL8j$X@auif4Syv~lkuF*wFc8H9 zEoq*1_ge18o1fYC({e->w3Ym18QTDaj{y)$cEV`0JA;H{AvNn)K-`PkiE){pe!fHN zA0N|9a=i9+S)DsJ3CjJ#wZ24S+Ot5)A-wQ<=~uYrgl_smGM#w_&)5Y(I#}a{a#iF_ z7-#oaXu17lL7hncwnt>II_7vI?7KWMw;H!x=MMsSfveL$kEU^|Db6lcWnJe6gXTub zN@q4Bm}@?`Og^>xTqU3;WAoqF&tzIO&gViA*4aaK>~K5B4*ijd)7q#;ay`fPBL)4p zNDKuoH3EtMdm9=Rj2cLGv+w`aoHZiQW?dsH+=0keqIXkq_-HFo)1HN~~ z)YN8cWjqq=@*Yd2n$3!(R0~)Y@2ekG`fjtCWB-A_mm3;g+dd~sWoZB{v_k8!n_~ir z%{UbDw87Qz?}}Fl`P&!VcA5&%vd$e;+;)TIRv+6$pjq@(x!261=P0;$B4GLKxu4dn zj`Zh$Rz}GK&PqBS(wrDT*N7OG1V;=a4S)Gz$=DlXLIuyHQ0Rpx{xoXN=g_AkkK8?d z#Xi4Rih%{Iy{D^Y#U}Uu(?D6g;nR>+q}iuG0ejQHiXG&Cyco`)XQ=kE;|d~wqXJsB zz@D0~h&2e3gElN+Un!(R;rG>VbD$gHl%K;{?3=JJ9p9nkMiP+%G86yoIlE#N4;0ez z8?r{Y>W1QVp9MT@5o9X&GRVg0wN8taR{^yhzGqhKQMt)EKA|WBp}j=se2srrbShG8 zYjm0cVkEFpo%Xfrooo!3v|rtG*4^J%0^UThQRTaktAUbN0>YnX6dkbBIftaSOOUC6 z`<4o6o`%`bg-9m8trHTz>`acS$r?4=-I#alLx1{|SZ&9WHwLjpQ)~|?#DlS9mFo%ve}A|a6+HLrf(&xaY}KDj zggX)qR*24GggA=+%~9Iy?HlH(&hx!)bj#)to^+Bn7K&AYhR~Sq5Vt*kt9 zK7X)J&Tae9a~o71Y*-|b-jsaF-QuhcAe|BaF<~Umwc=>rh@9@k-i+({PMy~cd%Oym z;k3-C&n1H{XHuh7F!jR<e3y=pjX2g zmsWqsIV7@uE><$jZ%Cah86`80{ce&TetqnNTS#otfy1-cKHJJi9oR6C{Bi^59dVHO z`)8|N-c3AbYCI_i9rM^wkdKG_4vM5%A`fDH@kcufDdb3a{~Np2v;Xo`id{8!GA2IE1U)#Eg$iA( zwLM&(NUEjqp811YI`Kfe=qsf0s?er3EA5eSfvinp48lxIgYlV464y2{levhq+aPHg zm)maP?GZ-T-%n*IT(EmU+OStu8^`f)u%Yn9>x@6yS@Xa1H){COT5#NHqUL*kTfa*C z?=Y$EX_K`lGTTHx(T1NT zJYRoU_XY1cHnKrW7Ty(qn!(`TQx2uNDrqi^Fs4qs4CcG92xqa%zTST=JnA9!cRzjQ z_dsr#QnDBSB!yJdCLmz#O1YMCB1 z;q3$0cA>tEpoN=@zO)NX;E7zASo-SaA2ro5({VQD5dy6hC*xFrUIR-&uzC}H13@1t z-Oxz@n>vsqFBLtXpm`OF{8gG7l?@mcj@$cGIv+}VP5%m&&GezMH%{+E_Sd3B_DmAA zi~(FjWX5{e-+3_R(og~?T!8P_F<*!G7);Aa)vPXyB)If4d7U~nz@z}c5&;}Of$i6l z*!M<*vrvpspPnFL?NnVSU1B+-h$mp?OGOfESF~eB&`Z>jR0R>CR4bV)+I5y1%H6>$S>~5a@PlFP^j{#d8n1t-^cDWFqRC2 z56tGXw@GWNXn2XIIwfi8i`3f@jS?IE2rW!pcY}#CC9a~!dyeA%**#EGr?&v@3WC!Ad6n+}ltc$6y>E^xAh1SjFQ#j!1WtQ zdS&$8Vin6z{Is%Kk$H?~ebj1Vw?IO0=W!8L+pS{~w^Ok{k$w1uSI6^}hfV#? zhScqqwU=EFsPaN7TbYH^kX|bHx_2Z_tr3}DF{iGi0ExoViJ-QXRpF$GNq@gyZ_=$% z#NN{x@Oj&pGllU{i)u?LN*GrciJnI5^_Sk!jbNuCYSH=DRcw#I3d_v9Ccr1e8MSEg z@e^O}P4ocKiaR9L-qyu}Y$au{AYrp|;AGL{?@2DFav}A1 zeZW6oX%>bS>S@ReQM=B z!aNHjsIgpVSxk0v`Fho%27Nd}7DeXiR7zX2z~M*{$6`Jlr8255xQ*16)R&)J1;|Qe zZjSHKE2eJ=K9t9ExR(L!g!S{0*sOZOdG%Nx;auyM6ZHY=B%&xS3-NmWY429PUVrrh ze;Y`;jc>Pj74rHyvdgD)C0)(+hhvk9u4{ej$;xKoez+Ussk0=r^FU&agKjd_V(2E2 z?YXz@O$!p28ax6_63X_>J$Uewf}IF0T2LiJ_;3(ju}kFeD@S9?cG9Fy?JyM<9nh z3ko`1H1)Ohtjj3zIKRgt7V?eG0e@2KWGX37 zTyYbsJRD#;RcK{AHD@liFG-akAUQVEW|BZag|zHTh{pzJ(>OFqp0+^&{@_DTG`w z9>3V~-lTXf$aVTnDb7tD4N!6I_w)AXy?F1_Z)31+F{m8-z)7ns4VL$3c0UO5lubN> z3iw{5wWDJ`d9VtN%S$)6Bu*Jb#2r!=;oOT65_4^g_Ap8 zy>&=R%Ch}^qe+=XDB@K<{C-)2Y=~ZpS+X}#{ne3Rd~|se85@Z0Rnq(~__Xtt7Jxv% z$U+YQ1OQYNWc75C58Y5_hfe8P1{tN1JCI)qx@;-U&f{8XJ4e;v*d}lcz+LFnncm4 zxtg3g`7}M4WHsG7`)a1w`#!DIIq?rKJo;e3+LmIi{&T`kHvVSt7_uP0`!kECy1rXV z-Axa1s}{OjB^Ln0onq(;Nc=jSa4CTpYM3i!pbN(&j!6ra}ln+9l>@aO6y1dxq=$p#W79ekcD-qb5# zDeombc>5TmmJrTLko|q#XAP?g>wBt_9R7Kvyc2=f0A$UlM3pKz5>YCeOcL9@gI5p6 zfMX#A2oZA$601olp-$qTAfRT75KqTRKlX@sX4_giQ3Dg3I!L6Aeh&@IYOHaMw76dB z?{mj(c0-}{A{)lqHH@eaTWgFhZZ=i7#(wemag3xRy!uTnKUmflWMcA#8`;C>&b@yF zwIwT`PLyRa>*$3C%sQ|R+xAsQP;fNrTZ}d2@b*Ajxp`dD9^k2J&44ONJiS;|=xSL+bhlx63>dTFNLXXAhSClj-#a2@M^Gv) zd8(fCxTW(@qXkldiOAn(VAC-H;fB7oVMFJs=)Se-IE22B=x*L3I6@nG#5me#o^QQ5 z->Cqv`M~gZ8J%VB`iWH#0ZDh?Lf>}h@+x8uO-XxJ7JNDIYl|T`Yw_) z`Wn|?J3_?w@Pm1jb!;#l07)q(@4-|*JKT8Zo7EetTSFCEx9~)Op^q9G83Y?)qGZ~N zmx}~3Y%3`H+@e)AB|O_O`Rd-pq9@0A*@_eqi+c(Vps=zdK?1&|(o^+dJ>mkXniyDN zyZ-%@F_l>+!H{kn4KEn??UQ8EOcBcKDkKg$a-H_MR&uoMI#rv^txSr~lf5HJ6B_)N z9ARvI5%%P8fP)-CU^PU|pR(K47*VEYTh+y15xzx#Hzfa)9bNdZq{BnEck9Wf8aOMe z2$}hQ=B<5w(r(49zpK5+Jh{Y?U+a+@h$ujCFu(P|RutxZzo}~Q>wHt~_maihy8s!Jh6(O(ru_HmIelGOu)Kh$_r7&4Q~z&xya}Ry z7}HoP7z;(v&=DEld)3P9c$kt2huPS|TJ|B1;grJ}Gl)>6h=-kXzy(z#B2cxlHf+FO zcreQmK<_P2ox%C;kJP%}-sB%$=1apK?Xf!jzdzdxF3%{iQO7CO$|@CEUrjiLv^5U~ zvN0yp%gV6X#-fVR<4Dv zjrg#A8oIJ*QAc|fP*`(OIi~U)F=3Z}C(rZZ!;|t`p4p*9G^7k8P^qLp0h!%rw!^ zIDkJHmDM8^oNaEN(T4w!ezjX^+- zuy``4wHtBPD2|EOi2o65T{H6$Zauf?tL(rD`gXZX%L6&PNY&;s*%qr%b$db*T$2~Y z6x=8U)HsRtQ{P?QxuYPDZF3ia61YX(vaKIBfww%CovQ=bgu{<`d4o74;K{3`2fvu# zT|5_A))ba6mX+QEhnS%bp z4Cn!M${l&I2c29D+!GKk{yjmDVKzO5K3JqEo!;yjpNL*tkYJu+SL*EuL{gs(TpXFj zgqweyGvc2flSS{U%R)v@MC|aYJ58NCCe6XG`kY=h#)Rq81jE9TTtIoWmIN%jz=fPx zbY@3sko(y^rk?94z%<}=@9&t{&`i-j4plq++m4uooBE%8}*E1UJ+?Oqe4KC$4fBZeF zu0doI!|ZWR(1D{=9D+x~qLC_+1%1AFH#SKRbS*$8rQ+(cS+)170`V&UE zEdyXEibB>-S1}A?nc9!$o@I0C_(>*Ft_;vHOZT+g4v%5sCQ8u|k#%oF2-FvkPAFxv zZe3HU#BgpquHmDriqxSdN{gSNvg~Xd)_jCsJeU<%StJbAHojB?EGjYLAsY}P7Yyt? z=rQ8)lwJc`5p4Lv8hcah;>6QXA(cw+EEAaF;$Vmaao|foNvw5j+F|sFsXyIeR2J27&794@E@Y8(k~tQg~>!w2JHY>NI{Nd`34hT1r@!vCBc*o!ktZ_ zd?6<iHr{MT51vADfc$Inh*u(e!V}58Xk0)BVy(q5e#7bXieHH4>y6 zG(YRBW%4>;{3k?Y6!D$28#^`W)PO_{&2v)E?dpUeoNEsmi=pmGd;+Si3BLwVv^X-r zk-emgu-FrYB@nB87ZAFlx@Ep{AR-A$Bw)(p!lL@+|7fB6zTQI9)=Lx}R{uIyjU!2X zKe<|}!(tFr%#osW{$p9G0$QTbVDU_WVOP0Jrk!w0rgWD99W%$?4u+g_@fQ|ph<^n( z+73Sa^-T*)HQaFvgIQFgk@f0+;l)fih^Y=4%rcpRnU)QnItRWoF3{=PIR+u;oj}(x< zAlDa6HKwTIJ`o?%ypAP;07W4bn27%FIf}ZlLo`bgD>*FcU5W#i;mz6m^$A#8zwd*} zrY*vRA}`b7 zLLvV{^+*fdp!?nxkPaC`fIdLu=OSZ70Nx1te@xJbkBgp{%d3F*oDW)~fHybazqo$r z@1vyuC_NuLHnWDGq~CuAKnFZtzOS3EK_GmWgNp%gi#fhR7r*(pjl7s~J-HpI#;|CY z6tLBj*j4g9yLR3D7lgp65HkGXCq7mjw;@nWYD2qgWZXobAyf8mtt4QS9H=hMNYC@u zOgqPQ%&v}=_OanSv?Oc(Hd)!+)=cxH+|1!OB6;jc3A%HjN>bfrl>yyeb6YN#RC&dO zp+VKv&tk?tK`-DTY*V75|;WGJn-9 zp8%_Qfxe@@w-=+3tJKey(_JR6h-06`;Sl{~Ko25czqfyosHQ>AhA}w*HIDSDvpsDk zwQuZ~UI^4*M4tVzpCET81M&V#v(@{=;S{&B!eUs=MdSI>imZ~dB!~8k7ErHHuU)t? z>tgO$!bwQSI~u$$6n#OKc+yW1ImxeL0E_Rf>yD0Mea5_Ova9HkO2LXOJkWA&dLfL8 zRBYrrrX}R|aQiZ9Q5qr$kj$6PD*7Q;_HkIpCLp3b0iebqL7z<1zM8jsE~4wyZm0rs z?Xm`AXYHh+skBAbIv zOmJ>{QIVm>K;9C+xAhanIP`%v(d^`GDW-`EUnxy~I!JEdxl$*?x0U6}764bU5PBLK zvxq{tsyt=piZZninBLuEBe`t?2Z@HC_J$M?e_ISTyY^ufJz+P3Lh*1AI=h98B9QFX zHE3mhomlAdRtVMjU$XHJ)|Yx##*>T7%j;d?_uZ#rhi4k!jYB^Yq3a*YTbr-kABw)e zQ+*NsWAs#7z5EYa>eqAmSoF4NSEFjf-dXHiTm@C93RvNubc>_-)w0&=5%ssP3&Nt9bph5_a_lz z7h=3C6+F{TXi1O-s15#_zIEQ%P|00ttBUTaKJ zLnx*y-7(D6kP;EYoPbso%er9^*pEv51FdB0<6t)h`G&EBXlCQL;mdyO@FRAg?PK;j z$>Ep6$Ln}Ac;b~ih8Uc~2etmiu0KYr7^R^`@MzohKhj^n@YF3MD3edKWZPp4EvdW< z;VhS&m||N@JL6J!Nj$FZD8wMg#V3)M7)vax#+EmKXjC4C`l$Lh z{U%dx4?`PdevF9u%p9D6PgE95&I0X>#5}wNXsBQ-0jV@rRsj=J-Eqfp!WC=60H5urqr`|av8bgIHj#l7o&683axfJZ zV8NAuzW|diFF<>!Th>B~aqaGFCxlL{w>>u4IH;G$O&|sAnzE3(pC&lE=Zr>qWUC6g z>37}U7Hx0&;X@KI9e?Rp8NDI@umz+e6~emwN1v0o#Cj-Up{q#TjN@T-|9E`Rqnxo7 z5_+Hbq{uQbzT)7GU^KQ_!{ieopP5B;cKsfrylu)F1d(YzSzQKvyj;IO_VnmIKA+a? z*1RTjI8H(O4+4n(1O0iwzW(_T*m&7^EWY{g&CB`sE9c+QZ5t`^oz4c6{-@z|Qg(;i z=_G&VUzO1+tkm^@8Huo!zh2H|Ogm0iKgU~WwT2Zhgy*6=GdRciRB2W#&$-7LFBO}s|^tfsOrc3 zd_q7IY=N&x0L+aewid%5(;-F1EBc#ZOSkvCAXa%pBS{R?JP(07z*JkHjFQY0GA+AL zSj?j3?8F?GWz{>WqWg^{YVbWL2iy5ricfdrg-5_y()@5;iUnq54YcN9{AbXJ6qPt` zFX@Vj#MFkq*u4%fMFPOKRAT8uS;?0@A0;P|`!TH^BT4gtDEy8o%-r)jJ1*+=B<5rv z@vAY)Wb8jQ< zUijPGi94>{hTM4Q$eDOcDMY$`Yy3WT*t<`(Jb5tTe8Dtc@geYMZf3>b39cE_OUgiBni49pe(jBT{l=3RiB+G5{Xiv z8cxtcS$WvYa9Jv<`S>@O3D<1@=H^Lm@=t-geZ5_SqD5a7qeT^=n_b_#`wD%$X$2vO z+oS(UUdt!0ZGPjD8j{2&lXq)e|G}|5Y$eEeYJ(|WG&9QsIR25T17-D^>Rg88N0+FuBoyXGVV8)fDif22-xYS_>CmWY;VIb2Jy)d-1yJ8W4QtDLjLr zcj;s6CkeyK_3GsB)i*Zh$c@%a`VF2DzWEYYr2q>p0-rFNM%8OzEu~G5KgS}yoYOIi z1}PFYO3KM6T27CF72&jSh;S-QOotXxOB!`AQyRqcj^~gV#*-6H;Dc91cBELHTv=FI ze3mh(XB*a9lrs$E)5M-gA6FSq4jZxjT8<7P4;$ZC&te!MccfUE@em4*=P>$%!9>p6 zC5|8`qh_43NaZCsY02OyQ7+3vg-@gZ>ya&j1ES!BiKD0Q&W%B&p`>kTWJpUY_r`Mn z{((WQdBtSj*n04=>e=Ji$K;p-OKP7fkm`{nJ4L5C)okW^@z9!XY#?p?meCD#2RLCc zXFVW~o%D0b&^dgoB+a2L$vsHM>DSVjlbUteadPh1T*vaT00~=l z?z!_2UAY-Qx;?G;ym;&qUwTS2d_V8d{rW*%)bl+w@8Na7~z-(S2Z-V3e?+Tt! z--ADLc^5qV4?TADdmyzNBY1-CzqNytfvd0IZB2rek$9GPlvw(sug1p1p5zk#S~u1v z4t!~+!Y(3w{nGC0(lS7@p{p=JviG@z3OTvE{;<8SVvMTcL67;fmwwG1{pzVMVV5)0 z#w9!VF%H|doq5NKzr0&L!SpF_xvCM3K8T_TVfO4z7#MVY`9mpcFVdzEQT#x;OHcW? z`!zkm@8lS*CAW4k2@K$Cd`RKAc9v{p;zsG-W0fX=yYEoHV3At|@N7tT3W+`g<{*`P z_oM4laU<~3C@NE78nNi!cLU@D@#3QjCYP5`NXw`QAqq@gyZ>93MSGY9kwS|Z{jdoC zUF5nR4lfi<0goXPdXFE^gl~9^(Hfh$HUP?<6#)uS;r%{|zZzRrCl*B?6Vim$p1hWV z4^ipu0?3V#`Ox&Ekcfv=h?TPVYJaBs+4oC7@R$CtJd!MI8(;|2uu}q>Inc8bl*DXO z3=*J1VtB3r=j(-KV|<^RCzT|v(p@HswL+08%{spEQe6X@OF}Hx7-5c0@MSWU8Nujk z@hX*p&MbTBsQym^*0nC$1Z70@V3`oAwJw@6iAs4Ds*crGXvyNE;*mV~A;N8yke!%O zJSqojG~S^b)7O7xYNJU{+2e~%9Bs%NWZ#M=#Vc{>w1*pQVHT?FSd+VCHp8R*PJZ2# zK$EmC>D$YdQN+`?I1yb3qWS_LyyYG+gxF2WkQ?6kaVPwyd|mTq`$1su;NaxurfH|H z?%jDPc-i;y=sNw0yGFC08X%cv2gMH9d3!t6ncQhmd|%9QEj))c*A4lCJ>73?TT7FpPuw2j ze_D=#AayWy{vg&pUtFs1nOeCt z!b+&fK6Ii-h#X@dfSWrC@smsuU2IWcLG-$Ch*RnE0q+u}HXJx&38xT_KD)8=bOz6_ z196-{_T#a^L+5*JJaLxSZVvK1H~GI{zn!2XC>Pd;Xi(RaLd=Q?nOAq~0!{hLPDRp! zO%z~IH9*GBWfQ}@kwXjHj=^u{h5WPr5AS7M5fEp6=Rz}AO7+|yj_;viuDwg{lv}xG zH=ay{>(N@+JQM&k%wjr0?M#{Z2%PgsKrB+WMiMnqCKWCKR%tXe@lQ_cY0TPzO@f-} zjCrz9U_?_CXY62SjQ|mv;-2iI+H4aQ{CnXzNMiR-?CcqeRZ5B^U_qRP$lTHZkQ-Q2hm6p0%^P5al?dHSdL?;WKyx+>WktK2xR zfAmA05mj^rHhy3%9q2hmr;0r_rQLtjXI>JYF z{jJHt1!ACx0VWLW%rg^&{EG^sPy#VPByy=&P`yZLfV7ss-decLw|GSPAslTvFSBW# zhR6mo^9VH+_f;|EghP=-TJku=V9?`+<%!>|NItH_!<*ZLt=R8z4&{?$4q!zE{=X* zU+DvOjv~XwSmU@r4!p~@vKKBV+ELtx9Y3%_*5En|&yAc$F;P3>x6uQREteIZmDH3T z2yk1dwyLqbw`JjWHOf1;X8CLz2{6(obTG=_PFngbt>i{3h+fTUU1x z*Y<-IIZI}I3V8O?XO~iERRgqA;gBbv)!)-~k2*$m@2(4?S}@Bl*!r>yfayib&s8%X z4-!Au1Dcj0je=oBf|vl2^|dt&XoZk{zscV=gJ$}bDsUP#d_%Df{aTO?ZyQ$A@GiTK zT2{n>+JioT#?-i6vAi}!!hEe!Y)AXHyD6=t0^TxG@Lxq}Nk;;7A;_Hh= zyPW@t?s?xTm-ba#d0X1c*Dqi! z%C#0(0+(9m8^)qmB6lA19@Ug`oF_X{egPK-WgT+!A46}jNM?BYIqszBa>lnAO&BllB z8p(-6QOyCNNF+{QfkS~zwZF(J174p>I5Quo6US}PtJwHbcf%#{>OUh&xgEVMP8)^} z)u*HMexZ_#NSAWm72KkEddY=9e)MpB{1YTUaZV_U%mJ|i9Xk$QW48PzrfKboewt7w zVmoJs*a0!`o(8QF5%=7kS+SwLqLTqIo)7@qRxn&}@+wC3w0l)oW#Y)nN;niQ05yn( zb4#t!P^Ep^9ZX^at91ah7=P32f~MbIoA@4$Bk|KFDFvgb2+g|v6pamCBs5OHSofP5 zOCaV-v6Es8lA?lk`__>Vqbvg==GgbtX2L*&0oxGDhE^(;B4WY#_5m~OG==XZzoFXX zton^D+Eqy!g@P0Cz9))e03`WIdWGPcJv0G3-z=)ra+*OlQ$e09MNhb=PG3^SQpQM4 ztDeXkmMA~HRIn>5=r&7|8@SY@pzf_q7a+07FZ(=YrdnPi?1uVrZ|s`gx15-=mKvIR zcY+_ph3~|l`XTRw?!W)>mG!&#^{*f^PtD3@&0AAX80%9IYj51gQ{1lc$Kr?E>*dGG zN9^u(#ADZ`-xKb8C}jQl(qz0EeW==C`6qk(e$jZOo~hI|;Rm#)1p-T!47E0+izN=lr~Z)`7#bz1o}Wwq+|?@Q`hy$b-|axO$>5Q|u&B z7r?pTEA$x{&UbggB_JtmhUKJsD(%B4Uo;Is9-QXk#Fo)?Qgd0<4WcIMmBACoe|9Af zp^7(!&vb25%2?AXm1`wL8N%PjcQF5sQ7w?icj&^^)8yLAkw|(;NiGXZR*fk8{co7v zi+j075CujWfE=4)9?0tg$);#5gAe0jb2alYyO&0WqP>&UM+T|A<3|^?o%^d zv#Hgo)y3FNSc*<5XXM}XSaYU+wwZI?Mz3oSm+r!a6sOG2kh=b|ZwCHripstd$r83!zS||J^g~$~6;CrU;K5m*?iU1t)xca}TU}O0%*P z<_^0KxHaBOf+&TckX%y^&z5A+{r7x{?k1QSa*1eeB#({I$Cxdj!lGZx3Ln6(%FFj9 zH53kJDBn!3Q%Yh9LW>$jV$Z=Y;xHkuLK|~l-PD0q+QHJlBkEIjbz*6KxqKFgK?s;JwEM52(%Gcc=fh$ zI^zz()cQu*#dI{4`WyO)XnE-@+OU3=tcW41*`Gm7x-}L0vv!&~ko<1C{6F)^gR8Ju zM;(N>R6zMpVd0gMeqj7dD$H1@SlV|_tCF4QFdRl`w&#7 zt(Zq%v^ZiArp3uP(H0um%5nzgF7m#}T|fyCj)X-EUTLDbljMhmrece)&DA7biMYne z)_@uZ=fc8)1^!hrKONYeM!hwPK(l7F2*OZQ21Vt6bTSFbZO3UWHRs!D*RZZWdwCU9 z`GyqWoRF<|GQ;qgnURi%(A1xUk%O$1*Ha<QSD2_Ke5dT!j z+j*U~4Pr7gi)f(=5S$`MJvV7qHkZV}xkUjrccU>*wBv|IF;72sv`1t9Y>z`5DvnJs za@M{U1`_5V%WpSCP&3V%fj#m*LX`MLujpDvBkLsJ%C8KfMbCHiAkHoS0%@J2gTQc$ zr`-(8mB2`vOzJs~wF%YyQ{3a=SAKGfvC4_m-r!CC8$1l}z4MiU5a#M4)H-H488YGT z6$_jPQE{s8EP}qPH(Tem{5h-BLmLGD6WmwM`@~+#;&<%=iJkohHHjmuz8vQ_QPn`L z+_X?XqL=lueQ}DgkV?kKr+eZhSp>~aI19rY;_4>5JNoCYRn)H`!YHCD12!r;MHnK?JN>EpuiXqw4wuj}SU$!BJx#*X6s}jX_e-BL6q@J{ zQn)8;q?cJV1IHaM)2owXwqAWWA3T#0gblt{r=S)b_sp#@mEqGho<3bT2frXX@}=Bx z;vLKdorB8oI7U?d`21g!sRqG`TN84LPAuNZ+ zXH5WkPs6yM^{|Ro9k6;fyK|=1SX+eva$UoOouWGkLljmWL`0v=AgRm5-O;UE^THB{ z+jV+~77~a#f3DJctT9J>Jdc<1t(6gg(36dOg4r%bsBYAw>Gh`6^En9ZJHO%Pa#O~G z_$o%RN&v$rGf#IzX-J7I>a*X7-vechDB5`U%?x%itZ=p^5p#Vy1$L=8=CP2OM53e> zMRuaV+}qQFQOY!g9VS#4Rkgy#Yqe|21;V=Kf!EL}-Xh!DUVjynFK!BQ>At(g)UMKM zM#)At|J_B^R^ZA<&t(JNm80X~SzB`B`}e9USX91OjwGioD~D8|0pzjbZyfg%*M;S3 zjNd{NB^S!Gm<|#TGRW==(x7lWnl3N!=7Sw{Qp=cgFlugAaS)66%t%tcmFv`5HDbj; zSH4A|$I8Ek%1=B2w;Gn`q)|hNZ4dytqQ-P@)Vq&8nQobVF|-gJP^^I$1+WkPsBQ z_Y7SFg9?EhAhgEkl%BGdq$NZcq2N@{Kz1b-L6J;E5u< z90olP#$W;p-$h_)fC&<5J93U&=0Kz0mM(oP z%AY>`$LUunZ?ay-`8H!hqB#5M++w1kNT_Oe;Tx&%B!}Dm zzu!Yf2LJb&`I=_JhAo=x*=WNsj$S_gu1C*v`AHI%ttS^N>JpmwhRR2N4XJD`KDVHF zE%#Zek(u_^tHW|6za+PlnEp$`OmLZy6}bo-)M$cG5I!_VtFrL#8S2iht$bLf18@O| zu9Nw}(jM9G(tG6v$!dU0>=7=in`3i`Ve!3JTpGzz!v*6QT6C#Meah86|PZN_-*cFHB4PUoVA) zUH`^fp}F%Bw1`XpTK3%(#pUJ;`k|*V1-N6ERIRa)(S$TS zxJLAD-#~nI6DaCI{CO;;#%ETE&(ys{Ik>97#*Dq{B!km7=lg zfDsqBM;(4R_UaHcAx;NI`pD-ZE^O_DDsX{r54Lk<*5eqjFtp>S`{yUsKzv*9fSpE= zRaJ%=z50SIPrj+Sxr6*to>^O@Uvk~@yd=NKYTg&uCj=@Ad0`pbb**^O-rSYvKGMRr2+Ob%~4J( zJ|x+7vHg^_ke%ihl3gK@Bx{F6pXpXqHvy7yS9nx8&n-O!X)3gBw0%?7R9=~<`g8HU z+7NMc1kopVRxsjxB1rPF9YEGl4=+)8LY1nCMWdi8)@NoLFY)k8G(qcoB6fVXX@-Ug z0{E4U$%_l>m&K(IOCQ0F`8!FBli8jU@W%K&+S||7KPEuI7$nk@N)#P1aS1q<+D*o8 zV!<=U4uN-OqYP~qYnptFyDheHrScb-HBph`QH2|c@*GqQLeO}Ww><{I)n%c+RY@a@ zNJM1`M=r8(Gyn^=4u8+8REUdSDD8Ai;T*jIyHp5Pqk6PpU@*mkJNE~3nSjofU=6Ye zV%!sJ@-O1f89c)qE5C`a^<1Bz5+0m@DpmAsW{-3~z_V<3} z4_DmmQU)`V9j@eK|J3I+`_%CcA-Hla#P%o=3#jj^d}o)@EM#pO7zZ~<`g?B?NenJ9 zU&D4-?sA$hZLYPF>IR2-|woG{N z&LfHgI!*NFj@>F6n(Km6Y@4INvIj~09lBkWz~%!l*9g{#!KrZr=hNq?Rq`sW9-JkO z;pZ(UA@gWd{7v!Uhf)h{GF>NEZ6|c|t9pXF(MSl47*fkA*^53l+^U#s7Htx2pcDHy zCKNEb&c@!0-n`jL+EBjB)29S<*Ewa2aq@A-i^Fkut{or;###r!T{4~3D)$bcyd0wN zk9N9T8d_2JAuLj11B6z3MdB*{sBY;g({;_CH*0nS*i|110dsx<%k~f!%>{oSPh_c8>+g` zZ3P*tJF$siJaevS8+>!lfn_T@!#$jB0sg>vnY{ifs$)|R%Uz4kAOI^h_uO1x)o3!< z-~T{$KSo)FC~n@X&i-r1VyG}`j1aR9^eqH7G^dHSw(qS%mrX5l7g)7TFBLA0$YEos zGK0Q0S|d8Sz0J-`H7$CmYS0ylKr8do0sW>QDWT-;|$Ud+#5Aov;c3BbXVX9v)9RC_hM3Rx-YR$ZA&jvh9b zG@G0c-_gIuIyDHKiA)x|5QKQr&J?gS<5yAp5>BT?D&lRjDwMnn)Rw2toTIU@4Sf0Z z@J@&dn`w1+41*}s`+gSphgdwy{_>~mKvmCa{W?+2`$7&O>a4k*eT8SyM@k#jq2ZW9 zuv0F`8e+hQ!P;Jh(4AG5Z+0kGXfFVf@u)a2HP3F0Kc6iE_8WeXNHWOoZ6gtE$5Ib zN9PUyzpsb2TCwJKiy;)L!M;Abd5HHCgu`Fs=Uz^1+BhX?XEk zhdlJq+Kq_*cN$780&Ph25tSG$0#_(D`5I#rRno<(?l|vwN%Zu|^>hWE3>S_ff)K06 zh~n2iORwl8J+|fJ_|wi?*(R|IVSNCS$442x0n3B$1W;Dj*|KFIf!e3?L}FOL8YXQ2 zf)8j@3d)_Eqa{kX!Wsl%XwZ-4BNok9!)l2c8V1g26aq?sDLOsRDj{P4*GE*S6I`E{ z|Bey&gkxB8tD@9>Hw@?R%QTMU<70k?#NrJ6YJNcc_!dHYUT7h3?$I<~x!{Z%>!}ej znwfV9OuQE2f_E#lh$%m$9$W4!bj+nCS;$n(6!BGq!FHbFnxz}{Q^ZO9Tn&&)qg&RN zYlSH2Y3zFJu?DZ=W?^bO{7l$kNI3*Sq&cwFGyV_<3o1IT( zWTS^{r3F-S1{oJ0nqA2I8gM({Hpw`$XP_&;q^iS)xj1%iz)8XeqN+KCFfo$Ewv_^a z-&9~zYKhnNObtLrQf%h6EEcM~N-Ux%41~p7CC9_kjI>*jyKQe*Y8pvO(S$1CZn}4^qkF}#Q zX;ZfUsp`c9Xrh70`p}H3mngu&NPtPQ`manl((-@Z5oB6tphXZ_#tl^u*_&)G_;??F zGi|l1LBJ;kSbUM8`MQFi1Tx0?tmca(bJP$I%A?skV;12H`dy{^9FSdTF#>Xw)^NUg<8V~E}nWjNc;yEyv#CU)<1w)K;zS~pEH`_ z#)l!Y@b!e*YePE4?SudC{yFd0Qq0zyH=eUF$7^)fs#CHXzi!$W&aE^nY)9 zM`fYUlR@b4*n)Q6<>rWz) zI#+Xa0kI1#Rcax=VP=g-%LPtzLNUqF~$Nnv5xNe$`$*N0SBG%4vHz!jeQ-<92n{0qaCMA|Giq+h}W;w+x};^^Zx%QG1Eo0=Ee zLffBD8&!H1JRTSTF}O9IO7zKAz(Z5<89_<3tzP${4 zHE$g)JhpM#k2goJ-h8ik>sESOg)+$?-O@;F97^d$g5#*V8~sYyrXXyvEe4hx@ls_b zW8S=!ATHA{e5MY@hLv~yWyA|zbUvM77{s7$ILcF)Ji3RO%+&D2AeXmDw^K>QO~x&% zC=N^So;IS(3#m4?xJE!5>f6{HKmjHwpeHGWn+4HBD~GM-jc%ApM_t>St5DOa*_R+G zK|&xA&azIgzUS=Mitz4pSTyn$7>!rn7sUg+*=Vn9ND#PPnX0)Dh%V2LIA4W$vnOMt zUFRDvf877Ugt+I)=5TjC0{Z8kjZSTM?76d&O^e6O^Jq52(aVax7wlAp$8fO(+2z5k#!pUtFSak!wl zA~0u@b35a)-R8jRJ7t^W^w%BOFEdKLdid-WP_{!(jYe8PXJ)2qyzl#$W-hf_zR|dq zyY7i)4@L8IeB3~?P5o1^FQUpXdYy?iYT24r0$GoO^a|B~%71d@m=6$#%zyq0FxRud zw+-DX!OMbL4CDCBNQs+!>RDN#)G#5$e|JEz$X;u=Y3<5C_(TuPzvCNGS^g`aNKU74 z7UIuGqmy5ZSL@Cv!u!c%P1?#CEW8He=qyl(B#e=^G=L7_(Aj6+oci~Ba`&Q;-(R=p zun`vnl%=8u{N+UUR063Q{ha2#-iK$=;rP$rhx*`D;K*kD!dHNFL4(b`2!NyPwqZ6r zF?#_9r67cSG?BK9V(Cu){~st>V?<~mjBr!4LiMCEA68idL&&^h@2FVT7mYI-e|tvI ztD{CJ;GFHBkX0__NOF$-?80Njcd-Pg;C@+B^#U{XZJR5EH0r}CznCggKQ3)4zDQ&uIxa0MJ-%D(fLt#LbIVhJ= zHTxc-@fE%36dRLOP-CSf^31yZf|SP_(>})^G>Pd@a5NWRE5cCw#u2C%aAA@{#7A2dN-b{>d*Qp#MCSV@V=ZSHaH!Hd^WD(>#CB< zEPF3DfB*Sq(CO4WizpTZx{mMv;2~>~Kt7h;2g!FIAf-KU^GO^YhVDa5Z3qhtO5Rk>CBB-%UoI=@1GLa@=Ifw~fq z$o7KH*Hj=brmp&+w8@9@91GyYk_M=Rif%F!92|n;L9MkS-d)?u*J{QlW4hN@lgMoM zW9ao{IiJdPEbWgCz5Gj%L$u(`o0waW0u$AukBFgI!wJtk-ctNRQ~R^Epl@t;FP04p5Tr@*YXta3bSW7~!9p8Iw^q+bwX6STro?Ho(X*VSq@ znw>G7OI_VqT1sF<2wP96&t_@9{zNiq{(<7@|Wn(9n@|kp@O+ zJAgc~NNWyanx($4orLT$I6j0TqerK83Z)mO4mp#_biFv7hjB2+N?v^#6!r5i-qu6$qP}3P7o(MGVJ-^W-HD}{#no3gJ>NEgP%8zv;)vS z9c%Hu_}OdjPgYCOqis~i@Jt;E#^Tpgoj*{f0t=z06+R=pA_L-7uO&IR56!&;$@r+M z7Ia1FTV~A7gJy;dPrP&A>01ldY7EEqSg22Pd5wc@5m690%6rB4qw$E$RSbYA?4FQK z@8?K%S9r@xjfu#>OueKJX7arV({t0XpU`4EKHowh|b0NXB4>5yP#Qoa-J$t5NlN23aN)vJu%M~ z9DPQTj3olsB8lC$XNp>-iP?PPZvzEk^_nA&?EDHV6%40?mTfMVq z6@EJmqeErCxDoG#;GdVoEZSG8M`zCM5qtCbkkE{OP?cdY*N3q_;N7?T{y=lD{P}7%HJ7_8!03AZS1+ zKpt|R75W*K^&P?t`1Z1#^QQZHEc{k--5~tZ@MIQn7Pk#?pW?p1e7yIJ_-wuyz3zTI z3BSANYy_XppOPq7zw8{7)e_s!Z4GJFs;b-Xm4JQw01Nh#^M7^jG8x?^1N!g06|1B_ z^>nW+YFj>jX)_9@@l}-3p(Uc@O(FM@o8w5^iHzS3Rq}m@BXOs{wEN_-YisdEV^Cl)0qW5j)*KCF&~YLK z4aSPcI*9e|wT2U-d|u3?|F#HW^G){Fy2H~)i^ng~!$rl=N297w6@K=fn88A25yv`Z zz2L)r*@G;$cnaR`A!)q-VtS-K*WB+?zDMGu{XLJDZ$6%j9UGxN_bu-@-H$#!xB>5f zKK`}*e=MD4Sd?Acg=b*sl$P%9mJUI>8|iM4?xB&EMjGh`>F$w`?oR1W$!~n#@7K)# zdmnpWvDUesaLCMSHr!ge&Z0_!dfHIr(Ci!XhVb&HVy33RWcBeUbf@`C(xi6s*A%Vn zY>$oGdX<2<`Ni?4_q8?_una+R_6@*9x6(2hjVZdJWIfcyCCRT+m&*ihj$dNxl`nI? zKYBOdy|jn%Hi}?L0RFt2G=0ir z&|K8Pr+*mREJp(%nF7SC-VKQdlZH)8LUYUX=y{Vz7H)ZJO&Y$7qy4}X4~OeOA@8?r z$?sq@L1=u}^8-wb6{F&P%JaSav%DRUD_X=dM{!{ENgmA_Fq$V>1`{PF#ZLKJqlr|6aNH_IMqRP5u5UV7ySNo1BlV3lT z3O0{fK#@}(>y`Byq;WnOG$eqodAXMVIoo#xyLJRG8lfUluhYf_wZBe}irq+o&+%`U z&6i`OR_7f{TYrYnn8|+>|MDU@pr5#;5ZSOB&247Pf@nFV7g_52&jvL;DLSw=E;$qd zKPizMyBzB8`;&b)DDsb)!9lCKeS((f+!egHqI*J0jjp{>uC2NDUm-tt=PemA8+;se zohwQQNo;RFNUhQCn5q~su}UL-l1szR3nfs??~QjUmy-l;+0gbP(}ATcGkZ~iLdnsp zSN=;I`Y4kHlOMpMCZ1CtLUYE|vm&V9O8nvZ2Ba=#xhsA)ls`)knlcrB!OOy`R61-S zTT!iqvp^MUR7Z_GU{ASTrbaY1HK=L-K+j=-o0uRMnoM3uoUcRCfW|C|bH40zMj&?9pBfe8PlBOQjJYJ21rYssGc^Ng{iUXL=q$2{DT*RULo-f{r z&FJJO5M*xw7Y~P1(v#r2!ocBb+r*Q1vFud8w>&Q(UG=w&BScBHbZ&XXIEUB}@b`)W z8oehP@1L)sTWuz|o)$;dE6peg>zH8L&jWzQ*@8GuG-~jV&{b^3;Q2eiBuMrH0=7nC z=#0wI;y}4yw$W`g07?@CmDK^xle`$>n|wkoJc>^iqT=s>AH9aqYjAM1gAYn8W(=ehfu!j&aX&`7egXAtJEu45r;)#Ar zP+2H@2#pv?@Yv`|hRd?Q)(A2=C0H5_j97&HBFsf2Csjwe($q7I3fA}>u&Oa6Y18^N zYgX7noL6WG%(8c#0WK#%Sy&RfWM*k6^enOIv&lQ_^?>SRL(`iSPCOyWS&4t44?%WVDeZNx@ zzu%1Sm5Kl$RVq3P1`uvR@&S~N>Bf*ek-C1wbBPQhI{ss0GeJS6i}`*jukGv*lF!Vp z&OgF_ikv;4R=5WD+a?m)N<+!CIZ;%$Z7R*~EA`j6Oc;ad7g-%46biBr%PqU7q z)KBs=BAeR$t75w=%dHo9Tnv8{_&8X{)b9zXJzZ;uSYFv>@OvB!l&ED0w6Y)&Rb8vw zi_U4SiX_TMx^6+VTB0B7Cojt(?-_3u5kE^4g|A@$Qt;)uI;LI|{C;q(B@z5E(0e|@t*zUb%5`ycb5 z9vBvzOcr}D%WRJiAxtO$OWtwhP+-C=y@kJiCky~0EwHVZp;42NIoWCu+Tx4Zn?&sA zAq=P*o0T+6CjT=CW-O6NRF476Z8e#NZkwAK>*U0>tQ_k4Y-j-7dfwW zhOd_<+mv*xwfIOu>;-}KH_|_%U2-1@?n9I51-c9e>kX}n^1i4$AJ>^Biw|<{tLZN0 za#DNpf>Ue|?_*tLEDjAh(XO9{xW~_g*WvG(+W014f365$IKAEIita=U8~yj*Ac#JQ z`>ieblLTxc1OVh1`1dbF-zMI;UMI;OB_HtL1_LioMLH2y?EEhj?S)-HN91Nq1N5}l zbgcY(JIzB(5JMx72#$G{Y`*Bvs&a5@*x$Lw;Z|G=*$?Y|{AFDMV>|%t@x4%2ok(fl zBj4uYYVU8SICIm>p6-A7W{<90HmpyZTQys0qHSv}JhHue_~jQE|K2Zv2W{8nFqH;J z?Tc6{a^dbO2b-)1d#wk@F>AQ;C?1{hu|>+B+PmVY}W-;XESrV9^j_Ak{DQ_;A(aOjm?DO2A0-#zTAb_7bEOH^@1ik-pL4#5Z z{GO596ilV`&W_`gTR5>)4A#4v0ZauG3n_B#DkgVEGPbU4B*R*NbRjTSp~lmr4+#Qe zE6MHU;KN^`XG)b}@{jaNmQg7dP^CAXD6d5w^NvllMQ&@m%1-Wn&kvvk;b!P|Er7oM z_QHp}C$vbcAvpuWkzP;Jco_{+n2!NRk)Z$Pl@~WbN*Loh+Qk`M<%yWaK31ZcLt^`} zSsOL?EmH|hMO!4a-xE6GacNZJ4x(igPh1a;F-dly8z^X#GcKCp2+$zYh{zsi&FC!d zS!NETDKU!&>-pCuY}D463+DoA?ru*Xi@N0ZYcKtx%$)e=PtHGwF?v7a%qF#Z z;TQvDF1c*okt{DbY#r&Qd5ZSJ)mHW)_}~!HeEFTr7dZC_K?`wi7`;*veng7R{Fs~| zOmuNI{r*+c26V9GggAT|Tk`xx_$-C0!<2Id5(G1Wrm-w}@LFXS!@3_ebav+ z_wj8lxo5O2FWRy@ld7vDD423b}sCR-6BcB11QiGPPg~Xm5RqMSq zHaH?q?%fYu4}*ZB(~5$zLWdD^=~JJo=o5?6zCs_4ef{e8okpCz zXMO6DPR?ICyCZTR%X0fpPg!1v1Fv!itOD<>y4Ol>$bIil-v8q+J?~Duz!>_Tl4Ry| zIqDMZHK8Qlp+$F9*+!|So@2xcJF`Aq*278rFy+Q~plX*RFokGA957ZjRP61I9PaUg zLViygyTnKfw(9tU*mZ=)O>?i#nK*+Z3Eo9>(Ad`I&uhgoa3wI`I>dFaS~Yg=?Hh_N z9Bp&~biDkd7*tg4Wfm!p(~Lh=*hk|=#Y^h~CB2^G;hG5yYxpr`NkD~*YHZ6m)dDG~ zJe&c;rWfr)>5d>X({ zc)h}?q9WsO31Jl)(orFlfKj~!J#tXv8djl4dY_XaOwzDs{nG{Kc|}G*siW_W|AS;=6wUJYxc6cjSW0M_^eMpMZx(*$FDYKO8h=2_ zU9@q*{5g(hzuHDsdAK?*T1nqmXqyj5)A4g}#cyQWph*|zoxIpq%mu}LnID3Acaa%$ z*2J<`^hNPhTDPbCC2#8X7y2r7eHg^`HpWu^h{mBk2(ws>E0U0o$C(ZB9YA~GY!pS@ zEA@3q4|%!QvLwJgsDoXyz4a76Uptx*3RNT5JkS_myx4>NwRIWWPBNz)Dd1n zuN)gmRxb)$P%7o?y>pUXxuROOtL8lsUo01*+JJ~IkXmd*ajhz%u&DC>P03Z_3E0Vm zB@&T;>PV3&Eca!Tx>FuaSK*zl``|OqH0DJ$=4u;bg-BDydrpL+Oq%@Dlr;sjU`c23 zt})KKIa7b4A2i2z-UKwPAE!o{SLpv%YgcPc>_)k|7N)Tro(ELxcS|<9DF4l=?9I+V zQFX|ienT!o^`tAfYf~$Z33vO%_yp}2Ip|puwwvuomZ^pks%64VrOKrk%ab(2gTb%cIpZTY$EYu)nOmr~0+&Y*DXUC1D))sAg+{E>-wg z2%sI``PEbeV-cAl0rlttajUr*^5o%4W$v?g!_{3j{+fUFzUo-@<|SD{clyi0cJ~i3 z$h90Hazo)a+Uz%4`TTMoDEj&m__h&vtbCwgxeD0P3Rqn&xy!mMBc~$wMX%kokA6w* zB;o7dzvL4(9-~d?w^%KDq;ZFweC7qm!+FNDLm98cC2jtd;;}G#zphg`sPt}G_!4=n zRbk0+OY4xi#|uF=trnrIC2ShafEac1Pwra+j@FXyv6XaJPWr**@+09a`}t3?=O~@I zI9Ux>uQRO;rs0znc5|O(g~&vc_dN82>zbCs+8!E&yNX>+TRX4$_jvNH)raA_ND5Q@ z=Oh5O^^6agFWf670<@7s13IOg|I<{uQJ=0u2s1(zxVP!@C#G_r;d*j~g?qum-2gu>y5hlVa zqsIrp=$=E{&A-_77Ei-pGE5kJ`|c17TeAFN{F=hd{`BXx<2&ibBR%YJMvyL$ps`*( z*v`={#G8H=;N+^*IOG~iSmDC(@dU{`D^-4Io&cvY%$EgmUvfvIp%Mvu9>!U!qQt-*cj=-}LLUnlAs_!^dm`J@dy?WP}g;QD=H`_Txh4!gDsX_`ex3 zQv6=Lbr-ea5c54oN)dT<1gq4&9e->;2sm2xn0mO9Q*J>MB+2(QBI}DqmO!d%or^WG zqYybW#I#%$4g|)ztJW0^747OVa#RFHN-a6G;#si~2FyXIg?Rq=9PI}!S$NvdHk?Bn z^VDG3CC{@qs6jx}WIC8`+95&`I#(niT`5dS^!oF?e-jDt&9bGM|J!Xnvpp0aJE+wB zU*R-j^q1-KAC89RbuI9IgJq=q?WvnA_X)~G;LV=fQ6zKB-+6I?XM;X~cQ@5Q8b1^* zX7QLR@Qw!NRT-iZcpp|d=YDaY_+Feb9{b7OrrQl$G>u82!teAsnRVsD|8Z@=zfGF{ zJX(rgbab1sKoK5ZujF3b1{*KiW&0%i_X6Jj)yMYyM3uOg91_9!MZ^4f|EK45&wNAN zojas(`Wx^1!6V6Bk!=0y^BE4CZ~25`)COi5?SrUMl?}r?VVnjG6|G;gM#C&vz*Zy?0q{!lXnWnqW*>EK#uz?_Rx2@K1CP#3zpiBvWWa z+R7B(3h|d+|1_Ehmv@%J>@CEeK}{A!x_Ec4-VV8G!(Vhgl`6hUG{Jl;sxNwJbohPu zmnqiF)Dkh6Hp&5;dJ8;wqEH>Zab6HZF6V@$F>GO?mkW-ZXSZ*M(=_jNEY(#jrkEF* zT`+L!8%EreNPzpLLer!_#yv3lm!qlWDHQ&+g^_143Pk%ob^*X>#~#L{KrK_}c-Hn4 zUbPU$f*MzDHI+k!6h>7~+>}P<+nIon1;-#=ISlqxs?~b_7XX~jRuA=TN;N+ZG@Rah zLyxRpE2EhZoAr(px<<>I>N?EqLNdCuFbdCoS;`&!ZypD@2oZen^hWQNmdtXSsn&-i^VaN962?s^nOioN34674C3|28>X7*tc+odMJidXa;!HOL4 z$qb5c0%D+R+1?t6XsMCdgDD3&dS@v+x`2Y?e+w)3iu3QKq4%p66;Pyi6g`BaO^Lw?J%apXs)1rm6 zQ}42h&sRfp$D7{ONgzJ3z~bhLyv4I73^ri#-LndEr^;8*q_~AZFIM%(Z;UwWIg&U4 zn6_D5`X1tRmG{~O|5pEco7BXBZuWJ8u&jf~o`>22vT6_`Rsv)fg3ER-FUhLj?p3isrS8-=_=lKw3lGTW@hNiDSV- zKIj630P03@wcLhf*is2>Ko}uS0}~K}3PuD|9s_zv&z+Gh9>4te`2ry441{=2q55GQ%1pkMYLJ8br z+m!FVkyr1$C*64dy%8h&m+p;v-D?N!d1&K-=#uVrAeUZkC*XO945A%)tG&baQq&Eq z;t{zB!q|*|sS^#?HuAsS?H=v^>*5voIR|`oaPWM{8bKUI(`f)hrvqea^k*)U)%zf< zZ9E`tbO{lQRm}UN-jm~d^38Se!-UIe``&y$4B&yLrq0sh#l{OIZ%mGipk!cFpb4?Y z9~u#m>Q+*gqWPLyRLKyX@sV`ek%H%x0jRFM0{mIY%#}{DvM^M+WR~}9Vi;VpPZ##V{6H-@b)VXKd;3+}2VTQ| z(4wSB0BCw-QY-*Gs#6~6sdew^%a3wjs^ch z0oWrjP%ed$%5ipfcHy0S@BJ`R^DO7Ll;*!w`5ON^ zB^q`9dUigP`|>jo?;N^xyHMP?QG8(O-fP|reA;;B`)p(c3kaAv|111(@JTeP`%3#2 zdB|m;`%`T9x^%N4-x8ORR_H|^8dWQAQ_I+}j{r>0&tz41QHLi0zc%rjd*9U;_Hdx|2 z%Z+uvepNf|!~|%uG;CSu14gkfbQEXWaBDS&a!A^zis86YXQrwFiLns=Io@1#7Q47) zYV1X4-t1hmAAujV?KE(RftEUpF36Kq^Y3=k(V=u2MR_=SM%u=IcsD!m?zbxR5lu5z$g-do`>Cx+APKA393SI_aC>II{&W>>Z}MX*TWcI$|v{>mO!Ecym?=HM|>4Eg4j*#yYUK zr0Sf0P-pa`wK_PcS$Y+ZL%+t1iA*`Vpps%K$7#t)I94;HUjq=+>m5V13Ex`v`NWhh z$pMbCBzbN%O7Bcctq5k!z#F_t3*yFM=;JBH%i@h0-TEa|dcxd=%y#@Fl^~%aeZNL2 zrfMLno>|IDt8wt?y-}T8p>RVSK+SL3X3`?&F2%iT4)a&V_7vA_(1j=@lf+x z@rEq=b}brjLO;{-?;Z@i;Rra}4?zEi(z%c_x{)e7`A0d~4}9I<=_h+-7kTK#*#7U& z>5K>q5CiZ~!1hRKU3EzXMeRIfcVByIK}1O7`AL695L#rpG%Ihhr--D8QCwb&ZRroM z{KX?Ln?njbY&eLL1(f;4)kN+`u&31#JdEH!q$^*rl;s30*)qb|? zhb&M&A%Pq6KnZwREL6i@%SONO`a{a*gKHD7@^K-dFfAh!y9MbVx~F`Pggk*hT{qS6 zQV(~Mrqz?pk9Ce>EckiJ!;n%OyD+OK*}UR&Lh?e5qFM5&b$2*$27^jLJ+*j3S8n^h z+skKf$|o0mk~JE@w7kXFoC+mQp|aM4bP4~yHMe@Z>;0Moj+?iYU&=(oL>~~|$h$6yy5GHxcVC}xo@K5Cq>NK=ts7AvC2F!3Ah6=KC8I3>2ef>ZT1k~Ua%f-U!F|Er? zIrURAu;4VamfAvu2aEoM!j&v@)_j1#%&$b0nOL65qnu0HO|fpYL!O2AgKbl;lC~FT zM;HU&5>b{y;zRD-ykuKD?=%IaZeN}U;^N z07z!T$q~+v4FK=zh2N}*HzkPucobE$^Ydv#eTPD!hBsXl(b)XB&NdU~&nhxS&_osx z#ZwSeNc=}9YEax*GC0nrL8=4oJ-r4DI+#3InZ383K*)~$MlQDa7X$wCN zY2_|ET?yMm?V=|!dTb0b#w9 zh)Jo*hc9-u4V9l)WzC?;q^fSLj|F{QI(*N}A;^WIAbUi`{$#Jaxz9-t+ib6<|3(5R0sk2750w*-%+KXUw-p!tqR;&g|Hz#k zYb7n#|KU1oUY2VVpC)pI`~Qu=h(%vfArxY;?=_-Ez(Tz?(woJcn_K1Y=9``KO?E8Z zWrRHt5tyGpTbdQ^$j-^Y+Bd*luz>zlh~wI5;>x{%v>*cVxvn>Zum(b;Md4-4>3aK{ zT0L?{bj5wdU32w0AFp|S(FDyS9p$nA`O0ee{ocU3T?JB-jK%%ykk(28`U4lXi>ZV& zHk1mf9&_I#XmqhPtBTEhbjUK!4XLZ-5M>FV=y}(r*sY&sHma#-+$bo-viW7hth=bIOaA_B5MT zH~EQz@(vn<#!0w#i|(^9SYPpZ-jlsu1>PagV?o7$ENAgfuv7=WD*w^IuyR0q?p)_{ zwIi#3$)NN@|296Qh{B?mqBcxcuZzl`{SBJjd;-L%se7>Tf=;>LvrF!r=uCRoC!mkw zp|3=>{LEH*@O0v`l>jP|pJGIrmphezphPN3)Kq=5D}J&6Sg5#rS_?B<)4Pw}70HY} z#0|;hKMm$l7W|3B;%Rfsv2$7`qB=5u zHnM`Me2~fEV@kDdtQ|pnz}%`QV3&K4n^e6(=qG*=WE4tPhpxI8AtK?>-a3#5usNq~ zG&obD334s9PXN6Z=Upb8U%g)-T6n{@kMlt2K%)QaJL7T?u42jxnhPX~QQdE&C!F)|3f{QDf z&EASsF$m%ArvCy!wlUBmQ6GWfx_+`rST=w(IvZIlvrGzkQJfX|QoD5fUbo^xxIQvZ za>|-?B;wo)livJ3n!9NvfqHh@Q||Yg;b+6yKE-#53cYg{sK!(*S}d?cGXT-JpJB`_ zYApJdZ7vF%A#83IBRa@Ja$+I+O1CzNj-Q%<55hGX;>jp+sNDPmlkMnj=U4;K6c?D5 z13|Z%&yw$BUsE#K|DJnHE-1rvEE+@vHmu$DoFTIddadx)D(` z;Ms03aOw}>x|fuz6YjTP%{oWhjjOeQH1i@>n^b)Bn18^P+42R%z%F z<45a}{_I~uekshAH;Kh6BO$Oal-;pKSh=!FgUyO@4JQ2<&n# zx7gJd)D{5a?+&|R(CfKyVSuqA8ngf(a1S!{Bp_kei5LckpfQlq?Q<+%_=xJo%>tTvN)kT2#UHj{QeQi6mi|RL#7yBYy}IZ$Z(z!bdAL7?b^G1O zoBIuQ`^+KF?4gDS?r3TC0%0eTat2MdgbvZSHqf0rP4C z3U*3dI%$TRlrLv$eA;ozM_6Bqb~qI4eGST3zd7K&D8!u!#qC38 zuj)&Vf@Ook|)$sobW9y_{HQ<=Ej!L?52;F}nVTWJAvYf+wt`&dpb#B6L; zmmWSM+jMf0p`8|UGmq~4nL$t-UCo)F$T-5yL^dkE@+6O+Bmka^c$HszAqkbOuem}# z$rSK29qZ5^r(Eg`WdnMMV`QUZ^152kZ?(l;@;w(+_-a6Na3KWM-AG5#uZ&`-GyUf| z##cRSyOgy6^m&snq)$^@0k!Z$R2T9YbBNxwCh)K&FE}ALO2a8kE#6Bw32DiU!-J#0f1<075gHD` z{+bUSF~nhFE53C5rJqbYE#qAemwwVv-jr6S#II4^x+esmWv(epzUtQLWHFOJCO86G z=-GO8u5Nxcb^2zYHZSt&=^{cj=0AYz$;7MubKJ(IAn*NCKG`FM-xhwrR^>BO*EZ9I z)$3Sp(Ektn-gZU*Dqjo{Ic?^?mb}?Ke|nn=3=KRXdxEH~#0Ng_Je1cwGCmh|-=yU= z4k3tK0R8Sg1NjX7Z|8pRi!;~-Zr{JLc6iK9h=Ddz0g@;>s_GG0M4{X}5wk}+{@CG> z99k*lSLtI-t+hfDv6mg>zfjCeT2D1jzUB7Ha_BSeR9fJ`jC8xa%VHi_T}0-xgjh2u z*d8Ui5=#!jS!et1=qM&kFe;!)NsmA&CV#_zvgrd;tS25QymAdWjTq=Is!CE1z zwe7z6bx{A}gY+jg$B*s^o@|j(7^3JO&?8*myYY|WkQP0S=)vQWi6azcz$J`C->VE}81;CU+tLtNcHbO81T5kZExjD@9-FK_s__;ouz z=&X{`5ktAcXRrfaKBAky`$J!hFK+?{j7j32-pv?u#G>!G1UI7Z^o9euYN6N+ z;>{tTW7Jz#nD80)FXx$_o+17A8wHVv&B=Z!aK&q`KvB15rpvzMM@;r+1QedZ1eB{% zBFZ+j+YIeb0(~-F`duy0zGrD=tg$$rK;2>J4vkvT&i-nFys&1|6ABncO&{3Es?SI7 z=Kx9u3gJ`>5g_nCqM=9FR$`OhXs4Eb7g?uT*s<@ys2Cdh zM%I69+ry;JPnT)rm!1CYjP_5Ly6O3w=hOR;6_dVwzo*=`!b<z1L#)W@4XTF!>oqmVn9H~;;EdBS+5!exEQbgulnin%x)p%E2vkcX-Ayu%1(CxVC~4Qa%6miCmL z)})9I*e<0pE@|V$=k~5(R$^LQ_k3e+p+=D|*U|S%Go!k=smC#&PJOSb$6%>WMh^al z$oYjfzH-x=fd_oG&VG#)&%vRwAZcI#yDVnZjWtqJonK=VbmXa8B5~`E`GZTPj9wG5 zSSAJ8AybJC(@cN_&mGr&Bq=+iOuh7(XOc+lQ~~w20oKLdx7< zA2fjvO$D)cqo8T3BnEH~vJxDV55&485U=ev zpayP0Xw&*KjI9?Ezp3!t?BHX#;5zsEj9h<=)qP6_Xa+?##j`}=VE?mlD? z;)DRb6Hp_N`b`F$8wl>GihH-9rHvUOG=U(eR}psT#Q{XruV3BqU|vT#qdd6xXV*cU zEv;I>uN_U-O|=>KyNpHDJhDJ=v<3u+Y3M+u2G;}ET@FdXE$Z8sL0o3sG}0M&0=V2FPP>rcFt`unL9z0U z$Ov@Fy>u<0B^OTt`tH`9Z<>OLknbP=p_Q;?@j0r&1YL&%c2TUBlm_m5lu8 zj)RQ!GroA}ZN1X?BzV^SPhDH{cKk+2_PmL3c^t4-)9>_v)(r&wZ~j5|`ZW+{Z@F*6I4a$XEx;>kO9}Mqc4#P<{Jklf7g;hhf+5ABiuVOI>(eg{ivX@f!lwW2o2& zUQlkHs+nODEq4?uN44#PQ$Nz%pf8KEaJ!P*h z)pKkXCY5EaiSHaF*{5(6s*o$AQTYk7$-$m1i_LZ(#=7tsb85Raz~8vszVCZdVTOMk zFIhybK7sT|s#@y$k}d>EM$LAiu|9AtRCZLOq>GbF5|xV3^J31_mRDqEb?zR)Pa2?u zY(3x>3AV8f)ND@ZB6g93+k=5z*cEiH);_tpQujYxFx!ZfsjF^5*S?AIqA3U+)oLNvDpdd9{D0r}t8} zJcl~0og-uZFe5q59H)LTr7xDyD46XTmHZrBs2A3b4A&Snl+vmsZd=m^F!BlJ6pVe9 z4QJ0D2iov}$4krg!d63ZLRdzYh?D1)q!dyx`DhSwZy}OVd>_SaWm`E(!ZF!sV3H4= zQG=|{gH0{kM5%T9xc*|QAc@-wqd%dlo>8*5W)q{R9kYsCPgo0u4@*+cqv40%PXEIJ z$K2`W-s(OGckkbCu5_KQL|r+({&5=mmu9&C@UIGPJ{%*p++Dxwf1B9&+|mpHHGIEaHT8H)Y|Cu^ArYAH-!2Oh`s9?oK`g%nrW!pz!tkyUXii@Hmb_zy}IW{ zsJeC;2RZ8bYY%zX3g7V2K3zglO9`9Z8M;C-2!4TS+nhFSxfin`AfD=K`7mh~8YVb+ z9>Iq*HyrJ8|;=J*Uu zBl|T%>2{3GP0N<8CxZqAbK8{{rxHcepPc(fco3NUQQ?yIeob?W0CG!TY!!p-HD@CMfw8_+JFBX{?*iRW;v^Gat7-uLXlMfMb&1zb9L@cg) zx%)`Y)a-F#n0D4g;C_&tW7PIX9-bv|J}N{2J)8;7`;OfR)ZufXomN@GeG`bz_gHNS zE-PySp8whAxii{#6nfo52U5OQyGQ|7bb(lJ2cnPhczdERGorB*PaPBeFE67h>ydeo zh&Qq~zK-SlpFY~(;tt^vx4?z)l1f2XERhs*?(Y1bvsNGBRvSjZPp4dXTZ^^K${NdZP9;5hA1)@?3+-qI!OQqNbuc&(o*xl+R%| zAKp5>W9oRqF|~L2x19M*3E_#IL~~5-P(76;6DE~STf|X}e719Qo{3zue@Tt*X{R-1h)sKCetD% zuXbnD-^JV`Kyk$ve}qq<@m%mRo)M+oANtVo`q$DmXdqS)M1|tr99O~C-S3jl(1?TD zby520(UK@AuVB>b-=n{b@7IgJ8Tbh82HxL15C!f)8-u@AuNlpUm84E`3ex;12}#E+ zuGXz$!fP-fLwbKOSdrI zcs%N|9eNuCDHE1560~YvBcy9)*)Tz()JHiatWAb2n_fMXEPhw3T(*J*TL&%ptwQ4K zgkJlfa#m7jIqdQ>Ld@juZ4*-5eHO$a+QCE=*Rn?NXr10P%xcxS=5s^0OkxYY_sYS7 za)G#?)fD2t!l4N(FgJzYwutHo3CHj&ee&DJ436lkBa6<tb1t0lmB zXF93wi@c7cVm#KuVz|f4XXr>&7Vq`JVZx9>UtQrhiuu7-=VH7C=;JDmQK!u5vHuiA z&-Q$}j{bt#V#EawR0f+|U;tao4>UETMF9`zUwj=}bN!I&mbqLYm#Cur369Y7ZOb;L zx(1ZbBOr$V=V;~@p0-$wHsQ-pE{0KzpcFcnr{eReOp@xDA$P0E_hEE=Q2(wJeqIGD zy|PcXQH^4ZxQT{J8Z-}j5uaUZ381s;f)Use8&);mgQ(iv1)v@72Qf+n3AMf^7yEO{*pL>7xY zyo43KfPI*EdeMSrhW-vzZ}87hc;9FfG^3}0i-(`oP)n|on1YAQD<0yGfe++saBb;^;P*Qs$6vdPD)wEDmWW9M$hI z>eTZk6?v|N;$}o%n_O%mN$2YDzw%f6QdE}OK-?WrMSo$kAF9X2jUQ)r1$4{duETwN zd|bFX*?x>E=7t{=I zic9K0WiOhoe(Sc1AQo~XtIt2#AZCxni)eobJ(F%N{myTwcP$WTYclhS=dsw`@Lttz zL3=A4L_{l&7UH--Q`(DfCeNUs#3bPcWN&-H5L-SQ5OGM7ot%Opo2GtVgb}7~avJn- z$n8%ANf6fdm+-#J!g1d9Szp%#Agg~L>w&3j1KkQ3`7BpBqqSg=Va_&92__K?;^u{&kdkY!M|mz|K>aT z_PE!LTXXhk`c4{!PS2|m?{k^GNqz%&#y$!Rx>TvM`e85a<(yBV{wLf>jpSMLWqITz zmGSLBnLvvzboFRBHD+VmYVN@&S3Q-HPfLB~WoOfB3}zemOZPu;nXB>A>$JKto5lJf z^s<|)%a@%SeYUXA7|k?|q#l0{if|cY<_l@Ob0+ci!|D} zKYp%6sqPLXzKJ05mY(9dcA*JASkRUU2=&I)ci*~Bimy2EZ+{1WyesR&Hl!6p25Y9n z%z{Tcl&u%S8S?|s75G0x5bVD!&+Z)-99{>S$!(AZKP!7;2aNmB*z-kh4bPj>MoCzCzKmR|~EM7Jxjn zcZ`$l;`)UUnFskkq>b!uuU}m)NgYBfCYjp4v|cwdx?e6Ey}cjzBBF~3y(wsO!)rur z@Sc`SJ4%U{J_a_Z2$)-bQ|3krl_3z<{cNku8FVnx2aM$EL3|Jq0lEt?h!ysI%(%z_ z(S!oXR4=!(vw_G+5?Ckf^&Ri9%_%PQv3-Q7S%1MZ71z)Wz-RO(E_GEj@2xa(eq9E9 zvel)Vt|Q>KDdU+3M7X*}enDoo>5rc=mMvupl6A=;N~YKkB0oLernN{XK4-b&STEl# zW#IWEt_7uaV0EwBA5qIo_0w-Hf4GEQDPJXqR)xAhAnkCU4^! zV7R|fTS;*U8wcA(&A@kI92$Kpp6Az#cex^nBfVF?hi)32hLUTb~&~74&QnQiY;1Hqd3@%FdRoA84!^fo_=*ZsO2aVChh5-)TD^Ix_s(LIlCwo`OOT| zO)yF0y@Atou(U)h5#HFP93ZaN!Ln8|7&Vx{L{tm~ohQ=7nGa74j1JFBid7M!FmzHI z@ZRS_M8oKgVUB#<36{1a9z-Ojk}``!$7}eM`q)H1qV{CEa?MZUUYdAUW-wI}%EXOM z7BXU+SLCs(q9uXvQzd}|{Ak{@*vO&wR}Q><>ONt`YUFvHq0jVlDiNkgC6}F#p9g$L{V&K`U2%O(q?+6c@5kvo-GzC;=9tZ6}WDu zp<`Ofcu#G;nzx}9fN49L8<~|-CU&H3g7;mW-YXbH&{#wD1c>9*GOZ2JuKq15{vn~L z;_8@qipu${b|hK|%VV13{wlW?k#Fiw456!Q_e`S(#1W41KKR38CTD-7DTWZ;fKa~J z+yzaJoUbm56p)w_5ga=Q%?+GNIJ*_{#f6h>MS>{Gmwd1((@FUoGL0su=r*@zEYFyT z&b*9m?}ROy(lG%Ordlq4+)fh^9u96$riGR^==zs9zrR#rf0%z;xN>}tlYOTYcxqZw z+c9RHfUMv;TD+&jzMAH%jq-X$xR#{kO+G5Zw9&E#!Mz;*0n>-g8s4{o^I|g9d!|x@P*X zpxEQFV=LO>ImN-#OX)6P`LBM#**D$6=O=2t$LH|YdelwWZ6gKxBbhfBXcsWBKZABy*&9W{&t~o+Gy=6^NHo&LFXbZr;fEq&K`2#6uXR&F z`j)MnAUM<_*)mLJiGMBQfF~)>iDt#1IdPt6H|kZO|&_hI^nFWBs{m#P5m_ zoNs(!!{sjN;EW3c(GX^vDWvYZu)f?kzB5EG$3MI-5i`}85>MA$jnOL-u#9;y*N>I2BKo^r{GtEF7`*S8;SFLh4p8^2Z!lyjf=oqA*W9hc1#ou;b)ZT z%RL#pa{m6iH`R|~wK8dws(g5#J8qCXMnf~Pt8Ti?&>#~3DN9gISBTsb+I9mHR4SB1 zH#0;J_N6nioSv1ebw1w=XFr&gz02yoC{as8^~_d2-4Pp01BZ_>b;NL!R;MA<7<58^ z#ELwa)*Ll?KOG4DfTwc{kw~(q05T7o zyij6n(-kopq>@hFfJE7kOS)%5W%#Rs(5}HaJefFJV16BC=MDn!2)%dvv5}`#=*MYHRMMSr;yZPKUVw3AvrGPhv-&Sc?wFJ%7HNa^u1p~mZGY7) zlOz3c=ubiNR&al`((e&GN|?>J{#@r5<#Jm%HNAjzo2_;}8FV~oXid>5JM9@L93?1C zg4tF+`6|DqHFtmg^&jYJi)!upvgmxf@Ww4-#jV?c*Q1r+J8`1!FF`}r*Lc{Fmpm|b|+!h0-xf}u`&{e#^y$p_AY#Z=`oDu|u6 zTonCHbW#F+&|*9HoL%-e8FlnrA~~}C*X&E{$<8DM*IC3>bcOX_I9gUqwdy{j-9hS@ z?W)qW{)K=YRTokvH1~lZdw}Vvr7;WWi{m>$8)_~ORqSaBGtR#7_%Pi5H@c|QQ;R?0 zB5Yv>Gx^k&vEG4T_tZ;-*pgg;3OhyrZPXvjUFgPsf3cU;|N9g7%(-g9Qey;GGAstk6TikWe!zf2f!uN z{Cd3Bi;q;ma%W35D6?&RhxOoSd&E%a$L(NCs9n8E|4n=X9UF~j{e#du;ahQ@cn^qx zH^@39Z}E*=cO1nRd2v9cl(dwIJB2a&Ev-fx|7e$+Cm9;gX{;pi$Ko4sSyY|) zRwC$|d%K6?*WZ9LetEk;8!{bi-|cl*_@q>m@Q;8wjY`TPy$BF^>p4)R+pjLi4y(fV zbnMQ{#Knd0vVY(;JhJLg)s%m0+u`*NuDp&=t{M+GaKJGqQKplF!C7I!ml%WA>sN61rO2?w(=SO9uhQ;EQK~pgRMC{&0Un63 z^r!L0TdHoY*S?U!T9z^)h&ax#5==xAup*X}1184==nO{iWumQ|h0qkH=-Pzvzk2gj zJCfMNQ^hd-qUXdT;;=Sl527;)ArHr(FlilPYbDCzl`Khj+jEKQh9|Zl_~eSLY0ifaJ2YX5v=$a_ zS{YZ{U#rIxoQN&zoQ6CcZfja$>qfi2ZG5Y|zI`+GZW}+LTYz1*0$8~v(k>m8o9mX{ zxwZ2^@P2r3@?_}0hBt&`Sqbk5!}L1+{>GE{u-!k1(2$Ls>Edg%10k&)5QQ#2;}mFH ztd_~X0R5lw2Fq*wA_L1ajx)g7WW%EQo#9_u-DHF+6frl{9Cz<^5^O%HY8W&QP#kxd ze4K1@#j8rB{3>2|H4K~b#m|hr#9w&14cU9qFf_p!g5TP|07OJHSV{_Y{X#J62yEO<+h31Yw?9*ap%S7Vg!W^0=#%>^ZaryBS1- z%~7G0Z$E`Tt7-s=Yt+X^eAr{+&YWrBDEvD;gZ_k+`9fF)Vb)|>9Rq2%`5WFbYzDa@ zb}2&XOy64|M~zvRe%=UH(&gLq<=~IcGd~i)%gcA6w?(CQY;c49^%Quuse3;H8!oQy ze|+AP?QwZ`&6B{oieLH&@ExxL$j#>+$H~+zkr@O4y>d0@%0VD*R_H;GD+;5E;S$!j z*4y@GF>Aus3M+ChEU`KOCA2vGK>?vDG=;UYjz`pasp>nB>U7`kffJWCvL&_ zuRZS3>u%O7x*g`e?@Nb~HID*jb~tUzjmc|y=Raj_Uv%{a-oN4{>U4{QBAgoEoOTA# z*F<7%)e0ONh1^&88Nq`iz}qKKEa4W?w$2`Q%jO(d1dPzQ{1_ z?lG$Pqhu%^)((25*ZRPI=`uHaCBvl8_`2MIf~ zWBWtt;Q_$dQ&vB0+ZM(J9F!vI3sZlDh$oqt44RsyFV@SDz-fUxuBr@TH^186GYcH8 zpx_8US_?wT3JslSxK2iq%FM zo^?Doclf8A_uEdC(t*S70y=uUJJ%b0f&yFa9Gx)tbB!4O(qI1o0 zBWGR%1{GlI(R@jyB$C78C==(wNt-Al>mSc;+HkfLR?naWCQ*<+qX-{)yx4o61|1*^5Bx07hz1bfs`*`;`sEv$c01bIkoj4xFqWg#D;FHb%lmVT4MehLQT!g}|5cCl^tIxY8) zNtb`K@1h%LOgcvMj$BYt~ut2pg))hluw=P?S8m(Qy>V!z?&QymP zP$)eDsUzO6i&315vZaboL<=!Nr~w8C>_Ha7WdSJU1mVluu>Q!70JpS0A#!V9tUHb- zS_0318a8%5S zf!{>+6&u@fhcBatgQl8$`Y|?A!M(0E$>O&C?dL%qLx)w29tw`-MCZdwEp)L5cG>&mY~EzxKkJ)D<}~&7ys;t zcbyXSXSEbQ^P2!G4hS33Dr;cXpBg(wjq#g|kTxTHTMq&py_pPZ-_r&<4}>Qm1HI^u zE05>%p2}pk#U!dDCAQ_i!?ETmqNca!_2(xln?YwU7@c{D>03%4`%1gdw?||P4PNh* zXmF(_Xk+!l%s?HWoL<2`5TXC9GGWcrcs{O=SOes1K3$Xy6u`$X?R$-oVcIjCNT&K| zDno7cSETL2=VIo>X|)(f$&`t{G?M-U;+>~?N}f`-ftl*~9_ zy~GZI^~Tq|$d1Inu{pM^peHsJlQqT2^1$h=6Wq1EpwThCZj!npTWh@5qmY+&Z(&!Z z9^p#*?cAF6&1U)okTZC}TYLMyl`TtFHOLBpb9puyLH*T;u0E+wG03$%j9?yh`8i$av~Ow$B%f3m zt;B{g^J0l1 z_ZQj7DdNz|ryC%-)(rgRh$<~wJPJch!H%gx%eW-Da2m4&q4EvrKUlw4NIUjNp4^zY zo}Ks~O7cCDTvB@=HCmqzXwx#|IoZ^gTCAuQR@#!8CKv(kiJPyUJ;-NZ(xB#Gic#zM z8g$V0xLL96_p*@l5tcKyfcyL=G-`4cZa6<`DN$wi)-B#t-!Wz>JJ@G z<70mnYm0o9inkst8Rt%%S|m$BL2k_I2M9F~&9l%0i;_lFvoI+M(@~gXJm5La18J<4 zRk-tE31YDTDpn%&(k~gT{qcqHvW_4?m;k9WUxFBR^DhioE8Q+_K%(#HDLIX)KTV`9 zCJsH&I$j@xlDsg7j}KOgymhQ#RGoYb#M~G8B)mp$7UF<49fV2&?+Y&zUm;s+YK@@v zsFOQl-8M2}0?R8GJZv(5o$U6et#!KmGV&f_o#&3FWLMN4(BdE#SYN$yw6eYOccEq3 z;N{XzDD=FJ<=g7i>9XeOUK?#Fev{4B?aIr_w2vOGnM(F%STySbZi*w-F9%b^7--r; zbb{9HaTE%S0abOnf?q-P1c-eR3A7R= zk(#=dzqgZ25l6J+E%BdBZpR~za-PXiKxkd+cpbltqGHK%8i(2vtoGk6>D z6!J2F>3WgczQS~QF$0kKt|Q`^4RgfbNJ9C3P8?eN1E|wE8oi59w5yZj&t+p_tJ2LrWKPt7Km^y}9YF$e>R^pPw9$L9 zPHz92ip}e^(1qumZt$ZZ=1w?-kg)zTArhjaA0~KPA`te< z<^x`{p)SPpr?f?_1xx7bq1FNM&|;O75b)#??N@K6fml!}>;uqi@+Yskg`Z!s&y?Km z-kz`XSvpZA4xU)B0t9ZfF*y4=@L5esKNW{cCCtnsTWXTD**Hw z`QYlN64SDB5z|!_MtZ)^?LV5|tCFl7gVjGC6YW;2AJs*vBU|XO`o^0#N>;rnl0)-R zW-;D*Q1CTbQdV@j>(2uvRqY1aSBcjYQ^!=kaCBU*JepU0>JHiy+ zyRz<&->IiDLT0g;CKntKpI>!>p+aWz(W)JXxRY0Q&o)|&+z-I$rrIRjRdV=znL07_ z0bS8Q%*eq|s5gJeTLZS6#6Ad9$fYo<106HP&`dZbL-@aei|CFBLVy}TGF7KE9SfJ7 zRtMrq+j^hhNORrUQ;?1dMw!fbq06FhIaE$h*Xn7_7yP$;meM~~nw|-!M%2w>ib0|? z>+osl^gGHDDAo|NmOWqH=a73qWc8P;zGM&xQ+t*%(pM%BGEaihk9m?zW%=!eOvyBj z{*o?D+BMto3A0Cov>IcS`yq#NC^D+IY&!k%f?uaiZ;<))jQCPtjFXkYtXf2~YS&3+ z0Kw%|#ph9h@b8ryj%230)|4qq`yBC2Vu0_SXl@Y6KpG=_yaI+Xf}^9XIYy*)F@N)k z6Avjfu`w)*^EN=Z|rzp5!;4jsnXbs)Gd)8?M{UKB>LyT(!c29zlBhVp9u6)#=$iKRy zpfU9y5uNL{{qnqizmhan4&&!tGEdUX64@pfZ1%aoe z%95B51TzGjn!}U3=}0G9atVm23b?Y4Bw{_2myJ}8!r@JHMweO>fIZ}tH=WzuZcPOz zTUf3WLR|f9L0umW4k#2-?`cs7Z|uZ z;-jxDy>;rXrwU^#z?h=I0gLO=JFV)F zSkxGpiYUhj3Gi*uG)AaL*cIZ417HW$kqtk6S($96ypEH{L%o@$;(E2&@Q&zUP@KL- zJ&b~!8^Uv{BO04e`$$dqru3bz@Gf|xS*h3AzkMWkaCn2zgNTMjQ0V_I5T+&~kkY3j z=t3ULQO%FOQwI)+F_}Jn?iN0Q2ht;(3tO$Ok?zwBwmY5((U;vziWcr z0aqg9x6>R6Xtknk)=q1pdWpt2eaUFGjf8!Fydq2vf*9fG)^gWb3V(Ae_p4pAYS56B zhg1+J$uwyv=d%B{o&63)V@4>jN&B@1dkmX~F%VO@)fsDvhhCNmU-VS`vd#$Vt|%88 z@CAy(Y*c|sel#1r2ALK&qZvBcSNxb-C|EzVMnp=r!@H>d`?0E=!L1Zy?c+2xb%u=| ze94wfaI(rnK5Wpc0!3dBAu!Hzk=P$(Z;HS^i$c%9RvHU$7&rib%Jg+PwlDIO|L*HD z2-_IT6tJpp^(4$cxmq#5&HJLq7<#kObfXPRsimkt;i`Hzpn+TF2jQ)qW2Q&2e9KYj z(L`a!q;xHi(VEkTGNMj|6Q(i|VjL813Cnedjpw54f(y{#wy`I)t`8<9Mbxj@E6&&d zOB1LS90>WcA38_U5lj}!h06P7SmEw9sq>{5F|5UEgb=i-3=1EcsRtz~~C4137R zb~g@ES3#UJXl93xINYUmHum(zHeQ}JiRN$J3~sljFx@LFl9wVWjVyw`8gk@z!Qu@I z#(aem?yXP}CLcqWyhjM}(t)U?jOarodnRV~Y-YrvbBc#G7hQd|0f zQpD@m9hh3OcSE}XO9>4q+NOZSfp`T5RR~I@GU6R3H2s)1LE5vF~X!WkrM4DXIoe_so zwTI(3sgxK5Ddq3ybo1Sxzdanq-sRHzX4nE?BwWj=0N2A>&wb#aP&)))9TRCCl$_w?kt*0)If2^dE7wS!($!UaBC&On{EaA3!X!e!TG6>0P zCiOP5=@B@r3Vw6akS*h3n1@TX#z2*PGS)2nvhzw>+Ny9OZ^YM#w(lrG6dGzlHYRn6>jXd!U)4$47){C<7D6-1e#3l_oo7&o zRl+2eWx^IkjNNrKMarMmlYf8So?KXSPDT*YoXF+QR@2V_x~h7i(%`-$^{N%G_%QR0 zQH%N|P?s`L61ou!G_K3z zNL_3k-ZL4|EM+`FYBwTtPON|)8D5j^s&0LHB$uo`WZuD*HB&eQp5IYq0Q^Z;_ipTB zK75&BZZfBcA8)O@z)Dd>G<--9IZg=o!~<@rp(qp}_ug3;-9HEkHi)2*K0wk_QQJ_5 zOYhN75tZhR0UGAUx0@9-z+u_9gwLG@L0+s4vrZu$I1vwUyv!0@vMIaTwYz`!AZe@I z8(M#R9_i_M@_yrj(;Z-PfldSbMiIM%mIg@nf|p; z{+FRCqj8h=Op?q=T=Edx?`WR{!tAY^K+$(uf}c8}ZYN5fxo6y_W{`0&53(^@R1s4J zsoe6DPo3;@>bIXVm9#5xBa>&Gr5lr`hAJNc4wfd|T76g?4BCLPHfzk)WZUiHDg>|{ zj`UB1A}iUi!WjOmF!&fW#0+%8LyX^`DVBjS&t+zw7Yr#Ws-5Wz*aJp)9wAw-AS(Cy zw10GIW1_JL5#9o}B^~^zt#0TEQ1LBBao?e6u{G7y;up!mY|5fRVEhKoceB@`ofTZ- z9rGX1(x3z<^aCB6a$jl%4age~OHoV$Y=(_WZX)#n1u@J&a$H)yk7$YC@TQCKwaLZe z;R$NhGn2r%^e|t>W*(iDYHbEbY@bPT#Meo$fs0{& z{KcFotAd@cciS2XmeG><=g{sjYnVzp(p8+UZ{k$!M24HH$yNK1?*Kg?>pD3EToaVw zY>FsUP3-IZZnkjekB2us`VfN*W=Vt9d9T_J zUjfqPNxT~OwD#&U&->Dbwu}bUg}=l*W2FngrSnsl<`%O}cdIsP;-l7l+VCB2<1$8u zfdVS#o!hSdg2G7|`i4E{3UZb!4abW(ekhwR^dY+|uapz{2JmW%HU7dv$#) z?c(@}C2HaXcv$Ii92iKNbZgK+QtJ$DO3`m<878zvRMTE+C>pSlLXBW8Nxv6#!!j8Ep)!O?mY+{xC9Q#ET4xNK=p82IQ~6+d!?72O zIB$hyJiU-lh7pSW{Sr9M0*+zbA{`N52*E{Jf(b5=+M6DJB3#0hTn%M$aBgcirmNKs z!l5a3lqBoPVPaig#AUgMQC87%EQ0pjMhA%oVlQfi>^-p&S8OWPRG@b|H>3o+J1hiPKXY-YG6GZOwRhx9 zmzX(fopXDG_0GCIn(aRb=5^=CAXaf-*uCj9!N?Puo8{N+zb zoD<$o0t15i5L&--Wxv+zKRaofDapK+uGV3q{caG(^f?>@WvtCMc#BpsN%pckhS*BY>y+jD-5|S? zj!_}yUd~>w`21KdNw5zgRdccYIk8Zs#E%p8&Zh@XuYvdwwYejKkj7b&Zx6<<%S~Vzx(;_ z;x?R#0jTmP>kRls6@zhhI$r*)SIOut;jyaF45>LxydW3kg5G!|%7ilfB*rPjiG~bn z<*yV4X@s?wvs+D~QB8i4Mh8})z7*z3lhlu=(GO5BjS8N^3wyw$JYrKM*0jJxBn8W} z+T@@T9;=(_L1GwaAx%(>xB!s&`p+zLq0!+KH)A_x-vIRa5)JxmdO_5BST zSv1YFHT%HNnoPabKm2LIt*#q6uXmd!Js3O-ab-RwtQe{C<+ki2Yk_6ey|B5W%*GnQ zNgT$V7|z601R&T7!h&^biW1}y`#MsZKyh8hXu`o#2xP!qWPlf+qlb!`;VSH;WpL6t z3r#P;xwS7q8X`a5#PK(o-V49vbQrWMxQ<&g~rJ8F%DC3VtF&rFD%0NpDH}WxSR~+Bac!g^e zOiAgnW1-QRdX$z8fRg9^R#p423joNeOP?h}OtCba&<8c|k%L|KX zcjSz!iK~OvTD*_IPbyN~DL*Df3a^ExP9d;VtlF$0v#NLPMJ(!@-rMJp!!o6p*z)k{ z^AwC81M92#J7cU)$MY5Ew*U#1{s=|>a9jX2`7cw{J3#)PVAZ-xXFK-Cp+7ba(2M`j z@Q`|C5evj^P;S+PR)cUgnDsBw^p-ex%!5owCV%Jjoq^^w%VDEDX%u;$P|6{dYvB7@-1;;cME%Qn-1g!cXM{=Y7eJ>`p zQA%%WBsPSt&^X5Ahq5GJJbQ6Or4#i%C-ZTF&Bv<_>39A&m^j~_h^-daNx|GSJZxT- zecEtmJ}G~Sc8X#}JZSb`s^5|ME91Vao37de76E&}^-BOdyn0-@F(>M>$I#G#tzl;0 zfOqJcc4kv_@2*iZhNhqbY+^ItH^M|kw_$-y;2%@wsE1Gb$Ovv+saMw~+QzR(Z_5Io z96wQc3O5>;B@bdva(=9@?|M`?ilECCu%ykvUqt4f0syk<8S3MZG|3LBSy2Q<>UhH| zia>aXC0XPyI2zudessgJi|w;)GB4yTILet=k|gxj3mnv}!Mrm^^~>E^*^P%E!M^CN zZ5sI9jG~waICMcbpY||U=iWz0F1JvF-VcC-Hvn&R?IYXHuorlOPX2{+Ji=J0fSY`z zx_d`hkVjLTCfgX9{e3&Zbyw>2%P%UAn=MFgf9O1!>gT#IrH>RVgHjf0Ycfg+o~Mu1 zVS5V9k?+ytp6+Jm|JWqgRQ^m%1fas zOb&=o%7}+l-o7)Wd)Fxwp#hmEbQ4eIv|Lev=0}_w2muMYr~KL3|Mpm&aPlA(;Y|Su z28|Cbr!0OId%u6T0$;zQbkr#hP?ewO~O%u z^CuQAbU`$eM&OongY%kZ!B_2v+x&?)NT2-&(|Uj8o3@elef7FX5Js=6Sl`RhyKCpa z-$w3MjR@4ePGxh4q>{lC>+4pE5b8C1pBuymg-Heco!e!KWvj^2i+`71f|NK??cjEPFd;gq%Ct!oGh(4Rvc&`^6={(6oaV&-vGZuv7=KzS zg^Z^p{VC=yU01H^5pK@rb!XjM=LH8T%3GMW#QJviu(P3vcsIX=Kk5lxDOEEgZwPb* z_o5BIZhzyH<{o=v=WJMZ38cm{Ye^TRMmX`tqP|AN%CVmY#u6^CmNM>8@N&!C&;qa^ zDFLCBy#|s9f47}r#+H3i677|8`K%{HAJXUH-kp)HAky|JDjkqCpQF zt`5e?*+^$x9Yf9E<)_qL!D{)=$3-FUfAkUvS zUapAC|K$tIrmwD{y~jx&pBt=4&+*puT~m5x-%LJSELc1mpAwToGLV3v)$lYbMYk#` z#0CEk2^xm4hosc_yeN9gKm_h&*eUW-$!r3ox%QIRC-RzPO;mbtvchoaz&EZb24>P? zEc3IG7Jg8gvd5T=zc>+C_v<2=E(R!2H$PsJ;D!~8(k@|%aze0$_E+gg`Z|*iq*`9+ z*6EL0fgxI&`3B{F^=_Rt{-;nq@`C`dv&n><1n@rxH=xThVwgHgu!r?8QOU`UPK%=j zQu78yC=XW1zFkk{;>PXI$`WhD9)&6uq+*T`g?|vtG6ZC>76o1y<{e@Tc;EO^EsRx2 z6=OX7gL0UCxf=J1iF!=?(TOZfg_cI9xM{I}0CmT=clUu$K1Q(cjx1!oAu`EFaLMHy zO!XQmnN9li-u7^50YBfh`H_UO7|V+mX>X;px50fM$N9^r1Xd zffVH+sAzD~z!Ca*-=4+{MKSj~Q!K;u^3ai|Z04D6R0U;e9aXWLGUxAKBmxeQQrDHC z|H2IZ4&#JZ;~m=z603;)J_(Rz|4isdYt=M?m%~a!2<&&fJJz7pmF5nIqr>`>Q}ID# zV3nXc5g!`P1+A!%$9(;|yo#Kkw$@7^0zn-H1h-4eA*Efh__#j_`wUx;T>Eqyp6{-} z0cav-L=VH`R;ox^0ORCtnp!~2e{S=fKUuou#=-DZ`}LDJ!BudhwRDdH5(}{>F9X;9 zE>(5-9;qnoas+5zZ_Af_*UQvR1yCgeluueiJ-L*|Aa1eEd-KJUY0a$Z>3c!aSlcZh zLKhqFpU=3R`HZvpb#fwue$;?aT@5ty=KL`ZuBCM4{nlcY#O4CAplRQ75(mP$*;Rny zn7U6rsnr>qIbCu2%o>&yw{h`}IsNCQpQM+RsB$)poTnu*;;P2Pv7^XB9m2fN``!55 zI3aP4?;O9tL2KwiW$dG~E)ga|!I^x=W1%Tg>${a;o$eJo5t<-mBY7@Ix|xTpMz9fy z6f6VZAo)^owgRp$kbEEqoVlOvtSlzgV2;3b{HlV)zfouQ7ol%+=?PL`kx5cAwh2m)j)ZF)if0%a=tr;{IH zou6#7ng8uUE>A$)F2=ZhM{q%|@H}Wh@AeKGuH1VyL2gR_XE9k2vq{BmvdqLT8w{0@ zN|;t%Gp(lu7K-cgAXP;vuMrV#K@y9q+CSZPB1~e2s4mOOcDF30nvAGNvO)Ta`GHu$lm`49L zyl2umh-!NMSi{)7%byZjUJx4rYk~m+>+7o>v>rqZz)mJGS)gaiND3+Ls=2DaCJ$O< z2FGzQ*bY^pS)E#AXmo8*Ai>rdKLH9geS1`DoR5`7ix+D2UNANVl46a60XVL} zYdXBnZd@f6OL&gHNJ}&uR4w)`S~!*w<+u(&dcYAo_9u`;&))=s<2WR6e4Wvgqw^ahsT-3cC>q@p0IQe9l}!G82TWUZv)g0 z>F5eI>*n}(ZUA{nI&6pDOIF#0kD}`Z62ct1HgCo#vsumjMu(Bpdl?grukv<>#*;og zoH2^a=Ib|Dtl$J_$9O;p91--_{8T9n17kN;qci}IT&~zwIHs+32Q)shwXLr?p}~yi;=QR|x9t9pY8kVj91m#&h^pJ3jGV$O{5*k8795hAHTZReHDn_$ zaR0eZoQ=lbi_CscNc!#VB1_CjnfG{~B}Tf((fvG4rQ#k*`>!in9v$%=fw3>3*hM)m zsup~M44e95{6bDIV<0qm*L{TcFQ1@dMwIV4sXMXn#CN*iX4GZh)Yr?jye&mZ6HNcL zG`>x5J%7EP=y{uneCYw_Yg2ywr>XC0!0!od??0Y~hUT6tybs8aIguA>Jdx!!zPn)- zL)h7edHTE%2YJWbqSLGQKXRzOHCK+Kbw!7z0Wl0v6t*IY=C&ZBIpnBpwY1|Tbn1V9 zOo?F&!`X5GE;9W{p_Gl$PItDm(O7@vY?+qax_ZV~%ra=$J$=s$#`jkl?jkx0wzg-5 zBDV`KVP1j;^W=AU%`B(UjVWSlb?{LuV@@%9E;Z;Lyl-Irp;anpeg}EVp-axx7CbUK z4NDDZeLaC!J1-3C9~8%dWQ10%bV{EeoTlAeO&e2+fp5Gbq9EFjDFo)tIDE5%B!|R% zvAz8$ZM<%GjT&;=3?-K>*HMe5huAiH*wJL(H&;S5FA#)6o_vdDhxM}$3|z411$$wt`-Wc5eFA3TkIL;Udy-# z=R9XW0;3pL;kN^(f5^)sQv25y=K=)du9VEaOU|h(JHqdfrHV&@ZJ^f3#g(*X{G$+H zG~k3tCWfVLTgLu&RAC2@GgE?>#UcCx6t?U*i0_Ma5ue*no@>v0e4o!(s1muq{&SZe zbrAR{J8f~ zhvk8r)W>Ps62Bk>UbN)4uk9Bqm!hzyFM9&hxxrf%jN$H-?v!ah+nC2qNL9hN?fr7v zwv^cake7P8AA5dpvFx@u+y7|?xUSlIsd~cl+f^ED2m5nA1wZ}^|EGtC2bEYeaPm4z zFGOM6F9a@cJrY-CJtX%l7kbwsb+@Em6Q&`;Z5)(h%n(SxIHD+|;!CoNjgWS3ctVnY zqJN`*Q6H&4$Fve*{~jWOe{|?`ZpK#Yl~^OUq%LZMzRn8;bJSWWkp&9ZR9*J&ec9BsKnMSA z1QwCF^V7JaEt~ytJrSQfmifgm%V}KE@SN0P6vDrb*EX+2tfAoOj%AhsGunnzhW;fo z*GUb^XB6FbD?KiKn7Gr3yY2+AnCUQ-qeGmpn5X9(FY1nb{S)bss@`)4w7asESnPP= zDmS?_E-cW8!2jyI?XqSu)Pdf_`b~F807ORCKpwPTlu$M{zz{046S>BT6H7&Jq(L(Wf_}X5zLHce*i2* zjF9Mn-|)I}P)ymPy?aFt8R8!#GaSQn_Sbzk@MSp%%25(>d&CF8VS;ETbsZnb;9USB zeIx7?^4$ct?*3be@hqEKTJ8^*4M3ufi=V1nrNs6>Zy7jsM0t3&WYIN4(nAKt=~U)W z9p3VhtdHADo9$e3*x3|Um#aA$MQ+QiRPfbM`6cYk4n1dT#d~Os&B_XjJYXaU5(yB0 z9phY=>Exm(ZYKq!zu*!kq@Z`_`=#VhWy~OGL)trv&O?uGG>Unu4RKUyFo*VU2j~2Y z7doe+@cv;k?>61S@`>-Tw+KvdxK^?PnhP{V(!(Qmv-(e#l^X#t2evtXllo7xQiaHM-P90D=4xQ&`2u%)bVi!TF*z4X6hzXQ z?LVTaqS~vSiqb4jXH+?mm&nl=Od9$^m|>4>(XP`(YRz{>o*X%jCWl1_g}?I-LVgGN zqr77J2ZN(ZbT{mUPu6+EUf$bvR+m9#&AX!aHoYKhY8rL|f#jmZjuVYTjd?g4ImEFO z^5EXO4zU{QLg*LP4eyFVh2!LzV)1PW1P{wyW@@!WiM0Hq+D5GUL%y>C*UD>uQ^NNt z+CD5H={<2Y?Mc?K-KcB39oqQE+cKJO#R3%C{f}8r>ARrNM|9?$Y`P8P<;zIx(P?ZE zVVjEO#)CDRJJy9I0~UveW2u;l%cX6R#0?9OMQPp)$32tUGo-dCx`8w!{YMshry@(9 zRMkI@U_ieviOAuW6)NHQ1I!|Qz1!1mYMB(me9=SEr^H$cGM9^o^ikne0<@?c{+ue` z+)?AQ)X}Ypf!6ZMdRgiCyT*h~NpV!d5I;7p=IwMig^`^3jxjV>I}r@2fF>hr0^%Mz zMW_I(@2AJ)BK|Sl=pI*%Z@VY%j?~=%`F-$B9(57kF!3TDHc!4?ys-wv^(RUl*MO<0;*gw~v!F6+srF60akjo_U%W_NG- zP1l?~_^|w+o*Q+-7H(YHV3)`1{S9_*OE|f-VAfECle^T9Lo&2iP#_|R10kPg#TaVn zPhf+qweWx8Xi)!W#3Lil#QGrjwO!zVDT>=2H413yoh&5=3=~}Tdclf0iSrRSLlY8gBJ%fbI;2*k z$bbf6ZOC2;zn^A?cAh?_Ot5uIt>N^VX#TDZ3y9HWvTW}U&eHCLrd9y%(h8XeLcePq z>ffNHs2)qmWhMeT)j~!J?F520ck3Oq{;BT_M5yVVsv*SIX9xd>rL&A`t6{o!aF^m% zw77dpaVsvtrMMOM;sk zI^f-RPc{RROrTSx47AdgsUr3(e_u=}5@4oAuf@;L(PT;=4*|=|eWJLNYKudm5ehp^ za|uqjs+A;V3{1kE!^<2q5%u;eP-S?`IZg^|#J-ZZ{u&XBpu%9eOv5CZ?ylgx+hcvo z6UgG*Z+riGJ7u{(*^G$>NURMiTYwX`wE7n=egp+6jxihK2PdcVl_hs$>x~+e?8@bt z=~*|rP559_KYMy_*cxaMiO8y7kcegRo75LKl=-(PnDoKDmq5>XSN_Q}wvGb_*kPW# zKnMrApXDGtmGh=~!rpaO8jl{}*3ggGNz)|$4W}a4_75;F3V-<%H7sWst`w+!&>QT> z(c}SCO<`uAPWv0-CYRk|wW)X0jtU6rm%{7x63UI%XTYPrYq5!S^M#Y0*KLI7Rsi?i zDT3jO$W?o-2`<*V+so9hLbk^-^xc_w74oi)T zqfk#Mq&_|@$AvdRUoDxXdcBwf>4>iS#)a#%=p32BcHNE?A!<*9ElDlh7EzP7528Wu z9R-JZlL0EXQizVpFS56rg{bj(@;R@Pv*&Jb&)o+N9J_ zspz39sO|5r(99^ic2heQBvW8fwOqBEDA$;7cj#9GsBRft$Oy|RQ`9f0ki7A|IR&n#Fc8^H6`p@a*~IXJUp?zXN>gH$-b#b|u{V**p!L zr~`{LeWO#GJIebYk(|fJ437AqiqFd^jwc@R{#=9_&ADhYS&EY3F(5`AZ zdSsFdwBz|>vma?r!`VgDXr61~RtA|_wc|vO8E2a=gveXF@HSev`&%b=83NCV)Suk3 zq_E0(RAzKy#1><47?1xE5SrKh{E4;k+~u`vD~f0zeo4e_B8AE@ zgbD+9mxl^C{_isw@@#4=E{y_B4Z_XTq%96qdLM$Rb(6TL0Z-#rF#*_#%7_fNdf=(; znPm+P9F{?fkOL)z2zfwB(=`Rfu>{q!(P;!<-F}0*I%w0J36yVzt?tddOYU76DAK34A z-X%W?+_`>5oBtj2AX+Z}?;k@}%s^MY7q>4}BJq{h$IBMtovfeT5GU`4Tc;@?$^_v?zPgA>;J(b-p&IbWj2I;vGZ!KVd60U&(*aK zN@b1C>F`nZH+M(-{#Ve#&kJcdxE^1=_X^O}hQ+L05#&@UiS?;GacWoLUn~Kc9|2 z2MCj1uyKzesXn#Pc(E;0YP^F$?8%h}NE*J{p4Jd&W>b4;wmiqeZqr->Wo|`ed(_V! zHx0#9z(9|s<^BCH3=*9n^dp%_P>Sy&hUUWv4XGEUQ`aM`GUpW#ao|Zxo1JJbj`+ie ztuYi25hQDE4)mM0EQLctaKC?-l+j&!7mW9957YepKiO&_PP15aE#-btbn?!5InsmP zZk&(bF<_UKNz{{pwR_2@FX}TI%U3RCh{Q#`+uOl$cC*1?Br#;K_c}I3YD8rDp>pF` zeP=oC8HE;^sW$uLssy>Rv+ji{?-M#)azBbOjuX$iGT}X2_e;C9D2hw2Zr_WBQy;;{ zEv(y@kkB-b@+b{+XsRQsUqUlHno-|44D2$t_Jo+cJNiWw?!T#eN+5m4rSHXGy-WC= z5QNX;>=1-s!;%lpymLgM5Q_X3?>`xEC0v9PaK5rrqCt|}t9&#Mw{^{9)%E?z*QqS9 zJ3NhQoARG96DgI+<{Q?^yhw=-4oS~F9&BT{M{K^ca$YUA|WfQj#{pT*v4--3xb zjkbS(+j>zmTDJ3EbJgVpI1K5ZNUObC4_uu&+m!)3Z!YS2H~ z<`q^Qy|g|qqEi`G8yJfe$>fmS_9Or49-Wz0%jFzjoAT4@=t9U@*wysTB?Eqi`_4mQ ztr6P)p^%*)&8gG}{TJMc;~zbLIMs1aBYi#lDvtzE(GdW}4jvIL8Z4Y#shI^pgJC%k zWhiZauw)J%vN{wu6dB4Z=6_^ay1MQ|l*TsYQKH_TEuTDIl)|}WRR?;`;v|^jU#;5!AbjPbzbz^jY9J~nztmJ zlyMJBS#>y!T!UC+;mZW$&olMk3oYVw_C;!j3zG5+)P+>C-{xg_XI@J%@;~?k-uv@i zRCybhWzw_^DSa5R#`Vt)jg5fK%xxstG=@bJl5%qXri-}q#nYyfnV1+<>HcdAV>Xjk z(*(q&>BTDi*mN4YxRXHTu+YS!x_2Vp=6WKTs2t;&d5FB10ciZ-;&2HJTZ`Gk(G+znNrSM&s_^t{)@?t<@9LpsAZe}LOn|_yi_a|t6#PpZpiRAlpsoif09=eJ2VkWy$FD+L{h|6P?R&S zWAOa^DORq+49Bn> zdQRUv&=Yq5`wIdsf$cz$9D40Lk~n|NMjEk1RV4l~-&TapkDguvOu|$`IFi_#nS*b> zaS~mS%x_{4DMn{acvm0rtTkpSF|N0}FSRqxouvQ0Q(j*FFGulz$rKl2uOm87FaI%l0NdZ53%8R2dz$mDf3aIgge2R zve$QnYoX5k&t}q$O=0i39bW&d`1H{+dg$K_T zR%oLi#zi+PBzGu#JLceZOV~FnyP;uOo96{Re~3!5UzAhJIg1jHbKX~;MBA*SGjqVV z<=NAXSc@O=MoX*pEfUjMge(vzugs#`J@JuPdXUX%$TFbvH-SiDe->rxsf=aTo(P-xcAZF>C8Tu$iU36D6>Yh6$Or6pJj88L8bz6i+HcX1!d1v$ zd%$#l0WFLZ)OV!$-X!64d~dL*YxBJ<+xXzY%U6c%5h(|Sq-|7rmH@hx(3*>vweue{ zKIRfTf>vaixzP>DJygQ`Vd9bKq90D2r3E)E5-fmasB(`GNkJwGN1u%lOhMyUzuu`l z=)nw*SHLmUPtEImWv@l&K5D&oc!%2MN@;iX94cEanD$GVkzzRYe&HUz)%clx`|agz z_2tfRV#FABbYbUM(tjR}e@1=1h5-bQJFc}~=Qv;R0-i?#DYhQ6w<0;;E;&bLUN&ZS zp1RLr$t|z5RQ-YhdxA#Kz~PG^lhqbKqtjjlD8zFrt4V7#4LS!xAj0tKXCb;MGtTG+Bx<$ow0m=I4;}6{R89bK=N1fQUrbw_qVLFZ$7F&Vo!?ihst#YH9Pt z#CRRG#7Xr^TBc@gqmkuoOoq})t)cQcM4Hyk z(I*(Zu`ZWRk|m7fG+}z+Xu~fTWB1d8e18rU9Owo=oDwIKz-UtszU#!WV;|*dv^wu| zRcNX`GqT!Ei#3`uLt%Ng9MCX4Sp;FLmZ8k%y_H7Kcr>M3QA#Msa^p z2)$qKWTav!nkrJ4TE>}XnteKf8^p}y_e|-*s(E$LOp$MJ!Z^bI6^Uv2333($6yggL zKx@~MJ>-UCf2no$oku}f!p`mZ>%!}PU-1n`o#xU~kTS&YK{;OH_wMaI#mv@A;I{jd z?L~V>ir~-x7Sx8_&`{28oqz)!sP)@f?$k-RX)kd%6zLfP3n6w642+)%xNZEu5Rj>! z2SKVAU4G|!6tAiuc$jBs-^#OfQZnERkp#T9=87hk)y35R&fDJkN1eJd^=qM=c{^5D zRsFa;C>&^B#1^U{hb*4Vv7l`4nf+O#=W4g;bu7+7&^^o}RO^46ZewR46&D$_s!Ya} zt{-E1K}f?0P&XEBZEUIkt?2aP-3YDK;*{~j@@3{G4>xDzW$Wtc0Ubh~7}Zw0@H~j1;0C#8_QS z?>DBN=wMPmosAjpje{Y3uGM|G%k`;+N<{45mOvQ?m~TXphK9avJ6%S>OG1>GfL8LU zsDWBag@%41)|f^|>U@znU*-0Trh)i4fIBStYY{m-%$E@;Svzu>P<8nsplsnolMkIX| zy5Y9AdtG@%Ke8JC+GFhA`OLj%zATxqx0s_8K`*`&tqs0o`8w5el5Ts_{Xy?uc_rnK zzrkuo17FW%YSE8?4kUT-%b#Li+yX|_N_cP%%}?fF-S-4ESF7U8q>Zy>D!W0eLa)40z^K1~nt}Ge_eldJ1n9 z{(8UILTV>q1UIY}=ZAX5qz?v5yaU`vDVHEeKJBFBB{Tb=N5GaoDI6|}Zv_gAcu4KA z@a}Ln0@m8m`m^eDZPC1%JQb2y8D zg@;^e*kq2^D7N;0xU>;A+HZxo9yAn1NQ~F^1s%cLh@pX1eUU|N_kR%>mJ}56R2rP~ zH1#o~)e`4$3R8lL4nifK_-8&`_Ul%#-FlniE&viF;za^TdO@LZ)K&sxb4_?)!%4Jp znIsRdou|6H5+2v{Y^p^CMSk^7h8nntB{Q^y_*lE~NNwblOX00JnN@u@gG@I^j$lGg zD5Zzv$DTd^R`JD*6Dud(Uz>?AB*@gneU^QTXZ^N}5yJ-513OwH8OvH=Y%Am`^-LK< z$pKNz%ckiU+f%+G%`ah!jI}+x-rGiJtMu+Y{u#;~{)D^}0;1A?vN!@^DrP)HGsvN- z*0IEUn7sdR7TZ$<|7SoD{lc<0JQLyZ_ZcRv%A_?k^X>rgEPtc9Fzit_6z38weCguo2ZW3-v(Xh%mwNqZDX_WPrY+Vx^e>*Ap|g3nPTGd65M-Dr zvpkO<`BZQkH2YcZ!1&u54w&{9mLs)cBUCP2l<4}i;kCUdW^Y#`r zk`u0vu?j3GlZ|h$-eMS#)oD@b5Qc*3p%~g>?+8ZvNnAMjRk1`*9>viQ;r0t9G+hpX( zSYJH!=rpcmqtOK*T%^skKQl(0*yPA!81RV|wx`#lrA;mkdy`1K!&YHBXkl>f*JNsR z!Qj>X=fv|h{tQS^UJeQJD0Zr}h}JOeDIxets>xT;^S-}DXERz^{#v15#4P_-Uv%?7 z4Wp||KAgp=tP7U*nSacQ2f4DHD4BtKpz7vBR~>OPL zbUHn>S+*yyH;ZOkdf4%3An13#twnt&Fa_-bBCs1ysXNRc*T-kXSG2`Yrq{>OHmL%7 zaSgx`#;T*9VgU+&pNw)QU3mNw0Olc3 zq>KDL>MUhpkkEf_sHv?>j8gvU+zF(jmGa^3Xt;7G>>IoBXnzbi?_AHGgB$NHn8|17 z$e66dFf8j*)Q`Ilc|e0%Izt!{I@E{aAJFEi(=dZRby~2iR?<%9@A$QQ_{iged4(%(h{8xK`9JI0FB|KJ% zQK`27~UAUTU|u z?4_YfV^@n^oiYkGONfCKR%2yxEH1VTb=xeAT$*kGs=MICuMXG#w*T4mHV<&ws(zI* z@Z2-rZ2EM++^faFu+!-}C$WQmFx63y$gb7qk{S@aLlL79O@6z&0^G3~GYQ&1m;C$s z(FVXwf!1rrBl%a;Rx+V^;nzWP!{d0>Pr^0nCrg$;ep<|TsXI?iWJvpZ2{vK|?XJvP zG~mUi5dQE0O3p7O+&9!_MkIELiL{Ew_v^Eby8|o~CdnrlAiXm*=JImRTQ8FZP-Qpv zAON~{g3%j(4tGDhNSw;V*>XsFz()tc{zQWeZK6idW#L+i;afJ5ov?RPPZi01cB=I0 z5CL(wYggEABA?kXv)P#{(lgP`(j-)RwC^Tz>Y{zGQGF7>$L{5nW$M62A#bF@H}!xg z(=PSSXf*mhX8IYC_SLX2R(uQvB1$ny3|Ss}B~B?y(Y)Rz9$~4)pcDgiO8yc=bOwPF zad=WYctiB)z`FeHOkW_H z;A^7u*!Z1!C}rxf&7+%byfMoA1O65*>ZhDkw}~-430&%4^dst!h;6!HtZ-D(XKzYB zwQO^bKUoI|!#^PRMJp$8%^x$7HMLW3juMKHKkM5D5t{T9!mn-{a931h==^QxX3g$t zomnxq7#d@1IRHR1?JllDKwtA}t12`ZkxcVSxlkow&6uc@JmicTiEAD4)1en{oDI__ zBx-K*T5M?fC*@1~*>-D$V(16DG5K{HHu<+04)w-Rzr zbAxA>e#QQ?cUa5#ReXAA)s)i8ElI1^Q~j+l`6waCZy|JhkjK6Vk;uCwu@ojAUO;8m zwxfnvGwQW)4k+>P&gPBkB9Qi;lKa3v5s>;)%YGzWc|k~ z=mVZNVzZt2+n&T^@R(!Sut{nPFa$LpP7PJf22obpVyX&e4JinpFOMhBD?Awl$60{l zP&rWDo$15hR6Y=fy4Mn~BIQ&*U=ceH( zlpIBIp=f&u;WVElMo&X>@QwixI#QDIg-BpKN57Q$&++MZeMwA9{9r|7dA~g~MBRJ>7fV)!VEd2y-jl+ElQu3wUP9f*=n_Tn`{Xc z#1w3Cf{D9wI+j2){|nvQI-&|k9j@Q zvF3l6;i#`NYGID9-@B^RJ4D!6ZxscuQ zs+kY$(K))nJ{0a(Vx&UmR-N7Z&A+gtzgY$fi%;Gpw*111Ft!Q(J~!FOHwL>5brI;Q zQRT2lSyRJ3>q?Ny*=`b}0napy?bE&H_W-A}W~?Da{6E^%j8XEid3O_ZHqvTCGwjQG z!4uTT=dCCf_BX7#d-OXDnWU;Le>Vnwr-c2d)F^E#`O9jETsuqlnYw_s6&4xEKF0FMdySTH=D; zt36@&!>Z79tAvr1?B}f5*3akF+wBK25G*PIYx^}EDl_6Z_UJ%Ys~6tEXeNfO#P=Eo zm;2}=3JovNS7J&-emWOGf)w#nSsuks*e(pKR;pwzDW4mcJu}e}V9Laydh|Q+>HoHv zFn1V}z=|R0*2kx*SGIE+30!#7W}uL;cGChz`9XFMz@4J+0u4w|GYbQH=S(Pf--SXO zR-S_;3O!5iw#0>=d7NX&8$5; z;CmbU0>wM8<94)9W$X|NJ660iKT)QqkqH$0ibjFIR{<9v=T2Wy?`fiNAQozc6n^dU z3|Jiv=es6s;goUWD3&Kcih0|J<{goVAM`H8lEpC;$H;B}Q2&C}4Qx&|AaTQ=%D0q} ztsO3Gd5r2!*}B~`5VnDPjtgG=w$FzXyd|q9ow;1OSS5|a_gn+E(O17Qj{hLSa5dLZ zc+VUzvCg&(C8LQE?Qyk4VULs3(8Off9DuHDqK6nJ=)VS$N67-3P`O-j#xu!9r)v>B zeRmfWFB^G5QYKM+bNc0H?i<$*avhHIeZUj-=Sru)AJ4h>jKJa0u%<=Q)1=0OV9-mK zF>g{W=~!!j)@7BA$JITaJs^UH#RA@_1#PwfFD0K@96-Svf>p5!o9)#crDZ93E!b(hNnfhEAYFWZ9z(!}Daz6S5shD| z#FzdV6TyEtyP7`KF*s~GIQrZmQ-mlXy(Fu*I2Qn*=!$ux#fp=!We#8WKb|=;UbIF6Ez5vWIj0R5~0Lh9b!%ODI z*h;r|)r60cN6^BOXMpQXxY)=5jt{ydYMOJLSP3CoBDiG1-DODf{xlh(jQ~$_^!lz| zS6x3AhLKm16Zbg?=OuHdWs%A{dp5xa)VvD=DR4bhAQ@Lv1+QCdNWtz*Miw;}T}IFwc!vAh9$ zmf-hy_|v7NSF77Q|1d=ivd)E9W=nmujQ!E9giHgx+Wq2d5A->h*n2SG3Hj$$;G?u6 zXcTcU79K%t@<&=)=fAz6WfB?qSE+hKSY zQb4)~GXI&qnSytqNRN|b7ca-zQ^La7WBU=obBPskzp37rcH?#UzWK_Fbl(ceTr7GU z^_K4U`>s7WPuI0N@W1OJWCRP;d#_M?FxiJkYlJXsNAN%JflsYzlsn&D z*{PiV?Z5KX&&X>PC}na1F(qC!LgsD%BR$y2Bue6HLP=zEJyL{-Kcr?h9d3jL(ytt_1v zDhE3I8R#`9Od9!f0m%y3Ul6G>!4OI9u)?0Dsm2dUo!*D*@i0(4?f$GS_|$2;15OQL znjxfG2|MXmTFk(I#rCQ%XT`AM-_--_`&k~^+wBPtlvuqYvM z-%VXZUrmf>swy--`7nDTAeEH^fXqxXm_VC)jz0)k(ry?!nKH+)0H{)dE``Q1jPy3s*! z3s3m$RNfb*xuvo`o-d;vjySF9?TCmUut|+c8d3wmtkL8rs5!+0L`HvP4j~p|D{xlD zn$WDE@C4X%kMvGmU2ecTC(uGn!it@f^{g?3`DaaG|C#4>emFnOdcIRt%bOI)CUSx* z1U#z)R2l^13NbKUx)42P=%Y+td`ijDoDoOqizBDsTMF*pl6cRvUPKbff{P>mn2CJ? z%nS4!+|fAC`-FYG)g|+``4eayO6m0F|CgSKDT&l^b0xX?_S2c2xe)^lRXq_WyWn5< z7(dam)Tkn^#)9L%;w&6^j6uWD2FqE^kZFuoYfA!~#@!P5NZtd1DcBvsKnjs# zpeW=RrFP>u&KcB)jVR*N_KJ$|MLWZoo7DGElt1D-x%VFSK-Ky5yT8F(3H2%MO{-ot`x%Dj1#cO{qX%G`z_E)Mi0Rt^38dbaT;JEk(21G*WY&!Pqh)&1)Qc#2@Yb~f z!<2-`J#m5T@1E4jp^y1=zI{VRjawthihvP%+gO=(_~0UthQnYB@&?no?a`y?#8Fm_ zj?+dV-a&4<<$6O+c4Q6cZg3mExBo4evJ{a_@BL40q0$iI&{ADx|Gr1YVx&`@pOkMX z`xTf(3K-O(_Dg=FX+0|(^Ru!MHSC%d!qP*edjs+;l z+~3x0RZ|lT^8}OS~%q}`eZ0vjRKZS!L}DNiBQXMrMNAg z8qqw9dFSce79_kS8GB!4=GdT{O8kaowCj_;%6|G76F9Wd%T$=92r>Dx0+YE=NQ&vl z0(l>}iH|}?(2I*8FHGaT{Wm36CYr`^i)Vm{ElQ!tep{wWTI@ij{1f!Ys`Sc5-$0DK z0jwh##gdhdZNuu!n66UcFh$JgmnCskqItc66H+%W3VJ$FA!FJVnzri2R;GSIpu*ph zfNB+HPplr!Xwm8ytDn1%c?WK+cgvJc?O1{b?u~fca0vkuOht631DXiIi}>z&S3s}r z7sUg_=4NU{u*?^oCbt1)><<77J!zul-}x2W9?$J_7sbxB&Ba%ZExU*BMxCh5f2#o1 zu`St|>aFd)#v6#Dj3MN{y~7!1g=^jLf@KRlkqbdGMmp0p`)0yKIg7^B-sRKZqdGE# zP;>*r6mVhishH8n#>}=Z@nGX5W7RqbH1d04Y6PsMB~vMq+(J?xLT4M#7sL;KT69mS z;duGeOPF(i(k7=BliWSKWWE4fEt983hQ0g`nPq8OJ@q~(Kx$-Yx;m+h^L9%9*g^$Q zrhzLG-asBX;5_{n_FjSNXCY4vTr;VZ7M_sM1&Y*=R$P+(r&O~b?PGMXZQD3*z!eJf zZ1^5YhwE9M53G?qve_4Z@&tDJ#+b4iO}RjJejB)t!2SMtEPC2i9!~$LDK`3_(=FOiFEa0v7&oOJuD`d^v0Do_( zok+Bp-Lhe7?f@BIlNm&2xThLyDdc|7WY5+)Xd%|Q_?GLq+s$3anrzj44+}GMN2awHRFUfmOS2f zBU`UTCmQ#`chX|;XwlMW=?J|ge_b( zIXRF!JFCXu3-P^OY6AD6>xjUzoZFv9+`J&FV9BQ!Wt3G`Ir?9}naYg3i-%nz7>Znv zXSX)RvAYAjj)+{3X&kqLDB^GhXN0#&;&5j+u}p^rox_jne#3S23|YKaX&f3D+kHvZ z;k|3E{y!j*Dr^-KA4=Lg4<9=ck|`{ky;IwM`>ajre86p>R@ic)BPx5ELE#6z@$`6S~w^;@pP%znJCqt}i2O9-w*dhbX6a~LTq zBBLrNYI53OQ<&rFNGzmzYly|fNW1=36@`jY!!`=ZNWYpwlj)XNTS|dRgQSP!R0l$J z9pF6l`MNL)hW^FF1;X5HFq`kEo6=FX=zh?n*1Rid~#XuhQ?49hG6Pn5RHm)NnqW|Y`qVD^U3+e;Z!P`opqR?EY z>7lf_s*+OdJEK;57yA-xyANCOB*RZj-%7`7f2@KtIBtIDBuI9VXrzbdi-!^G20m7= z)!`>cL{(;+mu`jCe{!a*^n*LYoFFXGMg>gPf5>@vuQmzUg2AcTI=?F@YsgKcQ5@b6 z*8vWWYy*71VbcsULQ)Hd%g!%$pW(k14x`}U8wpQ(UF;71TZXMpV*U?mup}%Zs-PM> z!o;7Em#_XtQdWluux9@ z$|7mf^8KH_&T%ctByd|g3=iP_4GP0D86d&1hBL21Au+WaYHe>j??+6c4vXX(kSz`~ z$5-m6GM=fU(t5YE+(SqNf{<>+fIQGQEE0uD@!FtI$%lia0S=k=(}&dZ18 z$KjUex<`Yzm92U0)jjii#-vgGRX2>fTU?J`nu^v382ju?@BrkjgH#;VU(d4B5lMFe zJh-&=Okf+|8Syzj$S!xadBzis#V=6(Ppc+~4LlTeCyo38ykX&npbvnW1P7c#IAx~n zugn6`^RT(oR71}}VGabikWlXhYT3UfnF{GH->m53fgy9&G(S(J-RxxrI&PGh)JF__XZ8LrBc&iyvtK z2aI$@MX}@hLdG6_8a2@@`sBoU{bmFP$#cUp#nm3Y_H4sRGp!5s9Rx6~|2U*Z*7V z5X+k;JYfz{FSPhszOtPnMGuvcT%9smvo-jSBj zbZvH{)N4gisk*y3mqs(Tt#$s5xWJD5Ih?|FAqT(5h|s<&G1|Sq*y)NWYz<*Re=D>h z@ak!#rtzNAA^{Ggcz2Tv*MzEVbb&}5nS+Ruh`o6zW3GVfc!QATosz&qeV~eGK62TE z-e9V`40jzepS~p=Cz%2|HZ%D8OCqK$G{)ZWtN4h8p^;=PDxY@08C_f65F&HxlR3PIPUcwwRWpVK1Lah#yKTu_K-qBSO} z=)0X}wS{5r;EyYEM}!{N`b4a)tIF!{5}S3qid*zYwGRvdxX;MrL=Nr{O<#LzW2_o5 zdSCY_236L|kEd-uMTYZ84?@LB>mE}rZCayort@cM!!UFRm_k_Crm z0*6mzi<)$zPeYyD@6NuClEBIgOolq#(lV)I5ayeosQP=0II+&OH?j4LljRj4VsvH$ z5O1Qt)<^SQzSAb?EeIhj!1%MlWDy*M&clB6A}#iTNuuNszqi~ z`Ks132d|!CN0a-_@{0ZKemZgP$5hvK6InkwVI9NG(UGWD8~%^d?#=%Td#gDpX!x^u zfqH}dLr{LGp4d~E@YqXX^jFmHnr64f(DT&CWM{f*&WolK3^M0f53jVg-F^!Q2bJa@ zV=P-Q&KLq7Pz49x_9)6Q(IoXy=+1cbdm8#wNvu390k1D;LIfBe<;P=bB(cY~!_Rz( ztZ0eGu~yaP&?Rt!@&Oc5j4q$b%4XY0g_xL~KZOJ+$kz42eWubR4Og*(N-5ln!Wk7c z<-LA7f{{+?ZQjMs6zQhjvF{(cYiNv5RtgFY#e!C9Z$BeRJFj|s4$_*Tt~`r~uO?rKh@{j34=E)qH7M{IKW!cJrv@^Y@cxD}`u;RW9;E z-bhVQQiBs5t27#?f|k5posBq7{Vx^Et{;?smWjq+@efRz;Lq;6nVx=Z4P5G1pVNM_ zW@R0vCHNrxub7gEAKww13t%yhUK3Ym`5V1AVtExz>>CT5=(Z~+k_Bn$5#-a{;3mK- zQZh;L)&R?_2`*i5WQ}nd2T$EE9GiYj z2d#X(qkO8P9%?q6AoMYuuwvF&>W&LWuo=MGcJ2LM(~cuoB{IB@C`5@NVE+pydGjS8 z=PPTuMY9={NP~-=x5$GfkI+Y^--W_hsJHKU{$8)|d*@pxVKvFFmyM1huX$S`Fp=Ak zqsLO1`L!W%SW@Xt*h{LJlg2fKkfnwdx_3nZ_E&r(TzJcnR@z8ALS6S( z&~pQtbaI@(cB{tKcU(KG1Qi*@Iln%Nz35^I}}X538;S$ zlvUGt+LPl(>I!B#H#Y!_Y={R_Sb((Ek^S1=8-ur>M~6AbM0+RbbN^J6lXLLaR&Q{$ z46(4+?g@(VM}EAi9-H*Q1dodZL-4%rBlJ`+#`jJX$uEf@7(WZONi0!_0r6NRy-vc9bV=_G|zVOq!LF`B1J`Poa4-Eg9v zx9|V{{hCYDU~Cg|fbg)_^#IEfd#>ss@yMD!GXit-*~wlNY7+4#X{gyiZv0{b!P-A9 zhpNy*C}FsM%s~SOTpT_eHZre5GcwkS5qkPF_avXTHWmnouH%x79-E;6A@}1 z=*zmp*z8QOiV_+&i6!X`j7q3L{jC1Z-nkiB5S5*fynI?J64YET$XAR=QM3efM`x)_ zIu-zYUh`S`CDX<&g}YKWEQNtA9Aa5FP!cdrg7cy3#z15sB>~dH3!xKC#zWeBXn0uQJN1!zwe?OBKCO8r5E8%z#%*(xEY>L{1@zsfA2@Pv|2_v7z= zcXg)@Cll|WxpbKPM6-EMAgFW!{H;b0mGHm=MV@V3%0^QiRdy(@{v><6&$6w5Y<3fn zK(E8nn8QfwVyLtVIU_rGSnRgcUirB(&D8KJfLDka|zD%m|Y3?Auf|(W9|&UJsST{-WXVJx{i#DFRx+ z7@=;l&Gdxx8Hb~-KNy0cwd7G>TM5P>Cm^^|h={0fqTV+-%LEPp} zjes+OR*V#)wvR7w>%yRE0r z^c(_eOwm&x$7B*_MKAxzM2AiqyjjPUETzo|D+o4}`u$iSTRCQD%pDd`L>V-Xof+tw zCsBc&r}Jt+b}HW`F)NwmY7Rr)@$#coV)%btarU{@!vHR%NqNojS>+9=@J>=}%dBMR zeO5Cq`G`Q#ty0?IJYZQ;2X9Ytb|xlhrNA|C=fiz)~QBoMKrq zi6;#`tu1ExqM=&V7l9|a0#C8DKuB-D?@_&cJeL^A0aUA!lIpdZO%Od=XnoFF?ibsi zt?tLP)iGsFHVg}JmkK6PO8;KY$y973<1K-Xv5Zweow7-6TgJknlr}TFF$hF4$CZ@A zJ>LTH$&#>n05I)+gZ+Mp4+p5M^ZH{dtjm zW4JR;NZ$FcVx=u*EAD?1rGf?|p`4BdO+rly1QdDzvjkj^>?vh$&CA!=qz{*waoVfA zBDUp%;AbcLgq>4=?iA|qXu*$mYII_<6d@D1gxxbbEH(g6xXQyYzHS(Ao7Rse;iDceKR`wD2W~7NqmxLy#ipn2Flpvk1+RRl5ZnR;NY)g zPFk46LMZtS!Gu-u?ZI88@9oek8-Y4qxZI4cg=b`64T!A2%EutM^6}1eU~;9hF#x)aU zYqmH}+u7z_B5&R8spaMKHF(IWv83MbbqqhSBh3p%l!=^Abc-}E8;(9c4T>RD&&WM| ztn6#Y(;YP#9~+Rg*pktukrHX1uh+yo@zRBl-2MuKp6cjd55OBr%SqUx!w`K5(0B+2 zD%!+cgl*;_iEwP0`N$HS{?E{zLwl^M#k9Y=Hx=6fCVq~Wx=i&qIC5@u8HIp=fBS|m z6en{okHW6zD#7dS0Ue$vC;$dnzMAaGAFuZ9fhplV`0qIlZZ*YU3i>_UBGC<2pSX5) zmvIul;Fw|s-1NMzjJzb*9t8hyukq85q<5g*p&4b$e2Fpu3k|8$`bk_K2~2FhviLK}Y~1=^jm_s~*vb1bv76#fJUJY!q!s1nN$pwd}3m?}tMxMk7oyj=}D9xdn z;clmqj&9FFmLJY7jGQ*sk+>)|j?W9X8oVzCd{%MT`>;&VL~Xf>kR|^gQ)d+xb-2Cl zp@vS$p+S(LyQEt{V(9LWp^>hkl~zfS?nb)1K?&&`L>i@Yh;R1Z|AX&nE{=JBZ>+VR z`@WT`tu*WK=1vCS!0-wHM^=(&ZdM|51(egqkH*0{6>|X0R7sg4bo?E$iDZtJm5qS4 zh3zKQyL-skMg&e>{nm?4itzaqWv)6|_wY*apWdT;^*ch-#AobL9H2uBp-h8eT(xZRv?A%-r)zJxlG&-AZ)wZjv14T=~{)a7D>fnu$?FJvM?K^TMeJtnGK&puS@Tr63l zmw&214aClxxMg0#-a6W~%jRpbppxqHV4&bCXravMfzWwH>)>_a*J?u1n+BOo5fQjK z6{#?EFqhvbwk$5h!S*rYw_J3VFeAy}Xeu>0o8zxTrp~@%v-q7eFPe6;k>gfhAHKpw zGU+~*&rf+>rlq=~!_(7+x*44&0)vtp&b z_an?4WA)=R;wj~pRd%8OzWIvYKFD;a07HvTwmi4)SiA9bySn{-XKnSo!*cgZr^mk| z`Lu=Y^hIvIsp2hNJ2tXpT}c2x8{-?hQuMjJAW1Yzr(|Mttqu@G6zirQIi93PS$`s6 zi3DP>J4c=8kP)lXB#w=Qr%+(_Xg*F8zVYtMj|6?jVV1$6V2C2C$z#YqYvQw!=6S*+5sQiQ!gx=i|Z^d7SxbIvlf{T0N5VUClHqzDY8QjPN2`{xV3R zfy2elNJnd~SDvH8r0v6iES5pz`4APP7xmOi{up%jkZsk?6?JUGkEpu)cJ?Tl-v?ky zW|NQ5k#5c#q{z;w7cO^NMkO`*a5&E|*t!#3qIq*$`&?_rGttz^_`hSxI#$+_5PcaW zY;52ce_o(}9hs|3ANu!eYTq=urXRlj>62+@A7W0Ix?AdMh82?$T?B?M0!OMsddBsk^hQdMn^6*7+eNA&Xg@K7;_q{4Zk*sjhkyipbe;-Q7aTxOZfSE} zCs*qHir87fF8W7;0vH;5Wu#+jYE_T@o_9xS?ej_W&s8?{A@mH{yn%rPG1TP!kv<7= zzcL&H&(mNEhHpYO?Vpd`icGa?01>3TKNka2qk9?BD-tEQz&m{IQ7_#B{hbsi^udg{ zzrR}|{x;VKbJYqEc2974u7!1Ni_a_&RcU7JEtHqDX|H3ORQ{ zv=lD$N2028r?^#ni#BL|_Bu^tYY0oRs$9!05L|qTfk73g+jm9a8vSzsmHL|~QEF2fpmK7Y5KT}Ts2BA20o_KJ%8T@Bd!M_^PaksYv z&KuQ9tKKsSkMlFvwn(fk<{wzpwlMf>5JB7A`i7m1`cj1R=yUFf&w~@Kt(G)H3^TfL z{S)U;vL*B2d^tFpOvW9K;SqUxu-%RqGfrOLmSVdPNU(egGzfq8V|#P|YBjG=%U&gH zQXm2?4wC)p8vFgNL7uL8qh)mbI9~S@0JAG-Mp+(hg;f#gw(W#Z@LIJ@mI?wO1XsPW zjL3{u(u)BDRF*z}-|y!X?8VEW9a7xSw`s)*Bj6`==_5GbIvtsGPWgn{gX7ExlhdAk zn_=Mj5{~4{tb+#GLYe8%Ol~p~Yb2F^z~AHdHeT1t?Ec_jvCXfK z{*;<$bI=D_9VS~eS1S7iSI{z6aQ`$A|8I$Bj`B_LR#?V7?;p z_b8@id4v0ulUiLjZTr!Va52mwMV(+PfPhX}OrO!_NF5OY*fV2`7dREWCMaR6I59g& zu~mwq!5P%3!eShmgGyt~`S=QAc*w61aUCXs?)`Tc!B_0fOi)C$P#*%_G_t3g*nJr{ z-8dd($FXfyc`IfXBG6T=v2{@edhnk-%IcxJgEGW&5#AJsR3~qe=%02IJB_xdQ~h=5 zwljm+PSgCg*IJg-raD)#(p?dy-u=wfRh2y^v-XZbDNR+HS zPzz$@`FoOfS~YddN_a1-lddb>>|GmE#En~}@a(phiT)mmJ6x3jM{x^S+>|j-zEpO# zh@YW-ZPxgG?_Vqz4^o#Xve!J2KVBu;B%6Bp8@&TY)+A2A2jHBeC29YhkW|a z2H;*c6SMdE=Z`M_A{lh5lwmzpJQKuB_7WNtdpscAv5#wu9j8muDIhz4h$9HL!5vh|@J5ae=%qh!HhwH4ozbi6F zB7tWiYqwX%IVFv~inxLR^O-*gF2p6?dpu8Uu&UmWxk`S(&`QMD-=en3aw^C{L_>Bi zl~1x#z`g>b5iU&@BCO$>%iHw3V+i5nQ}n zduA?4^8jNlIj=edLL1XIgx&@D-&cqv| z1Mbe7%%7om`~R(V#F*2e32>X8nd%5H)HLBKf@@pwg`FOaRq-p3-|Jxyru%8IS^0Ky zA+o5!_Bl68x|HO7N+=YXXI3F_h!)~9JB>;m)T6gZ_IK^HVIihc(69AP`Rr&uCjSd9 z9O-2ou{XU;n4rrg4~wGiZh|-Owm8xI67t$d;XqvJa@0h;NL>IX^5y`DsXi8Az}yGP zbP~wonG~h`>19Tk)=`HPr-X7#G9jB6VnvLVX{-Xx)ej6+#Xy&0>f$PF-4UHENV~Fl zk~;Q*VLcSvCz~yXoAuhrM4|?kbneEvGcnhF9^wvAJRMuiO!=K|*S&QR0uJW6D;3?& z8xboxZfZz)0;6kQI+LG9s9TA#fwxB8d^c0lwhmbU$V{m!OgkJ8hmlbQ`BRQ=3!rGE zPbEjOo$FU>-zI8Ea$o9OoxYmclcPxlZ;#RqMd|Vte$YYVq@{?(l}_XzibC1V)TI;; zF#P8I&jCsKzHEatYztCRs+qK?tZMK^$jO1<2I5<%STReOy(_PA%(OUP!hG^zw{tpHG9;fYIP+!i^6d&Ew)LlEnbbm3JcUs}M_Y`}zzlE@H&f)BA! zK)bHf`Xf8^ZE{Y?F7B6F#1%lG;GKBX4l={ULpU=p@YDB(q6>5J{4D*yk{Iiyx3~KK zv328bii-WjXmJCtj&pQLyR6R_!olc>i774~!IP)#=0rEM=uw>&^V>1_%ZjGkJ z?`^Z%cojzRz9<|{-gYL~0)C@!@Er&BqJYy3TC!Ac9RWLZptf8kx`5w|iZx4yxRPU* zn4nnvaAG9~%kViGz@b*e94N|@IRez2-_O=j!-m-qr>@Pwg$6-seeb`-wNh zq-)s=(@Y-l_=e>xOGmS|%vju`%`_lWT##?stjAt#g9oO5#Q$=Ns4!WWAV}u7SA-?I zGvpl%ba8enfAZ#Tp5cQTV!q^I?hp_Ri;12wNE{HVh!gq25B>lI{;tUGm)Kfsdp^TNr6 zRersQMEOqCce{4%qGSR9ZVwJX3oXL6{;!et1Pz28?6;ybcBsp{%=yT1j>jCV?hVS9 z&#jLb6b~mwhQ)H8bjgOlnT@g)L{nNTz0>rLBz7M9W&u#CdWjjD5cD4 z?37DVkt5h+UaSJ-cHI32M#Ao!_zUkwsyJ`~sMKRYG%~yEdePnYBPP*3sAWIZ%2HzI zzlM9ioo^@8H-oxmUt}boFKp`n>@fD?4yPy<_tE0VWz3Ia{%YfI7o8jPNp+iOAnM|j zIGoZ9MiJ~c6*!hc8E|fsocS4bVA44rhia>YFF~1er>pSn1OH(52pl^tXi$9;#}`Og zTP#I$-vcV0zxi#v(fPlrfc+4tTbp6=_n97{JJSfzj`s~%ENyWMRtOYv+WTb9{I9l5 z7$*FuEW5!QsZ9k6v3sk=`gQVFRE{DVu7@M@gnAdQO(M7ec-c56JQhWdOPVj;@BtpD zTmrf7Nb=zPqmZ9rEGhPTcOj*(Y&+xl0gs}jXWkNmQDFpm6&rw?Ijl#~V>WqRz-=F$ zWAsmZ2s@gY6!_v5W!$%heI_puxNoIse4aDv3r^!;R5e zu12jbkw-hTqubB*m`?k&!PjX`J*7|gjRPSjtj7|A>50D2rRpqqg-{25x zc+)w5>q52p>NRa86iN3Op?7*h9r@PNvPrtrXD-Ar0zhG+6yXSsNvuLIZa zmqZ*ToCCox&xv011w!P$R_bkOEv5hK61#dXvIo98fL1zifj2>uWo9iU3#WTrwQb5x zeNz_ArF*uqUaI@k5F1le(edV{zbhftRrrU{7qG76!b-e$WCT+M2frJ@1c&eDH_ih7 zzzXddjB(!*_g35hdsmY6rv-Qj1m$!A(jNri+R^#Qnh%RE3$!$}S8n}Nlu%U|L8c%o zx;Q|_u-)TmyUFM6jVPS0PqBSG?1!2}YRzS^d%$UYrL-g91k2OlM75f(+o7|n%ALi2m(CVKJR5ib-MN}Nji6liVvZP+Nu zJXrgUTRYx`3M`zIJnlpr>+XY7Y@M5sz=rK6AKkfCqfsw*Pmiv?DAaPa^h1B|ar=G2 zuGg7WdRohnv8tBpGqxp$Mj847h@lXJPbJ~|V&klI2@EHMP**uq7%<4!U2|H+lXChr z5p3jA+Mh48e4SAbjp|z#Hr!i9Sg&WaDQorQ)yS{lMt1P z=POIo^8+(Ft>q^jo>yI2v9`&m(7D$RHbO%Oq-60?K`FE5NiQz_t}Eg_VwX2AMePob zBhdRLlBqOe`aQnoxhZ zbMh{S8=7=+Gh!Nembu$t%*iT?n+JI4Cf)g4GNQ>qtQ%uhUXaDO=X{3Ya?EJRh;EMX zEy;>_5dcV0lIMVc&%qo1OV}`%9c6GBL1jc@icI=O5r$>ofSk%Zlx4+tf#r=VBsX<( zc$ktXg~AyLLq^h;&B7&45;61U70fcSpRbW4Jjasj?(4aaf0U;KHjmh>iIQju8${4{Oe4)Z{ngch@8LCv*9B&!EKP9S-H@sI*F_NwbHr&Sy`$!^{ zQh_}IL+`_WEub<3dqIzxv81C~jQ%&>yPi6T7L6kIaQx&r2Ea9Y6GSgjd?E&m`p8mi zqT@a5n|nXs3v@$kZBJN9roE8Tr5g|$f2Yr(3mT>9AAdT^+xFB`hMEssq%Y9K%0ycu z+Dd_$(fp}ZFR`*^FsULwJIN;-50E?I%Oh-XWb>0p_<)h zM#ev|E=0K%Ihb^r#gqi5h6pcq7GT+{Zw2SMaG=kM!aIdpkP}ih|J%pfsq*WALP5`u zmJv0%S0-W9PvR5JNa_5_x@R3&R%PRrzGm2;m{Wc0zO+S;o4JW{>Iauygpz}0UE>MJiN|vJ)5lI}Di_uNk;|*(jjqtKa5SL4cAa@rwO>u zB#tK_lGmL6H5MZ`sbea%@C)3*xGiFz7H_(l-0v3o;9=8bP`%D+22@|LzR-JBXk^~D z5;%rWo*aVa#L?>lZ+}E6+8v2_R636>(~CqmCY0TYW@MVWmc_FL6&_ne`?@E^Ui#p_7t&!Dckyn|*KQ+(&!>dj0=ajSg>+MDdKVu@&0#2W zjANitwpEh|&0iUfFi_{o3fv34W_SY43cmj53mF`7dfRG&>D>6+z(fR1XpGH8ZC=U9 zqCy15vIHRc#)eER5{T9xk;M`K_=1Me;)?zN(hdnM($NUat1hk<#8j{-0jS(mTKpMI zaFnnV^&xE;90G21Op0b0d6JPHoJgKut8Y~$JIuC>s)bu*>+TIC1-~V+IAP-PZ3_Gr|4zgbubiIe>bw*5KfS&`Y1=JMLM+vf)b_0e`2G0$ zY3J$ht0RFYD+&VY6&Y*kI<0kgzN*nw1OHa6COh_oI=LPoxLF1xhBU?%o7-4b-HGR@ zbx%nw2k|&v3lxZLn}{AVY^dU@(PNCBK)8TbYx@(mN|VhB^=@+~_v~4gmW%FwXsT3+OxR8t@IFDzreTSIGAq_zb@Mjjc~B{*U{t^MNg;zzUF95K9WKfS8wj!r*Mys`*40A z=+EW$e`1o2O+?^-AefAFzGeX&Jtv56!smO7XXwKp&8L?efhUp=#XXnlr|X-K>$f-0 zPn*(*-`P5&#?5WCH+^YcQ-bN#7H_T;k)QG6a$%>T$@dFH^RgSVs+qvQ+o;##-BrCU zerec`zp|jO*77H{*K@u^NUR`CY@TefY$3H$kqH%-K*D%P9F^f|DJtwsPL_8LP=(R@ zP+i%upG8#eQB3!2jG{psL0y^!)|#OV#dN{A9&C@9;5H75&h5l%;A zFY2d>w{40J0uM{05F8C^!?~voFi}Cmx-PD>I4|2VA*e_sSOU6vPV2# zn$I!B^b8*%CuBs8gcH%UREyf2&%Rh^N05Zna1_Iu2vk4#(HPQtrOMsUD5ASO#` z>D^CDAk_zuzVdFteUU)15-QD@V2Zk;%XdH)#I?3(*lw6}JPm`PWaE0dMtAbFUdisi zM8<)6c!Q~AFUs@fo^Myn;<kS~9#%Y|XVmII4DbZhV{D6CL#5r?Nd zH~pFDgb&{gSH4X_)$FLp$Fv`>yN4CUqBlt{5yz5`@3-gNh&$+Ad zCkpUq1Qw>~tKUtqAk>}Otn`Jes3o9<5DM9|^N85j34}?FpoiMh);u+LI05D}Rd%(= ztCu_gW1%fXgZFeJy;by-27jgxIfjiHP=GBiJV!l}N=gScla!wjpLV4>B6FCnSmf7! zgQp|8rzf{nwEs~GHB`jf$ZTN5h5}y3#;xGts~z}56zWia2L&=!%1e-r3~y<#5bJok#DNbu|l)VHQ^5+b&9Me8sl8Bji#F#Y2&`$T;9gUCX>|7A_`SRHy7@1Q3!D+C6 zmX?Ox7M@>s??StQBiqEAmv1*C>F@LD`xit0F8-E&EPamj&p8`;TB45&4!t>fTEAVX zeKdZEc|N7r&;S1g0dNCv1fCM-R2RB)%A<*Up@bU`oUOJ8pYwcV4ej-~4JtOe^tDW! zTCdXUPDNqF{Iv;qvSM|}7I@~zsF!Lj1}3qre0>$R-ZIasdo|Cr*&O8h;O_c7LZz%H#@<0o$x$PWwr26sld&6W-MDSSK>p{!A#wS@LL3H)vNBmQm70jD4 zUq3wkp(qK}d zOIn*@3#&9)d6%rP2^1UBef89s9;g(zpCq*5#lK?5e?7(LE$YHm{6UD$Uew0wEhPoz zD4??+RVz7|*W97JSF)E#_7fu`+8{$6vJ+c>k{8X5iaYO{9GvvDOf2{2aw<`(>1GRy zkA|@w<$O0&b4rrt*V7rf7TPxu{61nTXr6YfDEuhNz0dK{+(b*MbxgDq>qceq{Q5sz z6z;^H)Q#9fL+JI?>3=UbO%@2Dpr?!65s}afk?sGgh>m)mjwqv1a_{F#>%F8dGN<5+ zp^u9OwFr46d>^-#sQ0<7Wsb7f4mQxzvehRwul#774omc;Z(S2@)EA9}A0~CdkRVt8 ztR?TD?*_`O1U2Iy6VkX3<;A)a4=<#n&quClj^FR-G#3t>7bYjrD=CA&ZylEjo&>LA*&h-Q8PC z*b+r8}y$~(t# z{Uu%R++a7q)q>|2w)955AAi42p?XigDKdWFRMN?Yj0nw92qn*q7og6h}T0d7oOK668sS)bu5iPqvc!oOatWGJKlX+ZA}5 zM~0^$2G2_-tyyE3CMKg{iYmi$V4PfzDKg>AvZlzbzP0($f$O+_sKCG!5{Flf4pQ}}XbR8um;VWcd%WhG_Rydjt^ zF^U8du(cKw_ z{slEPllQVERk|1AU?tDb!;~t!SjLLB$y-P7PYxp#&6ZmpvL^ba=>#`Qewj76T^SE~ z5FF68RHPJFRKak2343MoSKR%@Rl|t^Ux^|P{SZuCEGREZhV!UR2}Ot)l`Bko$}wgW z%DEUYrY>c?cl1>-SW;^XKONRR1Xtcc8r}NO<5^RH{Iap$Hhsq9`rHmYZruS*>c85# zY^K~w0Jb8!?C zxhLo{KwGnTugW z7CqYlRdRe^M=i!Uz^L)*>_o+Nqm3tSlm(bgA3bZWW~YrB3;M>EuW;8;H_Z9x>)w1t z`1k!R6dYl5{d*k?oF2|q3kql3#|E)K=O<-%vom`c=2=ZqS^Q8_weO7k7bMN9_cxse zPqg~82xlw+jxd`*xb?s?X=1@1mZ<7{bk?iZGrBq?NH82f7N0^+!GHe^dX9V_ zx&P->xH9P>`z5#)K`1pkNB4{feLktc#Nvq=K#3X1Uug!i2O-1_{cxtW1J~40l&gsh z*|wDZUwDAgmz`>N@u zkvx0GsuQ;X4b+NYz2<^fMd4paNHYS*aDq9GRG^BP_ZNcYP~Jj8@;{&bR%-k32!sgY z+Ea-0am0OUfQaCLx9OX!lEwBkZ|R?py`bv_-RQ1RxXoX$w!0~h(eIJ;0P?!X!~MHj zuz{bXTOlO+zUV&A{Ls*kq&Oh(QL+l`J>DIcwN{@^dlCGUHEyNgPunhKKy9}tNUHpB zF~(?yec|LHD9kOvK}q!Iilr9-3+va9JxQ&dU#KVU}Xx>uZ@rHF=TXE*`EO=F~2n#LI`~y)YSwm6nlCk!LH)%NI); zC6jJYScEP}dyEkk?!Hhd_*1851#5cXBqkmlP!V0&>Te!dl0oESt?{C&iBDoEIyr^J zX<7=<%KHzO3^2@BF4!q2U-C#~&y;cwD?G_dlt*H8VnAku9v74HiCb8A$7Ub4u<%BT z21O-e^d#5Zaq&NniGCBb3-rqPN$^T?d&sZf$Ykv;kg%bzjB6t_gJ}+#k^QHXKtU*A zZ^M^u@0KWKM93!x6qWR!;+4$qM7j5mcaMml`9gJ1W1jwQaqsYd7W7=0r>}2}#eK0w zf9}0bC{EAG&7;(1a8l@n+Y^5GUSKyS;1sv#6!#XXJR=MGB@+CrutwU%z%M|8&V6ob zBmgnIZ3e-d)1&G2(N9sUFPYvpHEw~?;TOx7FTsP=%K8xV03}>kBo69tfXU-{cX#OH z$>e?2+Vo$1*^0M*P7D!?Es#YQ4pGl0U<^oN?az0SWEyMtIuv-pyf*;veWN@n;NN_* zD>u}fpmR?DGudH2TeVI1tZ5DZqJBop@5@Cv1}1r94?2YuZvXU{BU4#XZ=P-E_ru~y z($hu!7iOQQZt8O&M20m4!{R;-=yN*<-Ca+T9%awMCtYLk+#8n8${0fwnDE1E?dRi* zk!_q8Pj|Z`@1&;^^;}K7S>{|HLnMS7zw2GIqtl~}E&dF4Hj|u-&Hh+6x2ML4U}w~T*0#Sp-*_+*KnKaQ0K}jznW^@6SI62Q10v1e_n6Coz%O~x^f<^z|cZE>(9qoW5Sgl~Ua>8pmn>pwH@?@4^JN)3$`*BHHf-Qr5N3vPyR= zJ2@X{$J~tNXiYyV@uK0bWYetmc3AIMi*4)#*bV^sWCEek9lQEQ4OHFtT5VeLao zC>(A)aEK)^epyT7Y%4|m3ybngKSXX6r#R4iSgBPuAKZ+mi2cGPd@X+yLnNfZ3A-nrYRndU&yr^gJc>7B3UE925^^_2k-b22I@r z5&qm0=zd)7p_P6ZkbW>ux+BdSuyFAITb+6<8O16G)%wGXNtM5I%!6u1U*&Ss8y6=LyuY2(=G#^)C5i7{sY)^H%;QF1(uDntjs3o?j?BM5L2zKe1Pf%B0Sw8&#vz7pY9imFKBF$ zfYGRtt38iZCkr4P8s)GvVE zJ6}ys297c$r1Xq36l+qFJi8jMR_4tcRst_bi3+a%IY()Rkh(#STvpMZpfWG~BztGV zsXM1)utg!c0grTb3y|wMQfg!`QwuV<>Eg;TKA=qwvk5WowHhJ*xuVyr9bF6R3o;;f5tsJ&H-V_tOZa`q)wRah;zF;*+P&#HUF4jp&kJOSQ z52}+~d!(jZrJlLqGa=Mu_Di zDS99BSr$FHlED3D{Bq_KZk6ZAOZb1ckgezEsqNF_Y(>NoEKz^63F?;G#h$DNA7Qbg zfD#q@RDU%tl&CHKXBYvByT{dE9)H>q+wtFNv-9UXb}T0}!1of-|FHNd65HJjF&Mqn zF$>^ER;%k^G&p2*nn$g~784J4DnQJ4F*$ zCH539C#ZRhdq%aMBf2lZ^WE1_U|d0&NvG71{-9{3eg`zCqZR1ZkSgA8Ufdn9#%|Zc zRhXpTGwz?*)7-|V+_X`ATX*lDdVNb-?Pa@jAD9!m?QKZ?qb2&jIKtX@XtemgPOwHjWf^!z=E{%8@3$79~v+kre2JfyRs%i1-nNzi4^KwaAI zO*^zK(gJ@Cf95XC@Z&J5z@7f;)Lathx%+=n)H`iR<9YSe5j#IO25l2`Zu}_T#^g(> zB>%0NY>YH!m1fN!_!Yo<31@y0pcrKSI?wb-KwV03Ey@1JFUzU3;;CZtiOq0o0Qs^~ z+O&=2Ix`jVLLM;c>+N;0Wd!S3@eB=I`b5rs`jckHfsMusTk+j3@WXId@UO(*_V*qZ z$;Tz1Yi3&g>*r-uX5gF}ccinm!cbp@i7xXYEF&5eT@JK1BtGdZ@Oo{npj{ zf2gi?e^JwHU55PmPa1r<5%Sw!M2fVV=|j*u>iWSEyuI(P%Tt9!xLS;cifNyU9yHpn z+TK<=_a%M}FT~-~pVJq+T#Kfaq4qod99%R(fsUBZHQ7=>^EGH3K)*W?33H@F=Hg6d zPTogSlUsisIKeL$a30TqlnGgO@elmywz%BLc(InJ4)#HUMXy=&Gp!2=!~nMD)BtNX z92mx0o;&GyV*P|*3JQUM1X?>vgSa)nl*`ehGha3>-7ax0YXp*RhbT0F2AzfAbc2wM z&f0%Jk$AE_WaD#cw_>67{~DUd+)(%y-%EPOP4)D})%LcHFYC-58{?e~6~C7l`_^#< z$Gwml`{Nc&XJ@H@NSfFb&xFoU#tDW)G1amSJlf4d%7?1rvzbGa3wYZMji@?hw2`jn zCHZcAH8j5RmB1LZIg^J0byix za`_Fe0C;g0-G47ths<_R=3a|lzY+;Rx6z9VHgYWGH$grU?a)FoJiH5S5E6?(R(G8E4=aOB1h>< zuIp_t9PjX!h9jI}3C^>Fl>MPmlsC8OvZ)}g@;*|aHo40z=IFFdIWRuoLg(K5y-X*f znP4L?GYrVM0xzssVCx_A>ltAe|N6M(y)ehxX}hg(c2@HzOQrk`!!n_MxA@lrb@>hqWg@NqYzEHmg<7r*#uY-&Xvl6z z_m}#c7Kapk#p~d;+NgY@NhE>jY&d@a#nq2sksf|;vSi6{@7X*}-kQtxbNIK!>xYFn zY{H8VH=4EXB2=BCKW-D!nUScnp1bpS1oC|Ndgd-7s|hUaz`pPsme{ZgxRRQD-+wos zZdvl{6{_QCP0VrhWi35JaX848q3C?XjL?Dg;Xd^7dlKTnd;;UYk%~95@n2vo_!~Qm zI-0hRR~gi|E5W}((`9Crox?W0!{?pn{;m=+aw7?9g7gUv*KVB@xC%2;l1$>kLQa+xBe^?F#hYXukMzcaMhzR0{m)%+KW-^nYep%JsXL6=m6iXNSYO5exIOzlI$0An z?%@=&Ammjyd?&sw7kFDkO|hb57W>Kf`%n5`Z*e<_*U!3|hFM?qsTvg(*hw#7yGEY8 zSz>+HF?z4UVr{`_Cz^I9A@ov(DaPH4%(b30?k;JWIwkVBOQ?vZ9Sd^XhOYT%+My23 ztOvkaKuun!aEWe6s$gI$&pt|$%qz-dtIT74IM9dnJaPnhj?nHUVhAH8J|t@@M3svi zMfq?VB!m5_51(bW?J|KA0dUq}z^W>edRpOrN)(GxA!!9+^h;An?XK#lm1axZ6rc*; zzG}C}QJZG^CyHAQe3ufHr=W>{ zZ#FCNjc-iH#$JoLOxQS|O@f6m(?DyAgtDWxpG*N_5 zPx)J$@qnm=1ZC&FTl}!S5ziu`|6c%RXQ60na zLK1`wf&y|_CLODfn>qzx!|0h$OU9IgEBMm2dy9U~Mqk$5u0rwl^CmiLj0@3ev@<{H zf~LjSRd|P@1Xyb#`;A*K#x_Z(6VQ)F+}N-jGdwCvjN%_r2YJ8XyRf3-3*0^GI#@ZF z$Rq(ooL6p!hn|XgYr0Naj*IPT?~d%_$GTgLf;Y@x?;-@rRah8xM}0`Iv7`ND`mW30 zBJ(0cHp`V*>T?Wzzx40B%{q-ZcCOKt%idnoqOs#eM^^9O7cm@2SdPI}q-}}NA#(2r z0&Z&RgGFIW6J^ty;73(0H}U|?-()wNFsV`-#zF3_gw!zg4TxmLVq0?u;H5r23Z8z^bgw5r7%PTg)hweAtmw|%WzueRSrPmVo zo2;r=;;;6<7ca006lhT!B;A#E>$;}IR$Z=RZN>n^_g~LQz7E(K{##9`EGDH!1%+PwwPPnH~6efmH6E=ttKF+R->9~F~&-S1b+HJU5%@{#WUt5Gr zts_?v?`Dr5M=Evw5%ZE(Ocoa3+!jzy^>Snn==B{d3vv6z#l%fv5;L8Yj=gCE-@^{Z#v zH&&V~N=-qBx~K$}(tp^dHnWE7vp2Zk@|M$g`81M@)Az(!PKW;!jqTY9M`SlmT@>f= z6z6kc;`qNtV?cqN-t!03!@KwaoRHg@ZBG3^Op%UgxKC^27YE<{}m9C5`Rrgf>Y%p$L{d;3Y&ss~`YFh>cJM zmG{849ZL->Bl#J2+;O~{ zf48hGMVw<)LKuC~&Gov@f&?}q(B=yk45_Ju_qr28_BwD&0(n8ga9!N+|pNcCAoVSr( zxSa}IhCaLiPFFlk&M;tK{9I^xFQ9gNuP{TwYx@l*y+*%`lIhNNaNOghG^!yx2;awXm?kt6}cqYy#>}){Hr%)t^hh~Ei)8R(Lau;1U8V} zFc7>>Sgm{%B_`XU`NU05CEtp502UoX?IT<_ESoF%?aB!s;*zXtBTQ$lzJQU3GD+qq z>%!HGQuNqD?Bq-qclla)WGVz95(ihpmW@8O+nJ77NrF_N5{J`O+%OdN9H0U|?Yrvg zez}!0i==5h$se_1yqS{00xPN#^~*W%r(zi-lUfOwf|+9LGCMT}+xOqrTmNH^-JUg2 zs0+sCwA?GHfW|mz9ve4%NPpxSqJFL}fqE>-?j<`AqLm@-e51jEmw0MD0?nsa?_`is zia>Q)=y*IQUex2ynqvQ*J5p;IHdxld$`>EO&hXr;envhlAlbpMo_m7 zcsWu82{#V-)l*(Y#*D#R=t?rKui^-Gw4m{ff_B5}kCAs(9_?AyB`F76-1jbi%rg|z z+rbwXXRN6pc;4xYa9Sr(U>@T%y||lk$bO#q^J@k6g!~)8w!Ky>HW4!lfrT&DU)9@V zcXYw?>e2j|bE@AOT~DnRHPE)Y|C_glv@ZJ=K7#u>uKVK_yy?*Rf_l!y(<9>5^gvl6 zGfAgv%Q_4FLW}+1=w1v%IPBQ4yIkK_AA57oIa@7Da>8>Xx0bwns&g_EL0_wJUm zR&4eXBUyL?8n}6Ux#1Pt3Y8TlTCPiFYrJ2^EP;I=+faE~RSzcnwaRLM!cre)wg=eI z-&_Me*Q+q}6A{RahXy;AhMLv#i_X$YPVOg~>BirR*OGwqV+uIQw(JG|HSm7Lp!pY`o+dVR1Zvx)_AH@A`{yNIPQ<_(=JcNbso!2zU zeQO2H|5M26{(xlMi%)=A`f$kh4wWyGdvD;7i@hUg6O_T1Bru%bgY9*Fn9u^o z8*<3>ryF>SW&K`cDi3vkKN_X}3nc-2g;0AADUErXhv<9l%(I6j**9UniDr>h=)$f6 zWGF)_rZD)zv}btZf`{t|QSxv96rxD!keWuzSe)}_wp^zj$z&p2N=RV*;mB|7V#QiC zAx}l6kS&>&G~;wJ0@It=A>+nAUi;@2xX zfs|l&!PA;2W3|_+07Lk3C%Q((sl?_VTHQXAyachE?{CQ#7w8*xCO%wo>w(_e$UcEc zWBb+b$F)+vmPTg^P5LE@&}sMetwy9r(+TO=zVucK2VuLXY*00wz6dZlSPJ!}0fc42 zQ{{2cL+BfdyDOp=Bo)ggUH|%5&!26#;LjNXY;J#a*LIGgD~eJwih353yky)NO*jXq zc9`eC_R(fRnvY5OlXXGR3TzvJ2rB?D<6L!GS5Zm(659Wx>8rz< z@W1aj2BSfw8w8Z@ZX~5sLO?nP(kVDvX{4n_cXvsLEpA~W}{ae-p4h!t{GAZT1+eFTGZHgb?<)03R?UmapbydcWB%UDw{|<8~ zz7_oZ$1l2YSH(Jj{|pdCM60=Kr&O>r^GmJAKcVx-hU@Ijtm81U%-SafRY~73_VJ{s zv&~tWhy6a`80&oWN!feExbVXihzpcqp;tRwA79gx#~0C#q8_bLriUquc+uH;%~lG& zHc}26h^Bpc=cJCpBID z?0K+e$!HyosZp)A8@{N{>jd35+81WBvBP8eCttN1%DK}D;?n)9<6b*TSZ|6@Y~y2f z^0{}kCF$U3@3Una-K?}U%RRv{)$jYrMfSbdW|7v16d{`vme5t|%X^)1+Sp?cxH z$yGnTcuh#|i0it7#cORF2!}nS?b+Fg2s_0~LUzBXCMm`6=adv*M0)Y! zZ?Y_MYec5s%EOrZllsR)u7HEVfCfA%n$Yw4fk!_z@=ng}50+`tVs_LGF;}y-p)J)Z zlP#hjc$xqJf%=0+&Npic%dHYu*`66QI3k<2N!2P~d2_#^NQ=6GD5HO$NFrxtI=_q{ z8w+a*$M8|i`v3u0^ziiX$@K(j?eeDP_VI(k^-ue}z>_sfnbTO>SocsMQ5_#R0ZbBF zXB)+Uj;(q_^z=-Wh}i~W8C$S{b(kn)g{3>Fm-xQGq9|Ne#Z6oLaIk)o;WM9IFJHn2 zbbFBX%{d&#ws<5+i|5|yW!o3R*JZ9mjO%F>BiK39F;{5vH1>7!fJd9HpZh75Lx3J3&w<;LP7nu1=M2nXkwUPBZ6m&;$yamqTK)|kE~bg0Q$+pmSKhP=&JZZ?+iQ1G5A zSsrFyW~~rc-DwxvC_v{L+fV*Sck$g-e#yC=q*rfGr?lZz{c?dUW9-yZYQqXA*7keP zZ_l_$%_P;Z{h>Ba5Zn3XTwQ^FZUlXoi>85`8}yf;8-`Lvg(zmzXA<)AF3Qh#z&L~^ z!joP2o8`npaT>?;Wa5ilF`;`*lMHVm=65?IDZ(I7v(OJ|Q`&}W+A0G#f40d2@nR)r zqV<8=oKkml~ZSffhn1yy|WQHoi2L-SSc5cW*(XnZFtQZEEK)sC*N zmvmxet`l|YzWed3>?~88YRB$%bX=8*^+EvYFa-oXLTL&=qEd)8m$FjilbUKifAYmyCT}3vUGb{< z7D8Av%aCD7afmuExVH4k01>c{Q}*GSeB}sz5G}>zVjF~IE@#Jt7K!4&2^bN?s4(pi z#Xl;@eoWy>Ho~6?@X$((3!Ofn9ypbI|C~DHIF7{Ul2!>;d17%B)XV>?h+C7 z^Yh;_2YO>!nIJy{;Fej%B3+LZlb1eZj84Qsq%@ z&$6?M#CL@j#_<<(7F}?&DQmdNGp48LHxu>uns?C`@P!MKrAb%Aw=^4Zq$(y?laf>2B?LXA*H6SP758pL(2P>RNIm(Rppk2> zWSBM2+@E{6BH8zyM&c*{#I~*9ZO>$~r5XoChZavHEnTGmJJM?I`}8 zseaox4PO9CMb9(aYB)zNrO)hkO>+niYLW--kLP-WP*P#ahpy0p8z4t9ghhlHjfBjO z;=~CJP*Hf|LT448@QSW?+pdQVPz*s|UFh#5f^$MTO%xhye4rkOAhZDge@+-P+*Jwp zR+gA-i zL%Z0AslczlsX^69e&QG=k>%t=%MHwa!9-%K6uvhAZid{WCSsE2gNgKPj9KSm19r{} z0H0TjUQIyfVe2!K6=T%}Ad(;;(C(xgNfr=fh0JNO=S7k|tbhkr$Oc;EeUI>*=JMZv zn8|bDgP>-TI39uy>6ma_HBv|G)^c301GZP)@`5k3hE-}RdT`7$Mo~mD1s04`ThV+w z_3CSPF(+$Df@t-a61x9jr>*pwDt`)loHo6+cGvBapD zU8w5a*#4TkaN7ZP6xB!!SQpFGNJWsJP)OBX_%$c~6{lB-g7PsdjS(K-;HeY-pRd2h zSznUrxjB6MxZ7#uzLz$`9HIBI?DD8~xt#1>1l@w23nN#mR6+nqRKHJ0{BaV4Lhm8E z9T?=w&{K=Vq41CjK3IDx(+Iiix_u!7O6&H?Zvl)pR?tdgvNR z!~3ua$uh$zY5*ZVCOfZJeSTgJ!<%WN;_Qv`++M6@xk0M`F;!*9I z_0DNhI$%J2g}2md@fRPcOsFBL!!W#nYO@?OA63Lwdx)0EvP-ce&IT!F0tgf79+@tVVMv3N`6rZ;eWm{t!*u+GC}<>au+ zt|vlU>g@p8a&8}<)J1ctMF-c!AH0Ad{1Swf?>5!^Fj~^FO`Ijd(svsVX&i-IAV7US zR*n2A=V}QXUNGNMH-2_*Q(C8U+k(lc6-5LA`n{!TICz5Ynzis4Z6fkHx#+`iYf^y2 zHF7jiVYRfRC@C2%hoWi1%89yZ!st@Q__vAV-13}cK5wR<<^F0myM^VOw20CkHd!sYgRj8%S2CmEB* zYZ`vrki=c0geSE7b8Ev8%UpBJ03-iv8?`iiSM)0K#GZw663b(dyI4JBM33z{Xr7Sz zGQR@J_kUZR(>b412`TNr^LH3LB^&)TP0R0w+^EY@#wXACTyWU$`RwK?uO#PA(LVQ# z$TLJr1%HfFR0Q)z;!`F9E$Hl+$MAZlFMILC&DkQAzRfUS&vwtqGC8KZdBRhy zvCtt(hD=d3J%N)AlhZ97EM#gK+bk3-y_`I~P0H}2=iR1OaPT|Z z{MS4D)E8lW0flwcBrOA!sOjViH@Y5^!;jFgAJ=S^>X5UZW2=(bUP={SzQPHBYpcrR zlY!Za^D!$|;bHhyPO%`BRdkD@%~aZa$O{4VKtF$9Z~*t4SIG3YZ1U5^CC#%zZ0I%z z`w^$xamiOaa?v53cRU6j zlR6#A?_eR|0Kp6kc6PfQBRB{?GSk+Ss9Y9BUeP(3nq3>!R;yZfw(tr2WC4)@arp?C zC(4m_YB*;j^s7musekVb9qk<_+oZ=f+81-LdLibR)|mqLx-i}T8nQ_n1|2k-$~4{H zXrkp&Hc{3qa5)0YO{K^Jq%%wctogkw|4CzncOZjbAppRT%`@_B-|K0x++9UU~{>DeGo=V85YQ6ZFTgU9d9g-5cFBYKOl^FWNTl&TpGE zUXK4*JJd&~Q98BcbCvOL8k_oiB{OU) zCSdy2l3CbtcsQKpsRp!%$=+d@fcOS1#@sj!dsD79)KXYr#s#Mn3Fe)xN;5Jw$`>=Y za^GCWS4lH6Z)q$)@0SX;Gx+%z50@{VTb3t2WO1sY!0dG^2O*Dew}B}H9EXG*9X{=PDvgNM?$M1Qe3$w}w!O z@8(DOe6^9dtAK&2cf!YZ%{M-I5oRy3#>bx4TYj4f&|egEFv+^oZy9gXn7d&J(G~t~ z-hxa0svy-NJ4UC|!6@@0jI_epf9+rk;%<+#!u8^x}i^}v)?_C zNLd*z#yVun-bl2;iN!L}R?LS0_TX;|g{gEj8<8HxC#}ZL`~`8k>lu3`NgVG3rHsQmT_J;G0i0YUxvquyWCs)p^4kCP#oq zb3tGt*rNQVUPvWp6KRm`tjZ)$S^mKhZw6~|X)cInoil_I5Ly7jl0$m$(*IExjPv$A zk~r9}Frz6|z~{KYQ7U)qVt$z}6@he3OLsj?H$N)*Z(FEN?H~6I$=bp3Fo*JC@?3?G zMWID!74Ns@&R6aHGcS=<3Yee){0K=PaX1YUoBym&ky$(n{URpIr)#{M`QUN)dn#f% zXwYH1Azi%}-|+a_7SUx!WW^lTWSpw_vO3=38z0^eaIe@iZBrgNt1-H}QLpZ`F;)-l zla`|riLMiC9t#sX2>WWh;HJM!s0MpRHD*#XA&pegsTLK2EiPJtccG~ym zrng7i#kSgZRfkYp|CJh^<8~VZsuA+)kM|C?pnF1V z{Q;Bh6#|Ipn`+2C2Fu9qf#Wz1=t@z)nCBxdw8NBo^yu8YBr2yUI;KGQ&4|~8KJT2V zY2OoW@~8&h3uYy(|C-LrBTcKuh@kOM38jxI#-#1y#$dz`V{R~mx%bMeR8l3yUD!4> zt0p-!YrPjX5}uVypa`eod6y@)nQQBuN>318MfD{xr{}SONYn{~?3`j`guK3s8U79E z;R`M!m~5;?S&cOv5OXdq!9?ua{4TL~k0@^|UoD%7q3XEr&lln{m#E9qDO4%%PKj5^ zH20ylUXR5suR1yX+N-e*>Y27hti)!ai z)bNj%k>(@lT=isV!yZYl-N1v+wZ_z$31HGUip~P3~ zKfMus|41lZ*=8H`Q+1YoKR|)M4ba-p)y=%JizJ<xW|6Z$CE}^ROhb6`P@G$-$)QirrA^q5N2an+r}E z8@;qyCV+DemGYmiaOqxx6AfK{fBf@CDR>Qw{wBfZcT5cKiI0cZKdi}yzL}u(wW@=9 zOzLs_pF2qP-^F*F6K(es4CFE$zrifzb#vg(z6t?xTenrVO1LPs+ig*5^n0XK$lMW> zTI*p}yTTO0_Hg;A;K2R-p`5SX8v*KVhde_) zC6vH~i);BFA|bEuIb0uLrl|9@B_aT$2-PwBTvN;87h4b``f9O3Oc4;PTkE@L79&H0 zhbuoWo4X_BnjJUiTO+{{BU)tCRcPpbcloQl{ZnEPsy z+A39PCvlb8TC7Ity=S4`YKbeYpbr8lztg*8*9%{<$B+`?oY*~fKp{*JaE0g;m`&8C z^E6YJRU#S^p@zZ1x#0#&<}Xo~i?gLp_g^B8=9pi$$1WGaf`s$yJ|APfX)thP!1>OJ zuB0vBu#U$ju0Y@gpuB!Bm3v5phQ{BE&j7$H3>oIeH(Dpm$CLdy5&K-hCKtCi+zQEM zt6>~CV*I~{T(vjhK=u(Xp({96)F{W@= z4#cj%>U2}4{<9^}gx7m$Gd0p2b@Hsqg5Bx@P6OVjcQLhW zmZL2c33K;e+Oly@rOcCD=?GCxsLcBLug8}{{J?JZ1dP|M)M_q~;SO#_EMHERM^X7A z01?3={5u8wK;AMbHj%@#VV9%4i!&LBghbZ282o;O5I1kN7cr5fx6|BMSvM>sj`|dE z+Gply<``znF@wKKn5MEH>}_u!+CKQu|01P2J8-gr$YBs|{14XAS&)(7+c>yt$nf%n zfbONktq=@MgD-)#AXi2g;0YG+;-${A9meZ6qYvmr`U3TRHhrt^0Ag>tQ)5Oz=n+$< zU?%>&AD`F0eXAqUy*;v8b7N@x#CUHr@U{^@4yFN-&3@d`qJWUyM_INsF17ZbCdNc z7MrzEP`?PJJbE24MAlwUw4q|A2f3%O{N-e}mZua6VFfy2HMwSWmOT5t(UOD(K;bLgtCeF;J|hh!+9Vfk7Io+CX2`VfqiQ~vgCDw#N>a? zy81KnT`s}08(ps(lc}a_QS?>-DUI>vaGAzq;E#~lvMzuCG*q0aBJZt(!pJDiff5j( z;Y9eCnbwVevhX~&^PP&#HINtr7ij-LplbtABvWTr(_4)nQ|S8eXO%FL#wH>#in$_k zj7GHZhP}VN7^`IM)UEvFkdq*U(D!4WQ#{6V+Ljz@ileIWu@-AvDG&|cAq-lGF5T47 z$Kxq50?1p517c1dMlMrTUuV8IH-VXP)pa?5?zp&X>biDbtS&rPo^A{OGs`l%pg0cQM=uimzNv{x4 zDlN;7`QnA%@T$L1D4yE)+%zE2Pung_siK>IniJ{4$U>Dlx^BzXn=i{0Uk8bwQGuC7 zq@={y8vnlwop`jL*Wv@7A;#FN;2voC2cIjXZ$`ArGMExAq=1OBlzG zyss6eEt$+(OF;XQUd~ykrKSYN;g*7frH`q0cAB>d=Ao*S?WqQ=R6qgM4eqabV>rg*c&LQ%t94 zD|PXwOFrRj5BX8>(V3zkx>VXYKp15Rt$eR3*=hG7FsEOksvc|yazVsSTDW2eCtpms_Fwr=m8+c}o$HbL&@ zAM=#guBD}*?)_%I@7;@A=arq*q84RWO-BD=WeYe}8G+-Ix0&l$EJG>C*n#⪼{(s z-o&W#f_g`D!qXLXd(e98j`*K3&H3Qi#KtFQ7y89ob>3(U84wGn4S%OyJnzk2Qyi;d zv=vEw1L#}S!y@vIpR-nRd=mM7B7Ilt?eS-APkf9va=zrGN1M%>A#wPKq_1+ftNGvim-02KR-k6>1iX$oLbMmo=FJF z%Q*>doC0N$urLtr*EqdcJVfwjz8n=(SZIVQChe$V_y$IEz4ukF5kH3F9-$lP2L1Km z!EEhc+m4@h=48bdzlyy&i&MHJU20MPnw&_P)M{i0qo;YX#+f+c_a{x6!a6S8Jm)Z_ z7Wa2@Vd8-VREH~xhvGL*s27dGGN^It2SiA z!U)*@eP9F-)cgXZhAhsV14KrvOa^Z7e?}UbT_6?+_?<+6vmE#?cr|8~d{W1PmGHvVDj{7P3Ddh(v7Oit-*QB# z)uyig8L4W_291MMNf`UudzEy3UE7i*%=cD3b|xV^A2~c-U+9Ghsy(wN=|scss*@N| zYpL;sLqon0g(r5ou6=LgLG++YEAiZTWlfthzwdVO&A4x*sayPU5)o?`^y0mlG>4kN zE2U0ZD^NI!RgORY#Fzyk{~;PNVhSOdQK2ZV4mVZ9vVGW#d;VRs)`TZeWK32&z>q4O zn_7iXIgNXFeGK!#nTM{>`iXg^To5vUB;zXw~Y-Na^|CGrsjoEf(7+09-H6YL@0F1YM0so?LrQyeVHI6^* z_$0Okid&P-UJdW?;&pw#9VKNiQl=wCpIGUzaU8z`-cEU$(E^n9_a(k*5VSqbE-nu( zkL9of#Oe!td|#9OCLyaPJ@!4%`=ylIDK+va3sS}db{`tG z@KgHAu6tih+6|5$o=*k8TytBiyEu3wLdumfH zOv>a3^Qv^q9q>(jUXt%va)blb2M7QM#6R7i?POQVIaS&(k$Ly(LMUp!J>ip)yy$)z_TgHshhFQbBSr`%I;V)n*-p6E|8E z09Zy7gmNrTbAp%~iC05_w03ldCDgnWwB7`>>C&`OhA%mpOaLaTySBLPb?kJ=>to){ zWx4Jd1?els%0722EWf6PJqjN47U@fJ9z^e_acdOJnlVIlYZs9dQ)oHYN+nspSAR^J z8_+Paa0E?00@L@L7llzuWV;3&3QGx>3iIy%G1+N+?zNQb!MnvCd1U-ArgI>6xcxhW}ic{F5*Dd1rzecYidzuRaVb1vUR2#%M#jSiz-3fwIy{ zgEa6AClSmSS5Q{tW~Gi?gwp1xY$%bZorjBB1x}9^*tG7b%2|1zs{f!(ZPXtSY4<%( zDX@&_=wLJAi#<~;8-ck4?Hj2cHC0`o&RVvPIQJi7)!TLv;hpBI%D5m$l$;VI_ywSF zc&K&g{GpT=^Zm;voP*)mpp)XJ$G^Iu7>w`~VZU{_yhFs$3_H5akP#*Gx8YC$9PJfu z2+r_OKPG08SM2#sL0rxZXMO_rEg)Z_U7;3VQ-WIATEHPYm$m8+m}esM#ruPqHsn5* zj=ra%*|dc4-`%O2p9Ueap5Kq-=>4jc8cTn#PhpZX6Ykim zkxhJ<+Cj=E#4=0n983w2!(mpNYitg@n{%Av7JQ`3gjlE3H6z=-elNZE9W&`U2YXY- zn;f%p1ype*0KEC%4OfAL)F*P(Uyfii&R>wQ2Yb-+x};%%8!0XuNz<(Gv3#lI{qEws zydTO(E#`4vt>F49pTDuW9KZXO_VTcuCfIsZ@{`oAC(PDNuGc;@M#koHU3BGp>-rVf zyNYhW?4J#-zX4+UYHcC_lD`4GXqB>*KBo=yvvk=LtdP_nt=qyv?N&xSESUUI3cF>` zKOtaDiuv#8LICP6^15j;aT!eNBDX^5RHsTaStVlllnJ^XEFi;O;S>1ZV^iZo;&X6` zUa)zudtsjnWA~pbTJdo&XoodVVv=jZnl#QG4UQ1bl4$vaO4jjRjz&o0U$j_VNDR6~ z>{Y(CrHmgD`1$HUj2AD^2k)5v+f>V}|Bm^oyV*^d;K7^Uy4;$S>BGq>oo6}BFME9H z$tSs2GFcJz4&hLDkfh}rfD2P%O;!tQQ*X|{tq(+0VvDYD0VvEU2ty)X3-eZcO3hFi zaUQW0?Gug_$-!*e8{jB4dQ5Q?m7TkdwX-1ZZlT()AnNUC-;H7cYrLsY8()p9h@Vqn zD2k0;{(%BYXz65DHh;_!9B3poCez;LL{N3k30xrb4r_vO9{!LZ_mX{)8{8PwTv|h; zDA@72CLk<@>Z=_7H@WRSeKu-2;m>(lLQ+os423R!cA?B|B_L`^0l6*GH!oH-m)`s! zvE_Bk3nP~~_Y*RGlU}jU``1*HD@dtCE;>sXGAh6uEqX3dY}fIN4N#CS-sQ)|2PuQC13;KB?E$B%TGDMHGXqs`U7 zhldv|NRp#aB6VZIdR0^m5`p;WgAGZqVYq$tCm-hhx*!-h5tG`q8yMRciR7t@`K>6q zohr28z42jYjYgHgqG%)=jS>*eJFqz9%HmFm;L&j^o4IO2m{ZBy<3J|bj~$C=sztU5 zn^#>NNdxq=g!mu~w_o1bx<32(Q|{b<^@MW$oIxU?a@e1Qm?c9;R7tq?M*7rB)0j`k zJT|1ya*j?;ve1Z3FB*3-SNu{Ik4x$J1z38v1%nK{TBbdcLKX#x39@G@Tci7M@9`?; z;ncnEj6f4dME8RreU2IWvK*~QC%5v~kI}z>j-CPPlUn2#6>=nfwWu-jabN>O*RXRH zAX@2I?`%KbSX%F60&AznI6otnGr)z#j-=4sfa469VJYvtF#)_3CL| z8C}y#H$S3JbpW-pef;iiVk!^lv`mo&2>j)2pXFTdWfw3+OFF{RGZIw`yERh&ORR#U zDA*_XXmX|=lM3}p!CJAX@328I^D<4Upe`KC#kbHrmGdcQ3zY&4z(2kla<%7=ie45p zUFY`j&}P?NyB_)|UPB11R6d<_+@)h~VH`f-3_2A{Su-Tr<}GFTm;8b9*kaDe`w0s# zDpt9lf^-3^vRIbZZ?J$M3786|QZyJ_FmQiUg~DhVC`W+@sP9)-0pP$U1ci5t9!XO22z)(KM$rZf3>}qDp zsF-v{%zoVYR!do?OWbDE_QV>3_lUNG=)!Lel2u+^!FIwY!XVWfCb9o|h3#X984mP1 z4Y9@({Uv9;1x}^2D{^GkTeh5VfbVLkP&nWUz*k-0<9uhuFM>dQ99y%By?u?VUb3fzu^c@1&m;%1teKH-ZUr$Tt*mL3c7dbMdxoK zL|K0DX^Hj7IitwVd~-VH-owl{lK9 z8fskBAI0sl+G%3}V5s};A9SnVfGDrMBN9S-!iT&Uy1t58p!Y^^HFHdhzD}0sg|#FO;n+oXes?p)!l#j6OI?n z^N*P_>8uAN!1!zq+*V#q|XdXXKCtG0hT=xrpG>8uu?_#eB0 zkuX(Zqxmpe0D=0@CT8OAlCPAqvyWdWe=y&A(Tfm31l1~kCE}}hqJLu%7tux#P~`|y zr1aT2*HMz7MuzASiDqYCY2d?6YMj6AOPN!AE)Q3P=RN=)~Joy zy9vhZCIH19!C*pl-WJv(4v2U*r}A{R3sXP0A29^c;9D=?X>yafs7 z=Ni7FZ{a^!TD`WTI?A%iy$1JkJWrXvN0RseDA-U(qjTcxB_}zimwwp=yEpv+4H}Mc z`DXxO*1%?ov3-%ZEA?bu_8{2)kfl41l+=wJhRYfWuX*b z|D_3LPU#RhzdkV`?kWNd*-7p<~|5(B)w@glBR?z$X+ZGg_PYj}As!FzART~h$+Gq#WqA{by8CKc%DUg8x8(;WM_gFi z>WU`&dWZMe{A@df3kj3jg+aJfsUbd~SYbLp?JpIJkGLv+ahY0(%)| zBuz5?%BksJCbo|XYR_GN zfZ5=Qep5ICo{B2S?e&FV8Gin~uXa~-iOZ>TG}Z7wW$3q+YD+S!c}oDmX3Xyu5DRq8 zKQCGZnH5r^>7Rc!6)50BfBp6OvZ!n^N@Pq#LOIc_09xH0JsKR(>I003Stzk#tmHq! zX2iTqrw(>Sm#oDiBDs9kq&t~^`P|)rN^?4}qIqFDy3Pt?aSqyh!C}&R(FPDd-;^fJ z_LXyW$uKJmGluWVbH>hlrusIxbDW0_X0N}HjrET01#X$@2Y8KbROz)X;wL$eZ7o4$ zXa1XOD^^>xJ+4c_SHpkC^mbeAQC)4BHt#kyENC6m}uEjY=WFkeq@}C!t~rmeEn z^HyWO15m}6;2nW1&Xaob<39^f+1!)u-Dfmh%G@B_&sAvX6a!xz%kmo;m|AI;qsAEF zKJM3L@-k*s|5GB_wNPg)O6d!@GPB|Fv0<%P=aJ+VZJ>(+a@A9(7yfxpHmzKDVUA@h z?`YXhyNaPvVcK!dSg^Z-1sj~==8KuPJYKu(ePl@|4*Xs(9+r-{I%G-i#Wz<`GR4X%JYi2^tu?eruOvs-^4r6fU7^ zvlleH)WIRO6HIdbd;1 za3r-s(Sh${D1%y_*#D>Te@&*9Rw1;yz>;6SxzgFEfEnH>RWp@{{+55QK(2jaq?wd(`kPyDmf(aMS4II1DJ4a zN%81{RX}dP#}%&{(_&MG{EAUkM0Tvc=s5(dB7%a~tr#s;(Ul6ap&1?KWs#TG{D!hR zVce$W`(#So73t=uM(?w`{>6x(LSO6$?&YVnP%vF=2hE7%;h0>Hq4-Nf479EF8ZN}i z^*AHCp}`>RMODI2K01?^nVtM6FXUukMfYp&-d!ISbCciSsZS{ewX5eP)GGHOjQu+X zDS`*5i_ZRXU@H8{b-tq5!t-u1t+SpJqRA~k)?T9MAds#X`gT|l^y=W(lUaSx=#?>M zcGgjjMEbd-^M7x)ITs85m3LU`6r!tpeHbnyJZJ>sc(>?Y7iSYIygy$_D}ThFxx`8O zv^g_g$SXig8h7n8q8VS0X}W^xWfRnaye+b_#&O&A9~#nUfCuqs;T>+L#gto4XJXRm z+h0@46NtZR)g4=8gKi=|e|+#_UviI{Jtm&EIwP>HI$frB`kg`vFnWtki;08L#m^K< z$5O_JGv#1GLTk_Y>eUvd-u(x2;m7Z4mQmBWFVH8fFa*``z{>;~{&Cq0AS{zyM`vyr3WJ z5Mtw)HD&=#g9*-^rSok`9yl`FUL8JP-G4W%cM2D5qu09ZNWe7^YY zxsuBQ=GZ=Gggwib1PQ(TNj)ELe*wz_Mg8)sAyGX&XJZ+p$@m-}(dL3q-4?C7Gp0dj zRkU9jWVL;W&DRj}OYKf$Md8Us@Jy&VCyb^>9q{jwZhd+H(PZYqDc3_SCLI}j7XmfV zjFtWJDH7YzmET1HO%d%7po?<+;46>w5h?VoZYN+xOX(YJIUfsU#_1aSoKr3Yhwx={WUzO`lj@n8BUvEbdPV641r|GXW<#|6V6 zQab5jLWuNH>lQ*o2sIYe?fQh?Afsoo(j2sBjtW+vq6&Nv>)8Ap(wx2>{C@hsLdgRt|bawND03{wPwO>>$UT4eB zIfQ=*MBT_Ke6rheM?p#FEtfC`8$HHTFeRGSl2uB`iMl1L;UE4pyN?=HW2Cd5jt zN+p|OFQ32h$uwWbe0rcW2-6%;Oa0`!r^65tNTwd2Ind~+KOl}Mr^b2QiS;?xh*JQ831yxhjRX~MX|2Y z^53d&qvhogDMNS-aD}+2wf~$#G?cBiNAs)?bV6F&$l*X1#cqO6w*h_ z{Db_zPaEswO>2~FPi;zLPSS1F<1j~#%I=n;|euFz7vGXk4FLr zTNh2=FSLAVosPfy?kWDjQdE=ZlF)LdfO%>LEdqDMLC32Gounh6JMAH;1foS|g+5FP4ZyW_o#2 z`q+dIeRD$z7|=9IA6F zZsaMp0$>{SbBCOflpn~Q`(PQXQ>HkMfqXDwTA+ih{ZsbN#!9h>X5-tT=`-wNTH5Jg zK5IXho-(2dFj~`R22e8!7l#=wW`Sg*WC#I;O)`;oXC|=)fA<_D%(Pm)H&=3wJ6`fLUt&UL}#m#I?Zf4Nb)Z0LQUZWY{AmJ?yV=GaM20G*JR&E#-lQ)=JamIw* zKQFqAE`7&-y$W4Mn*QYQr-~aR$^p9yN8TBi*AG*8$sYqk=`HIY#V5ifJf_01wGIlw zSG8<23g`{-u^fMg!`%_?N=D}<+wyh+#mF$!GxgRB?CK|&Tl4g~+akAqG){ry^9`4C zS73z$S9qPyT9(VX1rnXns4@MnTUZPHJ}X6_Ju(Mr;pQaw0SnAV3! zCa8~7yP+abBmI&YYv}*MizaHN;tfJBT_p){c_4v4?Ota#fpE80tr^a z^~>TwL~F?K<0JjPKHx}0s?JABG;%J_p@L$KAesh%&@A{O1+c0l`x=%Cw9_`*3MTG{ z|1|?SLA43Y%ec{9UoOcxl@183x6OA@+@4CFNXT+z9gF*yWPRuPj=Av@gihWx8;or8 z?tKlM97BU>`_O~wZ=`9AQe+iPYJ%E3r|otct@6q#anD1c-&gXsM&m8gsolCqSRl|kh) zLa5S2Uy7Gi!cs^diEp*eU6G4BS`tsz#A7B)&z3H{e;mpWsQ0FzG=$O+{{Q+!>kPPL z2W(_cf3(wotbMxd8Sc3?{C_;XcQo7Y|NoyD2@!k7p0&5CU9nfys!=Pb)uQ&SHe&C+ zYDMimYnBRXZ)&TwC~CKgqWa7G^*O)qIgWFj{KwOEJ+JF=zu#{EAmj?CfZpSSUb4-9 ze>WdiT`V8@iM|D%*bg%X zE-qCl&IlkyusdUWvY3J<))*K9hZO9ItYXv@kR(|Q_S>VEj(lFAIhLU zo$Ufj>*nP=bQ_Q#Tq$n}IxtpOH;D%DlckalH>!wtisjiEOSX6lu@1rOxjd|Q(S+Dc z7#AWn&$`6_#vF7V4=AfC2B^TnlWck+V5KTXPK|2HM^)b)`U`S7KBW>)WS& zKd-o>V8M-I4h)E>DV2occ7Vg4yWEDy0VWbAcT*eS?za?TP<(?y#d;v${qDU~RtZog zkgik|HMT}lR0nw0kpTKpBwMR_hMgZbX*PLV@fl`j{`_3#p>^R>k)2!QIeiABoK=p2A7c^gy)IImhkakN+;uzLLLT z_;{}S(L`a3v+B)sB2tWirENWFY$F4Xk z5@v(J=u>8Pwpl&0&lm$~`teR?+k#I4_EeK$;PzHF_zx1?WF5Vb!3f0@;juVQ&uL%K zd=90^3&T4|4*Uo5N(D23>Jv2_BTs2_cyG#Ev)#V16P38f|;`apTYDfVSt# z6^F=^R%2FQ-23aJ22b~jMwX1^PhRfwSlsk8Pv5SmO+;WrAJ={G@BAW1pGgf}GscY@ z|E@gQD{g5iX&)i-ZmrR!XpS|)NbAO7ykKZ^u6^)-st{QsrvwvNtXw#`M9{iAN90SUs%=QgN2vnAXqYl`R>4sqni0U+ z2xMoH@!jW$(6V~6z1=g3yYGZlEH>(*UeQHFY@(h#urSaUu}G~R_)CWd#}xY@N&@=0 z0A$4~1^e|FXcN(g%+&{+x8G%;l=_sk7)r3L0`l2RiG}wSC7UStF;V_< zXs${)dgz_sE)j&vCac<*DF5_YHIu*XLoGXh*I(|xw={3BB;+(i{2>-1Q99)c8=^_S zvw=LL!9GVcpO1y8%eU7pj*=QrZl_KL<-eSf-Cj5RlT-R7`LJdASn(u?+~zFK4Ey+Q z`|;`S-Fne@+bF+c;lJD6t$%mp54L_rK5>z`uc{dLCv`D7WD5sXij6iogKOV7vKHDm zL653xt*=9G(P^cxDEs&TqMB^MjH_-K)Pq8-r=SvWi+gWW(n0h;|+3lRkwB!$Q zE|@Iz_H(xz%JfUd4LT-q%eO-&(}q){E?D|}rr5Gb`AFJe1Jq;e32hgvV4BY0q}Od& z#3D*8SW8@)4wB|{*qjIO$o9kIecp@21 z9!o#)%xU&T7)_OCL^FP{hKOl6;-1p%-8tg#eFG_O(SC~0^XkbXe}-qLrbXf>oc#(C z*&?>Uy=hq7S`Wu1TLqi;D&>xLvAp13Vuj6OGy!8Iz{dLq92twKPCZ|Ch$OE2CEfda=uynr+=Ybf zvB_rqly26dW4gd;huEmh_OqT(KL0k~AZyGTzWv-8uk^pMd=k8lx#(KzA^Z4K@FU9h z=Dlqm0W07X8%~bb$9*m%C)YnaJKOONeLmt8Ig+7Ho~Ac5r$zbgn>*kbs2Xm1@{Y6T zUHEi;y2K!b-Qrj^ytMUby%xvM4cX6kLqv)2Dc;@}`TL%!7|7w?B?fwU*mib9+@(~) zRmmYEWSGTH5;FVhdD5Si$D;>BY=1_U7wiZG=B}j}6*9ec)bT{$T`b{?AycpW8|DKH zkPKDTs2h#7cM^SC*%u|S0Dj|zNtSLq>X8l*_75vH!V>}fEdLe6AmGBPxb zWq&|hkHvb)Jf+u51%NSRg^xJFawFD}bM`+?8kO@o(+9}dL9DXILT&mPNezwDCQeCJ zF29*0CNQL5oJ^)N7=o@Sr|K9eiX9k1_*}7e+i45szS~Eo@+)W%1`gf&mV^**WJ3wsA)@7h~++iNT`P^ z6QYSJY93WiWjoUd$_Dl)J1h~(C+m$pyRhiMw24(`mPi^um|CvL~8W$)W<+wB0BfTJ_R5&P zPgYpz5N#oYyi0L?ye`+H6ZquVA=(49dw~|M)<_mFXpxF@*Qr*a zk@s@&>qD;~0y7B%d}TVf(xp6O@6*rqjq4r9Ix1^btNBHn9LuOBlC3 zoFH>RH=T}VVI9NQS-&=M5qfD3{_ zrO(-0lba{Ht)5S#Nrq9OH_8VN?5r^OKuT{JzOK8_=J1@M>E>++`LO5(&Uj^Z? z47#MjV^VEUXl3`Dm5hO-xxn`WEeIJucX*lOmG*59#hmyccs2_01#T9~j=nF58n5yD z0+FHi^{2tT(?mI9+T76+TV1giwTFiwN4DyVb6hem0Im4HyQR*e8`;~RqW>h0J9k+X zE|wht$dHpB`r+Z>s}W5lRMAam(E!c?^9D8@yx4MDOdd*vbI@)kHr$Ri4945uehZ&| zprf5ves6rDd4FyI@q+`dZMOGy2N4^}_sMe*$*;w*Ep{0E=$o;NIe;gf03i_qLX4H) z^Vv-r6Te%(kYFW$M{y-#c5E#z0hDeLVQ<}p?#W1Xedvi|Fxcjtn3sNE(dN3rAzkDo zlJ&G4^HGg!R%kWVzQm$ce8N+N*CGdw4mbMPmBWY-zpwW#=^Z-@u3|{Dn>HsKkdaQ- zql@2t?EZ9l-s99U564J&XIkeq$$y{7&AigiF3N~;qdc0iHJk|_x5;i6s*v{tJT_DX zzhCQf*~wrMxRL*8;RH8A!n~Uz5)jc({5d0PJJN z4<}=Xg2IAkFl(Oj3yx*m4~vXL5}vUB=}I$-sHj;Siom?4N>*XIdo$m3sHE_V`d{ZR zT3_UKUyz;pXfp+uH<+eYSYxN;BwOx)8XJ!#KvJRsVy0kikPP><4egwl?jVjMS{x$P zk*8&Lk7D&IfrmiYT!WhJaV$9UyG5-P(t!NqA0vJ?z6=^hgRudqk&Ch*n~;fYHZ>C$ z*Wk1Pjb1B>-`n9O1I7U-q(4yEJs2#+f?N_^2K%vwV893CJ>1ZVHk2(=7Ju zPG}FZSnjj7V_nX7y*2H$v2*$U37sqQKebxTTb`K!&^5q!L1Yh3BUY~n`xBJV5II_*XF5BF^DJ^?owk{sc|~*S~p|kr*wK`m5adY*xKb^e@gsUc;)Ws*~bc%@GT*5t^#rOyvGvYgoL$8-`c>*l^ z=b>^lG8d@K2HB#zpoGR` z;h%68Me~`s%(gx8iU{rlzdBqQQ)Q8+{ZT!?iuon&n1WY64*Mj5bf8Hd=q}1|UI~Ix zSecG^&~3BlU27lVE0ibH8;9h4O?nYnVRq#(=f)65k`(HAZl3(lB6Lb~FMlUOl7Br0 zp@AoYbanX@K04RrO~KW=TTM#B6tVe$TvVnM?UOo(>&n&kpkI9Rh>q)1L()<}S>f*p6BZhUj74I9 zbD5W%Sjt|Xtz$1TF#!AeB|Vd2R8ns9+)<$tKG~#^L_2xPn~PHm9G~CyT}=jT$Rp1% zh|Ht(=)BPVl6bxA zE4vT>{R!Mo)$~{=?pzJCjh>@0a&ktNz(SQeKEe?ZYZ=|B&8e%M)J&Tz2ebU9vVS4^ zbt!lM_?fXfu>cAaQ=ZhJdMx`u`+RY@^IEQ_$P#g@yK3l%f@qqXsq=~=OKp2|NF1%E zOuGRdVEbL-*awJ1Qw+OC%hMr9A^{@MFKXN~EwSk90)q?lgy)&0?+YBUPF^Woda8j* z$-t!odi@2x#CTv+VkTdEbKgJwf}+F>b*THgb3Vni24DXu;a2l~vcDybY2NLe&b z4o3t-XL-J6C+j0Uol9DtAXv}QSGHj^b{aVH zM&B52*)oGtMD<_p^s<4Okm8eKNBqF zGzIk|$Y4)z%)Ox*3+YvJwKMMl(?IpqR_Yi^(U+uPN@f$lDFg~A@vQ^A>Xve^5te*# zXzlp$i@}l*&s)6bd_4`wtAfn2HJ^P=RZE!ar9Jz>{zA#kv+2_t^`oCUCLtStCTLx| z!-YbNwrOfUumsepubT-xsx19sWK(FIeLU~{whO#}7k^{;i$5at`#sR&J88KuZEU)= z6asZ?Wj92rqMP({8GVyh-wQ)&)@p6>Tm7JFGSqvp_WDgcG@LEqX~6CSvG zi@5N+;Jf~X5`DzAMA>goC`jmHkNfzAPk{;!H%Et9{CxO4EsWiH-~|wFNZn}$P57D zrRZw|`T>ZXS;kY#h+h>m{cOp?_@17acG|}^1*NO^w15dfP&AvBHD*irkBkn~zElTt zJT#I)cgx`F#)C&$^^Nn%-3};CF6iYWQtm2N#Hq^ere^T>4Yd<=x$C=2vNsPV2{k9m zPkkFX!1rDmJO{6oNL0`zSab7A_!zS2${uee@|=B|`)%kXdovw%=;i7pgJ8CW#=Pms zWpDsnh)gjE(!Zj}BQUzWRLoT08Mj6b!s-=ZGmhB$YHu;SuBpN%pfpNJ@hK@#o#QT_doJwK{ zS6|x;8z2cUHk(;pHfP)btyYQ2jEFt8$SFqY;49O*3Arn$7Y=x31VWBfRyr;{dispV0V-Ru#YgA8tiADP$< z8M)&UspfJVSZ*K>)y|+dY@z|eog67zj2aZ2Pi)2dQru6*zxgw3l>W9rW0HJVaE)&*L1J};K4w*AY44Y&jl&GkwB-j$EOYbrjCl=3eFE}Y*0WNVV0x6|ea zlJ!(FRnY&K@(iR8Xc7>6Vm1N%C4=@o0Lq!+8Oa7X(u`+DQ-8@!1)gj1k8DX1j{>?v zrhcH?38DWOmKmg_<5R5}9UK<;@NGdIe(o(ShUE;8Fov0SM^?Ggv%N_9 z$M+b-KbM&r1svhKU;2@hT#q`7D%ND%eOwTB8{?>>2W8(e9zm)GdDb|+5y6JVKq`?g}1{75 zAp57UT$Ed+?AwG}c3dOepL#cPlq$Elne#S#Vc(sJaaq62Y7&g~g^)gdX{y;9K04mP z6l>o2yW|N}{2BgyO5SLH0H3S1JWrGarmWNq=ybk{*dtva+oS2s<~x6S6A0^(_HZ4S z7D=ojq#t{U;|-HCS~m!_dY+<8U>cS&cCDBhR~HANZ3evBQyu8i%2%uk@XAueZCU~>S(+EijkY_Wvq^=o7!n~$4N6?Yb>nS>=@-AW^~d(%-i&N??kZ-kMs8XI5Db^qrWI zz5G%CF7QJQ0)bbsXoan<(nij=e)y$}FE*F0`*bTUJg1L@w!14IF8Gt7Xuks%jS zRFiX6wGwd;@wVwqmIL+bd~EeNI?Zl)PO{yg%+q3uGkG(oh&v7$9`fBTF22D;dn&-i#ljX#g~pLrC|!}JDK z7;=Q=?T_2bU}Z7v0^#G*HQ1-T>31AkT^cmaG(3*3*8Bshf8%LGJ546buJlUJ1$%Ve z)<8yWE^1n9ZU;OO05v9)Fozr}EJ>YaQ_%YEl9;k8R&DN={;+t!t&hkS-Egrk*@uL} zC!z0X96axl*n0f*5dqsFXEdj)9PsjGi%8~}k)yx2mY#n77xFxr?eTv{103n*HAt-+ zo?$BXbTz5k_!R>SDm?DqGwM+cIf}~S$}o@}5Ak{+{N?yg6wezyGCtpi6g*NC%n%@T zIa8A1r>rDgpz#IlW1EZ3FT;!t{rP<4m3973|HAz$zPkxRimTu@pREzd(+vx&an01z zv3eLJB>2B?Wp|hjZk%Nn<)E+sH7b|YJHP+UDdw|!2)S@FcD_ZLgiSKMEn05^>dmtc z!vD+!ed;j75?mJgDxuC>Of?dKHZ7pcfZ?LUR*)rYJjpf!ArZ(=*HkdlkJ>8 zRrR4yFLMOq!0@SC@VyEN-)8_-C;qvgGzJD696v8irpT_Q9$>(L)aw3!QQwh-)?2`W zAOKw3i_FtY%L*8Bv;{NzQn+fMDBAA0&pq(^q^MIW3F+fC&L^c+K&kp$VHtf1sgfXa zBcR3h3rcMNP$3dh{Qdgt{zumyYYX%G#ApU3I^&l!Cx_im8}3$m_wH9+kqCU`SiNXm z?-5);+a=hYP;J@CeqITP3!6RlqjTC5x?Nnk1aDAB<}lu}j$}M`%-1~#%x~AK8)(7K z_812~_%sfvjHCf(5Penf$CD}kJp>kRcWkRW{UZ0F@PZmuCpm&~Q24{XNUR21(zTVQ z1>^~m|GBA+MTa8KnQP4hCg|{v81l7%E9^RK!pacMat}&4kh;nujeVF><|`nw#sY~( z*;Qxr70wjb2j2e4TM2-~sClTfBG19VOqy~4b$ULuJf9T*^zQbgZ1w2>`eIF80IxHY zuB9`ZDP9b#xyo>XQK~QdjtRJ#X{gsLws}AU;XM0g!%L`gM(MMamO4FW+ksMR;@pC4 zu&!A6+`Q{HP}750uRC8)J@NaG_&INWLOQn|%Kj`_msz3wpD190|7VSHr+s8(IQ8V2 z1Rnj-C)&4V8y9i$Tk$}I3{RAKy-J0|mX6ASaXojziOX|0c_S->VejZiA1_$OZ4TUq zB}7`rZT6Q#Bj6+g6=db!rMTIgLhYVv-$y=akeOYjM61>k?;G2677T@24mxt=FHc+o z?&kiy8tQ$YT|kr6o5$a&lTMeWmlJA$w{uA+<2Q;la+tnkE{73z9w zvn=uXmjapM!)UI;C-`ecKf=_S3=9MTe1de)mkvWW&PkbyA`MHO{;>k~KNxj&1DKB` zU)I2gJa;~wuRyhCx-pH{L`of(k&2q$hjX2k15gBn`O3YajcJ;yT#R!g`eUApykC&s z&Wec{in>jy@JYhmKfP`PB~z>IRiFa}kU$Fvm~8fPl&ai@Bs2>WA1w)~dOk9;^+dw! z!@{A-oKllQYbpZG-us^%Tx7se&&F;$#C)2^b;{L*Yz zFnpecFf~ciZdTs0o#`U3Ne$M7Ysi1v;wRkzRlO6f-Tk=seDU_L>=~{>_F6H>HtC0P znZvx=?$*gnYfr%6?U@KTjLZ}uZwTmn#B5*zh`2#A+-3LYNL$jcGVib_g^R~KF}KPC zazI8>xE5r*zd|SWtL}Hjp5&TYCC?~ipvmsM)AYjbI@SD@CcnaNUljap`ft$H8-d>) z>Hp=XHJWHT<_7NWUy#`k(#i2XB(QacP;lI*0lj(E>BeDVw5a9M2WG#&+WDz*sz31i z_?5nJ-GR-GuYa{MD>CJcP&BI|K4U@pM?)5h?sC8dM}0!jI;b<`2ny17oxjJeL& zz9{RL^m?&P87^l2PC`*I0vS4gQ#V(Rs1&OB_+gvGrjEv+Esjdkps>WP_^xtcr8>};# z4=m34nJ|x^=~_&;d}CZFlCg@jL?%27$h04P7LcBM#^|>QQ4SxoFN@MWdJ7=OwMMQ{ zatLOpnY4iR+odU{zO9&080F*>Nb^>y?XfvSpN=tKY2fi>d|GIVDo z;Cl^u@OZDN=)``PiJjgyp;x;2WRJ1Z&Ur!1FU}`V$N!fZi!&2M4dEqW;G|_4kA=k^Zh#u$P1y>|o?h8eHY?c=D$--I{k6jhsK{QmZ>ALT!dOWn$-JmJ;M$5$6C{iUCTIbFq*xli4i7@ zU8x3t5nqSf@Hd*fnAjOf`S4eW_9o_T`W-f*wkphAqds6+jx1;0fD39i5#&6(^iQFLZS7&3|$;=DYrbh z2-uCd}$gkFqaer=!AGqWeC|ZuOZJyG}uj9PUtM4U_iuq0zuWt%Fqtt1^YWBdp z7=g6Q-~cR#7_e0_?)ARKhr)6w+DWFDje^JomR0xNf*|yVU(?!RB{=15%H4wtt9%~Y zzJMtV4u2J4a(s{!3K;Aii!%yo&L6|$QzgM81Eh3C8FYuv1*LyOIWwn1%?%a>>nG4c zTwDl_JZzShm)P9-mm`V60qk*XGvO+20)@>{MWn9>7tki9KY1+Jv2pD=J-({LM!+)Z zn&H!zcv}@~uQUF}%>|S9c-O~*6pdXCp1wrUtG1toy)zD^L}{9NdhxCXDHBs5w5_BI zD|2e{ER#3FhCCe`9I3IynZFz!2_P6eo7RH%RQwb)#u?_8M`=e_cQQoyBY;RAN1buy zj8=ujuPXIwH_EWr^;XQ;RqVTfMi;sZtNJ2`aXjSnh(1R8|JO9KvX1x==%pYwv)bKh zj-@QMYhfS6Bs)3sI3zAg&rta|&XS>AcIV!i=^79O$C*A75q-XvxNaSokYpv4EHz0F zJ7Mq{ZP;kT*ZmP#o~y&h8dX|G(zY0QnZ{+SK@pCK zg*ub}7i99$iO;`@u`L3?I*%6W0~UD;nu+p_< zO!iNP?0?szn5Rp}AYjQ+VK6N{x3T%;e$S#n~yE6CKg z@mjnLQqwEEixA(9QWMWo*1l7qHXW!v=@Z6B|M>SPPSU~TE1?$=Y@MOrH!Iiq z{qW>*3CM@D*YhbHVA^sJ+6SicJz)Th>xOCKqiM7T9h>BdGE?PJ@gYb+G; zQKZ|hEpcwq!0uw76C62NMJU1MMSf=haag=Z;}2Z@xS=1-pi3ahhmeB*^mu@z3C$Q< z!-syi|F)r@!&%am&FoQyvfun2+O|6u8Y*f1F?q6J%OE|3_2rKbopLr#t&eE>ZWL91 zCdBsqbEC0G8eiMsB7RH195H3X9>cn1`tN?WhT^IG940z5=Ux*Hn2_?s@z%WlG9sAy z0*Ze7m$Lw~aX?yJoaDa)aqc2N;oGxXep01CJ#q|)ou8<<^;Q4POw)puY7EteK@^d>)(?sFin&Qjqgk z8pDyao5(+Ch;SFAvRWCS-yb2+85j@EKKu9=8`$t#L#MR)4eS#JtJSsCO$i_RGn^%L zHFs25!nb%}WWIW+U<6C|?ncJIc`CB>{EuxzCUqp)zC4lpFDMtyQvc#8WP6RkHk%*ScB;Vxk}G2nk3dkq6VNvuDkOyL?70EpM@l4<6DV0 z7JWmTXZ6KUN-6Mecg(gW%m1*->%7z1jtdO0cP;wsJ=|1ua(-e5;XX>I6lpI&p?_69 z`{wi55g%`8CUVL7($Qn@7dx*~n_g0J*Zt}Zi~n}E`5N7snvmwZgD7{dv;^LkmW9{$ z{#oRnP+eLX*^W{78xm(x{NT!;@Furgy_Ibt*dH$A4RDj`0|t9F(TV<@?pSN9T(O4HtWiFCYu zoVU7*s@W)%M69!LS9IX^q>Q-x+5Y$acyl2?wT<*y$-|Mu?_x{Zc5#Z4Ds!%~PxFK# zb3{kMr(;6oyi_^jcyt+%WJ3HF{?o(1G)q;lARq9#-k+ok<14QhqVcBh1~Y7o5xxQ& zyTvP#DmQ;yUg%XZUL;a_v5tdBn6r19UjejteK|_5Q|;}OmuVL2ODl7XjZfu_OagNX z_(a}a(dj)&4X5foNKNT~?7Bruub_!^QZ?OBEe3a54K#sccZr^jgoa~fM0H^FLC%13??(CX%<{I;h zW_2f3;OFuYV{#?lV-wLs^Oi^RQ#t^w|bAGz)<`@Z<-{JZ2RQ}(y+8S_*P=}XuB^)!4^Bj-nD1$nwRxyR(5eNDDwkGoloW(-+a!higc z_G8_(RvuFhXLB}rU4WZ8+Tx9nYj}GwwOJ*Mh35?|?}&b3RZR#5fCA_{4FTcKOvKN7 z?gf)y<lpZv50|-aHh)B zk@$!<;vbY~r9zLsZJNZxducy;11Ey6_&-umgNSN_Na7e_7<5xCinG%%b?y(1Ma7!3 z!V$KObdq8DOIX2u9~J|m{o9ly`~tiR82*Vse`v-7k%5um7xD>~Wd+?@$$oFXKcTN! zV`T!1>9>bZDmAl+okRZW|3N(eyQmR%viO2c;n9&B`+FR}SU$FTZGE<&FI4z#6C(zv z@OajZ<6DUCRwoFRx;RRUXtavVDHm?(aWmVf9rLIh($a^cUL}q51y=N_*yBFt!p`^I z#{;T(epWB-XF!n<2MBW@#i3jjsr?B=R-T_hFQ@H~10P7x1@IEmZKn;{5q1{rzdlt) z+MgK=B>+TJpsdIRSTVD*v=VtVJx&EqE>+1j6r?Kh*Y~+=UTC6oP~7;hqd-|wQzLVZ zw4clf>4+r3CApxGIb0t9&B4qq`xfTj>J(PLaxyH(gwTO1zINhlYSM5yUKTg z`ZjiUcnXRglw9#Ffpo$(TMhzoZ+Q}_Jp0^$l0meS=Sy7*imAu=naHtGdiYK$JG1k| zTOIsDANTE((pFN!LeSg}RRy#hiT zFjblVIa*@SebSFL@4IEXe=F;1OKe$(-CNgh8G8FZTv% zj2pV<#KJzhBs3N1bbGKl4dE(1I&0D1>-;RYhi3V5nVf@yY%GQL*4qp82am|hr2O-~ z)Cga1Ad8pw+U40~+?Vs!#Z>*6c5Qy{pDA;hquJBhV%{6a6F40T7qBUgi-5Qcd@ zWm}hz{35JlIaAzvR?p$DSSjGo-=QV()bQ<*R1IyWR}36ptSQ3zG$@6#>XwRr=ytQ< zPtvLx|1g&lLn=ufzw|AC3UX z_-}s+gQWvwBE=8+YLsp-{rUnzDl@OsW#4Q5z)K_Kkh5h2k`Juqjb3!Nv2s7dLf(%q z&Wfi?Z(W}=jv*W%S7O5oV`96Qxb1+Y+ zWdq1<-lTMEl9{*0n@7${fq^6i9j{CK`U5@fHkrcPPKe6vT2ynJ1&9Bgx6D?z9m*QP8=JPiDGW_32gH8>uk*>xtw4xLtFFqVX)C} z9uV??6^kQ?AW&e16nDqhwZX9@gXuyaE7vRqE@L!@`cny|p~Nrd{ztxjVTThy*c8l)=ZL%M z*WHX>9J>BtehKF4tWUv>Zkx=Uq{wUxc$+QisVX^UMVL0dxC3#;4}FJjmu8kBsXph? zvWVe+BDnmx0Q3i-9IiAU$behxEn}6D^mM*N4U3I_xg88jl@*|CZrYkrRcLWKdykMu zqJMS5NT!+~vKjl#>d|-mX9gFkmXuuK7sdVyQ3>q0MaUDuL7wC+p4rJoU&QIRi9Iq} zMjZTv$ye4BumHxsQD#tay&u$qQ-(*&W|PS{s%?KaG}4u!S7r-9bpz$D`obL4dDXd>F`O8WZ2PC(5`N4c$UR(* z+$%S+TTLexi`Ng2U`(-R20MrLY0q2 zsPq%EIHFO4GJ8Ec9Far-P617j8(_tq;3NK(UBLF+iBP!kN&@G-K?+98WLMQN2&ij{1ziJluD0&#M+Wt8C69;KD z8}rn5pDuyE9@Ig`N*k_6-0%6GB=LUb#vO%YtsDQ*THWI_4wLTxyFgwP%oCv0z6q0haeYUH5T-CS`ae%(sa(8t~g_uw`nY8^+VJms`j z1_5A4Y_KDYvP!<3{S#lyD0B}XCP%iJoe)kmtEII>I@cIZQF8yS#+5(mGja4T=JHez z@u<+|ZGCc^Lzy~&ymiycKo4LFnf~_Jz=_e^4zP{O53g4#J3yQOF%l`K zeRC$eYmTUZD09YN$|W@P4g`wNJo!iY|5t{4dzS|Y{BfMqn@;qs)Fas{2!29ek2e7% zCC2n=BI#Z+F`SUFfB3jJF3Bh+P>WCT6s(m>P?(IG{N+&UINFv|`Aa(B-|mHM-|asI z!uv*orS=7KWEe8>CI2D`4h$~Tv_2SQHXzsSL2j#7913bvQ|9Iu#7`0zTd7G3%UGAy z0XTXH!4G`z#JPzLcp`bN_EV#|w|?g589*MkIX{>Ws#%@H<yLGZ`zl?BmH?G_uL zf5b9F=?M0;KY)Rxb84NgPdCI$;104Z~SsJ|xiC&un55cxD!2DaEw-400`n zK*kl5NoAs+7I;)jS5CEKXbj_#u45}3jPw?WPNgc>wzsK#rMbYpb1O3&I`)~aw&Z&l zz9q0)rIi=tLm0ytHtY0*&jCe0-`d!xwUgmzZn zlV&L6l_L(G!Db{nez_8!|1}qePk*uT3(KIadd^8mt~0C?Kh^ml?M3tx zwGNm#iiB0auSQVI4vt1)3BzJ4Ny90*X8Mp3J?$hA2M-r{0`%2W9_BO^V}`a>fY=$lmV?+OUm z@#sk|VbmtOI?%hR75b+M2rxq}HU$GlC&44%JZ$nWPRiu9|o601cjkVFD;T)b9fY9HwrjMUlP7t%5SA7ELzSQ+LE0Qv(~a7EPuJW<(; z%ZLkQ70Eq#7Cs|?4qxV@FY&|^*fmakDI(=kN195iSI8_p;@vJFyZ;*Q(Hjxun`59_ zcT|m$Ng~h{eBQy?f8<n;jRRk*D?yAr-{~*`&s1IB^S+TLy*@P~`vKbw7ZA#sF;z7V_RV z)OJrr#hZ8;q8Ru~v2U1h^c51w_UMz>M#}-q*hj556Zc*{ zOxQGLm6AcS_B0r=ivuM%V+`@wa(Rw+!_=?yzFr3HU#8vB&21WyMbWG()wPi?drA|E zktuy=1vt@lW&9u+o?;c{Papl)(Y_iXsC5D0FOB|LnA*9?DBl%5^RI!-d$2sS7l*{~kdw!F{lLGh6VE!%IF# z$uejUiZ#K{78YFT5&k*^v=IOecC%VtI&Ld0Ysn=;NpShycTF;7e?ZOpkn4%7xjJY@ ztR9(k!Nd3uK00@UW@%rSR#dCVRyEHrg(Q!lMa`WFZ&T+^HfnYi#n^I z1A(wq9%imKy=qI}2_o%AYwUii&pye}(mbOrl<^0co3yH4pt%Bg`iCqIXfFd$p6NCb zJlu2~XE-_a+#$@oDPzvO31e|w#ZnB6t|(TzbliY^*4UNdI&*%vu z(t_ga8qW`6OKDQhp`AaD#2VYH$K8c%g@O;0%a9zx)9$+$qhXLx;9U!oZ$N&kN83WLSStTY^;KKyjf>9`?cI_*E??q> zPz@O=&K#O?Eh>EMXUX)ebxrw_E*Mw-)ex-Q80u_TXQZbqoaBDuWNw!302V<#<=XX7 zfh%`PgqQZ|g2YwI)0%3C&0eI~{>-n@W_;#OL(__$z)B);Q890uFBspNrUUI&yTf!c zhdD!L45@8a_HEp=%}Paxq|vLH23mF|2I26VaQOznw_LHZpZtHwSkOmn`&wTKK9|Ko z+CL_f+X*Tv0B%UcxzmE{^VtEIzDS@#1Xh69@mmL_(g6scwlrnWx=TL}NI9Imr{VP` za3Y4L^(U0s7{K||39CA4jjs$d65FHDM2f0LTir&dU30WAz&umFtQmV2o3uUN;K({i z$VLw+Iqtg&HKgZm9hX!#u6dgjv5FGXwX!*0e@(%7fEm6XM2@sQE++U)508TGYsXr-FV?07>rZffR5O6pX)hSu+%~`LM55H&#*fwzEHE z=*uq70SV#02xU_r7XFY7IYp_)QA%ycVelG9fPdVA2zVW2*xv8%K3Yv!-C5l|)Ce07 zE$#!x)%p}{F(bARf!;Ol--6@g+G!LXVH{s4paiDcM|;3&y_B0~+0$=$-r9F4(|JT}kM>WyCYr_GgcLC`Tibzqq^o{}wA_9T}QlxjJ zcLGRJkR}L7m7;(&>Am-2kfK6>(2KMnC4ms~O?-~e@0|Di@vV2Qmwyr_vu5_(^}6?- zeP38j8tZC2ri!v&U|Wta_zralgkAjrFNQ@fuOm@&r1Pmpt()U}BYm5fHAx?5@DuJ* zNEA^HTwAdp(rV`_>T>!P|J>2+iRZe8gs?XE@MZ+$W9TRs!h}t6@oQjz&KN+BOy#kw zDD?3uvEHYgo>vLi*?BfUXCJ9kS>)fc*A{VdF>O_DmpoH>B-C~9+V*W3pFG=e&B_Ib zkG#sUadzyUA`7%#mU_Laujrlb>|I4t$#l*cs*1f<{Q#0cFwuvI5k(PRJuBBSVcj+K z*aQDeHqa{q`|v}^N2Y0s8!46|h+A|btIaaFeWv`RGWe2Q`*?n&Szi~Fs9#=k#FEJVxMmWP zJuf+4(zbTrlb!7K5u5vfn_VYTE9Dx7XZuB6vY41RSpqX zUcB{ZIVNcL*R{P$=DTy8cV}-jqiE<^=f>B{>gLmCPm~~(n`H&Qtkg0qmcHG$Alw7# zM5JfB`!W#I-tK7aStp51gbb_aH9I{IDM3?Ppv`${vYjN?(w8Lh2xhZ-)f}jq<%DE{ z2B|~#4H8ySB34kriOV1@laG|^Grles>1%6bYHB1#?%wlO_iP(40;%fFz+dVNB-UYa z%v*73VlyvAys#2yA5tRJ2xwzL5h>h63_35wNO{j6O)M|EQtPND6LD^osy7#KBh8ke zL}a0v_@ro^3|)3I@JLItHsyUMkp*J#2j=Xi+@*KE0oL-n(mgek%RT2OEw)9Y0*ItY zs;R-)$5-pd@Q6*21PfOT*YXJ*p!7R`ad_!*{X&MMHY*Xzt$-4cFjzx-#XTNG`7@8>{RV+PXiSI{f8 z8s3l)eJ*JC*f%KA8e~f%!GPYh`#In01Xo#)VLJ-!{E_k?6fObRJ`Y)oD1)8Uy#KzR z-bH03ViOtKE`TMNant^@&XVIw%(S4{v0Zv6`?IXSpemJoyp;LO2V%$pdv)%N)vN4__Y z-SA2(T#FWG%zl%*{ovzU=8NnblPfQhlEqxrK4nyqv-Gu66@Jhy>LYo{1de79$k%F+ zU0X_C+bm~NRE%N`&eCMt^tm@tsb5_#sdtTVxB&dbgjFRb_{8YJ2RsMPxE^&1B4PeF ztPe>89lFNM`^#ScJSn6Qo1`ze!dE9Xq3gJu-uvWn_f*VeZC8U&S5J4m`cw@5mr4BC z4|n%vY10)$R8|%8xs9a=1wW?JZidsNOStuo6CSa0@tu+ivrqeS+FqHV%Li?A-yzjs zL#7qhNO`$*j;GQAJym=ayU2jf65JOShcEOUmXSgba`aOoAHrvvWYqxyWTFrZSmV?` zHqk0?Q_KTVrd*bXM;FQ|C{ik_G03X!*YR+;zfUUbk@du>teMW~~;ckOk)m z?=4H9Z4+oQ2X*e^h@8O6T>GTpEs($D7JGAF8d3LQs3+7{(4av?boeQ00b49Y4&Le9eVfUxn6YwdXoDrvI7g1UNOpF?X3R zP~8uhc8l^Osr{pM!-E(10u`)3Uj(v!%hMSmdgI~}u#u2s=q)ot6rQ+LS%Y}ADR9CO z!=&KAPZ*?KoA1%poyJnM0_Jq2R=ir0OynU1Y%*J6h)rWFbzw4$-e#7~p9rvxhsd?g ztFe5U=q~OKa(EC~m?bf4m2zF$bzslG*{X||5D-e$`}`16lZ*)t^`c5p&!$Rs#EKE< zCZ^=8yrks&s@tH-nYtsoYvppTzGy&ODjVZjs)z+xI` zVPCG?Y$iMa&jeD|B) z?hH!=Tv4$jYBPWyH>>Vv|KeUymM*zP=@w#CGsG@Z3*{pC#e3x$X~dRN7^XBwGfIq5 zrY;xTF?FQ7pp0hw{KPjBeu7ph8Q%5D3|jSuEevUwpGy8A6x*{=?4n>F+RRa>Y)zou zLhMKo@N&e9wCTOyKDK9q8f{FoNy^{*BEcWaFCZNep?=3Q(*EvKDK?vzCx#WkJNc9O z1=HU)4yE+NwrT2>a%V^HC41S;0L%$Z31XmIAelyD@Kkex+*hKD3B2H+Ia9TB2}f;n z5edyzqUtv&rxv*LyCzK^a-R8@IT*BeH8TRrzoDs zntn>cRKhRXYLMt_YcYnGZ?QxMGk3{k{6V^kG)oyEdO7c{W}$m5Q0~W8dab3!SaXT4 z+PZdkir4K6%BP09Yp2?e=>+~^pNXtU;%4nf(60ng0 z@1CXymsk1j3)zwk`32TFto}A{C>>Hc(;P`Ni-300+nBiZC-B z*`GO zu~R~mi`k%+rob&u2)~jhnGA%c?v0K!Xc84z{Dz$u)vXL-(!wKG&W4eOjBtk&K=v$! zqOaL(oaJrajOjG$VGFc_DsKXG`0(V zZFrg9@&+MCcBM@lOaviu(Yr~(X(n-oIUM75Yy~k7PVgX2c~MNyt@;EhO{I*?*FeNo z4}VPU36zLEHowBW8UNss=h;wq9Kg+{f}cFel^#IAB*foFZX$#^)ZbF|x$con9z3f% zeOnx^L@vYaK-2q*tvVkYW`TV3g82d)*uKx&fsvi2EtCtpRLE~)_ULKI0nxujGt42p z-hK7ilC@j9Cl&cU6*j8`0aPp-rO_Z9#oFvF)cq!Nr-fFSUz){}jlxtsQcnq0q-tNp zW_hwYi{0UDdPh{j9mUC^Me?1`_*(u8TF+g)yZ6y|_Ogh!)N}0i@Km;HB3|w)RH|LS zt=WR{o;MEQUpcpxUJ4!AORo)CCZ-8X&kD^@aI1$w|1ePkjlDFGZ=cf8kJI0HN&~tR zc~d-sO<#)tZmn=@MV~}*aK0)i(#%+930zH!N5lMCR-tLAJN}9{q8qQdYV2+wvCAho zRa4G$H5g-VD&*Rlpz7X(j6H9tzf#=2YlQW>T6e!b=x-cPPKc+~8G2eap(oHau`1V* z=tusI4;Br%W5yUsI`2B3)Pi8(g$JmhnG8BuV^ zZX`x3!oO~(E09cZb=HN2Ld8So(a5CON%;q?{A2AtW+~6tv2FLC!!6VEJ^NI^#H{!8 zKU`tHA1<q-V`0`dgY}aVb(hD`=Oj3UQ&{J^P*vISsP7H zWW=VGnvlGog+IYoJ+#Z(8%5!;$#&`o0DVIaY}L zHmNf0?wsS;8?F_c&6HmqzN^LTiON^1e-TWL&r0FzZ^=~u??xjdo6fCZz1pz8$ko=Brh4hrNL7jn9yEf0do$);oooOk@^wWa&L zO`sw+k2Ra$k|aDlOL!i)HQz35ykng}Fzd05>1Dd;N5!a25zR^)drg3&0cy77mPjI2 zwwE3Pe!9QeL+l31#OuP@lhJ-D5F?{xyltZ6>1Jsn#U7x&)U#oLserjUxjCBDKO|`# z-29!-u~$xpP9?IjxFCOsKUF?f)kk0mu|t@FXL&`^MLtJhlptyT1tyP}5%0;GB_TxX zgDNQb^p(m7QZxe2gsHmp-}Pe9?{Ro$csG{xBr2DQ;z8!kY-jKNa?^Vhg1~N7p%NFt zpurDcZpe&f+-EewXEHeCLrFaUIukjV@nph2Sw;W+DczG~e7ccXy^X}d`IRlEE4M%# z&dy4EHYfNs2YWAqYcY_`r z=A>X1-&E+^Fw?Ym%{$U1tQp(OyQNKRv%UE~hPAgJl{~~B!qa=R&}CsMyi~S?lG4UAzxT7!&n3Q3zlORg^!-4GIy*{;xR ziB7mcGzx>15a3lB>+*U2TqRhpX;}}G1@>4;U{?z->K1uY{iaI6wk$cIH z@aa_`9L+6kyvw3I^E)oEK4hzv?eZnH$WOKst;DyX9ntJDVY>4dD&}OcMQq0koLSk8 zDrr3-6SlGRMx8WtUb5QOa3B=w^+MqFYl>5OqWf;F*2sjE7Vao^&E{2xREBBc$3}O) z2wr9ApUP27?Pz4H%g{H_>3;Qs<&jS0tE;IQq=B*w=nwQrs{`KJ1B*V)0C_x~PM@s+_}(&!@k>ea*QD2Z z_@5`TMpAX;M9>k&=aaDWU@Y8HUp!&FMZUyif2wrvfH~-)qwRdCm`he|&DA1`?~|+h zVOW0e%`4Cux12vP8rVK%L!vQ4Y#j{NSl1_Wa4m>=cD}SqP*r*ybN#98_%A6}|I#8? zgNbAVpNYVG8oC4a&cv?Vh-Kmt(T*G=tBHwT6!Wvv-q7Tr`{GnNUhetWA3g z3pqY|)Sb*6`?(~PLUhvW9XR^guh~_P{6}S3LPtJ8_sJMqgCuT#nt6Kyx_#GM3Nf*E zRpM}iN*+HB?PbCwwK%b^*wAzeT||qa>+*>gpvD|oDG;Kca!C*h>24V zOSYT_z=Xu~T}u1lsU+?@=s?w@zB-sRv&SeMkFO3B^IFRUae}T%l-^60#Dr9?HVR7*xHO+)*jvCp z$YY9O;!ulaQg~T|nOE6rp^ju4sUk1zr3NVQl1O9mQkOLa6deR^J6sc9!&lJleoYDy zBNy2iAum;NefwL*(=b;p!A5>OUeC#oBkMODN+_SL)B9xTe$r>5Vtj@-@rxQf{wZoN zB{@_{#dRW5Lx>%T%M4H5I)~n|RDj8R7WtkR6}d-Wg*cl9!QHqW%3S3XCyH!^6STbPsmhyM=000RcHi{Kt;?1G@r;Vu|acTAJ!#7nH+;HN~UsR z*RM=CTjUab_@r_MV(W^X0gtB$Ne_s=OyxEJ^D=7Rk-RGRIR3bsGJ52`+v7U8lF|F0 zyu~!^SyiJ^LZ)k~B$lJB@yqjOH)(>}P`L%G&~TdOqz@11MoKw14{m*wqm?cZG2Yv* z1sZyJqKES??p+|Lv7#J7yJBr?*G0M$v8cq=b&;?$)xn85kdk?KlM=Cb(~A=#_g<^J zb&7a0khUOgGFgb@Y|!uzC6}0xGjaRU=RE{8-x+j}2Me1_kb$%9L$=)^0Xfef3iUYz zR8u?WU8lAlH3wotTeTk#xQvbG8tJQj1e10?ylYM081dXOuw_#!S`SDP>wp_Tx<@gm zS5os&4VV=##B7u_qAHXrbU{aMU=yh=P*h2Sf+pJ;{^5(D=aLnb5z0DY zv>-&3=o;vGT9kpw9otgYp^k~?_g{z+-sH3FAJ;w1l;2nj6wkONAYaKCoI2&RPb~Ve z2;n%<=lwCGYP9F(Z?31ppuT;HE9^zb@OENGw2klU=fv0Z4E7a@S5dxE-S7{h*JU*y zQHtBf3@B!jZ^<3OkhZ(h8JCy({lpl@j*1hK420Ya<%?$dhF{)B*Po|`|C%O=-%^Sa zqRMdit^v#!W|#1WbBeJo=k1#*JFi9Wj;e>WaID-C^GF6=rP5qeyVf1+-ssR4QS?-b z$2Qh%&Wf`;G27{-dQujx-fbj>0%e^{5` zS*{&GRf!=(u2Bc_#PYvuP#CM2XCsf8Euj6!TErFGkgU&Y1-_d}rW-D9dL0L#FpHmJ z_HR--g8d08eyc=OrO*9_2Fq+>J#(?E3P$)g1|Xu?gU=xlqKb6TLdOzMQ{c}GhR~d5 z$eJO8%I*mF(;Jnk?0uE#Ttv^_g06yTza|q{z2$J_V#B_}iQ5(g7pV((&i!CvK1CH5$jISACa;LVY@XAYm>-_i~FDCO1 zCD!lo@zOhFtJ~IAQzGfjC3G*?@OUY5b2=JG1Z5kVu+B&55)aPXJ zq8IR2W_2_sYDiyp5KGF1b3bM38Tr`dvz?ji)wu#|dsfCb zEObxp85~~ZuQ}uPp48fHqU7_Ot39<63f!?f+CiawNT-a~r6l^NRVl8UUk5U%)6W5t~Q|!}nWkOb_AA0yLZVc{`v#6`d*WRl555pF!K{ zjH}H1*GJQ#Cs@OvzzRXdef^ylszjx8{>9^u(r4lV*UisHj8{Sqf-HRV7TjVM+%_8@ z>gvvi_Nd2B>sPDpeX6-lvxGSDH0j4%duHfa%|k5XpqRQRE%9~qQilQ;GIBc6j4XPko()03`2 z`JuE%^3$AI)l?*AKr$&wfaGa>8b``mB5SdWQC!JQ?}Pv&@-QXlN}dN5TAk#%5if*c zD&=yoVC$#yyoFsdw{OI7y5tmpuk*2ze?zBo(yU6E$deQaeou7IM6GOsFPN%uG})U#?@h9(M95|)km7b@PqB@pqxst;?`X+`X*c|cHF_>WgjFSD?8LH{ zCGV{)|H@)ZBb}z{Gw{GGxju><*G-5f*}bs|Sg}EP#)sMy^-bim6U_oO>6Ps`k2}c7 z7b#&1V+sl#!A?` z+U^7SGlaw-yl!v~`)(b01^(vscKLXs3Euo3WxA`D;`kZc?o_dOd`bLOr=3t9 zk2Y#=Xp_r&SM+{#GPi)w49j;|dsXe+o)u8}7oV0s3?aKOFdv4+JUYW5;fc!Ug~}uQ zN8cMe7?cqN%A$hC3v*SU@MyPbzs|4^fHa!%-q4E%w>&0sAd&-as>NixQy-xm4wmKP)cRSz|kI7&&_B*9B)KatCzB9v4!s#vE?ecLmr|&L${r6r4 z3nXZIvImu^4}{Jz|456MQ+M8Yy9bML5jRO+n^65)ej_38LS7PBrT7`vf&>HD|ngZd1+64W_N4oAwZ- z()X+r_AhbID^a+(e0>?9QszLzSLEyQoP=7K72x3^mPI5-vM4ba;p?EpQhp!GVHJCuW7nw;GB&!QBCgO&Z zQ2t|r1)-9~o^dPZuuo2SC5sivo=d(U$`6u7QcQJUWicoH3TH=fAo#bC{ z6~|hKYk}VQ?yOzofvk)ETqq`f5~>m=IwT$vG63E3u$*hu5#6!Y>^AAvvQ>Wa+Wz`A9;WB# zcp79|104n}{j$%G3|otVS+Cpf?txQThym%2Z|-5og`EA2VTX*XiP+B00P?CblX|C_ zvd5o|D*pK*50JIKNS~|C!?;L3h@L6q0S&u6Is{msKF#FS0BQW>vfNkRv|ai0!dIgMs8eI@H_X|W z#1IPPyce{hBiktm*zVz1dMjPKZP!+$w-t_Y3+|N;CIAV23qiKR-~4Fb_X%Q;#7|lP zZgNs=u-8yOnlADKT|4LM{J$=ZJ105Is~!dedS=G$y>Sob={Pa%WUx3dws@W1d9jX4 z<~*a#&dFJ>uL{|#=&1W2&;8kG?<4TuxZjtMDd`Y8rL*COE935=61X}e_>eJ7Z1L=? zAHWD@^sk`)`QlfI6Blg99NE&o@Kj$O=xuGf`1>9ISIv5-#<&?^q%YG21HC5-{|^Oz_BF=sq4{qr{d02b zBrpeNRF`)#oY!X`c4}{3Z?cTT+wPd}?hC)P%T66FJ5QZEqpQ#@RlP^KNCgru1m|UM zSnRJiV*W>2@OWImWEmx;0U?XC`xl3t*ca0q!#DE!oXw9|%sPUZkqdsDm*3&%rZS3( zAuQN9C%?k>lM46#^pM%~<=mkCZXD0eALZhJy{2me>=s8T%1UA9DYU0pJ00-`Xk3cw zyP-r|fA(`wDY^%kpTcIyg2)9 zt8hG_?msBMxV*E|{&dgx3+F{Q=Sv!V9N|OoQ=IMHN5^0hoN6Srn2^iOu#T!KO^$9X zKuy6Ic2afu*rF=6=SS$?kAfY!i$?4z98s%`uJwn(7Ney*&&C^9EXsTtXYE5b>{p1x z5X6UDg^4STVd%!g84Jt|3Sa~X1cY;e2MKfsn6h>WI8mU>S!J}n|6!a3TN=)>DjJlr zn^*;<%fm!;rv>)J0_k&d*+xEQ`1#!zoD*#<$8es)?h72XZ7M?B01dZC89&~e8V6-G zc4GEIO`Tmi{Fm^KdX+DFJ9bCXgf_OehAy(gj@;wL1}>I4FF)W}Dj~wDQBnth%HnBy zF7mKPVYv;?RlZ2rYNE0mC_|GYppVJ=puhjQRVd(j-x|+wvS0#j0JtwtY#;-BoOk&) zS^xbYN6=W~8O|8I03)b{?d2|p<amVkpB>fc( zU>#GT<1_o_|1#?EftJ2$a^8g_flzpshdo1;L(j{%E-wm1&be_2rDOcF1;!~mI~&D( z<}_3JBUK*-oECtKTOVRL%8+*TupXS}1n>O_b@k@NKbd{HcUrA{xvD&XbK47?IuGCf zXd~i8{~QEOqmSIaoEllE083m-C#4Tn1sE_McuSY)gF#NrH_nBE&t*oG{<~iWLEs>F z{RQhp#pq3iP1#2wum=APiM(x?;b-6WkM#qriwL6Ne&b-m4g`584Wu3UZuI(o>p@3} zYhc*MzC{@jJ_C1=p$ACBggUiV>%z)pv9>=N22T({MTm3}E2Fw`7FT?Fq>Md6<-pI_ zgO><6nPz6}?dwflamC)^X+z`kkYHc^jv=jD*5z z9p`L4PW3G_=>`>ZLmeB=9RpRJxL6WNtrW}HwnT`+In@dZg}P?`w4T*=roDJ>C>OJ9 zU+DP$k>Lt1ANnta6T5wT{P;6b81}(J7C?&B#lN&KO5y)|_{I*#bT9fRkPYO39^V&a zi|9qB6DdrTk%rb^g`66w^O)>?3f|z@TyDIXJs3Kp)|Zfm;-bi-GAP76XjI-HZx~SgrbptWze|jj)-Z+PPcx z#%(JMsCGD7F&Giac^vj9RsU$)xENE8<} z432Q2bt4RWfOU-*+rS0L@et&KeJH?VIUBsC&MA%ycj zN5iq9M?0ODI~#XJ0+Y`++%GpwPedNKYVXQ-9?6dbalkqB#5imw_+ph48MbpW*LWr$ zwt0{N_%>7=AQ%c7bH|>#XT@*?4&l;OM<^4~5it1`weh$rFzsxKSlIywFy2E^@la75 zT662~G~5%eIlwF)+y?A8_H0X2*-7d619j$~Luw>PxW@8g1Ji&Q_iPn{;>2($!+9^L zV>4*1>hf3B`^3=W!VT%rRp~KY#y;A$SBBv%X&@soY>^VT7FvKoC1RV-fbes()=YSO z)Z)_fD68`00<@#d{cl0<&$W^;Ne4`L_~cAJn%Fdm^uc02Z1fiH>o*S}~k^K0AuH zSLiis^60)u>-=pzHmH1gxFv(#3B!mS-L^Qoy@5jgR6x%{rQk=3aO?Aqu=m(3C%-Jr zc;}h)IF7OjH(vYT0s9P`{c!j*8Q`2Hy?CM!XEnt*1ot*z-z(4g-DvTdMA!y2!~OiD zyEu@&!&Y{7Faeucq5~PGg<_@u8+`>BS^qzPS^tHl|H8F}G@QKuPqqGI z;+t;j|IIE^;}H^iMia-!V`0^xtVcSqGQc(gH? z9G9~c$FY9_F@Wy=Bece1%r5luCN4U!HiqsrBF=wwa=_1W;IBGQe_XP3T#R2Tm0sjl_-J)hP{sr%_N%M;pM)b)? zM?|^3BDux!U5m9*oR~me092v^p$@Q^&eisV5O*mhfN6?Lhn~;!FaG=Juo+yv@Q?Zz7*4NS{JxGt zVlUyVIK26-p+3&wuW5+=YsG%AaI_e>Ja`BI+i_GEu@WRPc#Q`I#hxjz0#(