From 2a072bfe2a4697f934325c0ad415a420691146f6 Mon Sep 17 00:00:00 2001 From: lichen225 <161898702+lichen225@users.noreply.github.com> Date: Tue, 21 May 2024 10:51:42 -0700 Subject: [PATCH] [DTensor&DModule&DDP&Examples] feature updates and new examples (#35) In this PR, we add two examples and update some features in DTensor, DModule, and DDP. ## Examples 1. 4D finetuning the llama2_3b model. 2. 4D pretraining a mixtral MOE-based model ## DTensor 1. Update op strategies on `Partial`ed and `InterleavedShard`ed dtensors. 2. Add all-to-all communications. ## DModule 1. Support factory methods for nested submodules ## DDP 1. Unblock gradient allreduce for sparse modules in DDP --- .gitignore | 2 + examples/llama2_4D_finetune/README.md | 32 ++ .../data/shakespeare/prepare.py | 57 ++++ .../data/shakespeare/readme.md | 9 + examples/llama2_4D_finetune/data_loader.py | 64 ++++ examples/llama2_4D_finetune/exp.py | 98 ++++++ .../figures/llama2_3b_train_losses.jpg | Bin 0 -> 27205 bytes examples/llama2_4D_finetune/llama_train.py | 280 ++++++++++++++++ examples/llama2_4D_finetune/sharding_plan.py | 63 ++++ .../mixtral_4D_benchmark/mixtral_train.py | 1 + examples/mixtral_4D_training/README.md | 26 ++ .../data/shakespeare/prepare.py | 54 ++++ .../data/shakespeare/readme.md | 9 + examples/mixtral_4D_training/data_loader.py | 64 ++++ examples/mixtral_4D_training/exp.py | 95 ++++++ .../figures/mixtral_train_losses.jpg | Bin 0 -> 28883 bytes examples/mixtral_4D_training/mixtral_train.py | 297 +++++++++++++++++ examples/mixtral_4D_training/sharding_plan.py | 69 ++++ examples/nanogpt_4D_finetune/README.md | 23 +- examples/nanogpt_4D_finetune/base_train.py | 2 +- .../config/finetune_shakespeare.py | 4 + examples/nanogpt_4D_finetune/exp.py | 69 +++- .../figures/nanoGPT_drand_train_losses.jpg | Bin 0 -> 23677 bytes ...etune_4d_forcebf16_train_loss_bf16_200.jpg | Bin 30464 -> 0 bytes ...inetune_4d_forcebf16_val_loss_bf16_200.jpg | Bin 30660 -> 0 bytes ...anoGPT_finetune_4d_train_loss_fp32_200.jpg | Bin 27168 -> 0 bytes .../nanoGPT_finetune_4d_val_loss_fp32_200.jpg | Bin 27666 -> 0 bytes .../figures/nanoGPT_train_losses.jpg | Bin 0 -> 37171 bytes .../figures/nanoGPT_train_losses_fp32.jpg | Bin 0 -> 29298 bytes examples/nanogpt_4D_finetune/finetune_4D.py | 17 +- examples/nanogpt_4D_finetune/sharding_plan.py | 19 ++ test/dmodule/test_dfactory.py | 284 ++++++++++++++-- test/dtensor/comm/test_all_to_all.py | 64 ++++ test/dtensor/ops/test_pointwise_ops.py | 65 ++++ test/dtensor/ops/test_random_ops.py | 73 +++-- test/dtensor/ops/test_tensor_ops.py | 124 +++---- test/initialize/test_defer_init.py | 2 +- test/parallel/ddp_optim/test_moe.py | 251 +++++++++++++++ vescale/ddp/distributed_data_parallel.py | 44 ++- vescale/ddp/grad_buffer.py | 38 ++- vescale/dmodule/_dmodule.py | 31 +- vescale/dmodule/_factory.py | 122 +++++-- vescale/dmodule/api.py | 31 +- vescale/dtensor/_collective_utils.py | 75 ++++- vescale/dtensor/_diff.py | 85 ++--- vescale/dtensor/_dispatch_bypass.py | 52 +-- vescale/dtensor/_utils.py | 2 +- vescale/dtensor/dispatch.py | 8 +- vescale/dtensor/ops/basic_strategy.py | 289 ++++++----------- vescale/dtensor/ops/math_ops.py | 8 +- vescale/dtensor/ops/pointwise_ops.py | 38 ++- vescale/dtensor/ops/random_ops.py | 29 +- vescale/dtensor/ops/tensor_ops.py | 303 +++++++++++++---- vescale/dtensor/ops/vescale_view_ops.py | 304 ++++++++++++------ vescale/dtensor/placement_types.py | 22 +- vescale/dtensor/random.py | 2 +- vescale/dtensor/redistribute.py | 51 ++- vescale/initialize/deferred_init.py | 28 +- 
vescale/model/patch/linear.py | 8 +- vescale/model/patch/utils.py | 30 ++ vescale/model/patch/vp_cross_entropy.py | 4 + vescale/model/patch/vp_embedding.py | 4 + 62 files changed, 3132 insertions(+), 693 deletions(-) create mode 100644 examples/llama2_4D_finetune/README.md create mode 100644 examples/llama2_4D_finetune/data/shakespeare/prepare.py create mode 100644 examples/llama2_4D_finetune/data/shakespeare/readme.md create mode 100644 examples/llama2_4D_finetune/data_loader.py create mode 100644 examples/llama2_4D_finetune/exp.py create mode 100644 examples/llama2_4D_finetune/figures/llama2_3b_train_losses.jpg create mode 100644 examples/llama2_4D_finetune/llama_train.py create mode 100644 examples/llama2_4D_finetune/sharding_plan.py create mode 100644 examples/mixtral_4D_training/README.md create mode 100644 examples/mixtral_4D_training/data/shakespeare/prepare.py create mode 100644 examples/mixtral_4D_training/data/shakespeare/readme.md create mode 100644 examples/mixtral_4D_training/data_loader.py create mode 100644 examples/mixtral_4D_training/exp.py create mode 100644 examples/mixtral_4D_training/figures/mixtral_train_losses.jpg create mode 100644 examples/mixtral_4D_training/mixtral_train.py create mode 100644 examples/mixtral_4D_training/sharding_plan.py create mode 100644 examples/nanogpt_4D_finetune/figures/nanoGPT_drand_train_losses.jpg delete mode 100644 examples/nanogpt_4D_finetune/figures/nanoGPT_finetune_4d_forcebf16_train_loss_bf16_200.jpg delete mode 100644 examples/nanogpt_4D_finetune/figures/nanoGPT_finetune_4d_forcebf16_val_loss_bf16_200.jpg delete mode 100644 examples/nanogpt_4D_finetune/figures/nanoGPT_finetune_4d_train_loss_fp32_200.jpg delete mode 100644 examples/nanogpt_4D_finetune/figures/nanoGPT_finetune_4d_val_loss_fp32_200.jpg create mode 100644 examples/nanogpt_4D_finetune/figures/nanoGPT_train_losses.jpg create mode 100644 examples/nanogpt_4D_finetune/figures/nanoGPT_train_losses_fp32.jpg create mode 100644 test/dtensor/comm/test_all_to_all.py create mode 100644 test/parallel/ddp_optim/test_moe.py create mode 100644 vescale/model/patch/utils.py diff --git a/.gitignore b/.gitignore index acfadd4..8cada63 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,8 @@ /**/*.tar.gz /**/*.json.gz /**/*.log +/**/.DS_Store +*_checkpoint_dir # pre-commit config ./.pre-commit-config.yaml diff --git a/examples/llama2_4D_finetune/README.md b/examples/llama2_4D_finetune/README.md new file mode 100644 index 0000000..299cfd8 --- /dev/null +++ b/examples/llama2_4D_finetune/README.md @@ -0,0 +1,32 @@ +# Finetune a Llama2 3b model in 4D parallelism using veScale + +## Overview + +Finetune a pretrained llama2_3b model on a small Shakespeare dataset. +Dropout is set to 0 for this model, so no randomness is involved during finetuning. +We choose llama2_3b instead of the 7b model because it fits in a single GPU, which lets us check veScale's correctness against a single-GPU run. + +## Prerequisite + +``` +pip3 install sentencepiece +``` + +## Run + +``` +cd data/shakespeare/ && python3 prepare.py && cd ../.. +torchrun --standalone --nproc_per_node={GPU_CNT} llama_train.py --dp={dp_size} --tp={tp_size} --max_iters={max_iters} +``` + +## Experiments + +Like nanoGPT, we finetune the model with a constant learning rate `3e-5` and set `grad_clip = 1`. +The model states, gradients, and optimizer states are all in `bf16`. + +![](./figures/llama2_3b_train_losses.jpg) + + +## Caveats + +1. Currently, it does not work with `transformers==4.38.2`. The error happens during the backward step: the `aten._scaled_dot_product_efficient_attention` operator fails with the error message `attn_bias: wrong shape (head dimension)`. \ No newline at end of file
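For readers who want the gist of `llama_train.py` before the full diff appears later in this patch: the example builds a 2D `DP x TP` device mesh, shards the Hugging Face Llama model with the plan from `sharding_plan.py`, and wraps the result in veScale's `DDP` and `DistributedOptimizer`. Below is a condensed sketch of that wiring, not a drop-in script; the import paths are assumptions based on the veScale examples, and `dp_size`/`tp_size` stand in for the `--dp`/`--tp` command-line arguments.

```
# Condensed sketch of the veScale wiring in llama_train.py; assumes torch.distributed
# is already initialized under torchrun, and that these import paths match your veScale version.
import torch
from transformers import LlamaForCausalLM

from vescale.devicemesh_api import VESCALE_DEVICE_MESH  # assumed path
from vescale.dmodule.api import parallelize_module
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.optim.distributed_optimizer import DistributedOptimizer  # assumed path

from sharding_plan import llama2_plan  # parameter + forward resharding plans added in this PR

dp_size, tp_size = 2, 2  # placeholders for --dp / --tp; dp_size * tp_size == WORLD_SIZE

# 2D mesh: data parallel ("DP") x tensor parallel ("TP")
VESCALE_DEVICE_MESH.init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=["DP", "TP"])

model = LlamaForCausalLM.from_pretrained("openlm-research/open_llama_3b", torch_dtype=torch.bfloat16)

# Shard parameters and reshard activations over the TP sub-mesh per llama2_plan;
# factory=True lets nested submodules create tensors directly as DTensors.
model = parallelize_module(model, VESCALE_DEVICE_MESH["TP"], llama2_plan, factory=True)

# Gradient all-reduce over the DP sub-mesh, plus a ZeRO-2 style distributed optimizer.
model = DDP(model, VESCALE_DEVICE_MESH["DP"], overlap_grad_reduce=True, use_distributed_optimizer=True)
optimizer = DistributedOptimizer(torch.optim.AdamW(model.parameters(), lr=3e-5), models=[model], clip_grad=1.0)
```

The full script later in the patch additionally splits parameters into weight-decay groups, adds an optional cosine learning-rate schedule, and reports MFU.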
diff --git a/examples/llama2_4D_finetune/data/shakespeare/prepare.py b/examples/llama2_4D_finetune/data/shakespeare/prepare.py new file mode 100644 index 0000000..4ed0352 --- /dev/null +++ b/examples/llama2_4D_finetune/data/shakespeare/prepare.py @@ -0,0 +1,57 @@ +################################################################################ +# Copyright (c) 2022 Andrej Karpathy + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
+################################################################################ + +import os +import requests +import numpy as np +from transformers import LlamaTokenizer + +# download the tiny shakespeare dataset +input_file_path = os.path.join(os.path.dirname(__file__), "input.txt") +if not os.path.exists(input_file_path): + data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" + with open(input_file_path, "w", encoding="utf-8") as f: + f.write(requests.get(data_url).text) + +with open(input_file_path, encoding="utf-8") as f: + data = f.read() +n = len(data) +train_data = data[: int(n * 0.9)] +val_data = data[int(n * 0.9) :] + +# tokenize with llama2 tokenizer +tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_7b") +train_ids = tokenizer.encode(train_data) +val_ids = tokenizer.encode(val_data) +print(f"train has {len(train_ids):,} tokens") +print(f"val has {len(val_ids):,} tokens") + +# export to bin files +train_ids = np.array(train_ids, dtype=np.uint16) +val_ids = np.array(val_ids, dtype=np.uint16) +train_ids.tofile(os.path.join(os.path.dirname(__file__), "train.bin")) +val_ids.tofile(os.path.join(os.path.dirname(__file__), "val.bin")) + +# train.bin has 318,905 tokens +# val.bin has 37,782 tokens diff --git a/examples/llama2_4D_finetune/data/shakespeare/readme.md b/examples/llama2_4D_finetune/data/shakespeare/readme.md new file mode 100644 index 0000000..acb6e12 --- /dev/null +++ b/examples/llama2_4D_finetune/data/shakespeare/readme.md @@ -0,0 +1,9 @@ + +# tiny shakespeare + +Tiny shakespeare, of the good old char-rnn fame :) + +After running `prepare.py`: + +- train.bin has 318,905 tokens +- val.bin has 37,782 tokens diff --git a/examples/llama2_4D_finetune/data_loader.py b/examples/llama2_4D_finetune/data_loader.py new file mode 100644 index 0000000..f83c582 --- /dev/null +++ b/examples/llama2_4D_finetune/data_loader.py @@ -0,0 +1,64 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +################################################################################ + +import os +from typing import Optional + +import numpy as np +import torch + +from vescale.dtensor.device_mesh import DeviceMesh +from vescale import distribute_tensor +from vescale.dtensor.placement_types import Replicate +from vescale.dtensor import empty as d_empty + + +class DataLoader: + def __init__(self, dataset: str, seqlen: int, mesh: Optional[DeviceMesh] = None, dp_rank: int = 0): + self.data_dir = os.path.join("data", dataset) + self.seqlen = seqlen + self.mesh = mesh + self.dp_rank = dp_rank + if mesh is not None: + self.device_type = mesh.device_type + else: + self.device_type = "cuda" + + def get_batch(self, split, bsz, lbsz): + # We recreate np.memmap every batch to avoid a memory leak, as per + # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122 + if split == "train": + data = np.memmap(os.path.join(self.data_dir, "train.bin"), dtype=np.uint16, mode="r") + else: + data = np.memmap(os.path.join(self.data_dir, "val.bin"), dtype=np.uint16, mode="r") + if self.mesh is not None: + ix = d_empty((bsz,), device_mesh=self.mesh, placements=[Replicate()]) + else: + ix = torch.empty((bsz,), device="cuda") + ix = torch.randint_like(ix, len(data) - self.seqlen, dtype=torch.int64) + if self.mesh is not None: + ix = ix.to_local() + if self.mesh is None or self.mesh.get_rank() == 0: + print(f"sum(ix) {sum(ix)}") + ix = torch.split(ix, lbsz)[self.dp_rank] + x = torch.stack([torch.from_numpy((data[i : i + self.seqlen]).astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy((data[i + 1 : i + 1 + self.seqlen]).astype(np.int64)) for i in ix]) + x, y = x.to(self.device_type), y.to(self.device_type) + if self.mesh is not None: + x = distribute_tensor(x, self.mesh["TP"], [Replicate()]) + y = distribute_tensor(y, self.mesh["TP"], [Replicate()]) + return x, y diff --git a/examples/llama2_4D_finetune/exp.py b/examples/llama2_4D_finetune/exp.py new file mode 100644 index 0000000..b5e5df7 --- /dev/null +++ b/examples/llama2_4D_finetune/exp.py @@ -0,0 +1,98 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +################################################################################ + +import os +import re + + +def parse_train_loss(log_fn, name=None): + lines = open(log_fn).readlines() + train_losses = [] + for line in lines: + if "loss" in line and "iter" in line: + token = line.split()[line.split().index("loss") + 1] + train_loss = float(token) + train_losses.append(train_loss) + if name is None: + name = log_fn + print(f'"{name}": {train_losses},') + + +def parse(log_fn, name=None): + lines = open(log_fn).readlines() + val_losses = [] + for line in lines: + if "val_loss" in line: + token = line.split()[line.split().index("val_loss:") + 1] + val_loss = float(token) + val_losses.append(val_loss) + if name is None: + name = log_fn + print(f'"{name}": {val_losses},') + + +GPU_CNT = 4 +DP_SIZES = [1, 2, 4] +# DP_SIZES = [4] +SINGLE_GPU_RUN = "python3" +MULTI_GPU_RUN = f"torchrun --standalone --nproc_per_node={GPU_CNT}" +CODE = "llama_train.py" +LOG_PREFIX = "llama2" +TRAIN_BIN_PATH = "data/shakespeare/train.bin" + + +def run_exps(max_iters, dtypes, run=True): + if not os.path.isfile(TRAIN_BIN_PATH): + os.system(f"cd data/shakespeare/ && python3 prepare.py && cd ../..") + os.makedirs("logs", exist_ok=True) + if run: + for dtype in dtypes: + dt = "bfloat16" if dtype == "bf16" else "float32" + cmd = f"{SINGLE_GPU_RUN} {CODE} --dp=1 --tp=1 --max_iters={max_iters} --dtype='{dt}'" + log_fn = f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log" + # print(f"run {cmd} > {log_fn} 2> {log_fn}.err") + # os.system(f"{cmd} > {log_fn} 2> {log_fn}.err") + for dp_size in DP_SIZES: + tp_size = GPU_CNT // dp_size + dt = "bfloat16" if dtype == "bf16" else "float32" + cmd = f"{MULTI_GPU_RUN} {CODE} --dp={dp_size} --tp={tp_size} --max_iters={max_iters} --dtype='{dt}'" + log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" + print(f"run {cmd} > {log_fn} 2> {log_fn}.err") + os.system(f"{cmd} > {log_fn} 2> {log_fn}.err") + + print("train_loss = {") + for dtype in dtypes: + parse_train_loss(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}") + for dp_size in DP_SIZES: + tp_size = GPU_CNT // dp_size + log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" + parse_train_loss(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}") + print("}") + + # print("val_loss = {") + # for dtype in dtypes: + # # parse(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}") + # for dp_size in DP_SIZES: + # tp_size = GPU_CNT // dp_size + # log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" + # parse(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}") + # print("}") + + +if __name__ == "__main__": + run_exps(100000, ["bf16"], run=True) + # run_exps(10, ["bf16"], run=False) diff --git a/examples/llama2_4D_finetune/figures/llama2_3b_train_losses.jpg b/examples/llama2_4D_finetune/figures/llama2_3b_train_losses.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9beb8c3c0ed29ea5cb1d08d04c88d7e7c466bae9 GIT binary patch literal 27205 zcmeFZ1z22LvNqhfCL~yJNYG#j?iPZD;4TdzKyU~y0U8OA;7)LNx`E&h!QBFlySvl) z*Eu8ioXp&rx#ymH|NsB~X`c1;mfmZxRlBNIRlP6!ZsKkM@ZgoCj3fX70RdnH{{!4j z1H=L7sHkYDDClTtXc!pinArIDu(7bPiE;68@X1LiDac94$f#&o>8YMDQS5#J2*VKM%ZENr7?CS0r8T~#s4xN~sT3lLQSzTM-*xWiiIzBl) zJBM9d{+Jg60O?QD`s>Von%6`4ybzI*k&w}T%nJe08Qzc{BBMNJL&X(SLNl<%d&2IG zj{hPey|e{`nnU@3;H}*-CLs;yBJJUisr@pue{EtuziVcHo!EcOYYKpkgaChdNDl!b 
zsL0jB5j~LzK=i+UApNnK=W6Hru1A!6lX)s2&|PStFX#+y$V&ZUK7efoC#@z>OFrLZ z+5~E(qL*S9Bql}%cfu5{#$p|Jwgq>|6GIoyT^+C1>34oPZWm(SVMBjYi?(;U7M{D& z?tdCjp~c{zC@M)sA`O5g|BP(J?u6v#Qy*a&ZQyzSP3Zl2NXJ{IYEb$4fkEnW;l=*&$MUH`T15i| zFM!p>NkCX6+1QwKNg!wqLW63VzQz})w2xo*m^b!Ix=v1di@LS9joEC!_K# zQ|wEy@%{#_9~|G%)G<4DeReo3@~-SGyt|02d(`k@yfro~OjnAWs*RYQF|k=bJ~G+y za7;rP-kSq$Uq@L|7jVcgrd1aWjA0XKMajVE`%0q}uv8)(4#5!{;aEzX+b8 zsZRx3PZ;SLonQC~EnyobHAPc=x5;u-E0 zXll3;!m1TokEfS-I4xchH>*mR+W+}k(z4)W8cr_cWYomTs(PbO+E@2f=Q!SbvgUN? z5Yo#>#D-OR+8jhjyLVqojWq8i9`4k6fBp_}n8z?*qw*=G{1g5GRfvSk@skev1Ibc6 zvt_MHRma*N+yJG3oXX-2n?CJg?&nfER5H*~bQIu08)8|r;iEzUcyU#xEFLXXpJ;Mnl_x5KL* zA6TLcjQo0A)^c?h-8VoYePj)PcQ<`|Jygyf>gM)g3Sa8-x-{j>uBo4$Im(=`?i^65oz zDG`@3n3$nior)X-l8jNAsiiYO0RT86lWZQ!ye5br zm<@vm#D4u9qxw95jY`!Q?jl)WC+wQp!oYM$@tUwW@=|+`g~f!zVEY!JF@Lo{+DSgE zy=NxrgG3GIc>4(rlPsY-t1n$(YF9`%ZR+X1$xZ;-{TZl zo&%!$An)_aoX}9A@SoyE5Mg^JU)xWS#bEP=8&*ZnA4wV^&arggT?bQz1OTKJ-Q*h_ z&yB)c*!3~Dtw-NA>ek60V4^^GAbe{Z+5UZ2pPCIb#ynT;=8zKks{K(fIX-$KVS?c- zk4Dv;UM3y4fqn)9Y6W27nW6aAP#mwIMhPMbO^%VAfAVz!oDY)9hKtKcI}d6br_ zNn!2pv?E5e{`y7dgz#r?J2kh-1_mG4Y}YOXHy&-)@Zm8Y4mKT}IYZqrDdh9KJOtzX zC?kf*ZkBSnmZJsVSq}$!Bdoa1>g)?-#G(r`sDgNj`6S;^S@NQ>dc?c8$C+kDR%Q3B hl~s|1RT84p6Pl^ 1: + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["RANK"]) + device = f"cuda:{rank}" + torch.cuda.set_device(device) + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + VESCALE_DEVICE_MESH.init_device_mesh(device, (args.dp, args.tp), mesh_dim_names=["DP", "TP"]) + device_mesh = VESCALE_DEVICE_MESH.get() + dp_rank = dist.get_rank() // args.tp + torch.random.manual_seed(0) + torch.cuda.random.manual_seed_all(0) + manual_seed(0, device_mesh) + else: + local_rank = 0 + rank = 0 + device = f"cuda:{0}" + device_mesh = None + torch.cuda.set_device(device) + dp_rank = 0 + torch.random.manual_seed(0) + torch.cuda.random.manual_seed_all(0) + + ptdtype = { + "float32": torch.float, + "bfloat16": torch.bfloat16, + }[args.dtype] + + model = LlamaForCausalLM.from_pretrained("openlm-research/open_llama_3b", torch_dtype=ptdtype) + llama_config = model.config + if rank == 0: + print(model) + print(llama_config) + print(ptdtype) + model.to(ptdtype) + + if world_size > 1: + model = parallelize_module( + model, + VESCALE_DEVICE_MESH["TP"], + llama2_plan, + factory=True, + ) + + model = DDP( + model, + VESCALE_DEVICE_MESH["DP"], + accumulate_allreduce_grads_in_fp32=False, + overlap_grad_reduce=True, + use_distributed_optimizer=args.use_DO, + ) + else: + model.to(device) + + def configure_optimizers(model, weight_decay, learning_rate, betas): + # filter out those that do not require grad + param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. + # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] + nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] + optim_groups = [ + {"params": decay_params, "weight_decay": weight_decay}, + {"params": nodecay_params, "weight_decay": 0.0}, + ] + num_decay_params = sum(p.numel() for p in decay_params) + num_nodecay_params = sum(p.numel() for p in nodecay_params) + # Create AdamW optimizer and use the fused version if it is available + fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters + use_fused = fused_available and (world_size == 1 or device_mesh.device_type == "cuda") + extra_args = dict(fused=True) if use_fused else dict() + base_optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) + if world_size == 1 or dist.get_rank() == 0: + print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") + print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") + print(f"using fused AdamW: {use_fused}") + # + + + Initialize a ZeRO-2 optimizer using veScale API + if args.use_DO and world_size > 1: + optimizer = DistributedOptimizer( + base_optimizer, + models=[model], + clip_grad=args.grad_clip, + grad_to_fp32=False, + overlap_param_gather=False, + ) + elif world_size > 1: + optimizer = BasicOptimizer(base_optimizer, models=model) + else: + optimizer = base_optimizer + return optimizer + + doptimizer = configure_optimizers(model, args.weight_decay, args.lr, (0.9, 0.95)) + + # learning rate decay scheduler (cosine with warmup) + def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < args.warmup_iters: + return args.lr * it / args.warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > args.lr_decay_iters: + return args.min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - args.warmup_iters) / (args.lr_decay_iters - args.warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return args.min_lr + coeff * (args.lr - args.min_lr) + + @torch.no_grad() + def estimate_loss(): + out = {} + model.eval() + for split in ["train", "val"]: + factor = 1 + losses = torch.zeros(args.eval_iters // factor).to(device) + for k in range(args.eval_iters // factor): + X, Y = data_loader.get_batch(split, args.bsz * factor, factor * args.bsz // args.dp) + loss = model(X, labels=Y).loss + if world_size > 1: + losses[k] = loss.to_local().item() + else: + losses[k] = loss.item() + if world_size > 1: + dist.all_reduce(losses) + out[split] = losses.mean() / world_size + model.train() + return out + + data_loader = DataLoader(args.dataset, args.seqlen, device_mesh, dp_rank) + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + model.train() + for iter in range(args.max_iters): + if iter % args.eval_interval == 0: + out = estimate_loss() + if world_size == 1 or dist.get_rank() == 0: + print(f"iter {iter} train_loss: {out['train']:.6f} val_loss: {out['val']:.6f}") + # determine and set the learning rate for this iteration + lr = get_lr(iter) if args.decay_lr else args.lr + for param_group in doptimizer.param_groups if world_size == 1 else doptimizer.optimizer.param_groups: + param_group["lr"] = lr + # load a batch of training data + X, Y = data_loader.get_batch("train", args.bsz, args.bsz // args.dp) + + start_epoch = torch.cuda.Event(enable_timing=True) + end_epoch 
= torch.cuda.Event(enable_timing=True) + start_epoch.record() + if world_size > 1: + model.zero_grad_buffer() + loss = model(X, labels=Y).loss + loss.backward() + grad_norm = -1 + if world_size == 1 and args.grad_clip > 0: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) + if world_size > 1: + model.finish_grad_sync() + if world_size > 1 and args.grad_clip > 0: + grad_norm = doptimizer.step() + else: + doptimizer.step() + doptimizer.zero_grad(set_to_none=True) + end_epoch.record() + torch.cuda.synchronize() + epoch_t = start_epoch.elapsed_time(end_epoch) + if world_size > 1: + loss_val = loss.to_local() + dist.all_reduce(loss_val) + loss_val = loss_val.item() / world_size + else: + loss_val = loss.item() + if world_size == 1 or dist.get_rank() == 0: + print(f"iter {iter} loss {loss_val:.6f} |g| {grad_norm:.6f} lr {lr:.6f} fwd/bwd_t {epoch_t:.2f}ms") + end.record() + torch.cuda.synchronize() + exec_t = start.elapsed_time(end) / 1000 / args.max_iters + # masure mfu + if rank == 0: + total_flops = { + "A100": { + "bfloat16": 312 * (10**12), + "float32": 19.5 * (10**12), + }, + "H100": { + "bfloat16": 1000 * (10**12), + "float32": 312 * (10**12), + }, + }["A100"][args.dtype] + if world_size > 1: + total_flops *= world_size + print(f"1 iter time: {exec_t}") + llama2_flops = estimate_llama2(llama_config, args.bsz, args.seqlen) + print(f"fwd llama2 flops: {llama2_flops}") + # bwd ~= fwd * 2 + print("mfu:", llama2_flops * 3 * 100 / exec_t / total_flops) + + if world_size > 1: + dist.barrier() + dist.destroy_process_group() + + +def parse_args(): + parser = argparse.ArgumentParser() + # Training Meta + parser.add_argument("--dtype", type=str, default="float32") + parser.add_argument("--max_iters", type=int, default=2) + parser.add_argument("--bsz", type=int, default=16) + parser.add_argument("--seqlen", type=int, default=256) + parser.add_argument("--dp", type=int, default=1) + parser.add_argument("--tp", type=int, default=8) + parser.add_argument("--dataset", type=str, default="shakespeare") + parser.add_argument("--eval_iters", type=int, default=1) + parser.add_argument("--eval_interval", type=int, default=400) + + # Optimizer related + parser.add_argument("--use_DO", type=bool, default=True) + parser.add_argument("--lr", type=float, default=3e-5) + parser.add_argument("--decay_lr", type=bool, default=False) + parser.add_argument("--warmup_iters", type=int, default=100) + parser.add_argument("--lr_decay_iters", type=int, default=5000) + parser.add_argument("--min_lr", type=float, default=3e-5) + parser.add_argument("--grad_clip", type=float, default=1) + parser.add_argument("--weight_decay", type=float, default=0.1) + return parser + + +if __name__ == "__main__": + parser = parse_args() + args = parser.parse_args() + run_llama2(args) diff --git a/examples/llama2_4D_finetune/sharding_plan.py b/examples/llama2_4D_finetune/sharding_plan.py new file mode 100644 index 0000000..89b0653 --- /dev/null +++ b/examples/llama2_4D_finetune/sharding_plan.py @@ -0,0 +1,63 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +from vescale.dtensor.placement_types import Replicate, Shard + +# forward resharding plan for a single open llama decoder +_decoder_fwd_resharding_plan = { + "input": {"hidden_states": [Shard(1)], "attention_mask": [Replicate()], "position_ids": [Replicate()]}, + # atten + "self_attn.input": {"hidden_states": [Replicate()], "attention_mask": [Replicate()], "position_ids": [Replicate()]}, + "self_attn.o_proj.output": [[Shard(1)]], + "self_attn.output": [[Shard(1)], None, None], + # feedforward(mlp) + "mlp.input": [[Replicate()]], + "mlp.output": [[Shard(1)]], + "output": [[Shard(1)], None], +} + +# parameter sharding plan for a single open llama decoder +_decoder_param_sharding_plan = { + # atten weight, no bias + "self_attn.q_proj.weight": [Shard(0)], + "self_attn.k_proj.weight": [Shard(0)], + "self_attn.v_proj.weight": [Shard(0)], + "self_attn.o_proj.weight": [Shard(1)], + # feedforward(mlp) + "mlp.up_proj.weight": [Shard(0)], + "mlp.gate_proj.weight": [Shard(0)], + "mlp.down_proj.weight": [Shard(1)], +} + +# forward resharding plan for the whole open llama model +model_fwd_resharding_plan = { + "model.input": [[Replicate()]], + "model.embed_tokens.output": [[Shard(1)]], + "model.norm.input": [[Shard(1)]], + "model.output": { + "last_hidden_state": [Replicate()], + }, + **{rf"model.layers.\d+.{k}": v for k, v in _decoder_fwd_resharding_plan.items()}, +} + +# model parameter sharding plan for the whole open llama model +model_param_sharding_plan = { + "model.embed_tokens.weight": [Shard(1)], + **{rf"model.layers.\d+.{k}": v for k, v in _decoder_param_sharding_plan.items()}, +} + +llama2_plan = {"parameter": model_param_sharding_plan, "forward": model_fwd_resharding_plan} diff --git a/examples/mixtral_4D_benchmark/mixtral_train.py b/examples/mixtral_4D_benchmark/mixtral_train.py index d6674e4..32f77cf 100644 --- a/examples/mixtral_4D_benchmark/mixtral_train.py +++ b/examples/mixtral_4D_benchmark/mixtral_train.py @@ -84,6 +84,7 @@ def run_mixtral(args): accumulate_allreduce_grads_in_fp32=True, overlap_grad_reduce=False, use_distributed_optimizer=True, + whitelist_module_types=[MixtralSparseMoeBlock], ) doptim = DistributedOptimizer( diff --git a/examples/mixtral_4D_training/README.md b/examples/mixtral_4D_training/README.md new file mode 100644 index 0000000..471db35 --- /dev/null +++ b/examples/mixtral_4D_training/README.md @@ -0,0 +1,26 @@ +# veScale Mixtral Example + +## Overview + +Train a Mixtral model on a small Shakespeare dataset. +We use a constant learning rate and clip grad at `1`. +`attention_dropout` is set to `0` by default. + +## Run + +``` +cd data/shakespeare/ && python3 prepare.py && cd ../.. +torchrun --standalone --nproc_per_node={GPU_CNT} mixtral_train.py --dp={dp_size} --tp={tp_size} --max_iters={max_iters} +``` + +## Experiments + +We run the training process on 1 GPU and on 4 GPUs. +Everything, including model parameters, gradients, and optimizer states, is in `bf16`. + + +![](./figures/mixtral_train_losses.jpg) + + +## Caveats +1. To examine correctness against single-GPU runs, we work with a smaller Mixtral MoE model of only 70M parameters that fits in a single A100 with 80GB memory. \ No newline at end of file
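The `whitelist_module_types=[MixtralSparseMoeBlock]` argument added to the benchmark above is the user-facing side of the "unblock gradient allreduce for sparse modules in DDP" item in this PR's description: DDP is told to register the listed module types in its gradient buffer even though they are sparsely activated (an expert that receives no tokens in a step presumably contributes no gradient), so their gradients still take part in the data-parallel all-reduce. A minimal usage sketch follows; the import paths and the `mixtral_plan` name are assumptions based on the benchmark example rather than verbatim code from this diff.

```
# Sketch: letting veScale DDP all-reduce gradients of the sparsely activated MoE block.
# Assumes a torchrun launch with 4 ranks (dp=2, tp=2) and that the import paths below match.
from transformers import MixtralConfig, MixtralForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vescale.devicemesh_api import VESCALE_DEVICE_MESH  # assumed path
from vescale.dmodule.api import parallelize_module
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP

from sharding_plan import mixtral_plan  # name assumed from the mixtral example

VESCALE_DEVICE_MESH.init_device_mesh("cuda", (2, 2), mesh_dim_names=["DP", "TP"])

# Small stand-in config; the real examples build the full Mixtral configuration.
model = MixtralForCausalLM(MixtralConfig(num_hidden_layers=2, num_local_experts=4))
model = parallelize_module(model, VESCALE_DEVICE_MESH["TP"], mixtral_plan)

ddp_model = DDP(
    model,
    VESCALE_DEVICE_MESH["DP"],
    accumulate_allreduce_grads_in_fp32=True,
    overlap_grad_reduce=False,
    use_distributed_optimizer=True,
    # New in this PR: include sparsely activated submodules in the gradient all-reduce.
    whitelist_module_types=[MixtralSparseMoeBlock],
)
```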
diff --git a/examples/mixtral_4D_training/data/shakespeare/prepare.py b/examples/mixtral_4D_training/data/shakespeare/prepare.py new file mode 100644 index 0000000..60e56e5 --- /dev/null +++ b/examples/mixtral_4D_training/data/shakespeare/prepare.py @@ -0,0 +1,54 @@ +################################################################################ +# Copyright (c) 2022 Andrej Karpathy + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +################################################################################ +import os +import requests +import tiktoken +import numpy as np + +# download the tiny shakespeare dataset +input_file_path = os.path.join(os.path.dirname(__file__), "input.txt") +if not os.path.exists(input_file_path): + data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" + with open(input_file_path, "w", encoding="utf-8") as f: + f.write(requests.get(data_url).text) + +with open(input_file_path, encoding="utf-8") as f: + data = f.read() +n = len(data) +train_data = data[: int(n * 0.9)] +val_data = data[int(n * 0.9) :] + +# encode with tiktoken gpt2 bpe +enc = tiktoken.get_encoding("gpt2") +train_ids = enc.encode_ordinary(train_data) +val_ids = enc.encode_ordinary(val_data) +print(f"train has {len(train_ids):,} tokens") +print(f"val has {len(val_ids):,} tokens") + +# export to bin files +train_ids = np.array(train_ids, dtype=np.uint16) +val_ids = np.array(val_ids, dtype=np.uint16) +train_ids.tofile(os.path.join(os.path.dirname(__file__), "train.bin")) +val_ids.tofile(os.path.join(os.path.dirname(__file__), "val.bin")) + +# train.bin has 301,966 tokens +# val.bin has 36,059 tokens diff --git a/examples/mixtral_4D_training/data/shakespeare/readme.md b/examples/mixtral_4D_training/data/shakespeare/readme.md new file mode 100644 index 0000000..1e6c457 --- /dev/null +++ b/examples/mixtral_4D_training/data/shakespeare/readme.md @@ -0,0 +1,9 @@ + +# tiny shakespeare + +Tiny shakespeare, of the good old char-rnn fame :) + +After running `prepare.py`: + +- train.bin has 301,966 tokens +- val.bin has 36,059 tokens diff --git a/examples/mixtral_4D_training/data_loader.py b/examples/mixtral_4D_training/data_loader.py new file mode 100644 index 0000000..f83c582 --- /dev/null +++ b/examples/mixtral_4D_training/data_loader.py @@ -0,0 +1,64 @@
+################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import os +from typing import Optional + +import numpy as np +import torch + +from vescale.dtensor.device_mesh import DeviceMesh +from vescale import distribute_tensor +from vescale.dtensor.placement_types import Replicate +from vescale.dtensor import empty as d_empty + + +class DataLoader: + def __init__(self, dataset: str, seqlen: int, mesh: Optional[DeviceMesh] = None, dp_rank: int = 0): + self.data_dir = os.path.join("data", dataset) + self.seqlen = seqlen + self.mesh = mesh + self.dp_rank = dp_rank + if mesh is not None: + self.device_type = mesh.device_type + else: + self.device_type = "cuda" + + def get_batch(self, split, bsz, lbsz): + # We recreate np.memmap every batch to avoid a memory leak, as per + # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122 + if split == "train": + data = np.memmap(os.path.join(self.data_dir, "train.bin"), dtype=np.uint16, mode="r") + else: + data = np.memmap(os.path.join(self.data_dir, "val.bin"), dtype=np.uint16, mode="r") + if self.mesh is not None: + ix = d_empty((bsz,), device_mesh=self.mesh, placements=[Replicate()]) + else: + ix = torch.empty((bsz,), device="cuda") + ix = torch.randint_like(ix, len(data) - self.seqlen, dtype=torch.int64) + if self.mesh is not None: + ix = ix.to_local() + if self.mesh is None or self.mesh.get_rank() == 0: + print(f"sum(ix) {sum(ix)}") + ix = torch.split(ix, lbsz)[self.dp_rank] + x = torch.stack([torch.from_numpy((data[i : i + self.seqlen]).astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy((data[i + 1 : i + 1 + self.seqlen]).astype(np.int64)) for i in ix]) + x, y = x.to(self.device_type), y.to(self.device_type) + if self.mesh is not None: + x = distribute_tensor(x, self.mesh["TP"], [Replicate()]) + y = distribute_tensor(y, self.mesh["TP"], [Replicate()]) + return x, y diff --git a/examples/mixtral_4D_training/exp.py b/examples/mixtral_4D_training/exp.py new file mode 100644 index 0000000..6e7ebeb --- /dev/null +++ b/examples/mixtral_4D_training/exp.py @@ -0,0 +1,95 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import os + + +def parse_train_loss(log_fn, name=None): + lines = open(log_fn).readlines() + train_losses = [] + for line in lines: + if "loss" in line and "iter" in line: + token = line.split()[line.split().index("loss") + 1] + train_loss = float(token) + train_losses.append(train_loss) + if name is None: + name = log_fn + print(f'"{name}": {train_losses},') + + +def parse_grad_norm(log_fn, name=None): + lines = open(log_fn).readlines() + grad_norms = [] + for line in lines: + if "|g|" in line: + token = line.split()[line.split().index("|g|") + 1] + grad_norm = float(token) + grad_norms.append(grad_norm) + if name is None: + name = log_fn + print(f'"{name}": {grad_norms},') + + +GPU_CNT = 4 +DP_SIZES = [1, 2] +SINGLE_GPU_RUN = "python3" +MULTI_GPU_RUN = f"torchrun --standalone --nproc_per_node={GPU_CNT}" +CODE = "mixtral_train.py" +LOG_PREFIX = "mixtral_new_MOE" +TRAIN_BIN_PATH = "data/shakespeare/train.bin" + + +def run_exps(max_iters, dtypes, run=True): + if not os.path.isfile(TRAIN_BIN_PATH): + os.system(f"cd data/shakespeare/ && python3 prepare.py && cd ../..") + os.makedirs("logs", exist_ok=True) + if run: + for dtype in dtypes: + dt = "bfloat16" if dtype == "bf16" else "float32" + cmd = f"{SINGLE_GPU_RUN} {CODE} --dp=1 --tp=1 --max_iters={max_iters} --dtype='{dt}'" + log_fn = f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log" + print(f"run {cmd} > {log_fn} 2> {log_fn}.err") + os.system(f"{cmd} > {log_fn} 2> {log_fn}.err") + for dp_size in DP_SIZES: + tp_size = GPU_CNT // dp_size + dt = "bfloat16" if dtype == "bf16" else "float32" + cmd = f"{MULTI_GPU_RUN} {CODE} --dp={dp_size} --tp={tp_size} --max_iters={max_iters} --dtype='{dt}'" + log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" + print(f"run {cmd} > {log_fn} 2> {log_fn}.err") + os.system(f"{cmd} > {log_fn} 2> {log_fn}.err") + + print("train_loss = {") + for dtype in dtypes: + parse_train_loss(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}") + for dp_size in DP_SIZES: + tp_size = GPU_CNT // dp_size + log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" + parse_train_loss(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}") + print("}") + + print("grad_norm = {") + for dtype in dtypes: + parse_grad_norm(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}") + for dp_size in DP_SIZES: + tp_size = GPU_CNT // dp_size + log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" + parse_grad_norm(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}") + print("}") + + +if __name__ == "__main__": + run_exps(1000, ["bf16"], run=True) diff --git a/examples/mixtral_4D_training/figures/mixtral_train_losses.jpg b/examples/mixtral_4D_training/figures/mixtral_train_losses.jpg new file mode 100644 index 0000000000000000000000000000000000000000..447c8dbae55a1cefadab75d85d3a620251c8222f GIT binary patch literal 28883 zcmeFZ1yq~cx-J|F#idXvE-miG-P+>C-3t^A?%E)QwiKs8i?+B0*Whl&tps-sF2Qel z?%n&W)wB27d#`i;JMKTmPR1O`m+#A$yi?x!ywCg0dH3UP5%54kMqUPhgoFfmjd%j? 
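Separately from the examples, the "all-to-all communications" item in this PR's description (see `vescale/dtensor/_collective_utils.py`, `vescale/dtensor/redistribute.py`, and the new `test/dtensor/comm/test_all_to_all.py` in the diffstat) is presumably what backs Shard-to-Shard redistribution of a DTensor without a full gather. A rough sketch of the kind of call that would exercise it, assuming veScale's DTensor keeps the usual `redistribute` API:

```
# Rough sketch (API assumed): Shard(0) -> Shard(1) resharding of a DTensor,
# the case typically served by an all-to-all collective. Run under torchrun with 4 ranks.
import torch
from vescale import distribute_tensor
from vescale.dtensor.device_mesh import DeviceMesh
from vescale.dtensor.placement_types import Shard

mesh = DeviceMesh("cuda", list(range(4)))    # 1D mesh over 4 ranks
x = torch.randn(8, 8)

dx = distribute_tensor(x, mesh, [Shard(0)])  # rows sharded across the 4 ranks
dy = dx.redistribute(mesh, [Shard(1)])       # exchange row shards for column shards
```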
zuU7IH0{B`{?%^UlzV{uV{l-dh^qSO}{x%b6TeeMHv~Cmc5mhO_4ICE6#_Tjf_p)$^ z7=?Sv5#Rs#CFGNDKb8F8?8bb6Xc@Se!w)^Nta8$Mhl)G$Y~T!V*h$}lzNBj|>Hr9L z5z+Cd3bbGP;CQW9XgdW3!@wU)Qa#jOsFn=&g@qMx%QBJ{=d2G-9?~inT3hL^>ZvLO z20>B^u}^4@@N7#8xA-CJ@#4uol;B)|T)_`EBs9_!!Pc{{xZF7Aw+|#jecYzhK2#O% zr-_(3Sv91Fza{Vie0T9hIsrSPZY4n2-==wF+xr&V;yQ2^qXc}hp#8dECW|S8v~Q4< z_91IWp3D|v6n`lrFAb@Ye-<|$(Xw1H4VP`#cc1n2Fs2%{FE$sF|5_eFsm#@Wzh^l> z)Kg~Efe(bF7>AsD$p&?tHeZ(Q$_cujPnLEYEpv?(XWT_UWyB@FeEP{O(t zmVWLoVQ7{kG3&F7GW{0+(G$5*`@ z)&6LG-ZWQJUlEJJ8N=PNm>OIcJ7bUL$jG<8f6G=?CAaF@^V$}^oQQh=y+)j3%(Uj! z=$!2&t3#h_!f|{#-c>=jczliT5+}~Kj*HQPm#;jeE{=RXtxv$CrKKM~uH{}OkL-deeHH(bZrZH6y;MSpZ#0W^*eymc;UJ1XhuCE7bny_%g1sv za`N^5U)aFEyq0jlb2B^5#cDo)SFMZHCeUAA52CY7BL@L!8ApKQK*kRJC-7iJAo1r8 za7rnuAWhVcVQr6sn|@IJrv0-j!HUuIeDs1+uEqU}= zD6x3Rh`lv)8k8jHL5~6}?S;`94)Vff4LFP&=Aw=F>Q*Od2o^gOJYV0l&W%qgtzLf|3=uV>yuNltf(ZT zi;(SL`9mh0%U58B-OV%)k78niW8j*73M{){2BeH`m2PEarA=qIeRh3_8fU4fkE=+@ z38wuK5hBHL^S4bws-}qzh*byY10(n{kZ&wz-tok}449h#Fn8Ym2W^E_lyYs+hVcU% z@g^BQuBSPpo>RSPjnXd9JE%B0kPfto1=(+MiB*i_CnsZ)5*FfQO7n87Zr7U;1oaZ9 z=(F30hdS@OvR&D#32fT6D)s2-9*5f{@D9C9XAhv9op|KwPBifSk;mxl?Dyzag2Y&P z5|)fedAbv#HG!!nn$ocb+OLMaV8c1vrz*aARd4%hUbX|hZI7NpzxDXm#gmh26-+oSjKNR{%aZ-I;<6Ez&gjVk(H6wlU}2Em{*;nm8%ocjP$;uPOG70^ zqveS6!wBBR+URp-JON8tP$6mrr;da&(S|(>AfynXtfQ}($_DsQS}yE(K;M*P+f>kj z`E%)F+Io<*m`S7edV7|XpW^Yg!^>fu?u`mK!j8nV#(G^DD!a)cVe=Dx4y3bv#ZFQ* zdppoCR)vPZ8fq+e#d4g1p>(<|?0a{B*Kv%L<=$WMlrkE$#Z)@dXiJ+X+1amk-BD{( z?a2C;2`fJKduVD$WRZ9GA3rioN%1H#9m}pue_5GyM9r|ZB$LFpVIF@!@j!?WDNZ7s zw@{Do@l(iAG(m^4n3TI?UUiJ$e*vWN4Tdx~#htGE96-C4(7Yk0wOt z3O^~~8eKZ%$Of+1)5?{Okxd@!HiYc$?R_RWXDNKM`*<$E>*5_>6K^YO{o(>&y7#-2{F4`R?T~ye+FHi2vO;+O zv*BVG%yBR0;p_{`505L8Tzb&&H+z5sV}&oLm|h*~wAWKUGY}Dz&^x%+Q1R(q=M`|y zxz#Gcg*OBXsEhs8%FL9sMh(>ZPln^4#Z(hzP zm;AtjvFwQy`gdY6NNH1Q1+L<6%GlmG3BO6shpR|Uo*F&Z8F7QjYBXON#wwG7Y38b) zy)J)5hg(O=m~DmF*3Sk3a-{KE($BY=FP;R5h*@<)7y{iXi!;4OX64e%Iw=Sgu6QDT zNYh~l_kwW*G|DG=;f^Oed3*K@q>;LVQoA6E!^jC0h&1TUN;_D+H%eyH?km5$Pd zG@9&QV6wTkMLc!cz^bI7TTPpI6tN7$>mXK#+Dv8dWvFn$)pwU{WB;TB#_2t|`c*dd1d)l!RxSd^AP1yXT z$CqU1`(I`Y!hePasEBYWoL7n8;PV~_qg7SB>6cd(N%oPw4ajA92hslJJX&^5G;l}N z39wr&$6g3*+7v~@<{aM<%I#zXSoEC)g}9R&QMY&+vrJ@NiTb=P->%?s$9R16gDSTl zKaX-iaq))b?Di5w0xbP}%DOsjV|xcUS^=gMfSMM(qe>{u0z_C^CPjf?H-$fsxnt!| z?zVA=*4nKn<0r6?d>#oPRTnxb;jBOqCd-s4@SwJ-PW!^aefIp&Wuj6;AuQ7OJ>0b< zbj7=?Q;Hz(;`T$M_0h^?V14evDypsn$8~gi^QVJn)yiJ5)33P4FDKn#R*Fh0Dh*-F zy{%6yM4hLEwjvi#ih!W^X%8U{kx+Is?$O>gSyt&M?X)LteLCjUnK~(b7Vj1vKRCSX z5}6rLmK|S;P`HUnbJNQ}9_H}g!v&69o0{i<+G-}*>B9)emdCPwe8tHl!g$vF42fDw zl9G>Tqxm4l_7T;CgN_<1x?#8E!U$o;6dvQ}OOxa%! 
ztT<;mqcZL`BQz=Lj^_34Gt$~Sty#j9)kjEW-6AU0-V2lqH@PZtZDg;RoeOs!G(6)L zH}Eoem!=|3D!69=@XYh`wdHL_+S);`7|J!?x&v5RUVY!;wYi8^S+_>6_${2uEB+0o z&V*$BH#U&S&)=#}YD(pc*nK8PZ^)fC|JDb8vLSSGLlv&K99mt^DIViCeIbk8nBwzv zPojhDEUE7h@7?aTd{(cyS7v&W5LX=endz-J(U9@S)Lb08=0ZwcKgC>94Ufvo31evP zc3gR5)pnKKa1|uzY?3wEMq8I&b#vQ=ge;-Kwf*w!+rhqS-6-F0PsG3$N4{v$*tGaKdLHjeH5P*v2n8H{_EiwKN{j8OlwE$#r#_P2!o#S(k{h+WBMu9=WU zOb+G}->d%^q93T5$o#@y45K?*4|(4Tojs66c_4LCaCS>5S2@!!blPYOD?Zz+P1+~u zu3YuDxc7Er*mb`8&ELLlYUKf6k$zQ9Ca7@xva;IFc!G0}D3+L(JgA`90OiQ_McVyJ zBjxq8ZtaPaaY*7&ZB=<426j1h!w(jPxJZSVq|K(kedDxoQ7ltX!Mq91bImHngJ%1x z>e^|dL1G78bJ{fkW!7t4TOc*64b)GZGK{(Si`#?cDPLwz2!Htgh4)g_OIHKXZ>m)M zKYs&hU|>L_e;4qcYSN!b=UX(bkN^uC7x*lvV6x@M{96cGqHeZ!+=dhNtKlmuIJEo6dkP(p(*S{~7PA>cKMdd|%0c(ns+7$oO&lNWqeaDi6ayv9(Kjd)k=Tj&6BL7zVL7=b~G?%#H=OPO`?)IRC# zP3Oo1%E(a6>Cr1ANG(@dwFXe%OV{XzC9wPW{EoQ_JbQfg0`~xKQW(4M}lA_EsU|N$ieIYgIfyfXqTXkt&*utTC z@kF8hS-Mbu>H0GCL%!Ex<6k>xOnCVPP$PLn#B8R~+trfYjOlTbO~&X~q-e+KDT|dn zS+|K3ixp2$go=Pr4iGWZE#=i&c-;U2ZfjIBQfigNW25|{y?)00JM4eI9-Nb;b11#G zYA$$Hry|(c+u`gr|?s+l#mEw6CGL0U>A8)>y(NoH`_kUK$} ziuGwErDy8T-3zcHQM5=a7$`IObF>G=N*k%)CR< z2Si12Xa!L(S;J5Ud-e+T+O zBtZS$4V!jRDaAR|Pv@5H`@4<&Ua3U(Fvyd#h)iKMR;QDIzg2~z=&Ac?KW|CUOR>r4JSp~0`&uCBv~EQC&)srFahagljK|;t+IjAjuSL(tu;lUcg{nch>eYCb4$mH# zA0Dix?g`YA{y9ytO_=A{UR>1|K%jE(v8)w{g{!-COUnBYvpktKggIW z#e)4k$Vlj=sf)%ge8c5mZLQ{)O7o>f{1XU=-o>U@bdD7wI+B^JX1OF57tFvS0&iXe z3#>mkej>p|^Ep`QXtEG&BiLOo*QVl&OC)NmlvWc3Z1EZ?xq3SF> z=UQ%Sb4RC~O(rsieqDdG!qkV6m&qf0$ZMuU883u~75x>)`p?0mN7;q8*99xehniGe znzSQA4^`6l6aPgR^FrUQ+sLT1m5H9PmEaBcX^@F7R0z-I4^(D!y!NAO7F>e8i-Em} z9#!Phx|6UDzy65R?l$!yUDVz`TkbMZ`C3JXYW>a<60+dHsf#smaCQGurBSb7m2m=- zI#|Gi6nFS0217~(`df`_v6QV>M?=Xo7V|HjHo8LIT8cN2-`KgS(|PhUSqJYP?m^y% zQ7c@<^Kpma+^dtGqm5Vcc;*>l8q`4E7Sg(ynROr;dG_tQW5RkH8gf;*_NBw52%COH z$PEJyM%N_r9FA_!-H6ORXVBU0+5HWSC*8{2mBhu}!6fY}(ICfM3UEpScw66zl-BgS zMx0zdN-*+OPIhd8zaD(X|v7{LYEQtCBBE?+dCoDB}a*xWy_T6*jbN=Y-i* z3!=qsvPUoh+Bc#D5}-^MXQT&Aw5xVN$OisHrX!1;)%P84MzFT!d70b#n3}oVOMZCS zq&*lI741U-{)E$;A>iZd&TO4ZN)fCaT7Hsr?Qq5#$l+Q?<)&?26gyPjQpNsDF6&j z0fN7WRPuexyV03b2JuW|5U*{8lxtO-w4|dF$nYJ`R6g#>wGpJngf_VHWkg0SgSHum zz&3AJjcPZC z61He&Ue(qrnSm?h%gPrRMg(FCc%MB@@+j7}w$Obbs?!EDM&eVaeVmGAY1U<5nHxPe zF_A>CcTth{UoU|rTV4NUia;eJN^XvGwaT+aRDGT(;&O$aMeCbMBqXhqtAOS-)Lqf~ zvP*5G0&ceIh_6YmTI-W}1UC0tCx#4-P9VZI!I^P#qs_ z2>kHbWopWz+s?+-^)vzr0av%Rx|v6l0p=h_h{-o-@?qn_3KUVy($adn;?-OO$CKJp z4bDgyg-XoZ#bTew>b&50g8jAw7^k#kgks%mn^R|GV(pasUh)$fKNmi$C#M@j%ZMDp zi=_LX%QVf8CE)=kZ?21a%`GUr?of3NIK@Dfd~wgLyTV4XTJSYdTV+bP=C=P$=;NE? 
z#KlTST&jaaY+|GAZud?exYvzQ+4l4M{8~TJ!9Vl_T@!EA#cVNL!N**xOX+ejx+*g< zULMX3DWei)pp^6j?~Dp5P{8t*h|J`amO8NGxb%a zUi_hJUKk|r*kFh%iHr9v0&D(w#y$whJTZQ4mDzH%Q>Hugsm%G9ML8FPF1GmqIL|#2 z*Agd}tf!m8{aUi{xhBM0e#1vOoraCPl5?6BA=0FRQDq-|480rmc~w0jvRl)m1evQe+Rn z+CJVW`?zIIpmRWSlAEa49yh5McS~K9Y9jBmw!FO4j#aQbA|$9MeOA~ra~5=iCXK$N zl_sb&;+pd=S%%Pw8%y9+4PlDV0Z)B>Sqa@O=UwbLiteOuk~h$Suo=pVkPnASQ@BaL z*Pi>_Mz88z<#0%o(<|Hb>`2Rl_$C!>0e@ zlbpNixYJIxZDP{YiRk`l+Mf0GGOx&%!K>>zj?#N(ZDMo-@LxmB1&8Dsd!~e<$yQR^ zsv|6q)rO~bH-wG@QN8}W%LvKCGyXhvXq1s}(Hg1C#rl`R% zE$gYc$W`RfWJP7@biZ@9zZ95e4-RSsxnKIAgip@WD04&v7yWk8E2JvH#>7vi#XOR9 z_qiELiYTy?I09?XlA}r@%xG!}z2jb0Ta>W0;Wv(DyR(Qd?-e>ejpjaP{sIu{CH=zO zWGFkT;Z`{$<OIjnfsCj()bn5;L0lF~1{{;{-WA!+)zJ#Qf zkMzkT3K5mNk9&_cM&z8wdb-E%_R|;q+1zs8B?gf?h7r_wZw>ahoZx-zeQGBS>dJj;WBZF*>pg`}3{-spvA8oy9vlVdfJ zth=~9YQq?nVF5n(@+LBma|uKoR?l5gaxSKA&RZc=mWqrsFS16d?>1AI3#U%pAGZ&agKkbf^ZL|YVN$ud7=p5MwZ~*Syd51sX$bVn%@h6ta8PSJL>Ytu0xiWX zQYIGg540=Mo9EkTLQAyG-%dxh+gv`gNeHFN@TB(eyP!M^mltsgyqJ>X!dWExw}<+# QbgX~T8~?A{r1&!OKfCn`L;wH) literal 0 HcmV?d00001 diff --git a/examples/mixtral_4D_training/mixtral_train.py b/examples/mixtral_4D_training/mixtral_train.py new file mode 100644 index 0000000..cb33761 --- /dev/null +++ b/examples/mixtral_4D_training/mixtral_train.py @@ -0,0 +1,297 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import argparse +import os +import math +import inspect + +import torch +import torch.distributed as dist + +from vescale.dmodule import parallelize_module +from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP +from vescale.optim.distributed_optimizer import DistributedOptimizer +from vescale.optim.base_optimizer import BasicOptimizer +from vescale.devicemesh_api import VESCALE_DEVICE_MESH +from vescale.dtensor.random import manual_seed + +from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralSparseMoeBlock +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from sharding_plan import mixtral_plan + +from data_loader import DataLoader + + +def estimate_mixtral(config, bsz, sqence_length): + embed = 4 * bsz * sqence_length * config.hidden_size + # MixtralMoE consists of 3 linear layers. 
+ ff = 3 * 2 * config.num_experts_per_tok * config.hidden_size * config.intermediate_size * bsz * sqence_length + # GQA + head_size = config.hidden_size // config.num_attention_heads + attn_q = 2 * bsz * sqence_length * config.hidden_size * config.hidden_size + attn_kv = 2 * 2 * bsz * sqence_length * config.hidden_size * config.num_key_value_heads * head_size + attn_mask = 2 * sqence_length * config.hidden_size + attn_proj = 2 * config.hidden_size * config.hidden_size * bsz * sqence_length + attn = attn_q + attn_kv + attn_mask + attn_proj + return embed + (ff + attn) * config.num_hidden_layers + + +def run_mixtral(args): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + if world_size > 1: + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["RANK"]) + device = f"cuda:{rank}" + torch.cuda.set_device(device) + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + VESCALE_DEVICE_MESH.init_device_mesh(device, (args.dp, args.tp), mesh_dim_names=["DP", "TP"]) + device_mesh = VESCALE_DEVICE_MESH.get() + dp_rank = dist.get_rank() // args.tp + torch.random.manual_seed(0) + torch.cuda.random.manual_seed_all(0) + manual_seed(0, device_mesh) + else: + local_rank = 0 + rank = 0 + device = f"cuda:{0}" + device_mesh = None + torch.cuda.set_device(device) + dp_rank = 0 + torch.random.manual_seed(0) + torch.cuda.random.manual_seed_all(0) + ptdtype = { + "float32": torch.float, + "bfloat16": torch.bfloat16, + }[args.dtype] + + mixtral_config = MixtralConfig( + vocab_size=args.vocab_size, + hidden_size=args.hidden_size, + intermediate_size=args.intermediate_size, + num_hidden_layers=args.num_hidden_layers, + num_attention_heads=args.num_attention_heads, + num_key_value_heads=args.num_key_value_heads, + ) + + if world_size > 1: + model = MixtralForCausalLM(mixtral_config) + model.to(ptdtype) + + model = parallelize_module( + model, + VESCALE_DEVICE_MESH["TP"], + mixtral_plan, + factory=True, + ) + + model = DDP( + model, + VESCALE_DEVICE_MESH["DP"], + accumulate_allreduce_grads_in_fp32=False, + use_distributed_optimizer=True, + whitelist_module_types=[MixtralSparseMoeBlock], + ) + else: + model = MixtralForCausalLM(mixtral_config).to(device) + model.to(ptdtype) + print(f"rank {rank} cuda.rng_state {torch.cuda.get_rng_state().view(torch.int64)}") + + def configure_optimizers(model, weight_decay, learning_rate, betas): + # filter out those that do not require grad + param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. + # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 
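+        # `p.dim() >= 2` is the structural proxy used below: Linear and Embedding
+        # weights are at least 2-D, while biases and RMSNorm scales are 1-D, so only
+        # the matmul/embedding weights receive weight decay.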
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] + nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] + optim_groups = [ + {"params": decay_params, "weight_decay": weight_decay}, + {"params": nodecay_params, "weight_decay": 0.0}, + ] + num_decay_params = sum(p.numel() for p in decay_params) + num_nodecay_params = sum(p.numel() for p in nodecay_params) + # Create AdamW optimizer and use the fused version if it is available + fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters + use_fused = fused_available and (world_size == 1 or device_mesh.device_type == "cuda") + extra_args = dict(fused=True) if use_fused else dict() + base_optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) + if world_size == 1 or dist.get_rank() == 0: + print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") + print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") + print(f"using fused AdamW: {use_fused}") + # + + + Initialize a ZeRO-2 optimizer using veScale API + if args.use_DO and world_size > 1: + optimizer = DistributedOptimizer( + base_optimizer, + models=[model], + clip_grad=args.grad_clip, + grad_to_fp32=False, + ) + elif world_size > 1: + optimizer = BasicOptimizer(base_optimizer, models=model) + else: + optimizer = base_optimizer + return optimizer + + doptimizer = configure_optimizers(model, args.weight_decay, args.lr, (0.9, 0.95)) + + # learning rate decay scheduler (cosine with warmup) + def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < args.warmup_iters: + return args.lr * it / args.warmup_iters + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - args.warmup_iters) / (args.max_iters - args.warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return args.min_lr + coeff * (args.lr - args.min_lr) + + @torch.no_grad() + def estimate_loss(): + out = {} + model.eval() + for split in ["train", "val"]: + factor = 1 + losses = torch.zeros(args.eval_iters // factor).to(device) + for k in range(args.eval_iters // factor): + X, Y = data_loader.get_batch(split, args.bsz * factor, factor * args.bsz // args.dp) + loss = model(X, labels=Y).loss + if world_size > 1: + losses[k] = loss.to_local().item() + else: + losses[k] = loss.item() + if world_size > 1: + dist.all_reduce(losses) + out[split] = losses.mean() / world_size + model.train() + return out + + data_loader = DataLoader(args.dataset, args.seqlen, device_mesh, dp_rank) + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + model.train() + for iter in range(args.max_iters): + if iter % args.eval_interval == 0: + out = estimate_loss() + if world_size == 1 or dist.get_rank() == 0: + print(f"iter {iter} train_loss: {out['train']:.6f} val_loss: {out['val']:.6f}") + # determine and set the learning rate for this iteration + lr = get_lr(iter) if args.decay_lr else args.lr + for param_group in doptimizer.param_groups if world_size == 1 else doptimizer.optimizer.param_groups: + param_group["lr"] = lr + # load a batch of training data + X, Y = data_loader.get_batch("train", args.bsz, args.bsz // args.dp) + + start_epoch = torch.cuda.Event(enable_timing=True) + end_epoch = torch.cuda.Event(enable_timing=True) + start_epoch.record() + if world_size > 1: + model.zero_grad_buffer() + loss = model(X, labels=Y).loss 
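+        # Multi-GPU path: backward() accumulates gradients into the DDP grad buffer;
+        # finish_grad_sync() below completes the gradient allreduce across the DP group,
+        # and DistributedOptimizer.step() applies gradient clipping (returning the grad
+        # norm) before updating its ZeRO-2-sharded optimizer states.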
+ loss.backward() + grad_norm = -1 + if world_size == 1 and args.grad_clip > 0: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) + if world_size > 1: + model.finish_grad_sync() + if world_size > 1 and args.grad_clip > 0: + grad_norm = doptimizer.step() + else: + doptimizer.step() + doptimizer.zero_grad(set_to_none=True) + end_epoch.record() + torch.cuda.synchronize() + epoch_t = start_epoch.elapsed_time(end_epoch) + if world_size > 1: + loss_val = loss.to_local() + dist.all_reduce(loss_val) + loss_val = loss_val.item() / world_size + else: + loss_val = loss.item() + if world_size == 1 or dist.get_rank() == 0: + print(f"iter {iter} loss {loss_val:.6f} |g| {grad_norm:.6f} lr {lr:.6f} fwd/bwd_t {epoch_t:.2f}ms") + end.record() + torch.cuda.synchronize() + exec_t = start.elapsed_time(end) / 1000 / args.max_iters + # masure mfu + if rank == 0: + total_flops = { + "A100": { + "bfloat16": 312 * (10**12), + "float32": 19.5 * (10**12), + }, + "H100": { + "bfloat16": 1000 * (10**12), + "float32": 312 * (10**12), + }, + }["A100"][args.dtype] + if world_size > 1: + total_flops *= world_size + print(f"1 iter time: {exec_t}") + mixtral_flops = estimate_mixtral(mixtral_config, args.bsz, args.seqlen) + print(f"fwd llama2 flops: {mixtral_flops}") + # bwd ~= fwd * 2 + print("mfu:", mixtral_flops * 3 * 100 / exec_t / total_flops) + + if world_size > 1: + dist.barrier() + dist.destroy_process_group() + + +def parse_args(): + parser = argparse.ArgumentParser() + # Training Meta + parser.add_argument("--dtype", type=str, default="float32") + parser.add_argument("--max_iters", type=int, default=2) + parser.add_argument("--bsz", type=int, default=128) + parser.add_argument("--seqlen", type=int, default=256) + parser.add_argument("--dp", type=int, default=1) + parser.add_argument("--tp", type=int, default=8) + parser.add_argument("--dataset", type=str, default="shakespeare") + parser.add_argument("--eval_iters", type=int, default=1) + parser.add_argument("--eval_interval", type=int, default=400) + + # Model config + parser.add_argument("--vocab_size", type=int, default=50304) + parser.add_argument("--hidden_size", type=int, default=384) + parser.add_argument("--intermediate_size", type=int, default=1536) + parser.add_argument("--num_hidden_layers", type=int, default=2) + parser.add_argument("--num_attention_heads", type=int, default=8) + parser.add_argument("--num_key_value_heads", type=int, default=8) + # parser.add_argument("--hidden_size", type=int, default=4096) + # parser.add_argument("--intermediate_size", type=int, default=14336) + # parser.add_argument("--num_hidden_layers", type=int, default=16) + # parser.add_argument("--num_attention_heads", type=int, default=32) + # parser.add_argument("--num_key_value_heads", type=int, default=8) + + # Optimizer related + parser.add_argument("--use_DO", type=bool, default=True) + parser.add_argument("--decay_lr", type=bool, default=True) + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--warmup_iters", type=int, default=100) + parser.add_argument("--min_lr", type=float, default=3e-5) + parser.add_argument("--grad_clip", type=float, default=1) + parser.add_argument("--weight_decay", type=float, default=0.1) + return parser + + +if __name__ == "__main__": + parser = parse_args() + args = parser.parse_args() + run_mixtral(args) diff --git a/examples/mixtral_4D_training/sharding_plan.py b/examples/mixtral_4D_training/sharding_plan.py new file mode 100644 index 0000000..523827d --- /dev/null +++ 
b/examples/mixtral_4D_training/sharding_plan.py @@ -0,0 +1,69 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +"""This file contain TP/SP sharding plans for Mixtral example code.""" + +from vescale.dtensor.placement_types import Replicate, Shard + + +param_sharding_plan = { + "model.embed_tokens.weight": [Replicate()], + r"model.layers.\d+.input_layernorm.weight": [Replicate()], # MixtralRMSNorm + r"model.layers.\d+.self_attn.q_proj.weight": [Shard(0)], + r"model.layers.\d+.self_attn.k_proj.weight": [Shard(0)], + r"model.layers.\d+.self_attn.v_proj.weight": [Shard(0)], + # TODO: buggy, cos_cached or sin_cached can be updated or recreated if seqlen exceeds the max seqlen. + r"model.layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()], + r"model.layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()], + r"model.layers.\d+.self_attn.o_proj.weight": [Shard(1)], + r"model.layers.\d+.post_attention_layernorm.weight": [Replicate()], + r"model.layers.\d+.block_sparse_moe.gate.weight": [Replicate()], + r"model.layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)], + r"model.layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)], + r"model.layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)], + "model.norm.weight": [Replicate()], +} + +fwd_resharding_plan = { + # TODO: buggy: attn mask is torch.Tensor, in training, it's a None + r".input": {"input_ids": [Replicate()], "attention_mask": [Replicate()]}, + "model.embed_tokens.input": [[Replicate()]], + # No SP + # r"layers.\d+.input_layernorm.input": [[Replicate()]], + # r"layers.\d+.input_layernorm.output": [[Replicate()]], + # SP + r"model.layers.\d+.input_layernorm.input": [[Shard(1)]], + r"model.layers.\d+.input_layernorm.output": [[Shard(1)]], + r"model.layers.\d+.self_attn.input": [[Replicate()]], + r"model.layers.\d+.self_attn.output": {"attn_output": [Replicate()], "attn_weights": None, "past_key_value": None}, + r"model.layers.\d+.self_attn.o_proj.output": [[Replicate()]], + # No SP + # r"model.layers.\d+.post_attention_layernorm.input": [[Replicate()]], + # r"model.layers.\d+.post_attention_layernorm.output": [[Replicate()]], + # SP + r"model.layers.\d+.post_attention_layernorm.input": [[Shard(1)]], + r"model.layers.\d+.post_attention_layernorm.output": [[Shard(1)]], + r"model.layers.\d+.block_sparse_moe.input": [[Replicate()]], + r"model.layers.\d+.block_sparse_moe.gate.output": [[Replicate()]], + r"model.layers.\d+.block_sparse_moe.output": {"final_hidden_states": [Replicate()], "router_logits": [Replicate()]}, + r"model.layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]], + r"model.layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]], + r"model.layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]], + 
"model.norm.input": [[Replicate()]], +} + +mixtral_plan = {"parameter": param_sharding_plan, "forward": fwd_resharding_plan} diff --git a/examples/nanogpt_4D_finetune/README.md b/examples/nanogpt_4D_finetune/README.md index 33ea83a..0a648bf 100644 --- a/examples/nanogpt_4D_finetune/README.md +++ b/examples/nanogpt_4D_finetune/README.md @@ -36,24 +36,29 @@ python3 base_train.py config/finetune_shakespeare.py --compile=False ## Loss Curves -Here are the training Loss and validation loss curves plot for fp32 runs that last 200 iterations: +Here are the training loss curves plot for fp32 runs that last 20 iterations: -![figure](./figures/nanoGPT_finetune_4d_val_loss_fp32_200.jpg) - - -![figure](./figures/nanoGPT_finetune_4d_train_loss_fp32_200.jpg) +![figure](./figures/nanoGPT_train_losses_fp32.jpg) For the bf16 runs, in `base_train.py`, instead of using `torch.amp.autocast`, we cast the model to bf16 directly and both the gradients and the optimizer states are casted to bf16 automatically. For a fair comparison, we modify veScale to store both the gradients and the optimizer state in bf16 instead of fp32. -![figure](./figures/nanoGPT_finetune_4d_forcebf16_val_loss_bf16_200.jpg) +![figure](./figures/nanoGPT_train_losses.jpg) -![figure](./figures/nanoGPT_finetune_4d_forcebf16_train_loss_bf16_200.jpg) ## Difference from the upstream nanoGPT -1. When training with bf16 (`--dtype='bfloat16'`), the model is casted to bf16 and we remove the usage of `amp.autocast`. -2. Sampling mini-batches is done at the 0th rank and the indices is later broadcasted to other ranks. This ensures that both `base_train.py` and `finetune_4D.py` works on the identical batch every iteration. +1. veScale enables EXACT single-device abstraction for multiple device training, where even random operator (e.g., Dropout) with Tensor Parallel achieves exact training loss as a single device training. This is achieved via our veScale DTensor and patched torch. Without veScale DTensor, upstream DTensor does NOT provide single device semantics on random operators. The comparison is as follows: + + +![figure](./figures/nanoGPT_drand_train_losses.jpg) + + +In this figure, the `1GPU_fp32` curve is shifted in y-axis by 0.01 in order to distinguish it from the `4GPU_TP4_veScale` curve. + +2. When training with bf16 (`--dtype='bfloat16'`), the model is casted to bf16 and we remove the usage of `amp.autocast`. + +3. Sampling mini-batches is done at the 0th rank and the indices is later broadcasted to other ranks. This ensures that both `base_train.py` and `finetune_4D.py` works on the identical batch every iteration. 
## Caveats diff --git a/examples/nanogpt_4D_finetune/base_train.py b/examples/nanogpt_4D_finetune/base_train.py index 70980ad..a98c7fe 100644 --- a/examples/nanogpt_4D_finetune/base_train.py +++ b/examples/nanogpt_4D_finetune/base_train.py @@ -131,7 +131,7 @@ if master_process: os.makedirs(out_dir, exist_ok=True) -torch.manual_seed(1337 + seed_offset) +torch.manual_seed(1337) torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast diff --git a/examples/nanogpt_4D_finetune/config/finetune_shakespeare.py b/examples/nanogpt_4D_finetune/config/finetune_shakespeare.py index 295e2d9..a7bc8c4 100644 --- a/examples/nanogpt_4D_finetune/config/finetune_shakespeare.py +++ b/examples/nanogpt_4D_finetune/config/finetune_shakespeare.py @@ -45,3 +45,7 @@ # finetune at constant LR learning_rate = 3e-5 decay_lr = False + +dropout = 0.1 +compile = False +use_dist_dropout = True diff --git a/examples/nanogpt_4D_finetune/exp.py b/examples/nanogpt_4D_finetune/exp.py index a4f66e0..464b497 100644 --- a/examples/nanogpt_4D_finetune/exp.py +++ b/examples/nanogpt_4D_finetune/exp.py @@ -19,6 +19,20 @@ import re +def parse_train_loss_per_iter(log_fn, name=None): + lines = open(log_fn).readlines() + train_losses = [] + for line in lines: + if "iter" in line and "loss" in line: + token = line.split()[line.split().index("loss") + 1] + match = re.match(r"\d+(\.\d+)?", token) + train_loss = float(match.group()) + train_losses.append(train_loss) + if name is None: + name = log_fn + print(f'"{name}": {train_losses},') + + def parse_train_loss(log_fn, name=None): lines = open(log_fn).readlines() train_losses = [] @@ -48,11 +62,12 @@ def parse(log_fn, name=None): GPU_CNT = 4 -DP_SIZES = [4, 2, 1] +# DP_SIZES = [1, 2, 4] +DP_SIZES = [1] SINGLE_GPU_RUN = "python3" -MULTI_GPU_RUN = "torchrun --standalone --nproc_per_node=4" -CONFIG = "config/finetune_shakespeare.py" -LOG_PREFIX = "" +MULTI_GPU_RUN = f"torchrun --standalone --nproc_per_node={GPU_CNT}" +CONFIG = "config/finetune_shakespeare.py --dropout=0.9 --learning_rate=1e-5 --grad_clip=0.0 --eval_interval=10000" +LOG_PREFIX = "no_attn_dropout_lr1e-5_dropout9_gc0" def run_exps(max_iters, dtypes, run=True): @@ -64,22 +79,50 @@ def run_exps(max_iters, dtypes, run=True): log_fn = f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log" print(f"run {cmd} > {log_fn} 2> {log_fn}.err") os.system(f"{cmd} > {log_fn} 2> {log_fn}.err") - for dp_size in DP_SIZES: - tp_size = GPU_CNT // dp_size - for dtype in dtypes: + for dp_size in DP_SIZES: + tp_size = GPU_CNT // dp_size dt = "bfloat16" if dtype == "bf16" else "float32" - cmd = f"{MULTI_GPU_RUN} finetune_4D.py {CONFIG} --compile=False --DDP_grads_in_fp32=False --dp_size={dp_size} --tp_size={tp_size} --max_iters={max_iters} --dtype='{dt}'" + + cmd = f"{MULTI_GPU_RUN} finetune_4D.py {CONFIG} --compile=False --dp_size={dp_size} --tp_size={tp_size} --max_iters={max_iters} --dtype='{dt}'" + log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}_old_drand.log" + print(f"run {cmd} > {log_fn} 2> {log_fn}.err") + os.system(f'export VESCALE_SINGLE_DEVICE_RAND="0"; {cmd} > {log_fn} 2> {log_fn}.err') + + cmd = f"{MULTI_GPU_RUN} finetune_4D.py {CONFIG} --compile=False --dp_size={dp_size} --tp_size={tp_size} --max_iters={max_iters} --dtype='{dt}'" log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" 
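+            # This second launch keeps veScale's default single-device RNG; the
+            # "_old_drand" log above was produced with VESCALE_SINGLE_DEVICE_RAND="0"
+            # (the pre-patch DTensor random behavior, labeled "Old_DRand" when parsed).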
print(f"run {cmd} > {log_fn} 2> {log_fn}.err") os.system(f"{cmd} > {log_fn} 2> {log_fn}.err") print("train_loss = {") + for dtype in dtypes: + # parse_train_loss(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}") + parse_train_loss_per_iter(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}") + for dp_size in DP_SIZES: + tp_size = GPU_CNT // dp_size + log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" + parse_train_loss_per_iter(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}") + # parse_train_loss(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}") + log_fn = ( + f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}_old_drand.log" + ) + parse_train_loss_per_iter(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}_Old_DRand") + # parse_train_loss(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}_Old_DRand") + print("}") + + print("train_obj = {") for dtype in dtypes: parse_train_loss(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}") + # parse_train_loss_per_iter(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}") for dp_size in DP_SIZES: tp_size = GPU_CNT // dp_size log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" - parse_train_loss(log_fn, f"4GPU_DP{dp_size}_TP{tp_size}_{dtype}") + # parse_train_loss_per_iter(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}") + parse_train_loss(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}") + log_fn = ( + f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}_old_drand.log" + ) + # parse_train_loss_per_iter(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}_Old_DRand") + parse_train_loss(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}_Old_DRand") print("}") print("val_loss = {") @@ -88,9 +131,13 @@ def run_exps(max_iters, dtypes, run=True): for dp_size in DP_SIZES: tp_size = GPU_CNT // dp_size log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" - parse(log_fn, f"4GPU_DP{dp_size}_TP{tp_size}_{dtype}") + parse(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}") + log_fn = ( + f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}_old_drand.log" + ) + parse(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}_Old_DRand") print("}") if __name__ == "__main__": - run_exps(200, ["bf16"], run=True) + run_exps(200, ["fp32"], run=True) diff --git a/examples/nanogpt_4D_finetune/figures/nanoGPT_drand_train_losses.jpg b/examples/nanogpt_4D_finetune/figures/nanoGPT_drand_train_losses.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8495ffe15cd84b27a72f2403fd7a914a9d46567f GIT binary patch literal 23677 zcmeGEby!?k)(4Cpf(8u)f;$NW4epxY!QF#}Km~U#BxrB~1b4R}g-e1k*0Vc42;C>o- z37{Y$AtNE8AR{9`e29XIj*EeghK5dv^B4=4l!%;+l!%0cf`*xvf{KxvgoKX!IpYgf z4h{}-T3$gOHUVaK4z^!z0{8IYLv%EB0t^fSHcApow!i#${}sSSfjfaafrq049$>@4 zW5eCI1LObzhX@<(F9ZJjAKU|Y1Vkicl!vHjunQoMfCq5!@DC8+5fKp(U{`y>&I1V8 zh&Yt2qDYUG43Vj9aoK#LGEk_+%D>?$4z%#NNMtwrVvH0}x7t(&E>^~#S_rId-AB6prE+~Kw4+mo&JT@Q|)2O;nl zE6Y7Vt$bV42|j@vx(5zye>qk}mK3*7vph_@1ssHLNM!GU?lI5*yX61o<^LAL?_Uxl z7iZ1paRXI9rHY!Wy3xIk%95sU3B;_D6qF&9n-rw{>azf@lK3B+#)B)dmxov~qbIwP zS`RJJ=I4VQbrO^!Z=xjBW+D=5xaFg+C-Q<5-RE&Oh;U|SpisS%{*E^7UZbdX(Ff^{ zZFJ9GCqMGi2DXY)4$p65T5#@xAI$6bfHr;s?@hL`vj=l 
zwdWK%wTK*4UuVvwJ|b#e6600o`zxQUVXn3fz$K-PKu{M2<{m0hvi~mYSMxqI#{8I9^7+ zk~l8#FB%+fMFktxv-S52k9~aRlA5>gT*dDJ&$NS#dtmvB<(&93B-NW@-tis~$h-%d z2QQTVc}c-WItuqJc#y7Tzw!>(y8a$WRRYhO-O#(yJ-|3OIPPV+-nj?-iT7I8KN*}) z-2sn^1gWA3BX@ic-GV?<{pl$Civ;NrNPm9?+h?2Q>^y7m;C;Z$d^%D*ogr-MYruE`f zui?L}-0nQGos1H}|HrLJt8FXZx!nFpc*gR&)B$9!+U7yakt5tiV zD0@~*Uw88?z{XdU%CFMS3f111mn1a&6I|Utry|R9^5*E{-L>$dTHeHaj$1yoe&xH7 zUS2#e_ux3r83`LSqB?|wS&^9OK9ZqsZknhsriNQ6+al zAIzZp%@Z_`%?_zm?m9*}XI0u5K$-JZGR$wr5c;t`izsg>x%2=FiCa~c?$-ut6FXl0 zt6X<8&YtN4y zOt7A(J@uE@NeJNgj~SJ?F1uXCzdZxD=4JkUDH{K}6jc`E`&DA;dz~s$JG1jC?g2g9 zyL;dS7cOwI>=MhR;jXpd9>{hYAD-m>sbdIE99Qw@hU6F?W^Guwr5!C$p}U0{#vmmM z&2Lb{2Mi@?TNqsnxrK0j*IHZg$4)7o-|U=Sc_+xV#&y4s&e42i>zi#O$lw(M3=BJf zEYTFMpH|8~+axo8n@}Uo7-~1rRY!ZU!&AIZ4({PUWV!X-1D~j9i0kf0w5+6D#NT~$ z7QT9(51;<4IFQx8W*xT1IfVa4^Ex!}qJlMIqBQRca)mlD;ZVn^X&h`Nkgd)rF|#Mw za0^0L$}cfjy_^`%Vr+cimY>4bLi=8P|o!$nux}{Ra#+RueVKXjVTcB zgZ(gvG)`Penn4hwf4R35vd|QC{($H-WTm&Lx0KMnh%;R+OXkCE>~sBj|5Mq$>I`vQ zjqJC7DrWyKdPVMm!(068u$vA$qJV@$Bsq!I9miQd28q~E)<%(vTo>eXvt;*4{eN7h zpu}VE2&T@%2zDU0V0FGTNn5iKo)Ug6ykRtQ5A1itB##5wTu9@qnRHY1 z^6HlW7BYPJ>CA*x#^UZ{_0Mwo8&{qyMZ4D&P4BZvK2Vh6_Me2c01m^OMDWazt!?6uc_R3NwWDJnv7470wpGro1MG6n?{CE znkkLbVN*)O)5@Z+C6B0{!2|hSy&|*-ZljsyRX(Le#G1z3OwD#tK5NxLpYc8kz~+t$ z9O*p)fj&kA;<)%^y2ED*PeE7w2}r3*mR|BcNceWBASx_i2p;r?ETxY@a46jJ4d*? zHMeAAh+_*@ReAQt`4DqK-h9{7n@YOT(f3@}zA-qZYhRIR36~umxS{e%g*r^t8F^5> z>4M%4Th#8X4HT-J%8IJd-8g9P1o1oky$*CkTu)%ho5x+Y&l>ASd$O~L!>Okj3hgoG z_>M>3t1cPJx@<#3;`ps!GH)xs*}ciQbrRluzg_uux<1<6tktEc;@;Q{ByfjH^=45o1Qo{DEr} zfs!+`o42Z8_y-M3OOUmxB`G8uThtoIKb$U-KAqH?1WwJ7^s)4WK+CQmbMJNEb((pN zw9_p&dTV4*AGH3m0Nlolc_qzCkDpWA`q^vnWIg?H01HlaRm202ho zsIQj;?tw0!!ntP^pd^>4;ZG8x&MN47Oa&!3uO?XzD%Ny_cOx)nnoURdq+h2N%3~tu zN`My}*M^g`J|@#tjm$e=&dg=Db$&7_+kKiB7k@dgClIw!r6leq@w15k?J)w@JwQGT zuATp`dI}F4z=^>g_c5N#wQ7A;9jkb*oXuxsH#73?3AG}zP}jVgEO3;=W^Y!NoD2iE z!dNNus$-f$+_;*vb!0NRmG76p;9Na@vsdQn*3mqE#w`f8WA}jea!I1$rUZM8Q%v>X z(ac;bX7O~0c2-$tf%$o*zQitPa$JY}gJv)5Ar+2Yv85wPt>UNEsBCEKX+s*XZ8=u^ z%nmKHHjhTvOg8%UZ@I^b@CLS;c@FMA7WCv9=nd4^c2?Ajn{i~|Rr+c@@)dUWV0kvn1&=5g)%zwY`N?LohWMgvLyDR1A!h01W3E(Y|FBO_2jiZQ1k=FhYKyfA z1<}F1zoT(6R;+t+R-;iNyB35(WIBYpPR{U{3_pK6^lg3p5z8*aG%*&=^yWE}xlqpX z(=XS`-r9&!l<0k_?4dq;pv;Pp{Z^Q5&9O@+sG~T>ajV?xO2mKpeILCX_cC^n;)NaS zD3v#VefW}Nb#>h0x92$cZ<=h7T*^p>_aG_U3%e~JO(r1>CHgMPDA2XVRDxGAC~cnv z*y;Fdxa2)+OV3!)!A9oiljPHG{l}2xplx>-Qwg4dg*<-&FeWl>2rx73sL45t)29xR z#zE^=cbQOxM>LFzd`Ng>aq2Sn(HxxNA#?x^GMeP1{oG@W8<%T+^43>E^UPI`4KJ{N zlA5SrOoI>b!JdqQaK}QjT<%9^WszIv{rnt zdh0sHk?n8KN3Z|>qZ#Kad_2xPgr2e_=F2iI){m|Wnu zU#CcWLUF)2(#87$!bkMF;2+A@_kawU_p{22_(g_I3MofPC7U~4WjS*Wr3Nzs%nf|r z^U{EAmCA~y%F5a#(~fO6o!k-USmpj98vr@R=Q8973uB|1=Fqv-le?`Cl&5X}IBm{A zuuow6iBZ>kLjWTxT)V-wB);9?5riQ7?KM|?&TS0EQKjz9fMjjh?U(vI{ z?L+I{LC7C_a1Ar=o+dHzD;?~{$AZqbx!aVvz_Lnbv9xJ7SCLlt%&(Vq$)%YBW<;R8)^zWglx# zDkkz6rD&|aq_{*g@rC2jPtknaIJ$%!RAs3T#?hR&sJ0Q$+MLAJ(U#*YnwjfC!}#{F z2c-n!vYOI_`2NE8s88tv+uq*X8F#oXmzuGh=S@W?Zn&h)mN2#mS%jkKUH$UyJ+Sfo zZ0|_t)e-Zg31VPI7e}}7q%OGS77;MK0NuTuW4;Hz6oNNNAr6iMF2C<^_*dQ{puxYq ze@W~06!u%a!MhNPjXyZZfWyOo`nI@Rz#yUqN{)?#a- z(tqJ05RWL8!a=|YxHiw`>5h!0V}rB9H{{w z&NuqjW-Et z7b$~En>M2t>A^%ww#0OE--9gS^Z6p~rS`N=$h&k*I|T&u$D8#XAMSy7air=e^D?hL zzgwjQSYv@w0SI-$fH{=q0~3d*yqvu3t)%i889Aku5ZFVse4OIY8$z^i?uH5Tnkk1TLX&>@-YBJMCuFVs+C@allGzQfh7w7j5; z8kKFvMu)bn#6eXNS!fPxbh;Jl92h!jVDbgz6nxI|c}|vrdAWupmhdaNzjZ}XfjUq? 
zKygA97Xq{&q(;~;-Xd}&WQR*xnm9lljTEA4Kew6Kd?jr(@kl=}B6<>NxwtF1Q$mxl z?7CX}O!?s~b-dwM64P)7`4Hgy!5Qpy6Y8W?Wwbfdt`ExSBi+@Hb|UZvCx!LgfzZ}T z4V}8iDF}a0=~I^;WxQrpMYo&$LiQ~Nrx0LydT}+SPb4wCu5?RXvAcpPcJ+r#!lzYd z!9w3V&g-pTc z7h?1SN{ytWiRaTz^#X-c5Em|?XoJWLl53zj^EIR9)**N2sLt8}*%+*?gD&Al>iQu?r@mhg47h$9D66al?5J zQ10V{f2awcghT83)&+j*0O&G9{@BXs4oaKy4Va^r^o-_}bv%{#dqUL4en=AYvt<$+ z`cEa(8Oc8Ob(=I?o^B{o!0vA61nd+%ijaRCHR;i zRsHSZD?u#G3&+KKK)7q(thwFJJSXQ}@Jjs)zzf0eLZewD(Pdx3lliW%lkZs%Gh_18 zg;2jHH`DSRnV3lgH;weongfe^%ck2(tAHg|DmQO;a@&zSlus;JX@xvrYPA$U#HPf| z=f&T!T7=M*wx$2+q{A5D^AE_(cUya;`0b0?EhncVpOol<{2y13Ra$Pq;&;nnRxQ3V zKM&!Oxukwcg~u~^sJgw|UpK!}s17F1Z9ZPFsRLcc-2>G-r6qgv@s8$Tdt5oHZy2rR zM#%7?%e3%9r)!5EZ)@WgDeu~XwUSx-ovYB3?PLxwCet_W9uJ+MB`|X|U|TZ=mw%Sp zjUBK(e!RhuB#h3{bED%~Fh|vIXv3ZJ<*jl>OdO3b&z8s_*kZk~sH#pcO(m|P9YQ7% zvwh3q6z3KemmYei5qC}>h;A6G`~3V{?$lPjQ~FMv##s$x79(YXuFt-o+6frT+zo-P zEeDF`vT8X1Uukl3M`gi8+zMq0T#M6V?$&86#OsclroKavv#L93Y8;5&Lwvg^lNbMuik`**UUr$o&n5?&Z1_23|_gJsHwR9ve zgVzShqB&z}_w8}(-P4-#k-5Qks0sc2TS} zIo=d5knswnj=@QOs}Uj|gYY~X5dam<&$bu8t!QjEGP#uE&#*!H#(OCO_-IF7w5g4= zlm_@OR=13nSP?;Gdsx)TS;rP{swnn=6_Z#5zmUGhQOmpqZ?+JljIMY(#OxjIu; zM$uZh&&iZsO;Soy?IdN=Uez43e2|-nuuR8LPyqFqK^nu4c&cf_b&Eshxm3$nMY}tI z_uKVrExB^+nl2^=Q>{r*W3MRo!)R>O3-fA|tB9F^k@n|mc_G}aK?#WL&7wGRQW~cfO?g53v2axuP=`AD|r4arn zXmsEthGY|@1=Ld=I(2C09D+ZE=8`J5*fTn;Ls9kwW9y8$wyoUa3v217FNfqIrF$7i3(Za8I&bIl+=7&( zzw(m!@^%2OO>1R&*#XOQIeE46z7dBd5=Q`6vL(Bbh%g}5=XdPO!-JL_-n7*p=wD_jppzkIPc%FA z;~i3~PCUKyYcfWU?DfcE`-o2O=NHuNCCGuNNqyVXw~~^OF>O+z({$I>L=qNZsyG{h zs0NZdyz+Yh?Xk$+8q9n31-0C9>p)XtJN8m1Cfq(i9kRnJzXp9zC-dfE7?h9oy9ZFJ zTke6EbFhoOtJ7{QylD^hbL=RqhB@lA1y$6IN$kpWT)e+mx?8IF+>J8X&kF__RS78|E8hy^m{uiObCOoT>AK zr-zqk+h+FdCrukp8A^6RJCgXwrKfPVrLu8NVG=}7f>Odiyti|dfh=u)8{NM zEQUz~m%G^n)2qogStJjzRFcu|q`E)Br1E zZ)BQrh5rD>#@(_27a64IDZ;5i&^|!Eq?#kGstL1{x^!rzdrP2n7bMz!RB|#E&mJ{H zPM?urPVHTN37vZ$@=6fX<*xR|BEr{d{h%hq+JiWs_eCXPF^yFy77w}7V7WT!_R!j^ zVC_Yw327zyAz;SSrxdGG+x3vl53?nAjZ{ZQ=pEYm7fh)e)><#>CYK+&{NKct?AM|j zv~}c*PWf%|7>uB2e9y+mPpnPBF*7H84ESwc)K!O=f!1J>sG{0wM)x{tWY~*`WyE07 zl;0z*N22^OG+BQGru9wyag{_6e`uQ-JSEqqHlmp)72AeF$G|2v zu$UP3JuuG!$Gmze*mBNlmFlO|7Yh%^!-a|{^z|3Jk%`@QqE+- zmy6`L$I6BN03Cqot%%J3Q<&&KiyQqn04b0eLI)IURWY5P^UK@@96bwPV?6Et*0p}K zZIQreq?yX5iKP5c#g~C-j9Ps_R7a}B&?2B44!mky z7fUP~iOX7MVFs>Mc-~&ACzkcg*c15vK3k4;vUfYPP2|=`Cbwd74=|baPf`=jLFQ#( z>XCPuA!i8xTQKOap0iISx(6c9Kep`9r(H=}>XH7T<9ERJZ*Pw`aB+PnJs%9-!MDE$ zG860R*F8M%0h;~xmNT@vdtg{jg!I9$Qbyw(0qO>@OipE&A4mB-KhSm z+fP2VPuI1?_cq$*b)h_0&6j&XaQnN9)M9olfqV3bv=`Ruz!;io!1W~=rwp~$gl)O z%eW6=>aElfj60=oK1>>4OkMA&8x(ZMzFTu@;MQclK>pbL5yFUo@e5UgaKcO>?eI3l z&se?n&j+eZmY+I=%di&&cl4{R=8?MbyE>a?4(pRHqX>ZWBxVVLbp25YCS|i>qJG(0TmpH2$Q%6?4a+Jc!DEuitl2_|NjET;gJ5dYXV ztT?v3@5V#7P~fDHA^gDcL!xky$II+*Ogo8cHB*T*voc3Z5)y2PPCe9nBxXR2UKoI@ zk#J$$>6|z_kybdTZxL7Tm0>fiBf*&YB-HL3HM3|unB3i;>1%0Pmk?!ATOzKC4*XeC zLv8V*A$DXDZ>JbmqYWmqeF(4#`a{0RBB=~0x23YNX>7H*nr`GxtcABybjK;L$Wy@!dlsB6%AD`Q&Te6YHY@22|Q z`|&s$`5-#hg1Cd-_DEZ9pTLAA!qzuDnvHfC3cUn*eh-ujQEUp)2CkHktLo?-x3=SCOpJ5JDfI)^dQB2q;(&8ap*K+UDiSPv;r8uP4u-N`_iyo z0y`2*W3#tkY)sTk`Rj{XRQB_CCmkOjAY!qF>K=oALSE>1O&E6rn z%pg-j#p=Ka47@1jS@v|o>vd638YfntsA4i&nmgK-dr{lu%fc`qtr|%RN5KQN4}T%< z=m6KHbwQ+~Vrt@Oc}40!Kh0F^{{^#{puR}9;gdQq{5t{k7;V11KxWP`K6g;l>x8&d zYeq9ETw>)=BF;{m3yPK0jG(1@R62O9S3{*7w-VQ?W7ZSbN{yP!uGD@#Vbl>$8o5mbw!B5-wjr9ry1i@7Vn z>zMP2Ju8Z{GIJe=MJYm=O_;pIO%88uqVKtC{$MP+$}laITwg&pOWxG)#& z@>+Ts8mbyX&c>N3X|hfRm2?ct6$`F&Z%^yD^9&T&u?*F5m+4cr?*Ri>L0l!h4?i1w zfAyGTJ)mVpaB2SIzr{mvSllIVg-zq(?OP0ON+JH-yzH7CdzHgKYd(j_E zJneOy<64h;-BlQeg0tRzkF1k?9crrSKrKMlStopN0;VrNTfy376}-uz&M~g 
z($V5f--!W!Pz1MqW2H1*jBV79=&lul(#|ge{EH027_$a?qVD-~y81yA=GJ?!_g=YI zdn*Wcg1r8wDSuX{dU!vhf$5k&zXp5@3Vy}>e^1Td?PM~&Ll36p`e%8#a%Vv+j>o{I z6al654NR{+D8UNUe=cXVwbx1PI>9p}qxTh>M`7#-3!T2y6An4SJtPlgu37_R`kJ2z zw^3Z@U6C7s2JT2Xnh6gD?#Rb1Z$3_mFdZXdMWTEovAtEuak)s8No6JoXl%aaWc^Si zQJI^l7}vr%>{>aQ4>MdaU;zD}!vB9368N(|O#3#2BeFy1TYr^%Kx(%rtMuoN-`~C# zRHD-vfA?a%s{er;+sG+FrNt;F%1thhLm^0Ah0-hq+-w#aFYX{!rg?!r&dB>(!Ie zwZx4IiSA|kufBrfX6hsdwzwIDMY2Y@G5O4SxgB-@ro_}E^7`d`n6LF-4SIA-bawu2 z{?W4$2rtRzRgqOq)F9k3#LZ!VEA8ST=N4Yvlj1^Uc`&|#at@u0-O3b~Z^-*-8l|f6 zZ%PuCn1V{8e)f|Fh%RfZ#UU#S#Q#x z^YZR)i;zbD(T4w%S_pxAU@0}8IV-2NyCIRWQIjNqRKo3u-_1uiJm%0^Tc=M7f>l8* zZu+>J(r8surOk8sWLv+ZU~{uzD#^(*?ZulP1F(Fu?YFfb$?M+r3u$J0(`7*IFeP{f#LCYfPriT8)Z|O-= zzIj6lGZ7uJjR6~bd?dQJ$Q{Y@aVg(uC*MSkD=1g{hj|gl2tB2gCVBhb=C=MiS!`Yh zb_hS3YFK?{YdU!wUOT7-jfv$HEk*Km0~fwr`9e6yG>H{oDlg%3-pfJ3=M#N7+|T=u zxm)r~2nr!H=8XKdwim^_VBrg=fh#!r>CR{Lrl|0-8*3bAc*&T?GeC~A`f%jtFGLf;}|P@-%RmNn_ah941|_%{I)k8 zF3BN%9k~=40n-MD!ix_tqYQxf?a=mXBPDv(Zh^35syACFbQ@+A&#S4utK57gF5VloqLu zSD>NuSs#eL_4aDzyTE4zFYPyo5#i~N6s#LvvbuV)+6Hl=k-!yu zu3fim78~0?`x`%p-~W|8Tn6T!*6e3gs*0w+Da@itozZxjj@HMdgr+?}9lv5{EMV5M zu2+J$zVrpN$whq~6k~dMyRJqelP9mZ+VTRHRlY%z$FfY0q~nq~ZX)sIDjA;ZoI1j@ zBs@zvXp7&kRrF9&`(TTdT4N-7>1gcIk}}?q!eIaod5dqdd@|xjYB->>xM?<|u>XOp z9sGo@oBLTA%Zm~c&ceL6C#*#QB|N<@*1;Axut-sATKX3I7L+tM!xnU=#r*jDk+T&P zOWuYrl)SSFlyJZuoOQ?@xExzuvf4~l7it>-)2d$t!K~1Fixcx6nTL7yQv4nKGq8!M z_^bN=FAMt~juHm_iAJrgfq1Ut88nU9I7PHw5#@KCJoI0gT?r$OCk~58bS|*kRbEne zF}j)dWo13;ZGA@J%ZIJ+Y>O4>*h4jTst}M$*Vr6Q*EVx0V%>^}U}XCPa|~`HmF%E* zLq&hKKYw;h<2*;;x~L>0B%Kl$5rAcbIzFobg&mNlMw#Dc+{P~IX^73PUZz3e$(*>2 zt@p|EcsI`jwRiG=g`IxN-TTXC{2!7-R7l~}p#?1@Pe|Uv;FMC19aqGTnN-X}whLdn z5a`VJAq({l9(9_OuTFcC^=aCtwa-G<=6Dw!qo3pQ@g)Z+v!N7QDztGyv*U|kX6#0d z%}(44-q|Ui9Z}#Grn}zO&pmEvI4BF+qps00(jQq;fsj7jc(~&h(eBjY`a>8bYNrm} zH|V*UT(S%)*`;UJ%)uiInb-%nNM5BTh%1kthH-w!P(f+XGBHFxebr-6+39 z^=u?s$ztHHc^{S&z_yN6tQbdHR8x5S8q2?44HD@0ram}!waJ*VR4`Z(X0|NJIgPJ_ zG1mCxvx;5R&AB1PE$?cKZ)~}EZ$ii~u*+VR7PHGom$RTB^eD%cMo-N<+GK_>*!~q? z>#zA)|AXM;t$7g~rmWRC=55u)C+3u<Dkhr=$bzx=!S6RN_yZHz@Cdx=gR9^6h zlb>X4_ypMQSVXj2W50{@!aT%GurJnpnJ=9yUjTaA%2=$;h58^SCUvd~Zt8alwZFc& znJ~=we`}?n4kOo_k1m$kk?3`fVz$%d>(F{_hIqn&tA2cmvpNWBb7y8snQ`Z-M%}@D zaFsQVG(vfMDO0IdWLTfyWlB!@-9ni$0$YGG3N=OeX~dEzDQ5|p6d5b?;2cU2vMX@7 z`XXm;w#HC|<17MYttOCLPdGap~$ z4BoachVgTh?YHaj-+kvD!~l1@cg&yRDcn+5mxkB67g>+F^BKt&JoIx@`+NB@e4Y7j6D`#fnwfu8m+mLvzuR02~@K7wI zrhu|Cz1cj3bLY*u9t`ZQwQ3rv$vRY2fK0AHq&Ij)1BSX(>na?96O zMvym~xPHMLk`P-6zu3MTY_eIyPD)>ugUVWW7wO-WPbzF%t-s822}i1tV##4( zJ6z+IAX}fPfBL4EP>Ledmwz(Z1xp)psJ6NyDJDf3K%uW0!SEQ{DoxbDZBYiDIFQdvFvtOPgZg+)L$2pKegRfBA z|J^G@((nRI4;|l+mmjgRlIv*1EU$?>$AEAzuMRHjxxBU%P-6ip7^=4oe7a52@Y0hV z(|~2oKEBZ?+^U{4xcsaks9*jPz#wM)A89Y?) 
zZ6&mXBo)aSYOJLq-Okw$VG@%f8?;tJEJ0n(NkIua7yf0)KB3F}+i~Go$F7~QpWhL_3hew4P4{!&Z zSgOI1CnbIP>4{S#t7;jYI?f`3pkUJh+!m)X{cH2%yDIk9!jim@nE1gxaA`f6!7>XX zEC;>YLu+voo3pZuWhzs&TowHZ!4{v1n^l?k5G06m7iUtrW|8)F?0^%ugu&Mp4OYy$ zRf{9};k_!`)z1wG8!vZ=~dOMdwmIvrOb zT0KvzR0hmGKJ*qfl3?Gk<2fl)KR9aw+b=c{#x9RriIG=xyl;pK?$%{_jb&x4`-ycV z#I~%YKKpdz(&i33m6*m!Iiz@0RbwT$VR|IkyPqyDm zi~m|OEC}!~uN^)qdOvKiIixw$yNTEU{ciFJefjZ+wFC(-C?orBsUs|a=E|QV}t90*wyd1<;l+P8ldXk1E%wIJ;4t( z=Qkr7PCBvD%S19OQV0|r9`>?+GJ-144NO)~o(M?l9hLn9iyBydtt|@anWg)mQ_fT29WPC}b%kYpQ>WU#k*NGzP zVmu7F&4~yz*3ulK@tIU4a#A00N<*~UX&HU43To`EG{a?+M>Tr@HMzV;RKXt0(<`kz z+GvYrw((Mqo@c{4iq+*R)f#CaB;-&`tTLWc+<81sd1VnV`vG|cJ2{=N6IuI8 z&egLVTlH7r*=E}H$+;ud4O5VuUHgbJAvz25Y*_&^f<{OTkbRuR@iNuE^*X!>^q_62 zEM(srhmK<#R6X$eg~;k>=gx;Whx(oIXM=8MAlm(*2dc1)2oTn*gWQBg(5;F}@K=1f zhBaLj6<&J5${X)pJq&q;L{1K!-5meWJmnl2uO7R1WzA*HP1hu|YdMF{)<$>H%CO>qRzt5f5?TVSC7Ljrdb3k)kM%P7RYO<3d6tj%7&e%M&aaIaBG@LMeApG6 z@8myJZjO4>{mqg>ojTy?TZf%mmJUv=%;U!Q=gqg1@oSF8bc$ulr^~&g+L9ngy}hd~ z^@BY~E{M;9l-0KwkT_d0T%ZoicG^5yf^2uQG_n_N60JZZ0$1NTY6d&kkUOm^(7xO7 zbUNSl=qZ}inruZk!|P{a5f`K{&J-LOYO#(|9aN+T8#ujVq}I=L8;nk zVT~+5=z=UkToG1&bJ|2UtsFvV7{2ICF9W<`lW;lh^a?1ObTcnXm*a#XSE>*UW4V!6 zW{S-m?lnjB-C=Y2JUWS2!yj@G{?Nhd|35y&Z$3#dQx=-E2PS2P5_RA*&7E8K9%iwe1>@EX$**d`$gVf!bB#)keiRg73oaW~2lt2UZ{e1H?Cp{F@63K|~GO zAGhxTdf{i|3CkgLZI*DKAGF8Y$s;TBs|a|c-v}H&r!c!_UQY&(K7$#_{{?>kTkGdP z_>Nv;_8On{wiN%-tOz$_LZ`?M2xnkR;n7S<$<8XC{6WwQa`>oQFyEFnQ{2&rtdo{jm1m;BYO?f(V zj2}ToY_4%N=CPFVndKzQ0%nl67vLoau1KM@rINGrv-#%<#*9zN@~m9XNyswm^Op+p zc4lIZVU_fJ<^HIOj!*ab=eGJJb5fyaqpBx@%#9ICnEB1Dk>*5*A1$!nXq%TVU+2FN zRTDhEieybpQ#~XrI7GcYVL=b{drf|L?y^t{6A~1z%SGeucr4}?4vK_7Ta<+qAez|ck7xK;W5=I zTy6%1MHq(ZGOKgsRD46I!k)^YqGkMBki+p`?XW^1tb)&)SS*mQY3*tz56zQh!DV()bp8m@5=3#}(XOhNkVRB%9Z{b||sy zmz?W{xn6f8xIOV4wa8J?E3?j4lWrnpr6PK2I?u`*n*3BgXU1TCew@s8f&a4rZ^qM< zjs^yP*O=GIH(ljda1pHy*Y2O0mvBVsyHg9zcX8kQ6XF(ZW!r#=HRCXdrK9;(HB;YG zie07V9o@x#5UfT`iZw3~r5zdUp?;^5C8VutG&*c;Zx!h(qKij=@&_AK+HmuJ^Wpg7Ocwu z3fQD8m2c~Gu1`&|1e-0~+Ue+gj(cZv%!IUJ7-?hHy>MP6d=S0Wa`F_Ec78Pob6tjh z-4p5eeZ%*_L*}A;KujH$%ms!e&R)PaOmfh8mdeo>qnxCmfV>^Bg%WK}eas}JYWGdv z%zb9%euqS)42OkVh`Ic;%@`q*eHe7lfXv& zHpL$73L7zl&bvhO>|-X(F{Nl5+MWfcBKUS_0T6+>{aN?~`nMrV!?tnykA@5r=s#1U zUVWIex^iwKx^Qc|Qu*m6-({Si@h1)Vgv(1k1+pr*2_@b28=CsE+lSigw|AAW%90as zt85`G${$+_p5GR}6BZ#Q#T)pkPp#C$+|^7l`|*Xat~ATj7hjuAGv*cvyO+F#RE-C7 z0LM=ne=-*G-xKj?yci*WY>|%RfnWLtJPP{t;ZN~;ez**rtBjVE?c9);%7W~1`q5KO z(`Q22Fwrkb$WG1vY?AGN?)-@}=`e1sjCRIHTd0_CF5KG8isH!8c$wz)hHE|_ObdT) zed$R*_|v?=4G;e(bJa6{F>wBRwA6ZEfV@;3WcgT*Bh@qm8p0-za4CQ}{>CLqjRG6C zh>?G@i2sw}{=B#fC+0XB8&mgCV2=@qvkNjaU>kaG7qVb4EX+`_JmJmP%aD8ze29mo z7-ul(!nU`V0N;2;Tu3__T;V_luW!v@wdfo0S($Z%zwg_E42H2Vu2=bjk+*}NuH+Kn zf3M=7q2?t~=;{r1JxRx`_)pgP7U3?_UiIaHB=cO-7>PRmPjeaA4$VhMgDrCMzb*2y zUe@TU*r8GLS>i96&}m*5rlqi;Z}jOTs;7UI{nRMmwCTG^l2fL!euDN|P{lpZO1#dl zvM?X z$spxp`Tp5R+?vk2TJcXZYPy&mJ``6wWUn2QaDzI1P z)RX>Lt{&ZTZ_(u`s*TJSiiNX<@$?q}X;I8!X~~(rS$22)Ayj{l+gZIj=C$(EPGd6V za`ekim#;JS$*RJru-Y&$(#?-n74%<;sL1dmBHIsZ(L=GjySMoKL`YM9nH1H^(UKi2 z>rzT|n&xB$E_u5pj)(|_IL>17pez5igauONt*;d`J|u=wXOGClB`8}y>>3YRLDCjw zSr_+l$(Z|-WXbLmZ| zEaCTqg=I`*m&?;j6N?&sg{4~|PmOjm+URN7nsKilGaSOww@32|V8)g`Wd0-&@>cSK z)d_GKLLdH7z*+lAnl0{;t|U{HRi)I6kq7Ak)MKYbV}LStr0Bk^uH0IzvFV}l%f{*& zsXc4Pvcx6PXcKza{&~3y>msE3|B|mO@mybO0yq_(ybwN!5-q8hw6DF4`p|F`jeYk` z1mF5)gCJ-EmiQ2|az|sX&KK|P*u%m!@)k0>_cqYUOA?!lm@LzqkKkYO+|F&aCQu=e z<9hzll%z{!Yx(ro4HGu0MVnx92|JN4LXs20qC1SyGKVtMDX0Y@KB``07G+G2gp8l0!j)tC$5@fU2|uVxXx`Far}9nYCf`Yx3L%^ z05HebuTZKEawfy46-Ki`?)>@L;UnW{>}eIRekk?OM1%SI53Dk*?g5-VZ7r5(V=O26 
[... base85 binary patch data for the preceding figure omitted ...]

diff --git a/examples/nanogpt_4D_finetune/figures/nanoGPT_finetune_4d_forcebf16_val_loss_bf16_200.jpg b/examples/nanogpt_4D_finetune/figures/nanoGPT_finetune_4d_forcebf16_val_loss_bf16_200.jpg
deleted file mode 100644
index e5a068f4a70a9e2eef41225f47f557132d41f7ab..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 30660
[... base85 binary patch data omitted ...]

diff --git a/examples/nanogpt_4D_finetune/figures/nanoGPT_train_losses.jpg b/examples/nanogpt_4D_finetune/figures/nanoGPT_train_losses.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f4ab9f0b0cedcb4beea812fac828ff5591e87f4f
GIT binary patch
literal 37171
z&$YlqfZnr50O0u}VAcE);E&w0bk4Y)dJRYT2zVw+j`Zugj=$YP_YrVMC5-h5_@VLk z5zvg%be$lw0(b0kk5)_n0FZbDv=1YT{r8xU*#GFw<0eu-QZV<{6(v2 zSV8V!H!wd>&Q<(bk;lVR5u``JFyzN0;5RK(9{XPNJOqHS|E5;>AB(Ar27}={_ZxS6 zDKI5tw=|tInS=kLWT`tUZ%mk0b;SRwRlv5(uM21Vc8{&575HPDd$d`~zvy@bW)PI0 zz}xD>u;u$@5kCCnzplIe+bwF2u0`)iVfIlK^jG`%pQ-fW5E)~e$^huog66a0G_w!ZUEr?RapXn_rDqV zgIeGo>B**jeB(zzhIK(kk$vBTtYqaQz_5Vwp43T{oE%Rh(Nq6~ckFcB}gUC$%nV)%!OeK^1J=GO!n_c~0S0|A;*^J4OAj@1D<{RAC>o%Ljhp+HvL zhyI^NP^w!Kv$@5F(4tC`3MPdBH4CxIy_p-Xd~B!8)hk$nk|rjBrv&%=A*}$%SrLAv zIO@YVH`hj2D?@8}xMuf}m^LAPKkgB5)+%zlSu^0#=??aXIG%VR_BUBb&a5*gf##Cd zBahJ99svP;Y2cxVtVe*S)FYr{lD$}gEl%c*qUK^1Z??(uLdYhppSqQoJYVCSNIwbqk#1nFeJpvJ))=a3@(n~?Lt9vn>qdKZpq)gYrt=6e(+j~Yt@gn6!MiPGH?F$$j~i`$xo*6bJ>G(xu5JD*_;chH9BA39WqVDJ)k zF5BBmn9uK)nCpw2d@OJzYBBpU_in{K3c}ZXCBy(<{jZv@pHjR|w09z>IjD^cWQAM^ z1Ro5fBs%MFPGqj2m5sFhRciKF1WCZ6qYf%ir%zgB`p zwm(~gt=eM$JDip97fj*q=rf0*Us!Vs?6_KN*++?KdKC3r%kMrskk_9?2E3SL$sDi9 zgg9ZSX@l9^B;L)%$mIeXlI~^k0qmzs5l-rrf-D;OFv(du-*F=H=TnZ z07u9Fjo>a-fp^$3yU9YL|Q?F#AVVz={haY0(pHdS8_@Q)jIqoVX z)$H_S*FuO^M=dHHy1bFLXFYK&3DeoGze@!tC>?wvwMUv?%fBf8P)vF)RW5fD8O&*- zE}sQ86D$*6x=(0`N!Gpq4va%ytluYr2;{Vp)3eoH4CNhU=iL@z40^y3!2%DX%fm7* z)kgr@XvHHSk{3LW&n4%t448rnV{BE z8zF!cBlF^mg>syeUO^8p@qqKW8ZjB(BFcE?$l(FPWs1onElF9-fiHGa8Aya}sd;Bi zXwOFF_`EhK%*sj9Jp3!a*Ffe-oWOMVNVIR5F`C0rmZPc?a`4{p?xa3?`Efolf$kt? zQAbR`+mcZ4aiEz9+tC5!5`~mME37>&ft3nJwl{Zk4PsQP9$-`~8mI8>Z>xGwPYn+)d=-M=8s7%xH>~C!t&yQlVm)GgYq&jq$ivWmK42squnO@|8f{v6 z9xoNxrBEw3Q0o?*vZ^}Re*}2$=Mb5gwpnM%07wu0x5pD?m}PA@&D(4;VO_!U2v|{O zSa5Vt3e<^w1W=BEtC#xK$$$U!`tbArFKg1wk7!xnd;mT4l>ip})0okN^O12b!+3O4WGc;($F0`{yx{5xA1EV z5kg0K(bpPryjr!(0a1sS+lU@?Z;kt2K!AId^%Y0?O9eGc;SNbgD)yTo9gj+&`K=8U zh=`}GZVUS-`cH$#YwSEL)27z@D;$2#om<<(a$bJ=oPUVh_Vy1KKJzZ=0G!K@>lrX7du^n%Y!c9&tW?1++b&rO7f=uQN!RR3JJ zaa3^1I`Cstk&df_jF-YCU&1VpM07gfD@Kb8S@da`pI3PY;p|3bU}d%i>lYcP-0y-v zh!CE4)_qT;HMB=2UJp&w7^9kW1TL_ze?+gQAiTst8JbF1Zls(6x~F6~khu(a(91r6>W&ees_Ip0oT z>gYAqpDb(zO|?h4x-;pxntXD6EeH*}EmQ4@u5|YG;fQf7l9D?ffCT!akWW<@5bMsj zV&F=d+Sy*Zpr;T^+mYlE@$z;FYB;QZPA1Vxl%Kq#%-t<{`>7{Ii043h2A*2WUiET| zQX4Q{JNe5=p5yLK4dv=_HogPdfU4y3h$Gsi$5RuD$n07K-$Ml7Sv$_v^*FLvPC$G_ z%o4_lO^4)twLW=!x^%cbd`|<;Y6MnbO7FzmcV+ILEr@4v1XYLjif)IpyE$T@=-xW5 z)`q+2O*j=JTe5GHI9HlEk-io20Im<7#8a4kE$%T)ER7xyGZ{vz3SXTA>8O2?i&q$T zoVy4b(qvmfF^S}5&eHM)L}Bw0lc^I1)T<=ew9N_JYkSo7eLpIKtVm#To;A&yqJ?79 zqcv8;dBABtP?UUcq}_43OaO;yDC(epTinxfOZtu)wuXzVWOq=#MLeh4v_d?d+Tu-G z8`GvrFE_zvT2oQ!05!`$X>SYU#r-b`cvWlxkM{D*7$lx|&38MjNV*L_(*kA>G$ z8E$Up>K5v+QIcYU5tgZ8M@7h&+~H<|C5ixOq?^sJhE1$9tw* z>n-K-v3fcD=BlTp$rjl^qpA)<=|0Tkq4(TV5cQJsu2lpGC|nDq%J{%knTRXAVHxk6|a}8*&HC zHk?ETI?nYlAJvOz&FQw=9uGJN@7kq^@sW{~XSsfooWjVEnJ6so=RHF#+R_y)AepQ!&f;pQSMVeW`I>wvzK|95_qd1 z&bJ?9zq*-pXZgHOiUIT>U}a4Dyt)H=6iDT}+Q^Sy$Z#GVkah)HWCZe&TDBj)BA4Dx zFObW<{?w}d5N%Un8B(yL4H=qlAF_@bu;EjAe`utoV?hpw_Xr@bxU8KD;KD0utn(Xu zkfA4V9uj!|O;3!`C&?tVz>Dq?5W#hi#>?sTrb}jwO+zVVOUb`^l%EM@1gFus&m;!0 z>cb$D=Nh_w-K3;Q2sRVX)bGO}IMo&%K9G)X{-6P-3%;EwdSDo@#71>h@vvp#5^@F_ zwkYw$eLMR|#qByB9(F#@m=jb_m}RiH$j~$;C}bmeP$-}y?`BClZ&s42AkLnPXKV8EDKdg}0t|&9Ie8|^7`pzq zsl|d%fGSg3jmOY{^V~Z{JK=(2-3`G40Vj5OdS*3z+iJDt(&jaKi2da3-qQ8zmhMyq zHV#U~CD7~n1ZT*)s(vINwXI$)wwQ*X9*{oRT2NqB+VS44&}w#-DG_MI0$FAD6?J-7 z5QT=}+k;F+0+<}XGiR+QUIjw<1sVj@M+?`Z500qS;*d`JXz+i=`gV%-AS0C#YD}Gz zE4F5@wWrSE>~YXO(~<|F$qKPAzI3=2?|f!Kq8Ng>F&HdSEWAm2akw-CQ(B3f>XGDQD_6E5}mOl8~&j4bYuzmQnhV&{N9 zv{7i<`feKw55TKDn!NmZF9c?YFHYU)t4G<^E6z+m4AJTezH2{K=u%R&&zux#9eATQ zLMl@;%H_cnyjfGt;$}?yI@WFGE0Q6c6)*42je3x8)Pml1qBiBO6Q)`_a(NeSA$3rx ziOr}FfBb)4p1M_UWL&VTZNlwtLNL}w+jdYgU_b0u;$!dI%FROt-^CQ&>7a*+2e;J- 
zKVec5q#>`rIN4z~l?{bH9UMLU{%upFx*FcRl$%c-76doY}0S-BJ7tedf z79_Ygcz0zDZ#^%{k8k!DuL?zaKA24^y*Pex=B1C=ijgLCPycd8Ju!SCYsfTEZV1cttL1 zCjmvt1;X*5982x2bLGNv@tju59`(V|{r-H^mgYg%9Z_<8d_}o&q=Fm?*D9!bx_nzt zU2xa)r|smu?v8hon7V0CLYXLAZ;tM06A-j>W*Ub!Nv*2hLJiX z9Dp_+R+m#s5`wrGf+}gIM%!Y`N4^TGUimM@-Hxc@J?*S;5sI12AH2 zZSjVYD)GDb8ign*+~k)6owlGdaX1cmw%^u28Wup=Klw?M`O3MhRhEop)YTs^)>)4( zbeUn+bzy&Jj=qh6PsHc} z_?&QK%9`913_5)tJJy^a1I1f1Ni&^3C%rvCZ%GW*6n?YOGAm))oh!NQ0f&q|ey5b| zk*utZdDQhf-d@b^1E`CYuPAvIo~nihs)Yqtzi;oGa2f=aeof*XNy7s=Jw0^YGi{(f&t1$kumGisL;IKG3#77#_ z&rT19S*ufq2r426cKF1!-rN9p9)=LOEyFgd;vQZohrK*fJ6|HvT(Ugb;c?cEh|PUh zcr2DH`fL!r&&rg!P|6m|96mFfSvjLF)3TTvaTA(!F#>dKQD@by#N8G;_Z~Ux=-nZWY$*HnXAu>rqySwtBJzVB$W)>g;B}^ zX47kQmPiyoTv3?$d%hWN)^Xg8TV*ep@a4dchZkH69Y3B5raF=I1s2ZEF6ijY`;BCc zBHE`#VZqTL6O)}Xz-x!*6j&;4&vISn*_Lq-p%Kd0gKxhSrsh!=jpv?<3nSPZJY-`vmm$_I7Q`9$6T3*m%x5S(Fw>AH z8L%c9ygW$}BM~=6jgzE(HkL+4A@jtCQ@XolM_ETl%nU?cS%}kuAWOiTE}HM-2*3gq zdSQ2{Cv0p?M9$b)aiN=P#0}n1#UNN6*!ijbmePl&w$)BgRZJzP-E#Fm>j62~G=XCo zU#svKnmAG|o8X8NF8w)VXDxCEmUvhLzt(hVI#uB`xJ5JoZ_jD}r9A(=llbF%rlhh* zKuk0QJk1CLrX&|Mq;xmpVHhFsp+#NfhI{ufC0dHg?<*lxPEE-yk^2m z-H!dMl*T-kb14EW(zT|i%XF&mSvP2S)wol$Ptlp0hPe7SBs4Z?uymk z5?KRL=sC2K1SD5Utfr+*1gRykQ?QFH%ObDoc77?$iSX3P+;#4yn35R{Sfdg|uk)>- zzEahI#g~GSJtpzv$;g$8I=$yTAc^_SRZkA?Eg{8z1>}-u@&NWo@)1Y2r`PeVRW}DI z(S}ZQer9_1^7>-@(q!v+r^k4P#qmjo9~wwikE;r24?62s2HZI3Zog$joGt?G1fSs{ zTf1RoB6W)GL@+K4wH=}B8OXV@eiiIV?T^K^z~!T?(wvYhFg0ZgZiifNPzkxQb^oGf4dLOWt zGEE(3Mp*0R!yA5=~4Qf9S*!6AuZ;Mtm`lS++slQ9sU{tc@u5B5!Bc zN0&mu#yBGQ>XCiBJw{N^=b~e2dv3)`>@(8l_Z5Wl9-a?(*5!1d4P62gJ%SG*Mcg*0S(4oX_UtI#G&I!V0HZs_ zM?mx&ky}(P@XlK&M4>>iM1m_Y& zyXRlsuON7!?bVBANUf6Nda+uhFqkpWF|cdFBbna3LR2i}1nN_a){7%3TgLf01wh-y z%Q!({rdrZ8k>a9cj@pY2Ho3qQyc2{;jq(@abmxRs(chpiWbva>skq&UcsrDjn*u6} zaa=iCE%NsWK1o*jjUEcf>2Pb&PUm+Q)9?|InW1q+Ow3t$V-Gu;+t8Gjpk8*&*wOT! 
z$}p}Gv9?#eQgk01+h?h0;Ga`0kL<0i3NZ>4>ZnQ&gZmQWO1=#raDoN7^m5;2yOzC^ zF{=!^5y}qPYEQzheGs?4xb3X0Ck9`3HVoDkT6#RO%iU=*#koJKyUkokxwrEj`|b2N5rh$pl=W%3av7cvkmZu~FMxBRUiC%f}_#c?e|+>&BP6U12svT+I>+ zrtm7gdc~Uys?a>RlK>Z14A5TLb^z4L-bNqtQmQ-CAE#ky7{n=(UI*OKf86KvU@wNu zQ+IuHhvcWepEwz|Tf#S3rK~Us0k8$7S>6!(d)mGzsTku7Ik!kaM+Z(;Yi8j{6UW;Z4Y%3ASLd2SC)aGA%)7>2 z4(dG-`0V_7Ux`y5zbif%BkOxv4K42L9x@HE*>b~^F2!-?v{D;-W&akJOcsomf(oK67gg3q7LojqmIuRvn3L-YTE?wibme$Ez zD^Hkn(eRN}J-EcK6W&Z9|9Cy&LRUauwMdDxSZ^#~VgZ#zdAJ;I_aNEfN<}xXy@v89 zM z?n6iL0|hdK-k$>&_-(@PimxvoMn7BgyQP{%_TM>#PLABY%!HYyLr&Ig@6N^-uJf}n z6N(rCfGgP=7h8~{SlUi8yy9iz_pO-C&9#q1+!`O+4|%9h_=_7!SM*9&^q89lOURxCLXVoSC z8U-f3>~&Y(3|DT2R#tT=Z{RDXhZE^8(KmjG9zw32^$L+EwX@3$g`2CIw5DsX3OxPo z_Bn=(wI(0PZ#i$Spz{0|KP??syiz&T*fxNkiVGS|iXJ+a9zuA^+GH)ir;Pl1|EnYV z(&ASs36QXO1W?#M0;(A;dmG0e0j2eVHV++c9vjcW88Z@_vOlRj3!;5 z5bILOV$k!hg1l| zO@WuyZ$i+PE4r~~_ds5-5^gH3W@6tNPhEB|F``JM?1K6}fuo=D>*o%8x+rQINI zP{+N~y==9zrZ8vi1hXI=7@W~lbvl69t@oq8{SEOL0%iS=AhNR#GnrKxCrX0tssz13 z&b1ujJogs7qYxiW4O%jjIO|v5vOfGb8M%Af>JhoKDLG~)v@4kXdlX06s~q|=l59H%!6ZN`<`T0Fy2Q!lU)FiTeLt%RW z(cYdYxy1X?3pp0c7Pg9?^?PP65?)#M%Dt~z0LQ&jE4uUnf?6&-VQV8MldP2UcW?N-4^7LgC>UNVQ2^KIz0*tdJg$9QnX zPHU7HT6!`dmsO$Xi>oi@XjPWpMg$G}%^<~k8mz>Zdvu%-L`{}fhi?&LVa8pfSc3tgXkPT+tAsKzxv+1rXkl$@-!?mL9Pw!*-y|Ld4R>*^59ASQ+$w z?}sg%UY(4rP$w+daPqE}c$l+eOt}B#-AD*4neZdvbsW1WT3Xnb6fdUdBG|hLP5zD` z_c_H+O4W()zvxttv6|pOT6>8k*~Q?ZF#}* zvQeH8e2?^lQ=Nn&-p(@9&3+)CC^=A|_^aKZM*=iMZ;z*Sg$u&kIRe}1P`OjRl(fYs z(Pg8%>BBM&dsuP=hczvoUF=&Txq0ceNCeO4;nsfaX%u`q-b^;m@>DrNEJ`-Bv9PW+ z{f?v4PNU=^H;tj;CW=?dV*GC8hfJls)mq3pmEpp;lpIR{Acv6xgQ+mbsz zx5*6xTa2677d2qAF0$E2fZtjy{z#=mtWgx$j?FS_O`d;9eCmz~vg16WZO~UJZp0z# zD|k(ZL9C125Qap&6AIYfIxz?W*(7wX8H}#fOU_Q7>PwcvZ>uJ4R#;mf6dnw8OtdSB zR`-`xs}WOSDU!t?ujwm<)7}(EoL_s=a6{=j)4I#N*WYHg5TW(&qg*UgRF>YyI$301 zKWl-OqvmMHeQw3tcs8-U};`nL+L$Tm#|H_&)tG7Kjw5WZB$fF=iL~+J8Y5!2f7yb z8w4u%$`bj=Qvlsi3-Z@Xn(7^2lyr@5^7th@letOcG z>|*Bv1|o%XR?P+sCS6+HrblsFWvSw3!bkC89;nhJ-fE~FV&E(dCyosYEB7%|9~iC;Q`Fmp&N+p=XoWv z5JwDO;Dxe0g`G7s2VBAzR20s`aG{r$(*Y7?EA^rEP=4EoTVqz>n1aW%xI8fzAuQ%R zV5JQlQW|+!;41Sh4H;GPZ6;&g;N>M-{TKE>@HDKvV7e>{XSq|Ip`lS?eAre>A641_fxEAV$j=VmE9-xq z`f@a;pORo8cl2Y-KUq^H}$ z)frnA_aYL<(i%(zl49-@=Z^#vd5>BX&l4WPn_0+2fl4QohU)`5Q>QQluj%^tuQrS;&e z31>m?7plUATZhhfGV2;mDwd3tm2)r}qAQGsm@hnK^9T@H`vu5^rK_tsGV~H_8v5}* zuVF*DCJ{kqx#s`0#MSrw?GVM%iv>8Gj`1z)&8@JlEGJ6fFP`cNtTCYK^^d93nJzB? 
zC4+87Bu?JXn(wZ+SzW$lj(W+*5cBLxQi10sm3mxWrfY=P8?E!=XOOSB0O29-lrQ;D zU&$u-O1hs#>wy3CvQ$l2kk#XJ0wv3M)K@JbXBV_Kj>KufQP5gydW{ytIVziCO7tRcvsi}OS;HQ34E zY2s}n&<)fsXHFdnjHSxOu=Rc)re4TH~XnD=SN zqYc!RR`qkcEP@_5!)VEiBDcGd|4=WWf5h>?deFjT;$mamg+Vofi4OYqLc}dZBtj@3 z`ne*p?wWJ%o}Z2T5g^wXamWBgWG{aC+$0J~IgxmcN`33^i}*Wr!+l@%LVm!NC$|zF zyY$p;fdynX@J$W=NmxXkUCK;GJ3sQ}@o3kqXJm=xt0i#3LQ^08#_sE+i>)2)KUFW1 z#^*w|N*lDLomxA!>UuG2y+l%+ZVjs&e`cx1mp9p8z%0py{4WX}v#VMo&z-)L1De&} zMbXfIW<2SNkwL?j&!Z%^#TO@?q9pwyKqA**`{F6=D5wJ!YX!{n9rJ zY;^4+Kk9iTvaen`_p{FVX!?gkSdtkzyXDfK3jGvcUO@QH)$9m_^$Rx`d*A4m7M;M#7S)=gxhVcmr`$inNx2Wax7<4rcDYR6nyvcj z;n@`Z-rC;b${~e+Tzyu4%ysits|Jlqj-hbB*(!;EjR|Lt-VnAqM!oA#_1QMEf7J+9QjGZK`Mg|u;-gsP;kvFq*rU&%pS`apin^Uf;KuLQw z+dV>Hs9KC*RwAe&^Ct%2Ftjr8@wV+|@_WH`fDNf~g@>x^zB(;*J}k>GkS?;HA9@`2 zA+^x;on<3l`N-t?+=Av$TYIRP3WKL7t2NMIJy63dCC!_yW^hG zN>yW19`d7FtoJA^y1SO^s@al)o_yCse8DyLX$G3^>{ zjJ>?|pSgjfJ0>t9c(^zVyv#n)Q%}575YD(Hm-x2S1KouUv&P>lJi=b|ASrdQ@=qsb z(6)r_Xn^VZ<^7};3n+u<&1)MxC4^gc|JfD=w|O@~nAbjdailPIWL-&A4-tRg5O=xC z(ozp;M|cpxacKMCPBz3=6jD=cNoz@47w{w-F(;PlMy_zyWg8=#6WMRcc zEsoCam)P64xId-`PSeSG%S8tScEYU%t5HOZA`92mAR8l!1yU zM24^*0p%oz1-*?Jdfs~)Wkyh-6cwdjGc1QGRx5hP)IgN@Ap>Q?GWj6;_F!^W|26veJ%TS|b{-Juk)=1g*!r8B$ld;bqyO#gepP7;xesw`%n<1td zZ((iF*(q2X^xq|0|8;}-7IV{jCeYv#trUj;LDfd4{s8a&!=+%OoFuvHr>;S3#wb*) zq4hFso~!Msx<7sLMA{-1Vh{NB^RNo6$3zS9bu1Pls)uYKZb1=(s6~4zOm&Ur&iHD0Wds2I6tEgPO^h0TF5ol z&dEF8*6PtO>`8JJC^&v4`21aK_rTVjnX9pwDeF;+ zk8|;abyqo$>f#|I`n|L=J1)}EIEjOx&&*kn2sKRXSv1YY+x0mrKmY85ulaJm-vuBI zzj5ayV9A&rpDR&dL7@E=&Y#DY_9So@fBeMkiUAf?M*+< zeEyB9DRSt$#t|Rmu1M{Iz0|vtB#m%pt5@oQ)H@?bVAk4{Nycz8M*=3)3^1OD@E&Xf zqaJ8cS@pbc>3|gF@SO)bF83C1ONVZ;Y_jyYUlO-_Mn{q^?HbE4tb=b7fS+6l)&G}7t!P@__Lu2*trlwAGNR|BFNG|{EB0eu-tcf6rHGN}$m0VH2S2}{hqBV9I ztxb*zix)Zn5-zK~@#ylDjt|+=J*^CMkkuj+U7vU7kL`!HtKm&9blXa)0>vggT?Omg9@hotwZ#xaC4S4nJfs)0TG zuME>9E|BhLQ8kK_)jy3A5%M zzLmjq*f*vXp!?SvyK&JVplSNFwq~`5G|O59GO|FlDk9%eGrsd*xL^MP|3rBmf4tS< z&3W>wwz9f}#1n_v`J|7D32R>(O;Y9_P?WLK(b$x30a<}KkA#$n4(oT$rN?m7qoB6T z*GDA3K1f;3-6PS}p8R5$qsZ0%dazFGVEuLVmLQIPVYzH<$w{l4T(tJ!Y_ZVA-mcVh z#qY$zlp{8rLJk_?_#$PjnS`$}=euaUF1m8Sy*0&zOs+0GSD`~Tg8TU+7|Th~No!SL z#_lrQR`3W;Tv&mDSbCSQ|Sq=-%}fx*<4k(LW&S`|1(VQx>#HnLw`0Exs$oV(7`* zAHt9eC$bkjAuP=5Q|(=NM;=A_vsRA)g)PU|YA;6} z)tV?W&y^+IxBi75^Z(|-ct26Gec?rN`!05drXCuzFyrpCs~Y`t3R_TE1bV!mAUNXc z1zM|mC7id}(=KLf&@*>cbr!ZJJTb918(3$23EtO3W6|(F;_*72`{~-rpSO^dnk}@K z(l~MF@dFN3E1)JOfJ9j+fNx}r9G-Ej&(^+4cY&f1^MEjF! zej1`d%9^;4<^*xe^|T3(67WWuZbV1ml2eBoi&*}ESXB;p^b&GAcdvn4`wagA?K`fxf1rk8~h062E z!nnCAOn_cC1hs>BzXsCTFT&?NfeP2#?l>_Pqr>_KOCHW7?soM zb1h=oAT5hU4g#H5iQ@dOv?Sh#yZ!Z5QnYeHx7!{vxseR-9TY>|b>;ii9Q zIB?=xqeheGQ>7`uZ^q8*Y@nCnos-Ccl(J_9sh?$jTpyl^kEi2mc|oW{*ZFgOUk9>; z9k#hgxe9+PRxN&V{pq-?{idh^;u^j_rao*Ft7c-@N~_P#eZP<8wwG{GmcZ*NVPN(- zIqgeTv5V=CaD7v=XZq-U2AH18gzE3U?965$MYtL;;!Fp1@>1~%I?w6%MwC6TeX15H zM&9A-b&d(~q!@O}%?v3zFJ12)HUYgmAl`Pd4LW(zAksq82TUQXthF!PPN%2IGX;=E zbcqZ$y1#PuRNdxr0*gp(M`m=16d5wC#qHS3O@OX4(sQ@i3$Y9(+(dt>MZj^m*P9CrW8i$fRMDdfDT7PHiH!}-x|^SkiyQ*{b9XRZ?b(%v zMqj~L-Hco8e!UY1@d4F29{#hmaZ+!R9b%>+Wp#CGEM|uRj zkM9Nh)dP?@d`9}e?s0rVT9;dBwN+QmT+%`+N3F{ICby><@TGYh2^bFhS%HRSF#X`e zRBG=66nyD4IKM?S6hFl^u}7d{sZDq9&ynb64ay97poc$m+wGxO*Tt&OV;z3S7}&nZ%!0GuS$H|;n}ereo7ZsKpVZ4y>3Wk zHm}}bpCA%YbiRbMGwWa8!LGe+k=?_pm?-G`_0TBH5RrF1sivstKN=;U{8Dq2PqQB! 
zvL7alRol=9XdeB=^rAfJ@z zgxW^{?&Oll@D-|#l?f_7arK+UhQ>& zslkn>!!%LmBqg9oCxB>rAmu@4{ejnzX)Hg3=oCT&bks0dESu6XXn^j@GoBW!Sld^{ z-&T5nayZ818iR`Fy6K?d6XbBp0d{`E=~=lf7}h3n6~E zM7QumEBZ2-yOb!!oQYZ&EczB~33NP|E?}ok;@4-^ZN^3Jc02kt&b2cQNi%&pki2pu zlmvQKR)^C#YM<9@!~pms-NM@VaH{BxM1$jZ_AgsB0u7~>Kf<0P8Ez@bQhp$E3HQnU z_fG`N_8EM|s4^tj*RH2Q`w8v7jW77L4TWjLIOMTWUSGAeK}OGinVpPBKebt!5EuNue!2 zQB~INqRW7^m(Cd|+T1pD69vvXv~-bsm4{vxP&-y}ul`I2HnDha`acagIRhalnPmg} zdtW&omn2IU@W430-dD9i`H`vM0qTfJ&Kl;~321gW+r&;$;_Ujx7A5idc>S@z26N_T zpl==Jf}z;`o2iUR!mxdCMLlW$99-P3fON*hmrct2DbM+s&Cli91vXB(hPvv{(ZJ}= z-jPz*>_Ke{Z>E%vQ&PF=?SniV&-7Z_*zf$$mYDQR-6C5G=4Q@G6J_*i@GJYC0{GCH z@w-+@MYPX%L)FWbXOC{04D?uUh&5E+UrnT{%^f`y!q&y|pZjkAL8E8a%JAc$Ez12= z@VXgak4pfRT$zpPqp=TsguLnr+nrC&P4R<2)8~>$Kva0~y|dybQ)|D2lPe+JC{0j~V!mj4<7F@|soi&wm~ zjAG2QfG_=Io)AXs7`3R8#UZDWebF~g*S&sDPB6$c6gbuA7|qzXd2N^dVc9Tw1Jx?U z-t9&sXJLBxy}lw#W6|ZXP@*{5^93TQ*Mjn2dZ_FwPCDGqZ>FohuOzO~g?0yFG?`ZL z3V(&$;Wz&Os(bISrq*qJGzy9+h#*J@l`5eK(mN_$sx%2rI>dnV8c;w$s&wfhT?jpd zj!2UtNR9N~LQCj1-nrH}Yd_ny%h`M1bMN!}{WUYcIcJ`kFXJoY9q)Jt1mMAqCU%Vf z_#DhhVxsj{II^hn=;eLkgAT3{ox#PMcDJ;{N1Dhkii{6~IJ%?LE_|LsM~M}={!VcJ z_s0(<+IaP=@`cy5PXnZb({;BI@&}nNwpOv^l4{J!dh zxVz|9)KFK!TD3s@5FhW1YqK389$V9~*;YpT6J*gWv?a83-}lmYSjpWC9#cJKmY0}r zT)-7ELPjKJWWz?NEWw8Ox#c)AeFV9^yS{9rrp)Jgk5P9Z<`uwJWtvY6!$p!^J@Voy zd5t5}9xNHx+^p3hN$aAAF-nD_>(FmK-){Me;WDoE=mzn$6##M5zV8*h9g;qYLYsE= z=Gz@-q!ZpY<7FTBcPHwH?ZEBFOyGkUF>0PP^|&)t}vCxnnrls}}IFez?Rd!uPvVT2S!j3<6VvooZEsfsv!GI-Lcok6ir zIos^Q^8pP;vPT=x?_Y|WG5ZRF3X@N zoos~@Q6EBVjBkHIUnC?iu3(+}ap8464dFTH!}aFJsJhDUuItb>!lZZO?F?Jug0E#Q zj?zXj!Tp;u=7F0{8slfk-G;P{ndnppuMU@p-LRrh>w1O)3h3a4!<--k9;@^yaU5Ow z2hyutd8vhC&|BowKa98!b}YdqinnJ2$ZiK*xxI6x_*8>>RewD82V}-Ec`7@yA&I=Z zC%rqC>c^rPHYbBD7X@8n6ORf={a|nE6}LkgrgI zA;yoXqOLL__s#s%#@B>tmMb!bD}LMt=cO$IPc9n{x+Y;co&y?x%_`w@RGXVK0f`b- zveiQkS^_J}Mq4W~+ZNF(UwB6tOh2LPMtks6`hH=ZxnYOcczd?90mp$UKMMCfiE+-4 zD9P`CXJh@Z3W^Z);qPOZt+4MdfTxAug-hcMHe%Y_Ud!6{NQ~%%R z)n4fIEd6o|e-Eb@aFUqHxRn5PbZ6hC0aeJvRYnB3AnUwclmYdg?9Z8dJkNl{qnQQc z1j-Af={dJrh?%n?c^KGQ)YOXSOq>Kh6nsA%U4Ik=1QQx=|h^=S7$#9NS7YTwNjhrT>cd z+Lp-0=J!H9o_-zkt)1hG`r_%MA6L}Mp9>AP6%!@9?cg^lEh6^%rJ!Yhir)PFE#w~x zDzr^^%JDeK7Dx6Dq6|U9T z>d{hVzYP|#k%-2NiZT=rJ+EbpT@=1EQtE*U}&N;+kpMkR7!7!5n5oo*Px}s2r zr}NNUVrrgj7u!c=u$(KdU!pcjhp?0oD`(Bzhhpv+w+cEzo*sowXrlOw#@29mvst*( zyJ^nnI~l$8-0bGqv`uvi7H&-_PZX*5_-%VYy?!K{@pr5B2{^)<8I#smTtaan;b+P) zUyF=U7?;A!7K-iO&mc?R2A6i%X*!%@Sczc2_)II0|IO;91SF-=MIxH#DGsIc+aW`S z5c|4pmPsSR%Vf>mKCg&wGmu=bk*Xg($eho-w3aw6pHXVXaC^lz?sbRE)g8N~A7Mhe zUTe3OY&=RD1yN8uAzc6FJ(a6Zy?MlR(Ipwx174G;Vtf3Z%2ez0xn|Vkr|hQHIX>#* zX3@K0JK$)#b=-P+8{S(0zlmrXRI+vmk&+18A6h8AMjlx5Y!^RcxZ6xMeWbFjRe{8= zLK^Ch_*U;+CE!kZoK9*s_j^3${XHqD!@G@^EqxeoWY;b=D94_b zDhCYjARvBhXVebN3B@q_MCWD3t~o3)|Vd9n6amyeDM=dBlz z8;!TZv62=okSdw*x~8-rR_RM&S3Y-qzkWV*`1Kq*w%*g}-qqS!<5laE4idM08pt~q_;Rv0Hj{q}DQ z_@A77KtL<$|BD3Ozc@;NI81*WzoA{h%;a{F&kI-t*#a*{2J10ikx94HQqCDkcEzZq z#m{bx_RcoXX&ymO4}XFVj{v1;|1Vis^x-zZZUm&|=An%r%TJJeONv?N$oKyJYw(r6 zCefp3r6Qq*d}Oxhn``9?S?oTcUG6O0Ubmw>J0yJ28nc-}?kH(`mbPkYcMm)ieAnn6 z1s{#WgQ%MI390oTtEJW@HsbF&&EhE{c%a4-rI~QJI@{ZlB9&%|<-~35S(W6`89lbz zLS}6&ddu&~3D9XlhpaPELu)9qygKA0%#4tD#PU+zdheS!(<>ERz0S*v^d_Ks*=2V>6C({m(tfyYAe0r`FQZ= zH3w}Adhd7Q2QukiHIu$&gG*W(s@skOf$=Qex&B@g^8V=WCjwzF{4wa&dn2p)uE2)s z>F;~S9NrJJWki#M=sthi{fUAzec>b$42wdP%hIAi^}s82f9CGG`<^O4dY|0i)2r}q z>H?!qyvrQ;!{ehrwh&t0P(<>)%p^v7xiw!A+StCSVq^83rehE^dwO_6Zan3}ITE>G zC|RMFS1-6BW{BTvF`-@P(717?Q$55I`CefFLVGPhpmP~y!9u&-*cGbFx1=SL8U z><>fJg+HnMkmA!&bXikC2c)QZjRX$n!eCKw@%>#x7@P@nV(TrY4JF@%HbJ(ASExz-Aj!D03`DfTepRBzTI6jwgS!Wb{h 
zJDXDP9c?#UaAFO82D6SbCo-~e#DCNCqsGjPPZL*Ln;A|;NQl4Cygbk^#sAE3s1emH z?~pip=m_V@Pv2WF?coy4o33^`P4(w{3u&r{nDqTN#%2Fc7~_A8_5VIf^39iqG*?Qh zGUXDek|K|<_mh#tUKFvI#QNo+i#$!+=-eA&G6y;VP*uenURmy2FE2PI5ilysMb+dT z4aqXGVB1Gx3q-C;PPnYy-D8es!2#@Vn*t2+14HEAt#OBpZ4Wnw*9{%o^?pV#LAP&a z8ZJd|VVtFI)nV-`|W6 zXPmwI?g7Zp-AtTG+!9Z8Myy;)NHoBqNs@pcKnN7`(tT@B#N!OJ(RSy?GZ>Wq{brjf&>~zutCt*#} zU2-Q(WGm#^qgki@QZQrEkjOTRGFrUbuzsq#x|Z4&=PqPM5B@NWvxxIIq7=AWQq^$| zeN+_-t#-d>UM*)C>#4lEV)3S{3s5K0d@WkSpuo1ZD<}wudXTm(s}$xh*!M8ef(p0mYPiy$ZRw9l&7Pn z^er5xE#Su2(_=8rAOSTQU^)z%C+us-#$xg|Wl-y2gR*2N+@ zdm-xm_!`<1;Ws+%F~DZo2ze3Fc|o7AsE5aFU+o3j?c`ujMAD_ODP}fZZ|nY(j~8R! z!17&uvx#v1K(1Shl;oumAp#l`3d+tGOYi+xg*@i-LuWI;)V-jCVPXoEEEOZ7GlUo- z$TNziuC&9>xXVtdUvDmggX4)V)Mz-XQ^2c*Ks|qFzR}4Wz%RC8D`~J&5@O##>LQ@Qa;zQiX^6zyu=G z=Y`uGS3&L5_*x4g>k}S_)EzbUvl35_R=95_HR8>{HE4m#r}8K&dKO!zKxV!rQ@LyY zlzbtCrm&&rYQoI*hUy)6hqb8MiO$i6kd~%9*o+wKK)vrgN}hLIBWp+?H2h^mJYp6^{nrxK-CaNpc5DkVvcp0>@|71VN0NjTEy z9@(NHv~vaK;W7qx7DO)w1>HMw29Hag`J9Ua_CBt#|*o!nuifr^OA7hRK zm9gM9RLUD<$6JYsV_|;Bh!^MU0aL0r@5EiVkfon|<*@A4cj@xtcvi!uO*3tyD*=t< zNgINlPtDh!58NFUDy&-LnKwn=3Kg`ipvgTg_kNY@?IDCF>AWO~oI7^Z65K*e&p2y9 zw&6nMLOZ?IbSiH1gANlCFdsjnV!#z%q5Z=pW%9x~8dM^&M?;pV2PTsQJmoU(?An)? zR|ZlUq0A4l=Ij2Kr>6~cOSUYbUqeozJFk!4u;4E}W2V&>JJZO|s;B5;g|Gzh7FNZ2 z_{7fS`UuCE*I7qenb)*85HmA=1>H30;LA)I^~r}ycjd#~GzPLqv-Lb*Tun@h-yT?_ ztb(q&RdQ zFUCmE&bKj5%)VMEwAPs5`cPs7S^SphJGj@R^vA|)ttm^xCtF3?mI;M3&Lb$pMHKVQFObU=fZQ*Ixg z3Yf>oC0j}2{E0UrF=Q6!{3ew~{!tC&WMSO+0f$>W+EnMm>4br&k>_OgVZ%9@Rr$@! zu%v}T^~~Y&=2jo+a`?GKrifaJTK*AgGqL|>ZgELDh%BjUQFvh0Sq*MFH&5SM=TgKH zf>N9eTryWsX6h6x&2WVgJDwTm*&fzeHE1xbQvz9D#oaWO4hwM+5hoW+RUj>-hz9oq z9IIwq7E9!1DjgCK(Q}gw4(fpRj`*_SNn_C+;muEmh`Q?pj!(y1j~Fkx5Vf z!U+Xno4sBwCgcmW?0`em-s1fpV!$+mg1JaD>US!p$Jpw>=p=%H?%vzqX1h#(fsF#5moA^ zRXw_aplP0tRgV|8yt%J~onqdPDKlIr%eB*zrPcGJ8`z+08-K>l-r2h&EJ2#5(WD{6 zFV$l?Vk=2KcM6l^GcNl&knV)_(VXPV2dLF2Z470myrA&2SD|9wGtCZu&@l$`kXO&&FS!L(T z(gd!oL_j7zChkzFRY`YIE!8qw=KKhn%HY8K7h0ZD!WoL>Nd zlqF^{W{%&I(q>`3MrmT6u@U4>+rP}kFF6pDduiL-)UnHOvtBBek?vE;z&grO$ zvqN~o=w`-@P0<>9jx=}UrX#I!)OGt!PybfV(k$E)_qZO`d)F<6H?o%}^E0)^8^2Vx zm|KNLq#?4aGv4C!K^Y|+TnAYi*{Y7S(qZZpBNJjdXR|@ zqFT$iHuMuzaYmmzK_0wA7$r0A{~n*$;LvOG3xHm}(c9!1 zaL``q2dak!x**Vf(58WcOrtYOegLt*z62Gr9EnzkJ?**oknboqex_&2H10R3r~iK8 zd6Jpior3cyvg$Wu2hrS)oq66Pb(}e|5UP7BddnVVKxD7$>KvYHJG^CZN)ePi(9m$J zgX)Mw#z^1ENT&Up`i$Cl*229dDyX^*V^3Zh_XquD0$V$m%vC{b2fFtmK4rZbdJ)O*^tMZzM_0VaQadG~=A~ zPKg~mv#7mZlc|`a&`W-_bb|f#&N%SrN;)r;bP!Q3G?SkRR2zP3dh3AW>AoP;01C2* z+ckSBPG-O6N7d<@NUa9v&iyKHuqfMPa4BSXWx7Gd^SD9f<7U0F<46AJ2Mz*(diu?3 zcw=<>S@fg-0b1LC^}9CxKT8$)*TnDsA-47>zxyxu?T|&L81UuWo8O;{K`auwu`Fwy zpq5OUjWXNPADpUwKAg}OfSAuoRiAI4xgxEMr4>fkO>6d?Q(~fc`7>L=?I+5ReeNT! 
z!_6kzvZRd+WgpwIDV2BI5h5&2iqh#4`idqHti$ol_jV*cy*_T(NwZ1wc zyc0wrI@N3A{kqU~2lUAH^2mY*1Z)V;npSxf*c%KgV!Sj zw|}gD(e#Xv7YA=D^Cj%heBJOP2lQgBuz|*9$6G5lMCS>|iWvIcl9P%vv3%fyr=xl~ z3n^(GIQpPG@mv|9lES~3_YwMjUb;NIP`}>aL*p|z^r$eKRJA2ZB&M#@CBMi>k;UJY z?7g4SxopXFZ8O7>^?Lg>J^b#3)~&nGsB&GH&R4}IBBD+K3|Y`_uuz-7c=OiCRB3)5 zN^BKf7*gWn-X9rebrLmT(e(wQ};6XW5t_ zrZcC#P6RBulSqSRlz5r~tsZTS7=k9NC)mT(iw#%mn-%uNQxN{}%c=R3CX?j`wfT1C6ZitVA^^?tqo#CVR!OgkhtfW-~sCPZuDb*>+ zrHq3@_nPw|P3amC5np7s988@*S{?0UR}YZe^z!5X5~ajg^ACZ_TJc%ekSE9rN4g$9xBMfZ1trH3sloi z$0Oiq1y7*l4jHW{IeP`VYO(gqCNgn#5y*$~h0jXtW;&VF&cWO#+NS9%MmFQ`PO=*l zVMXVXu(#1*+|9Z6m8k9FC1X1n55j`0OpqS$?Db-cQC8~+h*XC7U8%j?#^MmEZ^ZK-vrL&?D*Zbr>9~We1OrcWM$myh04daRG#K-340&RLLi&`cSlfwz@ zG1SA|{m$0pXK??R(&545HZuSErp96C2PSL__kADW9D4MqnXLkBl+&?STFH1$&WX>U z<2)#My6*C`bOqYs3vqjl;Ecqm8Qw+-KtX%rF``n6q<^LaGktcDgeOeEuNhMe$Wwjy zYB81M7LmStR4_0pUo8=-=v;J>xtr)6ki}Vn+>+)T>3A~-SJnAS>PR3!@nCo1TIX=s z4qzFkjyS81998ia0wvnF96o8+HKauJXhL`8*->_|UGEvid>hv>%;Lct!}poThr~jd z^O57qN!yQMlRE)`%YOaZd}nYUYEG{{(0R(>CZ{FF0?&U8@+UB;Ba7Nv!fw6&1TTmE~Qvgb}TokCZ z7EnCpX^}DP_;l#%ew=$^Jl&(ZTAG(%`cR-_Ha|GJlgwM95efBSYup=MFW&+IL|(wz zZtz(!$-iAcVx&j{M}Lo>~dWlYGRY=`?E>A-Bb@0X&UEgu_yz1 z8XWx9?>5v!nhS?$dczQt8|0CM=?~#Jv@T332kbbFV_YW$+{Ik{d zA6?t|35xVKoV3+5Op0d}-g!UfG<^T-x&D4hbBlKhrir$q7UB_^LFn}?gzOHDulL+* zD!vHU{!RZY828s6VnKCe2T9A!1ts9wzHIJso^L^@_1jhN=kI>F!5niEGGCS$*PNMY z%p7Y7-uGs6X@cpr%1!9H`LiA;X7~=CMr(Zxcq!nm=a7#WxAT5E%;pBY@8_i(OzC~P z+i}TGJSLzaxfRCiWO}nf)$Tl!DIB)g@8jcq#~^R|lY54`97FI#1aAH58Rr1|U=?(& z#9~bH@PgAjKV7XT#d;x6jfmxueP)XOg6i~w@fKDDTUoNfLfVpsdcLI7U{mCg^N`G| zJydP*{JmtFy2*>o5>d+iqEd?0;7aQFk=CbKp@p2)L(3}TBE@>!8+eh?Z;XQF4wPI2 zfEXgO)Xy{#jsA)OhR(9s@Bezk%KT4IFj@zD=3Qmm0T)A6_aSK4_gh-d)&~mX9x=92 zJQ5%Zljp>E-)AAz+#VQ9r+MfE0U%KCPs2z+!s4Aw}aj? z@`spRcGh>4G5mtf3rk6Azch@1<%sL34?S`LUmV0LNoTL@9~6`7w+`9@BB?oQz6IKr zP^D=tn#2cjadE*5%n!j8M{;=@m)b{$lVC2l7+GwW$tWTUKX=mh!Pk2Z5PD97Xk}mH z7t?j@BPkww-eVF*>xllKj$zk3TPi$-kmJPaDV`yf!|`psXb7-l8J7I~hf)6d_&;m` z|FQfrn7+*;$r~OvC0Dp6{({pV?o^m%`&kr3A`VjC7pi=g(lRRCCRk)FzN9`UVA6~8 zHVY`g9}{}J$Wong6i=>?6M7tKc0>&-y`Jk&5-3V|_Ed7yjDX}o>SR>DBY!@XnaXxt zC#gx&)w<}>-7@QQ)M1Up7|P7JS_H}rhw?)Z7ek5yI`b1%4l|b zeQ*YyD(dj#>IE)mBp|livrDWaIN7zY=+SN_Oc>Sla^rNSXdHQNoo-dvJFKw{o*qh+ z=^8;IH>&*Iqp`5`>cGNq+tN}u`V|Pov&97*n=N&`-RV$U49**L%PYfChg!l#6e``J zi~>){q+W~0e~L9C54bGlW8-lFv7X(2f~YdQp6*^fpxi|7-jkeWb6v-Vg(z#*1p}=Z zz*F0cBR>yI>ciQWTXv`p#hBjIaURf~fisEJtFOkkBE7@b>g&PNn8Ohvz`YNuIMK}< zb&5N1O;3p)PfAbHFinmwg%ZunLN+$59;kx(Vwl0D4^`RQ)Sfjxcqd`YU9wmonp-056;Ah>mG6c|CF zs0ki@<4e(#tDs^sQa|$%rFKvW7kL5HsrjZrdRhqCM`Mp!U8}U@y;%Q1X#eK;E3E+# znfg;g{rCS;nfhl&^{enWK>Yh9p>!!7c#unNEptnL1bFZhpq#J_I; zmEriWlG`u-^}1iPj?m`0Ie$6V$-}NtlA6xyqas^Lrqq{2W!`Hf9T68Upi2?%f`Ifi zAdK1sXhcZYt>CY0%N&hxoUxxM{RAZ$yZi(ZdSss2*Qq=d$ng#XEt}yE-nuz5wl^5| zepks_h?Tm9OBjv_xDrlCKAi(3V$G=h1ZDo^F{GdxkDwZn#$a|#t}57ebL59vnfz7a zvrp$BmrQKxn(6L(iT`ZWF0l4%{rMeSUH{ID$#X>SiCTt53wx_N$?)mz49NtwcOEklNkEP{|6_ENRX+BchD?Q#v%U~6~eu6px zlKG$BwoaUrAnk8jqkny@DEs3pvHoA%K5oka^vNqLHUlSWO*umM z+h{EvLHIpEpL;ynF**9R8H5syi?s%dqlmn~wjlj~FvjUg+sxIKMv5&x(Wc-*J;2v_*wo6GT@M6$V=Qb7S_e&#bO~xNaiDyy~LN?Id4Z|K*fTcfgTY;qwcT ztr}6YmpYear{B}^16bEW0X{y#MqMVaU-Rfrca z@7>Yo0c?D%u|T&Je6WmtVgS0nKQuz>sH)~DH!N7lY8!hnuIlba%f+Y?DSvc@-PG^H-_{D(&TxJl;w_|oPmk7ZwI36AvEu5%FuWz3T0Mcp!=*%N&& zojZfEjDFjC6#TKv-vHw=jE6IeSbQ#t?b8 zvGvfx9w235>sO39P^BGVvX%~Ckkw0CkuK*Kg5#P zdFxyJ^Qz>301s6Cp40esb!_p@vnOUXDYiU`W*7LqDp{?9eO@g_`|3Z^*@y803)`09 z`Y@lP_@-I0PL%T%mMp@HWahx!I4}sO?!0Mh=kMcTA)LZTX<{U}O&QMK(OS+yaz(W7 z!qW)rx|5E9Gm-0YUQWJ;sJb_MxMDRIZ%iMHQGy&&=R&N;lUOhIYTN(vCFxC(v7IJE zLnJpa?)_IGUV-2Es3#4+P<)tWoRDQjyBIYvT@y}O+A4CK=P?&^w4vzD 
zT;0gI9ux3A@d^>hVIgBN1%&xybJPvwdPCzi-VzZ4KQrA-mpcc7V2rrNgoMC7kNq#=``2YLpC3F zmCPCNQ4H$J*UG&kuu>+xsIbHZ@cpJxsNlhIFx0c{<(B}_0dKC?uZZSk0f2@Qq$_1; zT|Ka!uNu0UBH_iL!8aN4hQr0@dj6U|7k!Ts-$n`dRyGA>Qei+PBpXaZeoP12HN9eL ze{G(Wn3o=cg!J5_TR7F5?P1(_trK{Kb)&QQ$K400BkzS_aQcuWGc12*Y^tXaqrYh$ z$(UlIQe1XXI1Q!?Xn!Pfzf^P8`8~^O>}GFdoD1))!jdSeBI#s#ouRzy zqTitSAsQ%mvQ}X)Qmv_uV3@Jw500T#sYqX)yy0VZI;uAF`qE}~G9{jM^!?gk@^fcc z9|UyI>VATR(@vVDWk${C%OQd4dA?xl@$$4OibZ%flViMbL-PaX&PoUPE(EEsQI**J z%`D|)`FJ?iruGVBh0lhr2@pH^Jt1B68^5XYQAjaFxo}3qTCjkR)?v#~J~HK5AQjOY z8hM86BmH(6g4Ltaw{4Tc=T)EE!5Mp<>;|Sw$Q2Xs?A~DBFHvP+QoJ_Iu>bu?vUT1D ztr05pmLo<_1TWD*tl#f}BTvJsW|jCMAqk&ADnttS&{`V?BC zfGT^#zj@xDyfZTD(qoX&)Ot5_VxMOb8lAA#m?Pcp!h*3Lf6D{x%;b;y8;=j_MM}tE znAVJRPJ(Y8%|)Hq%9a3Phb^F#9nYp!vUOW8Ep68>y(owEHD_7YDA|(I7BszjZmtID zBK4?tR*FRGVFQg@cZJeff~e$;`*QQy6oVK%?KwYAwgX>5Goj`cQ7)BV*4AchBW1Iji5M88*87q-k&I_KSE zmN3-Yr02N{J&tTWg155*9fnV{=~Jt{fC6ubM=dLwI7;yW-1Tz8??eXXK`l!ir6})L z12+T9U1Bw|fb1WPcwSygVy-93yLn6KO|X=@M1bI_=i@L0eS7Bn*R?boEKa8jL`*27 zDm>xPO3IuX{$m~dTUTAC`ZN8nu@5JHWF~oDs^y}&=+f>c`0pk!{yX3W%Se+tBTk z!?>RyIiZziGFKkz-^!bt3L^`6zE=mBvYeW0F>Mf_x@Si}AH>SdM%5PXGeXn{?>s~DJRI$2mv}VUTV4nVK7%`5e}ax41F8~cdL(UC#Us#cMqGkWWuF(A%-@VV|zmpj0BzwJN6o-*(3_yQ#&u_c|Irau7cp`_M%+_3tu21T9Se0211+A3&DKM%Y%%MV50z~7<+^ym_(PQj zxMQh}9`Nc!>_(lhe;9U67tw$7e0c8^@*{qW*yZZi4<~c4i{dp2@AeHpJCA@o0tDCK4#6R~(=j|if;+*rakoYSBxryD3GObz9fG@CaCdiaUS}V< zXYaeuzWbhg-W%`#$L$_#j2^41SJhlKYtCoHBPJ!Kq-CS0q-LQZC1v1aWO>TL z&CN|gFCfCtDa^*j&G~Z>xJQp3p`oG?qN5XXQjt<|{h2krqp!b3zPA{CkDMM*Qh!tfz2m%g3vBb3Jkgha%2Pv{vKnYej)`S=9{ zU%nC(mync_R#Z|}QB_mdFfcSSHZe6bw|8)Ka&~cb^Ly_f5cuI^P-N8S=$P0qU*po! zGcvQXb8_=a%gQS%tEy{iTUy)NJ370%dqzgb#wR9!OieE>udJ@EZ)|RDA03~Zo}FJ@ zUS0o`3l4z)i&%e`>^E{@!{mB^fB=tx^iwXl2hOktkB#t(YC6bKZ^j=MZ?g3;>zot#)VbI|NcIcDv%V!>B_rQyykN1E#HspvKZpd2x zfPF*!6nYGG9J~iw;1>_B@g3xEf#7@KV>%Rz0tqsdp9Wb((7lxf9DW&f_?HXV-vfu< zmqJi~ZQOeRpuPu|P49vCNX-kUOqG}WO`v38>&W?o9i2ql0^A<$kx?8*5oKK$D9jH0D ze0EC$Q@+Z;zbYT31!-~N8owO$7i~+y2y}zELInA;FJe!M+@V+?#Cu?w(Y;~Vi$dWC*;E;VsrAn{#aocz@UqyuT;Eo)m#S za{|Bc+n&D}pb1N2#PHNx$&sc~hkRyCKIw)N(O#>PgSN^a14CRP$$djeW@316;Qpv- zt;ODc4z++btAMU~w`&DASejNdEOgINO74N-DkZq#VKIAAmg*#=3aozGcL4((G$JC* zu_Qri8W?4&6-`1UdYiT5_}qK;S!vbx0GhUzY#yhKa8cu1eiK?p{2xm1ID&~JrrR|& z@js?)%16}sF-vw7N8e6X%!T3A>= zskHqffS%}}{mEWFss8Gi%5KHE4CHEw;_gIZiuE4Ycn`Vgx1l(!alpJf8f!ZEa1UJ8 zf|l$yDNa^$Aq&{|z>#D4M9w|Hv3(Ed8G}x?67PY(|B=G5>K;gfowUI79n!Kk1ls6S zxCa_(V7_A$b}zjtkioFbd%#2D9%vsv%LDaM{7?_Sy*7dT=X-V7M`gC? z%9ISr%u@5bB%CoQ0x_0oEpWx`wK zB?~fR{x)4x6?8(+W^J`}z&#|ld7ncoEHUfH<)r!%Z|*Yx{Q2PPrpT6&?qI?+wP+bo zjOa;;!5HrM2i_u>e;+UWKc?%1?8Qo}?BQTXOcIGj^@?wQ-$d0TPGq8VTCn;vCS;=T z%ekiK8U~l)tvHwL;!Pzk3r@Ko{LEOKRV#fQZP0D^G(-)0wBk1x{x0_ zH}`-<0|m-Gu$nXlne5+yM$0@SGxMDkeh7)NM%uJLcWV^1z*I#vu`c1rlE9-b{Euk%G%b4PSgpy=vNt2sBXx=#?s zz0H??6Ip8{u1rI_RVR~ucBi$`NYu+)NVZ;Z%;00ldnB&9F@XP()^--nV4?xBesqZ} zYG%fqoh5X`N{pe@h&P%<4vC~%crtTBz=H}+SP4HOi%LZq1UWKY=%+tku<{YZQ#JeW zXKdQfbCgvBm_P1Q#qmTkQlo6HQ27izCu`Tp!G(fg_QV$QHN~uaKjh%@)vdri@Qju3 z9)MH4Wq=_D`PSsq2cwWfd{9dzL(aL)uJu{%Jum=0H-JH5`lZ2ZIJ2fpl;kCeTWT-T zdm!a;JPf+zLh(z)$Op*$b{y}4#Jk;8qtq=y%XJGf&Vgx#I+<3KRGLvrhehRUhCC_O z5EB8}d8>oI)iZt%d&Y*!GEs7O_22?7Ib(Gi@d0?shg#mA&vfZQeLVvOPcJ*bq|R;}0s7QuJ#`aqF3f=M8B)(_xuQAIrB`&jb)ARruD1 zRo*@LiamtG2XG<_hgJL-enyJWh&oYRu1gROKkh0tbE;J&tFJ+ZNMqd(C8h+QU!olAIou|6{d?)yU?)Ly4 z4J>;3rhEYnttjlg2f{a-X06wmF)pCe;;04NgXZU{YNd>9a{R`kz89%J8$X@1=)7JVh|SShcvfo^LOM;gVlG(u)7**= zXysb>g5WAHDMkJg8hh6cHTI^@o184!IuEy}E@N8p-?p=rIQJ}$5*F3%79#j$1>z9# zik8qGX`heFMG#tl4~I1uf);XdO12Y)t5DA{ptM{QIjXsH0cmXz?%xtm$P<&=g6!m? 
zi87TZxh!K1pniKJNosg!;Z!YSjk%4u%`P}}rP*;0BsdI;?3CW1m&iCALpANtUfN#S z-tPD@x2;%+%VU#hOf)8^4u@n0W=)nsinrC%;!5(kN<#?1{rfgZ#}5*9dYTly*KX~X z-NVsMGs`!rtH+I1d53xv8wLM;fQ*vy-%SaS!5hJ&2udII`@iTy#uLGVo4rxi=2v_CSK{r@+cWg z7qQ~TG9O;wAf%%J4@wgUY;%X>_vgEmk{~aIvX~4lj$?fjH8mp(7_kp|hcRd0uUR(9 zW+m*Hw`omEJv}lozsv_5u+s2TCah2}bdp#2znR@PmY6k zc92zzzFXWQ)Ye+v)==VMcCq#Bhmn+{d)S#@SK;@;Z*1&4h2J3d-vgn=6|6!nv7$u0 zW_Jx-lopqXZO;viDc>8d;?WeUj|htu_ZL`rkvCP2CaacK$i$l!PK`YB3qTXN6tlg zVzs~j#5Kdl`*>hA5*WmTOJ2OQ*BYx}b(fn=AC1{~oy5>wg83~l1f(j!KNqs-$y`v( zai|*NA>w`n-_D_FJ~7gm?hLP<7j?5(xlLj{lzW>$ngc)R2-fV3)2T=|>o{n&a+F&2 zr`Y8RXScMx0yS>zFBfdbb%wB1KrlBWS-eQI1D4*D2kZ#bPZz*0@ zK_#r=A(LmF&?qUTdb6yT`L|++ju|J~FG3MhX7;Se36}Jrw3I`LRsQu-kQ{U3^{B#^ zX86!>S_|%p0*|jaujUK*?6S$8)00<7w1{#y9kB;F%BZux3zU&DEh@i2X;IfVcEp(C zZ?&8S-~;gH?3=@lmY757TPh z-8t=t2K~8q>FaMg1Xvx!+Wub8>*zyleFjAs;77H6X@pErIBmFeOqK>IB=XFs_j5f9 z=QOw69HRD0p?!`a&&h|^xdTLIj-{qPL8gjji|L}EQsHgIH;>FABi=%Vd>v5I@iVbe zOGoaQt?@XXNb! zG23OCS1znHGR}N5oKfE>XCnDd=HxsX;ENk-kI3#S-d`iQSz-+WajIjlLb^X8zxf~` zoDkVcewu_cxqns%enBQaPsoGLd$hYR9KecOn)l6ckP@raudB$p4JuT9I2h<4Lo&4o z_K?V(Nvp4-W#R3^b1mi5$BXhOlY0ohM6VZ(dsH>VBYDOnC8Mp)eBhdG{WWp$B3w5M z$sE|!g@Rzn*S6_PfzZMgBQg7(Z*kKpxX=o83lW2^KseN#n~oK@b#ET(LN;eKfX&Uc zAb{a_8dE@i>XydtBaFgqjRi3h>sQqu_*FVM;SN+cQ!wOT9KI4IB7hf^*^JD=P1&IC?*e<^MCDB1Kp+D4B!F2FTbIUCiz3G8g9 z03jj7>$}e=u+Itr(6X|@D71D?RqEHi*IQ5!#vBVLH_KR?9m%T&34G1`G{@Yb3is>_ ze)GZiVzpxW5vrKM7440ETeQ)xMq|*2HeI4S2+CQ|O7ktJcD#^n*ll=htVH%*18s6Y z^KDBILs9F8-;?%p8Ro3~AF@sbFI z#{Z=h4S7&XMASm@GkNoT*A-T#&X=Q(QewRjvoxN>jJK{)ed;eCDrp_VZ6sjAo`Go* zMCM3in>fUYkN8PhRzFh|c0(l`1w1X1xDgWR;wOtIx*wgM!j!Y`n!MVQd*>^-$a-X+ z&!3ogdhfW#Go5H){xjT+GOLP+w9I4eC(9G4^`XU%sM6lfJ^74x=7m)scCC4i(r2zB zr6;hrQJ=#1nkIf#MG_e7k$?Rz%v_SE*CHHye$$M!Wqky>h}I6VrSvfXpRlneurlJ! zV0J#wA#)k8FwU)B&@5``bdDpwYEVWbNxD)}6!eul7>!fftX=kJzl0%3+%s2eOj z9ObwZX(dMk$*zwxh&vUl8b+fvmRl_Jt)4D8xL9MqU*3SZOD$D)=jOLv963!KHA8;+w)fs5Pwej$#%L{ z8hAB}?fglT^wU#Vnm~E7P0`5R%*=L)0*Ce@Y3EM~Tn%1PuH3YU(?=RazM5Oyi*j9B3z1v6JI&zA~i zBu=1%=QGZJ{ZmcGOCc*Xpf82e8{EY>^A@6%P37eyu{WF@YkN-`-9HibBT7-BjMjcsI_iV39AST~wN%)0=aqvuNh=j+kkv>?~OBk`Xe~IJu zRq=IUqhVy>L^D!%Cx5;#kb(-!oKOrysu%iH$PQri9|5#kb>W(KqKe}8?H1tR@Ky}K z`-6r*pbqJloM_s`H-PZ~R{7~gcD=B8qT**fL2*ao_?LKsT*_k)d^EvnWF!&EtF(bo zPKE9^=PD7rDkC3EEXBK&q*4Llqh}Bv$qsL6(5ZP^E1h;CRHEvT%f|+0-I`cn{>dJi zAphAYTj>!@%#nMbx~XsM>;a_nv}vFHgd1*P4)(}g&%2?p;0O6)%A@1hV?oHP@3yX_ z!!~3&{Bc$lePFbHtW*$)HvWn&NlBTXkRy84KB;khY?LlYV1=WXxGONPx(noQyfw$; zoEx0!Fg~9hx9QwS*%UVb!s?>>s>hz;>Ybr#57>-+vX5v+c}9z48@FZIv?O?ANnr&G zltk-FLm!CI}8R^T~?DY(kZ<8Tb$A@0_WF_s9*Q z&arC;NsB4xValg}MYOhBB?)c0`0;$qsf^AQ)u{3F9 zjjvf4jQ25;j~NX+qHA#+wF%vHa+*|-bv_yI+5`C`9;ZT+XNSO&xI59*I@aML4P!7*Sy{gWgkoHV76}|Jdx~u z)8l%%O=lJj&(Crahb5eqqScg8p1Q(I6Nw|7?>sUbT+d`)QrAh)8lFvWzC^odUe6cH zR=eaA|77Y$dJ+)j-fa3(hM63dOS0AWf;A49{0ER6eTd#(QQV@9vwXG^HX*`8Y#g=! 
z9(BO8M75w^%sp2-GrIGAMZTnQPyGcV`iG}(d!Li7klgaE4?9~PZ{Zn8yXpm#P*&F` znFa`YgsnlOfs;M*1toc;TaK{d<*EjY-awzC6*KI&k;boc-pO}~3y|dxm$PHcZV5JK zFSFV!2kb+)@j~bMJwCh(5KaKpk}@Kt3hN4~Th6yz*4=GJen=WtPz8PXk*i>Xf!QP3 z+%Kd&Qc%&rw(@#Rk4t@)t=j;jyV}i2lID;bu6qleJ>*u|V%aA2@y9Hdk}dxL^1WHG zYsnPHr}Rdt)1u{?LZN1241Zhv{baBQU8Uk0)tD8nZml@4u)jPhT2LSq7hN9u9)bnq zDPgoZc?zbsoW#s8_%gV`rXvFx%H#X6{X`Op9lEvGa9IY3Iy)sKVDYRp<62}_%(1oJH%Y(?u6ba1 zj6C5DB6(mMCMpEHQ#)Ts*1mSaa#5wD7QisVJ)M^y<()1ZDV66%613CSNGr-w>%UhJ zGgau2?@E9%%!O!{;51j^6xPEvUn^_>V$?$_qc?MVVtu(sA#(+X(*yy#=aZYye3|7f zV}^@u=2-cT>+81y?XZUIT4aCb9=$9K)M+e=`zaI zENiCK7U4RyK!I8>5}Wvw-^_FaB8b3To;vkV^#c|MQ05sAKMTE%Z4)I9hhjD1nq9MI zU#sMzxA_YXo|Scd8BklY_t|_1*e3tDb`IRFM=vrR0i7qFt!36mM~t840U7MC&mJO9%eN2gY@^qiDt40qSpNW!5BDS7XARS_g%L zfKU(BtSxf*5gxJfNS@jXA4`W#sF|Vj^HO}WLCb5jF(}3`J9j=wZ@+T7U{S7}HQNMx z4K5bGf$p2U6qB0n+_NUcO&zMF0Z-`>iI{G6a2(0kcN~b{`H!dE0j7!26timt&i3rW zmW%?T`kh=8fpbnS?8`R~{~1H28UFh|>@QgL4;h(X=eU%ICk2dtvSJ!PS+Q_vXl)8c z7z|fpGF*N|4(z))iT_mJTrnM`aPzO+uHo@VH4qnjx)r;AvVA1Mp{B-ac*EVo%9e z{LhRqgzo88p&tbJ4MssdGjAtvO;1+ufn+lX&pkj4qYeMJ|6iMjB3~&sgkp-Ayv%3x zv>W8f(rv%VSxic3t^FfD`UBnSo8-~UPzwxW7d>{yn%4C52s6r2ITMLJ@MEw^*ptfeK+ul78 zWO){>?bI_b#1ld3E6bBXWr$2+@7hYgy5bXZk%AFT=`!)bo&_HV+Ee#n}=%mz~)Te=1xNC&+J$sg^Qww5dTW9j*`$;{|`K%J}l zCs)1thnDf*RUfv&4q44GZ8}G3gjEq#R=_Omuk#Lnv9SLUz5in*|JUEY)+-0cUk4aD z@aZrs3QyDi_Nozkf`AC5|Ji3BD?+QPMYF2*@a{xV{QX=kD)G`Y20tHI_N-_|A$V}4 zn5x60iwwd;U4O8+^Q>~Z;PDA}UKWAAsuhu+p;vULLN6}T99Lstxr@d}4@T|1CT{RX z#cjA8Y~H>t{{Md7n7OpZ10wXv`Ri}GXJ49GVW5p9`H6Tp)1JSPHwQ-3YQo_c%wub< zQsNqx=jJrF6=z)YaXOG+z$$e|i?xy!EFQpW+61?^rWZagjlFl$P`-`X$+|*TP)qwn z>z@$1)QusI3=i+u54iFCMxAHtYr{H5G{35-gmKio=in~#3Q8%thSRDFJd^R%?VMg~ zFL#y<&iau`pi(+h>Cu~hVW>=NKS_9-1OM%(FQlPC%5>#xAtaCo!`!tI1n!ikD+Nd?8(xP9s`)NDqA(Elm!K zG*ypc3Q?zK%!oxA#&^t`reke;9o|w0d*iDy&h|M2lJIM{-a_W7d7Ex3bq`+Vn@v;= zi_Z+RQDfkjgm}A;#p$FzF3eCje}jlM_dTkEBuDj4E669`0j_&f2RX2ndeL-LQ~fUc zEuZ3*iaB3%)Ca*s#C?(~Vf(l9F60(pt-*?_9GkHh`up3RyNh)KREa9G3rt*WlGF0^ z2*!0N48U(9dTp}bXu)4tOy(Z$M>Lg}-JVvJOJr-+tZedw)zjb=-paHagv>Xk*4=g$jwhyPZVYL47=H>Z{TV>@ zN9X*;pu<3BU&v}eLwRfP|nIplb#8rU@8nvcb}s3ln)2PTVPjG zoc5Z|T>_$W$jhX zug0#}kHbjq9sZWTGqk^BQ99gE_&Yd1qk@=Mjb zONW&ad= zc;YWq_n&c!Kr!QGBkN-0h?TKA>XpKWAgrg`Z$fVDbqaMXj@&vd>cCVfZw#djx1X(; zirfQpMXwzj-p4~Lgr4HqwuyBJ&&&7fw%U&<`{n2jg#>A2XUJ@9R&qfGWn5-X-k7gB z(mUG6*seu|ze@{ZjQn?jtbaDnzX-mt=H|ose#qpjTdzO#%ZIf0dOKF2y)QHVd<1=k_qtkVKGvt z4Hw~~0a`_uzbZ;Mi)!u5P8LfC^5@c90 zG%6dZmCWl3&p6vk_#MTzTfkmZuqc{0L+|PKbwI-K{(4{MiC=g<86kOsZA8>FBAo~I zHT8$zEqcEi#|eF1W)^s2s%$7^@c9$xb~VDqPSn_u({)E~=B%lrcGCxp8qr*`qHO$@ z-gA*iftO`8J>PqxIc|TDUKBqY73Q$Wh*z|5_CQdOWl7DNWy^><5U%&=6b^ZZ5(YwP znJ(wMs7*@G@r--PTG24}LEoo~#s(!%&0F$eumLO~+`cyE#2Y!=Q(A5t+DG*<)BaTX zVfYZrT0JOFa~mo;$;O&q`%0U)9l5q5X0B{qo1%_7mnNyYb4}a-_g%>Vl%tKLB_{IS zZESvm8y*B&;YMfMgxek7T&jh<@~K(ax>9U5JHZY_+=|y^&+NpLexw;6v$Zq5wUbp# zB8vT4cJYp^dGMN4COOULj@h`T7cBr!#(!7;zo^tiNi<0o)~MT)k1rU(+$xH~_DypieQ^kH;0KpAx$!0 zQ~lG`rbrvJcxCy6h%r)b#Ldnn|1U*Is=r}XegdTb94g}XdA#(u5=PCX zsT^;!)Fhfl2$Jd?Qk7Rr;t1|PwE9*(z)EZBNdOut%PW?X>5dz!st6rF(;xz&>2NJp z8w4n_*vc|7M}&m7o$gK45id64xRzXMtUp>d%?R2H=pdxTdx0Uc(~#3rHynJ-j(4Ek zGtrP(r>)a0hNrUer|XdaRiOB(E&_Oj0|ptxlc>Rw2GRXag}7r^@R@U^&+atg(&Nn* zfsf0UNmB5k*|PrACS8%JX^4D|sh65ICOcABRYl zgSC@eeSE4aIaKnpVkzS#OF!!r9=c#RPNcwzDtRnoAnw=O_G|t({8FCan2wiUPU}4& z!kVGmFPT3!g(XEowfWsoDnFk84`KaZML# zBV=}aNWzM029L^H-1eO+)X$INra@zb4I?%ifMF)@^8S0;~n z8#`dRZx4ITyf*lTjFIB|8u_|{{_et;{%Btb>@qh1fTnyyIExE0ZPbJHk(h;$|u(E9sFWa@2a@e%8J1(!ixO zHlsLFOj|eJTs=Nb#KXvjYIeavU{lih*h|4nS+PW1VqVw~e&V;>h|1c(>Q`Wue}64t 
zB!7E2O;O@9saAEe;(06)k#FelPL!1nLN);)F4Ux>BF~`hjjksl|Afd1fzRa5BnPvj&pq(;jI)zwO0`ZfD-sH{-UfPfcZCE znktO6{l_&e|3TUcr=jK+LIrr@rJ!L(&X0V?QwA)?ce>ADb?ZlI1L8QrUyc0C2&c-6 zT9dAd)ry4aqT*!sxSB`?#s%c0Zt~YA1_$5#0~eH4S72Kn*8KK4OJ~XhZ{gnMtOi%K z^EkPrf`Q&zW90$V+Z?VRhYOkdF4QE+Xtyr~$yQ77F%-?9M-YxPbe;-f&XxpB4nJqg>Q)yEDnXN;G1``)PQfN(*!r!GV zDDg;EcV5@-OhigAr;cSN?MW(eD-||9$<>pTh}1v6rhv7&tq8*(ux; zC!C3{(pBzR=fQ_tHxwAP01B=H-#$;%Aq}B<7k;irv7?iiw?pjp(b%ejBxi4#43%8% z8g49EPs>4wK=U6E3xe~0#gYOCSq8C4V?qaqvg^ryLFG+0H(5l)tb9QJv}TM%X-_qKEyuurT`Ipm-;p zTUds35vdh#dE|BW0$B8A;fo{`4wpi$%AoOLzR{!GaOaYS@Hgc$U&f;~h52+anELI! z;KsRk3h~s`CEI4~0|&I#R`~~MO=v{Zx3rDi)TF}-k>xbYi+9G>ie*Q0mdBe4AJkBM zxX$02$4er_t7xCsM*-D@)LerRyr^rpk1>gpFJ*3>ANyOl7(~usQTx+L(Kr;ye*+Be z;u5qaP-G>X)2V68o;Gjbu0PS|H`Z-k}`Ne#Y(hjCVw6VuE>lHhmR-&%r++0%Ty8Dr0Fw~0Ob zKQCcTG{X>2q1h~jAV8swPAk+p$IinL7l*+P_$u5xaf$2VIBlR~U(NB8+UcGSlzY6_X1 z8VfxVC%a(Ye)He!?f)gu{wspv&#`4ctKE*}Kh&lMY8Z`K;_MG_t$m*&3b%RKIq5Ai znv=B!#cw)VmhsH(vP?h0W3(}hhyT{yf=1EjP&BD@C=ynd-|q<#E=|wDZuTkRZx&s> zJ}(4eZd#5AB3;mg9yz83P-Z zt2AC0I&PV4i!I#)C}!luABvK}AJz1wR)sPkxlq3c<*Q|tG4t+DzFc@q?{f&?F3?~V z^Vw+)nS|}?tQd5JnTFw*jU$erKrhND#g9cgJvfg?9nSspVzKO zqu-_DlKMr1TuEw^ZI$carcq#Bx^2j*#}SbthBi|oDMlo}jnGkCvc}fEe5v+%C!ABl z85C=a5P=DA&4=CkZm$1W1(IR&B(n%d3`b*uYu{+)a+^SPEyQfXJ zeBZXOgUpD<=5GC{k$WFPk9I-0K9`eUZ+V9gif(0`3$~rM!%Yi`8yC(ZYv|%dx)3(* zmQ=l1EGqM}rnYn}IJMzz)^RY_;8$-^?UTBnxxUz zNv;{@^89lD(8`jlKwuIz^D^;($AXa6%8K43+M|75;&TnrTyih|xW|jJj;~7sYSIxh z!X>S^)m;jjv@1UldDgs0a813KMh$E0#5J#0JVUZ^#Rrt6l4g%VUB!ugn>|@fobrZs z7u#<9E|W@8z89Jk=S?zp3r_dIQpLgoY|gG}2L7IH`L#RwZ)K87JYnRTEn5 zSWoN9tZBb~-)Z0qt7g5L7&qRoRho`?-y2%DNIq zcqBK|zAD;kf;IM@8c zakssf!3o-Xte{|jvG0UBY9~o*Ll5(Nb9TXWHPJ?w&6W>D8H_6_+<_6uU)Vz`VqTqO}t#$c#3kI z2=7alAv^_j49@2?!fsw7`89bLN3Cp$_GH~{=(vwCxcoaonFb%S@)SN(DDn5a^%#m} z)>f6@ITF^}cbKRtsgiNo2r_Mp&$KKs&1P;j35nzwX{f+>E1y$LR(=>_vpu}S7gJ@i zbfrd%B4_m8(9-#}&HD}bN^v*U5+841lew?Qrh-A>q^hvWt0}W_Mf%Is@&QwP_NBTP z`dsl_m_D4p4xu30N}h&djqTwX%hPrfk~y>YA+3aX@%SAUor?O|XT^WnCm*m33@zoY zhE?~aS{EJH2R2L2lUyd|ny&L|n#D;*e6x)>=^QMYg*tX>&K5s9b%R9A1eR|OB(u(6 z>>1AIua4g|aZTo=1?}?t2IW4<*26nMi~Yos8`@<^q;DY=XMM-vL1kLoagFh*(=SF0%w!MGEOX6a={x^K@HtyZZegx- zFF~(KaNh&+ZaUDn0zV{zg>fSt-j~-7ax~X^RLm7Io~`0}enlQFHIk_lC)FtPr1Bzd zW^_cA*CBXq?9^RoZDkKDd#=@tQl6TM#LnG$CSuutO%3*wQJ?Hm$Xa ztS%`Fq6y9X-uvysp!-&qMMUE=hcPP>GT}d($e?KT=Hgm=em-9Ybo=yB6~cXH-Dc>6$_duEYw!!0XY!TQ!S52$j7rw2pDe~xuv5a_ zX8_>50rK%smQ3;dgtflJ(<-n1Rq;B1F%687E6vqNDbZ_-RohAkYCXHDkIuAaB*sfw z#sSD8|81>`XwBh*RiGW1bapP6i?M^&_}t31)BLoSxwHr#Q)M!2uKXSdyo$^hvg@WN zpBNnk7q1!HK;e%Zh(CZamK`i+l-n#ZugBnLF~7P|=rmUEGxtI%UBumGHg9E9^-^_%Y268}s>QvOuDUpMP7$qTvz(L)%pEQt4S=|L6x;0mscOVu}Z zFF(E#=P#Lb_sG9afZ|udGBD;)`pR#^)2l5;tTr~8!T5yA-_K~F$^lQZ%Q6|sT!?i> z>XnPZUVGq^9~Fm5?F7%O#~O4<8id`f6|v8HP!3UEv|`6G40BE)`?g@a zQ($bZP{3Y*6~dhy7YtkO48t4`f6^8I8r1;zkAdc|qxRq3iQ?B@>>p=c0bgZ1O$_*wNb3-_SDN`Enu`Q%s&RF94Tm3z|b?tnUk6=JR4dU+q}fJ!?Kc zBNA;@n`g$Mz|eM88PRHlt&*zydPS;V!%=eAo8ZN;9Tp;@{D_Xual_qhtGL)7lVdW4**UWl zZfDG@3)3Ivv^RHa8@aH$S_*$RYAspeDgry<@A)}D^8F^^tt1vLf**yntI7?_&B1UH zN5^M6#BEV2U9eGRqb{q{a9wo8!AF%LF$?$9GM$->_gc9=W1JCp>+lHgj=451Ubcv` z)E$YYvaECawza)_<5yD!-^&4gky?N0Xh5p+ZR>*y;g>oI?!KP>AB(HD{RS?GU_lC^ z1vSlfy7Hrzo};MaGDM=4o0d|?qhP$;es#W43n#MIBBCY=pNmm71|{72JlV6tWo{RA z-@**@oWP5}Qluz?T?T3M{@{XlHY7yN?Jlbt?R^)mQi!w-;&;1K4^hJG(r-FPt3;zW zK{$?0yO2jc5j_Ugb=K6pIw_9&Ah_^c{O+wL%+-cVd$2Qxq9&WQp4|^xQa!Hy6pGx_ zeiwnHN6R}JLhW9l$(Olu4Jft=>}Grr?!vU)`AScgo1_Y&KFQ|;O7?QapiA(wR1wSM zi?Jz{bL5NXF$m?5Inj{DzyyP#$rRNx@dEQ^Y3lZ+%z%Qe-fW$n>w5t9mNP4U`ROW{ 
z!^Z3Nc)}X&vx+|l(4jTzCx5khM7fdRT+~r-()N@uRFtMM5JSHzbituDN%^V#(9;`Nr;OGCNOUcK*R=!)>_qVP*9+G@fqw4ZaaW0bb5KVAZ6m!AkYh zmJaKf2h}qE2v4s<33laM9HVCxPKIZM0}?4`gvFWxru$Se zv~WZ^pU~W9;p*s3)SCR**)NumtKaLXnY?|LS=C|Nh!om7bsndrZSs0kp%avB3?p}G zPV}s7T?Xaq>+Ygn;5xd>QM8}A4)$2#O4sKXh9Sk;-6v>XYPS|&FJI=yH=)TJ9@NT6 z4Jr=nM!*6ht0FNMeUFSLJSQETR=dmtdiT@H;Dd3`d5P5j;9mbHh6X?Q5Av|yyg5Y% z-&JGzk_gdV?O8@8?1jZlX!}tL0xj(#5w%r;q&qveBv7l{d8+$e{Ia&rg2Llkf7os@ z%h5c%YAdW%VLZI8sAc@xZ87wb4DjRsnZ{l{*3^*URw=OwWyy=#^d6WN#o@CJUr?+5j6JEh6nG3eBwlY9t98~alq;Z9 z8%=H?Hyj;Rb{@4B0U@Q%`Vu}sraVT@EYe*%|Hc!Jq+a7WG810bC9hTd`cBpC)*`2H ziH<*)ElYVzP?=8l^e0XwdWfgYYM>A-F}#09)VNmv`tpp@=4UiM0;Y3P(%+yX+#h5; z3K?<9UzUvxTtb;RT&Aam54jzAdP)h1=IhjcDCsCk^A`gc$N>DcCKT>bch8*5NV@~t z_iaQsH9`sDx6UWj4GJ|j!CwW7UA?q1>Mh5|`Q_stJ*le9wM_P0h&PGycz$>`;8*l= zxu2L9Z3I)`jc2#K`1fs+RI=`Sz`TK=^?FR>JOY22F-*5cn7%*(BM&n8g6V9$fUh!g zdY00~NFgk{nrEdJJU-YePby>JPFreLWr_L46`g2c3N)`1cyyi{y{Wu@!@1CX!Q>n2 zCW8@WUuSV!xRHMk5M0KKqy>o?IiM8~sTpF~!RBw&`=^n~&C|__&6tDzg>}gc`Rw}u zi1;M@y;IqoY@YcAPIV3Ve>tdqPrMb2t;{;uQTWpOK(OQ2ko8P-B42@3&WS3{Z$TM`IO%$-a-*faBxIXX6W(RvxIWDc6E z%DYq5EJ}7PnDRSPRDbM6ROyRNnze7FpB?Qo765vin4IW9jy@U?za&FgpZRfBUi8&l z?}7~{5Sq`WYA`&)`FSOA`h3Nj&lGx=GZKQIKmVr7!T0lGuhDGxLf&cDH?)zj3VIsg zXd<+#gDzbC{|j20EdA6$)Z~rm1$EK z`wShaTg|}8<+HCnm12$B#$Dq0j^k<)f0zjTvRFX9F*fsu4CeCWr@7NBB8{)T96N){ zUCqxSNr9<8_OU$p%#7BaNn7jonUN+momA#d@ z<}@7Nypg;Xr9N6a9q4+0exa3c&FVJpiW$B_PJ9GSUM6hK>wEmNDqZj$KS+tv>q&zZ zqk@Zq+7E84)abw$n3R)G2w(ABnXp1G{_RYR}VniMSbgP5~8-b&Gg zDXd57X<4@gOKuD)SPOvHKZbP=9{Og}hz7#*OewR_-GrT`I*hs%|2C{( zTRPi`+rbr8w_zt_lj5bk=|TPwHln$tUHjd|wpO+UAw$y!ROp$!o{&4QzhfS~se)sTk#ccB>1~ zg{_rg&iPa1PQG)_-uv78%cgO{;t0E@VA;y% zHZndgvrcMO(B_JBz1TLqrr;)BMk9BuiIxgmIaC6nml|;6o~7qSNFpH4pn!77rlsy3 zY{wlm0u{FFQ>67^8<$fQM;%5pjP&bjIhNy>!ArK@62o81K<6|zR#DxfhZ;T-r-~Dcjh+etvk)g z6oiQ{v10)W(?v>h3p(Bn7sG|^0U8tfpDce94YpROZkV zIyPW~OwNMkQ3He*>4936%rRgJouzjTR$YA=`tYM#;r5aNCCgO;cu<`GVATNs-tm`! 
z1?+NW^pSy)t-$(hj+P5gHwku|AF=k3(1Z`)=#m_pH|1q_pPG2=bp|qH*l?~Q_(`4g z=%~!KNy(3Wu>N027me_(?guYDHG~^-yU(fdEWTQNrB_5{XH!Q zAd^6Rv9;zDBDy(V)}6Qq>`9IcQXJY@+2ySk39Q*wdRiyF!EblVgh-mGYKijgKA`bu zyBo3%VP1)6XWXj91Xt`w@bkz&@lAr+kdikTkPzT2#)#C(P6Md8_X}_8noSndZ_yS6%aL z?3}GU*Py2Yt=(nE4(=`PFggg^dyZ)Rgv8esMo^@VsB;!jU?+Vr0EHtTJn-y|pNUB2 zN%c2NG4GL>V+wwTV!-7dn-L$5T569QsaRT{E(BBGW^VKEC-ObXw;hS0<;Q`hvrr==^W%zsi#`}gd}e_F8NZW_uZ ziKO|QH-t*44v71U#Q4;Ajr*WWo;ACW?7Wwe zFlhP)qQDhfrfQEe8`*gj0@?AnNx{o9ET6U&QU@^vR#+(9{CIeD5x2c}s!N7ljWnTS zkJP`Nake(Uozy=jugx%SK|~HCo1Lilnk&~ECupRiyjSC7=i!PSUc7|ojj}D{DM9g1 z%%^ci-gOc~pwWVNF$24w3gL=Wt;uPr_fyvlU`6AGvb0u9T9Qu|k{|T&@Ra}&-JsR4 zq&{8vZ1On;>ci)2S`Jb*xZp^9hs^!27fL0{l`bo(iMJ)cS~NZ968bVdwfSJH4s!mQ z(I$yCDGmPtMc=HKph7cQ3g2OYOJQ)pFW6Vy=l!<>zfO2hgh{anWgN!l7%zMCRgSC+ za|b?EB!zLS@RC_XDwFQ?z&tP6*}1&KTy4Nwx0KPDMp1cvGY@Fd~=5)v3rnL*|i&Rlf*V)Z}^L~DO{bwbA-&f-A!qch(DZRlr z!(qH^rrJY^auZ_-1P8|h|4ibc2r)4jvEaXrT#+g-bnuIn<^--t?WkfAt00!vk)=eT zUgNgnr>aUUqM0l{oX%@dtH1Lyb*T=OFXcX*oVGoFlXYcY413#6fN|E7isf90p^w)2 zi(&m|CjW7i``4|6zb?#w`k;dr3Ylta?E5}k0bk9^LFH?yvd5Z-2sw(UEO>aFH!=zh zeaUIKP;E9OwpHgFmioL@iI&GqP3YHy2nbq#7tAHSG?&bwGm>6aRqAUZMH$_pdPF;y zCa1QsqI*I;)1%RR98dzW52a>d%)U}vrrK4)=3O~CVB%S}>eS-lI*CiJs2oWyg5}#R z*_n1@rgH({EHnMD6z#ve;{Uo@i|BvYOJwXal1CrBMXd9-wqiKte7M*E7&Qt>nbF{m zKOJPD4DwkU1!`pZiC+T>6w~TJ3o#T>MJI7fJ+X}@T=e2bb{AED1NmL=BS=7bE-~rI zDS@7eBEHhX30c8-w6{9hLlIr8%8DdPeq@_E$R_)2WZ&|fcH5#4^wDyGC*YRA7L z-w|73^d4GbbpA_`7VT@w_s0hJbR{S*OPrq+UYLIjm!F?VBi@XrjrvTWxp=63#`J{u z*x=-NE%gvw2gs7@{$MSXjh3kQhcT5uPSG0xgdCTC7^7uZ=*L-zDbkk_FM9SJ&>SjX zvqT``x2nUir@$^BS`#ntYoIP zkMp4G;u)yzL>p=1QEI@-=xU3p@m>LuK3`c4!eqa}q^&aRqQuoa;5@amVX>6iY8=TR zO8^8}sYfoqVe?Pc7QH0rXZMTy=O8xaScW_8uhExT-!VW0g}PVk1)o#hE+Z27B!~AB z%+04-QKM{dD>%#8`3z1`8p+BZo@IvQ5GUOFX_#63jy;u=n%{Rr#|^Ixs8PG4J$<=2 zO)g?SQmM|;1o+*3B*{J73{!MezSQL?E=SI}ppP?LWqbAv4Q^3Ti28T#N*#T-E*ia2kHBQWs@SvJ-#nHy_wa%xyR(G`3$R_0k zOaGhR`b(1fH^O?U&K~KKgV1-s@{_BPpdv#;-S0<4S_&MFuod~xx!?-%hWgslBu$185n*Px5T$5~N1vjDNLZ~w zd2QFAUKTI{%L}!-x9>c#ledl$BfQo1F4ojV>S^QoAb_z}yAf?9{=L@IfTDQ3)1ugB zCEIqqux4^{Rde69);dtAu5mFw6nEv(S#1Z(jWim&>LQa1r;{|D4)icv^rftr@?a~- zuB#5IY9ma8gxE*QvCl~pj= zq7J>#*_$uxZd>yO!}r{G$XDhk&w`m07v0s*w%C4PhPJgVKjtZXOWe!G0A z60V}-?~r)XW_drnpwu<|EfXuH6GWUS`@gPQ(pPqgDd;=|riFfMeD&hWF{?HzE-~qW zwX?T{sNjdl-cu-A_@2b_bNB>vP`r)DlXblDEmT|2Z~;GOswZ>%$2tz1eA9_Z^nUu| zQHMf?U>kvqiL?%)vMfI$uv4tGbk9>Pw%Gs?8?P&Krxgc8Y2F%v?=po?f)|xLtK%KOPOqtPg1s5TA~z)7c-}(T+lx@f z1`BDfZ^+!dHCyGE$ zzC0SENv#2Wj));rZaZI0d03!W0ewg~7Sb7Ubi%|_@QbzOH{I!*=ryD>76C^HzAZwk z;KPfeyy)gkR~fdk@8RfoSb?{gu&CI({A-m?J#n9(K~H~o9G$*_sNh-Ajk5IU|&g zu1vucxL+k$JlJ_mf(6UZrypTo#@GyK_&?CveM%-*;0<#5B^r@>D01Cdf)!c(TNBIm|1@<-_Lf}C*; zOREx1uVN~z42_qA|D8`@^8L(EZGw!XJguRYZuw+5TF zgX=-Izs5Rxau5hqt7E*fXR@Z0q`t6?Z?0;l<3wOVclgOdLzxG{9`^+LlFI{{G)h{v z_oz=RhG=ol^yA#&gU^j57&nIFmbm@059Q`oS zz@PW9Q=}vdnkx014QLJ*I6iZ8bDhBn7bS?Js3dc*-_!;%fhU$(7qM8<4 z*HyRo&L0F32DcV4sD(r-?M95P*Fbm=_QvQVWu15515;p*>l($nZ4?`l=Q=A`##*92y>-=`*q7d z(x~Vm@|4X^kHed9pnFjH#ra7{&Xt0BwM+1INO#?-iCA(Xnc14eztsQyl_&Zck02bf zYz*Rvaf@YDBP(T&7M=yo&MgaRfTYhN^HpF5S-ZB3nAn%5XUKZYkYd4f2Ew zi4O#t%D+I?72t${VW}gm=LF7P(8~SGv>Qij$`KlU4X`2rs|-9fbOMg)43q!*H_+3z zJ_Usp#>uvCpxTtfvmE7BpMNYVKt|Mmv-Lim|B?7`L^2l8tIdf=WF;~2P{WIyq4@0PMN5K(Y<0Gzjd2>lt*!?K_Y6= z=ujCCAy45Lb?)vl6D79C&c!{b5G38mB5 z-e}-68^gSrtiQUUi$0M_A0IP*fL$JUmc$TWY-Gz9P*)h&7DX}-W%zK$b&XhK+~Mg_ zXsjpsmY?|(>eyXrZpTj_RqNz36myerv2j2m4pEw2Mx4=nc@t;s)PK%4`x#sH|Kc;g Gjr|MTxapz* literal 0 HcmV?d00001 diff --git a/examples/nanogpt_4D_finetune/finetune_4D.py b/examples/nanogpt_4D_finetune/finetune_4D.py index 9e50bf5..e0ad58c 100644 --- 
a/examples/nanogpt_4D_finetune/finetune_4D.py +++ b/examples/nanogpt_4D_finetune/finetune_4D.py @@ -41,8 +41,9 @@ from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP from vescale.optim.distributed_optimizer import DistributedOptimizer from vescale.optim.base_optimizer import BasicOptimizer, GradOptimizerHookBase -from sharding_plan import nanoGPT_plan +from sharding_plan import nanoGPT_plan, nanoGPT_plan_dist_dropout import vescale +from vescale.dtensor.random import manual_seed # ----------------------------------------------------------------------------- # default config values designed to train a gpt2 (124M) on OpenWebText @@ -95,6 +96,7 @@ DDP_grads_in_fp32 = True save_checkpoint_path = "./nanogpt_checkpoint_dir" load_checkpoint_path = "" +use_dist_dropout = True config = {} @@ -115,6 +117,7 @@ def main(): init_process_group(backend=backend, world_size=world_size, rank=rank) VESCALE_DEVICE_MESH.init_device_mesh(device, (dp_size, tp_size), mesh_dim_names=["DP", "TP"]) + mesh = VESCALE_DEVICE_MESH.get() ddp_rank = get_rank() // tp_size else: rank = 0 @@ -124,15 +127,17 @@ def main(): # world_size number of processes will be training simultaneously, so we can scale # down the desired gradient accumulation iterations per process proportionally master_process = rank == 0 # this process will do logging, checkpointing etc. - seed_offset = ddp_rank # each process gets a different see assert batch_size % dp_size == 0 local_batch_size = batch_size // dp_size tokens_per_iter = gradient_accumulation_steps * dp_size * local_batch_size * block_size - print(f"tokens per iteration will be: {tokens_per_iter:,}") + if master_process: + print(f"tokens per iteration will be: {tokens_per_iter:,}") + print(f"Use new distributed random: {os.environ.get('VESCALE_SINGLE_DEVICE_RAND', '1')}") if master_process: os.makedirs(out_dir, exist_ok=True) - torch.manual_seed(1337 + seed_offset) + torch.manual_seed(1337) + manual_seed(1337, mesh) torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast @@ -235,7 +240,9 @@ def get_batch(split, bsz=batch_size, lbsz=local_batch_size): # + + + parallelize the model and wrap it with DDP using veScale APIs if ddp: - model = parallelize_module(model, VESCALE_DEVICE_MESH["TP"], nanoGPT_plan) + model = parallelize_module( + model, VESCALE_DEVICE_MESH["TP"], nanoGPT_plan_dist_dropout if use_dist_dropout else nanoGPT_plan + ) model = DDP( model, data_pg_or_device_mesh=VESCALE_DEVICE_MESH["DP"], diff --git a/examples/nanogpt_4D_finetune/sharding_plan.py b/examples/nanogpt_4D_finetune/sharding_plan.py index da34ece..952a0dd 100644 --- a/examples/nanogpt_4D_finetune/sharding_plan.py +++ b/examples/nanogpt_4D_finetune/sharding_plan.py @@ -34,6 +34,23 @@ "lm_head.output": [[Replicate()]], } +fwd_plan_dist_dropout = { + "transformer.wte.input": [[Replicate()]], + "transformer.wte.output": [[Replicate()]], + "transformer.wpe.input": [[Replicate()]], + "transformer.wpe.output": [[Replicate()]], + r"transformer.h.\d+.input": [[Shard(1)]], + r"transformer.h.\d+.attn.input": [[Replicate()]], + r"transformer.h.\d+.attn.c_proj.output": [[Shard(1)]], + r"transformer.h.\d+.attn.output": [[Shard(1)]], + r"transformer.h.\d+.mlp.c_fc.input": [[Replicate()]], + r"transformer.h.\d+.mlp.c_proj.output": [[Shard(1)]], + r"transformer.h.\d+.mlp.output": [[Shard(1)]], + "transformer.ln_f.input": [[Shard(1)]], + 
"lm_head.input": [[Shard(2)]], + "lm_head.output": [[Replicate()]], +} + params_plan = { "transformer.wte.weight": [Shard(1)], "transformer.wpe.weight": [Shard(1)], @@ -53,3 +70,5 @@ } nanoGPT_plan = {"parameter": params_plan, "forward": fwd_plan} + +nanoGPT_plan_dist_dropout = {"parameter": params_plan, "forward": fwd_plan_dist_dropout} diff --git a/test/dmodule/test_dfactory.py b/test/dmodule/test_dfactory.py index 74ca372..eb63ab5 100644 --- a/test/dmodule/test_dfactory.py +++ b/test/dmodule/test_dfactory.py @@ -39,19 +39,6 @@ class DFactoryTest(DTensorTestBase): def world_size(self): return 4 - # def _seeding(self): - # import os - - # os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - # torch.use_deterministic_algorithms(True) - # torch.manual_seed(0) - # torch.random.manual_seed(0) - # torch.cuda.manual_seed(0) - # self._rng_state = torch.random.get_rng_state() - - # def _reset_rng(self): - # torch.random.set_rng_state(self._rng_state) - def _match_factory_dfactory(self, factory, dfactory, global_shape, placements, device_mesh): aten_dfactory_pi = _factory._provide_args(device_mesh, {factory: PI.from_placements(placements)}) @@ -64,7 +51,7 @@ def _match_factory_dfactory(self, factory, dfactory, global_shape, placements, d start, end, step = 0, global_shape[0], 1 assert not placements[0].is_shard() or placements[0].is_shard(0) - with _factory.FactoryDispatchMode(device_mesh=device_mesh, aten_dfactory_pi=aten_dfactory_pi): + with _factory.FactoryDispatchModeOn(device_mesh, aten_dfactory_pi): actual1 = torch.arange(end, dtype=dtype, layout=layout, requires_grad=requires_grad) actual2 = torch.arange(start, end, dtype=dtype, layout=layout, requires_grad=requires_grad) actual3 = torch.arange(start, end, step, dtype=dtype, layout=layout, requires_grad=requires_grad) @@ -98,7 +85,7 @@ def _match_factory_dfactory(self, factory, dfactory, global_shape, placements, d ) goldens = (golden1, golden2, golden3) elif factory == torch.full: - with _factory.FactoryDispatchMode(device_mesh=device_mesh, aten_dfactory_pi=aten_dfactory_pi): + with _factory.FactoryDispatchModeOn(device_mesh, aten_dfactory_pi): actual = torch.full(global_shape, fill_value, dtype=dtype, layout=layout, requires_grad=requires_grad) golden = dfactory( global_shape, @@ -111,12 +98,12 @@ def _match_factory_dfactory(self, factory, dfactory, global_shape, placements, d ) actuals = (actual,) goldens = (golden,) - elif factory in [torch.zeros, torch.ones, torch.empty, torch.randn]: - if factory == torch.randn: + elif factory in [torch.zeros, torch.ones, torch.empty, torch.randn, torch.rand]: + if factory in [torch.randn, torch.rand]: manual_seed(0, device_mesh) - with _factory.FactoryDispatchMode(device_mesh=device_mesh, aten_dfactory_pi=aten_dfactory_pi): + with _factory.FactoryDispatchModeOn(device_mesh, aten_dfactory_pi): actual = factory(global_shape, dtype=dtype, layout=layout, requires_grad=requires_grad) - if factory == torch.randn: + if factory in [torch.randn, torch.rand]: manual_seed(0, device_mesh) golden = dfactory( global_shape, @@ -155,15 +142,126 @@ def test_match_factory_dfactory(self): torch.empty: dtensor.empty, torch.full: dtensor.full, torch.randn: dtensor.randn, + torch.rand: dtensor.rand, torch.arange: dtensor.arange, } - # self._seeding() for factory, dfactory in factory_dfactory.items(): for global_shape in [(4, 4), (5, 4), (5, 7, 9)]: for placements in ([Replicate()], [Shard(0)]): self._match_factory_dfactory(factory, dfactory, global_shape, placements, device_mesh) + @with_comms + def 
test_nested_dfactory(self): + device_mesh = DeviceMesh(self.device_type, range(self.world_size)) + + replicate_adp = _factory._provide_args(device_mesh, {torch.empty: PI.from_placements([Replicate()])}) + shard_adp = _factory._provide_args(device_mesh, {torch.empty: PI.from_placements([Shard(0)])}) + + # Off, Off + with _factory.FactoryDispatchModeOff(): + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + with _factory.FactoryDispatchModeOff(): + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + + # Off, On + with _factory.FactoryDispatchModeOff(): + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + with _factory.FactoryDispatchModeOn(device_mesh, replicate_adp): + self.assertTrue(isinstance(torch.empty(self.world_size), DTensor)) + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + + # Off, multiple On + with _factory.FactoryDispatchModeOff(): + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + with _factory.FactoryDispatchModeOn(device_mesh, replicate_adp): + self.assertTrue(isinstance(torch.empty(self.world_size), DTensor)) + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + with _factory.FactoryDispatchModeOn(device_mesh, replicate_adp): + self.assertTrue(isinstance(torch.empty(self.world_size), DTensor)) + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + + # On, Off + with _factory.FactoryDispatchModeOn(device_mesh, replicate_adp): + self.assertTrue(isinstance(torch.empty(self.world_size), DTensor)) + with _factory.FactoryDispatchModeOff(): + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + self.assertTrue(isinstance(torch.empty(self.world_size), DTensor)) + + # On, multiple Off + with _factory.FactoryDispatchModeOn(device_mesh, replicate_adp): + self.assertTrue(isinstance(torch.empty(self.world_size), DTensor)) + with _factory.FactoryDispatchModeOff(): + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + self.assertTrue(isinstance(torch.empty(self.world_size), DTensor)) + with _factory.FactoryDispatchModeOff(): + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + self.assertTrue(isinstance(torch.empty(self.world_size), DTensor)) + + # Off, On, Off + with _factory.FactoryDispatchModeOff(): + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + with _factory.FactoryDispatchModeOn(device_mesh, replicate_adp): + self.assertTrue(isinstance(torch.empty(self.world_size), DTensor)) + with _factory.FactoryDispatchModeOff(): + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + self.assertTrue(isinstance(torch.empty(self.world_size), DTensor)) + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + + # On, Off, On + with _factory.FactoryDispatchModeOn(device_mesh, replicate_adp): + actual = torch.empty(self.world_size) + self.assertTrue(actual.placements[0].is_replicate()) + + with _factory.FactoryDispatchModeOff(): + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + + with _factory.FactoryDispatchModeOn(device_mesh, shard_adp): + actual = torch.empty(self.world_size) + self.assertTrue(actual.placements[0].is_shard(0)) + + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + + actual = torch.empty(self.world_size) + self.assertTrue(actual.placements[0].is_replicate()) + + ### With Unrelated ### + from torch.utils._python_dispatch import TorchDispatchMode + + class 
UnrelatedDispatchMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + return func(*args, **kwargs if kwargs is not None else {}) + + # Off, Unrelated, Off + with _factory.FactoryDispatchModeOff(): + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + with UnrelatedDispatchMode(): # Expect On + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + with _factory.FactoryDispatchModeOff(): + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + + # On, Unrelated, Off + with _factory.FactoryDispatchModeOn(device_mesh, replicate_adp): + self.assertTrue(isinstance(torch.empty(self.world_size), DTensor)) + with UnrelatedDispatchMode(): + self.assertTrue(isinstance(torch.empty(self.world_size), DTensor)) + with _factory.FactoryDispatchModeOff(): + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + self.assertTrue(isinstance(torch.empty(self.world_size), DTensor)) + self.assertTrue(isinstance(torch.empty(self.world_size), DTensor)) + + # Off, Unrelated, On + with _factory.FactoryDispatchModeOff(): + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + with UnrelatedDispatchMode(): + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + with _factory.FactoryDispatchModeOn(device_mesh, replicate_adp): + self.assertTrue(isinstance(torch.empty(self.world_size), DTensor)) + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + self.assertTrue(not isinstance(torch.empty(self.world_size), DTensor)) + @with_comms def test_api(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -313,6 +411,111 @@ def forward(self, x): out = dm(data) self.assertTrue(dtensor.equal(out, ones_replicate)) + @with_comms + def test_api_nested(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + data = torch.ones(HIDDEN_SIZE, device=self.device_type) + + class Inner1(nn.Module): + def forward(self, x): + return torch.zeros(x.shape, dtype=x.dtype, device=x.device) + + class Outer1(nn.Module): + def __init__(self): + super().__init__() + self.m = Inner1() + + def forward(self, x): + a = torch.zeros(x.shape, dtype=x.dtype, device=x.device) + b = self.m(x) + c = torch.zeros(x.shape, dtype=x.dtype, device=x.device) + return a, b, c + + class Inner2(nn.Module): + def forward(self, x): + return torch.ones(x.shape, dtype=x.dtype, device=x.device) + + class Outer2(nn.Module): + def __init__(self): + super().__init__() + self.m1 = Inner1() + self.m2 = Inner2() + + def forward(self, x): + a = torch.ones(x.shape, dtype=x.dtype, device=x.device) + b = self.m1(x) + c = torch.ones(x.shape, dtype=x.dtype, device=x.device) + d = self.m2(x) + e = torch.ones(x.shape, dtype=x.dtype, device=x.device) + return a, b, c, d, e + + class Root(nn.Module): + def __init__(self): + super().__init__() + self.mm = Outer1() + + def forward(self, x): + a = torch.empty(x.shape, dtype=x.dtype, device=x.device) + b, c, d = self.mm(x) + e = torch.empty(x.shape, dtype=x.dtype, device=x.device) + return a, b, c, d, e + + # Off, Off + model = parallelize_module(Outer1(), device_mesh, {}, factory={Outer1: False, Inner1: False}) + out = model(data) + for o in out: + self.assertTrue(not isinstance(o, DTensor)) + + # Off, On + model = parallelize_module(Outer1(), device_mesh, {}, factory={Outer1: 
False, Inner1: True}) + a, b, c = model(data) + self.assertTrue(not isinstance(a, DTensor)) + self.assertTrue(isinstance(b, DTensor)) + self.assertTrue(not isinstance(c, DTensor)) + + # Off, multiple On + model = parallelize_module(Outer2(), device_mesh, {}, factory={Outer2: False, Inner1: True, Inner2: True}) + a, b, c, d, e = model(data) + self.assertTrue(not isinstance(a, DTensor)) + self.assertTrue(isinstance(b, DTensor)) + self.assertTrue(not isinstance(c, DTensor)) + self.assertTrue(isinstance(d, DTensor)) + self.assertTrue(not isinstance(e, DTensor)) + + # On, Off + model = parallelize_module(Outer1(), device_mesh, {}, factory={Outer1: True, Inner1: False}) + a, b, c = model(data) + self.assertTrue(isinstance(a, DTensor)) + self.assertTrue(not isinstance(b, DTensor)) + self.assertTrue(isinstance(c, DTensor)) + + # On, multiple Off + model = parallelize_module(Outer2(), device_mesh, {}, factory={Outer2: True, Inner1: False, Inner2: False}) + a, b, c, d, e = model(data) + self.assertTrue(isinstance(a, DTensor)) + self.assertTrue(not isinstance(b, DTensor)) + self.assertTrue(isinstance(c, DTensor)) + self.assertTrue(not isinstance(d, DTensor)) + self.assertTrue(isinstance(e, DTensor)) + + # Off, On, Off + model = parallelize_module(Root(), device_mesh, {}, factory={Root: False, Outer1: True, Inner1: False}) + a, b, c, d, e = model(data) + self.assertTrue(not isinstance(a, DTensor)) + self.assertTrue(isinstance(b, DTensor)) + self.assertTrue(not isinstance(c, DTensor)) + self.assertTrue(isinstance(d, DTensor)) + self.assertTrue(not isinstance(e, DTensor)) + + # On, Off, On + model = parallelize_module(Root(), device_mesh, {}, factory={Root: True, Outer1: False, Inner1: True}) + a, b, c, d, e = model(data) + self.assertTrue(isinstance(a, DTensor)) + self.assertTrue(not isinstance(b, DTensor)) + self.assertTrue(isinstance(c, DTensor)) + self.assertTrue(not isinstance(d, DTensor)) + self.assertTrue(isinstance(e, DTensor)) + @with_comms def test_with_fwd_hook(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -339,6 +542,45 @@ def forward(self, a=None): self.assertTrue(isinstance(out2, DTensor)) self.assertTrue(dtensor.equal(out2, zero_replicate)) + # submoduled simple case + class OuterSimpleArgs1(nn.Module): + def __init__(self): + super().__init__() + self.m = SimpleArgs1() + + def forward(self, x): + a = torch.zeros(x.shape, dtype=x.dtype, device=x.device) + b, c = self.m(x) + d = torch.zeros(x.shape, dtype=x.dtype, device=x.device) + return a, b, c, d + + class OuterDefaultArgs1(nn.Module): + def __init__(self): + super().__init__() + self.m = DefaultArgs1() + + def forward(self, x): + a = torch.zeros(x.shape, dtype=x.dtype, device=x.device) + b, c = self.m(x) + d = torch.zeros(x.shape, dtype=x.dtype, device=x.device) + return a, b, c, d + + for mcls in [OuterSimpleArgs1, OuterDefaultArgs1]: + for fwd_plan in [{"m.input": [[Shard(0)]]}, {"m.input": {"a": [Shard(0)]}}]: + # factory = True + dm = parallelize_module(mcls(), device_mesh, {"forward": fwd_plan}, factory={mcls: True}) + _, out1, out2, _ = dm(data) + self.assertTrue(isinstance(out1, DTensor)) + self.assertTrue(out1.placements[0].is_shard(0)) + self.assertTrue(isinstance(out2, DTensor)) + self.assertTrue(dtensor.equal(out2, zero_replicate)) + # factory = False + dm = parallelize_module(mcls(), device_mesh, {"forward": fwd_plan}, factory={mcls: False}) + _, out1, out2, _ = dm(data) + self.assertTrue(isinstance(out1, DTensor)) + self.assertTrue(out1.placements[0].is_shard(0)) + self.assertTrue(not 
isinstance(out2, DTensor)) + # complex case class MixedArgs2(nn.Module): def forward(self, a, b, c, d=1.0, e=None, *args, f, g="str", **kwargs): @@ -387,7 +629,7 @@ def forward(self, a, b, c, d=1.0, e=None, *args, f, g="str", **kwargs): self.assertTrue(dtensor.equal(out[-1], zero_replicate)) @with_comms - def test_with_model_patch(self): # TODO: support nested factory False + def test_with_model_patch(self): class MLP(nn.Module): def __init__(self): super().__init__() diff --git a/test/dtensor/comm/test_all_to_all.py b/test/dtensor/comm/test_all_to_all.py new file mode 100644 index 0000000..bc143ef --- /dev/null +++ b/test/dtensor/comm/test_all_to_all.py @@ -0,0 +1,64 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import copy +import unittest +from typing import Tuple + +import torch +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, +) + +from vescale.dtensor.placement_types import Replicate, Shard +from vescale.dtensor.device_mesh import DeviceMesh +from vescale.dtensor.api import redistribute_dtensor, distribute_tensor + +from common_dtensor import DTensorTestBase, with_comms + + +class AllToAllTest(DTensorTestBase): + @property + def world_size(self): + return 4 + + @unittest.skip("failed in CI, strange!") + @with_comms + @parametrize("shard_dims", [(0, 1), (1, 2), (2, 1), (1, 0)]) + def test_all_to_all_first(self, shard_dims: Tuple[int]): + original_shard_dim, target_shard_dim = shard_dims[0], shard_dims[1] + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + x = torch.rand(8, 8, 8, 4).cuda() + copy_x = copy.deepcopy(x) + + dx = distribute_tensor(x, device_mesh, [Shard(original_shard_dim)]) + dy = redistribute_dtensor(dx, device_mesh, [Shard(target_shard_dim)]) + + d_copy_x = distribute_tensor(copy_x, device_mesh, [Shard(original_shard_dim)]) + r_copy_x = redistribute_dtensor(d_copy_x, device_mesh, [Replicate()]) + d_copy_y = redistribute_dtensor(r_copy_x, device_mesh, [Shard(target_shard_dim)]) + + self.assertTrue(dy.placements == d_copy_y.placements) + torch.testing.assert_close(dy._local_tensor, d_copy_y._local_tensor) + + +instantiate_parametrized_tests(AllToAllTest) + +if __name__ == "__main__": + run_tests() diff --git a/test/dtensor/ops/test_pointwise_ops.py b/test/dtensor/ops/test_pointwise_ops.py index 9b5207f..abac354 100644 --- a/test/dtensor/ops/test_pointwise_ops.py +++ b/test/dtensor/ops/test_pointwise_ops.py @@ -275,6 +275,71 @@ def test_mul_out(self): self.assertEqual(input_tensor, dtensor.to_local()) self.assertEqual(expected, dt.to_local()) + def test_mul_placements(self): + device_mesh = self.build_device_mesh() + torch.manual_seed(self.rank) + input_size = (8, 4) + input_tensor = torch.randn(*input_size, device=self.device_type) + 
input_dtensor = DTensor.from_local(input_tensor, device_mesh, [Replicate()]) + + other_tensor = torch.randn(*input_size, device=self.device_type) + other_dtensor = DTensor.from_local(other_tensor, device_mesh, [Replicate()]) + expected = torch.mul(input_tensor, other_tensor) + + # test R mul P + input_dtensor = input_dtensor.redistribute(device_mesh, [Replicate()]) + other_dtensor = other_dtensor.redistribute(device_mesh, [Partial()]) + output_dtensor = torch.mul(input_dtensor, other_dtensor) + output_dtensor = output_dtensor.redistribute(device_mesh, [Replicate()]) + self.assertEqual(expected, output_dtensor.to_local()) + + # test P mul R + input_dtensor = input_dtensor.redistribute(device_mesh, [Partial()]) + other_dtensor = other_dtensor.redistribute(device_mesh, [Replicate()]) + output_dtensor = torch.mul(input_dtensor, other_dtensor) + output_dtensor = output_dtensor.redistribute(device_mesh, [Replicate()]) + self.assertEqual(expected, output_dtensor.to_local()) + + # test P mul P + failed = False + try: + input_dtensor = input_dtensor.redistribute(device_mesh, [Partial()]) + other_dtensor = other_dtensor.redistribute(device_mesh, [Partial()]) + output_dtensor = torch.mul(input_dtensor, other_dtensor) + output_dtensor = output_dtensor.redistribute(device_mesh, [Replicate()]) + except Exception as e: + failed = True + self.assertEqual(failed, True, msg="pointwise P mul P should fail") + + def test_div_placements(self): + device_mesh = self.build_device_mesh() + torch.manual_seed(self.rank) + input_size = (8, 4) + input_tensor = torch.randn(*input_size, device=self.device_type) + input_dtensor = DTensor.from_local(input_tensor, device_mesh, [Replicate()]) + + other_tensor = torch.randn(*input_size, device=self.device_type) + other_dtensor = DTensor.from_local(other_tensor, device_mesh, [Replicate()]) + expected = torch.div(input_tensor, other_tensor) + + # test P div R + input_dtensor = input_dtensor.redistribute(device_mesh, [Partial()]) + other_dtensor = other_dtensor.redistribute(device_mesh, [Replicate()]) + output_dtensor = torch.div(input_dtensor, other_dtensor) + output_dtensor = output_dtensor.redistribute(device_mesh, [Replicate()]) + self.assertEqual(expected, output_dtensor.to_local()) + + # test R div P + failed = False + try: + input_dtensor = input_dtensor.redistribute(device_mesh, [Replicate()]) + other_dtensor = other_dtensor.redistribute(device_mesh, [Partial()]) + output_dtensor = torch.div(input_dtensor, other_dtensor) + output_dtensor = output_dtensor.redistribute(device_mesh, [Replicate()]) + except Exception as e: + failed = True + self.assertEqual(failed, True, msg="pointwise R div P should fail") + def test_triu(self): device_mesh = self.build_device_mesh() input_size = (8, 4) diff --git a/test/dtensor/ops/test_random_ops.py b/test/dtensor/ops/test_random_ops.py index 1210e49..d89b599 100644 --- a/test/dtensor/ops/test_random_ops.py +++ b/test/dtensor/ops/test_random_ops.py @@ -21,7 +21,7 @@ import torch.distributed._functional_collectives as funcol from torch.distributed.distributed_c10d import broadcast_object_list -from vescale import DeviceMesh, DTensor, Shard, Replicate, distribute_tensor +from vescale import DeviceMesh, DTensor, Shard, Partial, Replicate, distribute_tensor import vescale.dtensor.random as random from vescale.dtensor.random import is_rng_supported_mesh, manual_seed from vescale.dtensor import empty as dempty @@ -38,8 +38,7 @@ def _run_init_op(self, init_op, *args, **kwargs): device_mesh = DeviceMesh(self.device_type, mesh_shape) all_shapes = 
[(8, 4), (4, 4, 4), (8, 8, 4, 4), (5, 6, 7, 8, 9)] for global_shape in all_shapes: - all_placements = [Replicate()] + [Shard(d) for d in range(len(global_shape))] - + all_placements = [Replicate(), Partial()] + [Shard(d) for d in range(len(global_shape))] from itertools import product all_placements = [list(placements) for placements in product(all_placements, repeat=mesh_dim)] @@ -63,14 +62,16 @@ def _run_init_op(self, init_op, *args, **kwargs): else: torch.cuda.manual_seed_all(0) expected_tensor = init_op(torch.empty(*global_shape, device="cuda"), *args, **kwargs) - dist_expected = distribute_tensor(expected_tensor, device_mesh, placements) - + dist_expected = distribute_tensor(expected_tensor.detach().clone(), device_mesh, placements) manual_seed(0, device_mesh) dtensor = init_op( dempty(*global_shape, device_mesh=device_mesh, placements=placements), *args, **kwargs ) - self.assertTrue(list(dtensor._spec.placements) == placements) - self.assertEqual(dtensor._local_tensor, dist_expected._local_tensor, atol=0.0, rtol=0.0) + if any(p.is_partial() for p in placements): + self.assertTrue(all(not p.is_partial() for p in dtensor._spec.placements)) + else: + self.assertTrue(list(dtensor._spec.placements) == placements) + self.assertEqual(dtensor._local_tensor, dist_expected._local_tensor, atol=0.0, rtol=0.0) full_tensor = dtensor.full_tensor() self.assertEqual(full_tensor, expected_tensor, atol=0.0, rtol=0.0) @@ -110,17 +111,18 @@ def test_manual_seed(self): with self.assertRaisesRegex(RuntimeError, "different seed values"): manual_seed(self.rank, device_mesh) - def run_dropout(self, global_shape, mesh, placements): + def run_dropout(self, global_shape, mesh, placements, inplace): torch.cuda.manual_seed_all(0) - dropout = torch.nn.Dropout(p=0.2) + dropout = torch.nn.Dropout(p=0.2, inplace=inplace) expected_tensor = dropout(torch.ones(global_shape, device=self.device_type)) - dist_expected = distribute_tensor(expected_tensor, mesh, placements) + dist_expected = distribute_tensor(expected_tensor.detach().clone(), mesh, placements) manual_seed(0, mesh) dtensor = distribute_tensor(torch.ones(global_shape, device=self.device_type), mesh, placements) dtensor = dropout(dtensor) - self.assertEqual(dtensor.to_local(), dist_expected.to_local(), atol=0.0, rtol=0.0) + if all(not p.is_partial() for p in placements): + self.assertEqual(dtensor._local_tensor, dist_expected._local_tensor, atol=0.0, rtol=0.0) full_tensor = dtensor.full_tensor() self.assertEqual(full_tensor, expected_tensor, atol=0.0, rtol=0.0) @@ -131,12 +133,29 @@ def test_deterministic_dropout_1d(self): shapes = [(9, 7), (4, 16, 16), (7, 5, 16)] mesh = DeviceMesh("cuda", torch.arange(self.world_size)) for global_shape in shapes: - for placements in ([Replicate()], [Shard(0)], [Shard(1)]): - self.run_dropout(global_shape, mesh, placements) + for placements in [[Replicate()], [Partial()], [Shard(0)], [Shard(1)]]: + self.run_dropout(global_shape, mesh, placements, inplace=True) + self.run_dropout(global_shape, mesh, placements, inplace=False) mesh = DeviceMesh("cuda", torch.arange(self.world_size).reshape(self.world_size // 2, 2)) for global_shape in shapes: - for shard in ([Replicate(), Replicate()], [Shard(0), Shard(1)], [Shard(1), Shard(0)]): - self.run_dropout(global_shape, mesh, placements) + for placements in [ + [Shard(0), Shard(1)], + [Shard(0), Replicate()], + [Shard(0), Partial()], + [Shard(1), Shard(0)], + [Shard(1), Replicate()], + [Shard(1), Partial()], + [Replicate(), Shard(0)], + [Replicate(), Shard(1)], + [Replicate(), 
Partial()], + [Replicate(), Replicate()], + [Partial(), Shard(0)], + [Partial(), Shard(1)], + [Partial(), Partial()], + [Partial(), Replicate()], + ]: + self.run_dropout(global_shape, mesh, placements, inplace=True) + self.run_dropout(global_shape, mesh, placements, inplace=False) @with_comms @skip_if_lt_x_gpu(4) @@ -152,26 +171,26 @@ def test_deterministic_uniform_2d(self): placements_list = [ # this list of placements should be enough to cover [Shard(0), Shard(1)], [Shard(0), Replicate()], - # [Shard(0), Partial()], + [Shard(0), Partial()], [Shard(1), Shard(0)], [Shard(1), Replicate()], - # [Shard(1), Partial()], + [Shard(1), Partial()], [Replicate(), Shard(0)], [Replicate(), Shard(1)], - # [Replicate(), Partial()], + [Replicate(), Partial()], [Replicate(), Replicate()], - # [Partial(), Shard(0)], - # [Partial(), Shard(1)], - # [Partial(), Partial()], - # [Partial(), Replicate()], - ] # TODO: Add Partials in the future + [Partial(), Shard(0)], + [Partial(), Shard(1)], + [Partial(), Partial()], + [Partial(), Replicate()], + ] for placements in placements_list: torch.manual_seed(0) torch.cuda.manual_seed_all(0) golden = torch.empty(*[self.world_size for _ in mesh.size()], device=self.device_type) golden.uniform_(0, 1) - dist_golden = distribute_tensor(golden, device_mesh, placements) + dist_golden = distribute_tensor(golden.detach().clone(), device_mesh, placements) manual_seed(0, device_mesh) dtensor = distribute_tensor( @@ -181,7 +200,11 @@ def test_deterministic_uniform_2d(self): ) dtensor.uniform_(0, 1) - self.assertEqual(dtensor.to_local(), dist_golden.to_local(), atol=0.0, rtol=0.0) + if any(p.is_partial() for p in placements): + self.assertTrue(all(not p.is_partial() for p in dtensor._spec.placements)) + else: + self.assertTrue(list(dtensor._spec.placements) == placements) + self.assertEqual(dtensor._local_tensor, dist_golden._local_tensor, atol=0.0, rtol=0.0) full_tensor = dtensor.full_tensor() self.assertEqual(full_tensor, golden, atol=0.0, rtol=0.0) diff --git a/test/dtensor/ops/test_tensor_ops.py b/test/dtensor/ops/test_tensor_ops.py index f1c24be..0262aa4 100644 --- a/test/dtensor/ops/test_tensor_ops.py +++ b/test/dtensor/ops/test_tensor_ops.py @@ -16,8 +16,8 @@ from unittest import skip from vescale import DeviceMesh, DTensor, distribute_tensor -from vescale.dtensor._diff import EnablePartialMode from vescale.dtensor.placement_types import Partial, Replicate, Shard, InterleavedShard +from vescale.dtensor import empty as dempty class DistTensorOpsTest(DTensorTestBase): @@ -128,17 +128,6 @@ def test_op_out_variant(self): self.assertTrue(res.placements == replica_spec) self.assertEqual(replicate_out.to_local(), expected_dt.to_local()) - @with_comms - def test_empty_like(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - shard_spec = [Shard(0)] - - input_tensor = torch.randn(4, 8, requires_grad=True) - dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) - empty_like_dt = torch.empty_like(dist_tensor) - # empty is not deterministic, so we only check that the shard propagation worked - self.assertEqual((4, 8), empty_like_dt.to_local().shape) - @with_comms def test_fill_inplace(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -151,49 +140,49 @@ def test_fill_inplace(self): self.assertEqual(full_expected, full_like_dt.to_local()) self.assertEqual(full_expected, dist_tensor.to_local()) - @with_comms - def test_full_like(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) 
- shard_spec = [Shard(0)] - - input_tensor = torch.randn(4, 8, requires_grad=True) - dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) - full_like_dt = torch.full_like(dist_tensor, 42.0) - full_expected = torch.full((4, 8), 42.0) - self.assertEqual(full_expected, full_like_dt.to_local()) - - @with_comms - def test_ones_like(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - shard_spec = [Shard(0)] - - input_tensor = torch.randn(4, 8, requires_grad=True) - dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) - ones_like_dt = torch.ones_like(dist_tensor) - ones_expected = torch.ones(4, 8) - self.assertEqual(ones_expected, ones_like_dt.to_local()) + def _run_xxx_like(self, xxx_like_op, *args, **kwargs): + all_mesh_shapes = [ + torch.arange(self.world_size), + torch.arange(self.world_size).reshape(self.world_size // 2, 2), + ] + for mesh_shape in all_mesh_shapes: + mesh_dim = mesh_shape.dim() + device_mesh = DeviceMesh(self.device_type, mesh_shape) + all_shapes = [(8, 4), (4, 4, 4), (8, 8, 4, 4), (5, 6, 7, 8, 9)] + for global_shape in all_shapes: + all_placements = [Replicate(), Partial()] + [Shard(d) for d in range(len(global_shape))] + from itertools import product + + all_placements = [list(placements) for placements in product(all_placements, repeat=mesh_dim)] + + for placements in all_placements: + sharded_dims = [placement.dim for placement in placements if placement.is_shard()] + if len(sharded_dims) > len(set(sharded_dims)): + # Skip the placements that shard along the same dim more than once + continue + expected_tensor = xxx_like_op(torch.empty(*global_shape, device="cuda"), *args, **kwargs) + dist_expected = distribute_tensor(expected_tensor.detach().clone(), device_mesh, placements) + dtensor = xxx_like_op( + dempty(*global_shape, device_mesh=device_mesh, placements=placements), *args, **kwargs + ) + self.assertEqual(dtensor._local_tensor.shape, dist_expected._local_tensor.shape, atol=0.0, rtol=0.0) + if any(p.is_partial() for p in placements): + self.assertTrue(all(not p.is_partial() for p in dtensor._spec.placements)) + else: + self.assertTrue(list(dtensor._spec.placements) == placements) + if xxx_like_op != torch.empty_like: + self.assertEqual(dtensor._local_tensor, dist_expected._local_tensor, atol=0.0, rtol=0.0) + full_tensor = dtensor.full_tensor() + self.assertEqual(full_tensor.shape, expected_tensor.shape, atol=0.0, rtol=0.0) + if xxx_like_op != torch.empty_like: + self.assertEqual(full_tensor, expected_tensor, atol=0.0, rtol=0.0) @with_comms - @skip("failed") - def test_ones_like_partial_sum(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - shard_spec = [Partial()] - - input_tensor = torch.randn(4, 8, requires_grad=True) - dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) - assert dist_tensor.shape == (4, 8) - - with EnablePartialMode(): - ones_like_dt = torch.ones_like(dist_tensor) - ones_expected = torch.ones(dist_tensor.shape) - assert isinstance(ones_like_dt.placements[0], Partial) - ones_like_dt_replicate = torch.ones_like(dist_tensor) - assert isinstance(ones_like_dt_replicate.placements[0], Replicate) - - self.assertEqual( - ones_expected, - ones_like_dt.to_local(), - ) + def test_xxx_like(self): + self._run_xxx_like(torch.empty_like) + self._run_xxx_like(torch.ones_like) + self._run_xxx_like(torch.zeros_like) + self._run_xxx_like(torch.full_like, fill_value=42.0) @with_comms @skip("failed") @@ -212,24 +201,6 @@ def 
test_fill_inplace_partial_sum(self): dist_tensor.redistribute(device_mesh, [Replicate()]).to_local(), ) - @with_comms - @skip("failed") - def test_zeros_like_partial_sum(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - shard_spec = [Partial()] - - input_tensor = torch.randn(4, 8, requires_grad=True) - dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) - assert dist_tensor.shape == (4, 8) - - with EnablePartialMode(): - zeros_like_dt = torch.zeros_like(dist_tensor) - assert isinstance(zeros_like_dt.placements[0], Partial) - zeros_like_dt_replicate = torch.zeros_like(dist_tensor) - assert isinstance(zeros_like_dt_replicate.placements[0], Replicate) - zeros_expected = torch.zeros(dist_tensor.shape) - self.assertEqual(zeros_expected, zeros_like_dt.to_local()) - @with_comms def test_zero_inplace(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -242,17 +213,6 @@ def test_zero_inplace(self): self.assertEqual(zeros_expected, zeros_like_dt.to_local()) self.assertEqual(zeros_expected, dist_tensor.to_local()) - @with_comms - def test_zeros_like(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - shard_spec = [Shard(0)] - - input_tensor = torch.randn(4, 8, requires_grad=True) - dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) - zeros_like_dt = torch.zeros_like(dist_tensor) - zeros_expected = torch.zeros(4, 8) - self.assertEqual(zeros_expected, zeros_like_dt.to_local()) - @with_comms def test_equal(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) diff --git a/test/initialize/test_defer_init.py b/test/initialize/test_defer_init.py index b6f0df7..c268dde 100644 --- a/test/initialize/test_defer_init.py +++ b/test/initialize/test_defer_init.py @@ -47,7 +47,7 @@ def _test_accuracy_base(self, op_call, global_shape, sharding, mesh): dist_golden = distribute_tensor(tensor_golden, mesh, sharding) manual_seed(0, mesh) - tensor_defer = deferred_init(op_call, global_shape) + tensor_defer = deferred_init(op_call, global_shape, device=self.device_type) dtensor_defer = materialize_dtensor(tensor_defer, mesh, sharding) self.assertTrue( diff --git a/test/parallel/ddp_optim/test_moe.py b/test/parallel/ddp_optim/test_moe.py new file mode 100644 index 0000000..bcc469b --- /dev/null +++ b/test/parallel/ddp_optim/test_moe.py @@ -0,0 +1,251 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +################################################################################ + +import copy + +import torch +import torch.distributed as dist +from torch.testing._internal.common_utils import ( + run_tests, +) + +from vescale.dtensor.placement_types import Replicate, Shard +from vescale.dtensor.device_mesh import init_device_mesh +from vescale.dtensor.api import redistribute_dtensor +from vescale.dmodule.api import parallelize_module +from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP +from vescale.optim.base_optimizer import BasicOptimizer, BasicOptimizerHook + +from common_dtensor import DTensorTestBase, with_comms +from test_models.mlp import ( + MLP, + BSZ, +) + +HIDDEN_DIM = 512 + +PAIRWISE_PARAM_SHARDING_PLAN = { + r"moe.experts.\d+.fc1.weight": [Shard(0)], + r"moe.experts.\d+.fc1.bias": [Shard(0)], + r"moe.experts.\d+.fc2.weight": [Shard(1)], + r"moe.experts.\d+.fc2.bias": [Replicate()], + r"ln.weight": [Replicate()], + r"ln.bias": [Replicate()], +} + +FWD_RESAHRDING_PLAM = { + r".input": [[Replicate()]], + r"moe.experts.\d+.fc1.input": [[Replicate()]], + r"moe.experts.\d+.fc2.output": [[Replicate()]], +} + + +def get_unfied_param_and_data(bsz, hidden_dim, dtype=torch.float): + fc1_weights = torch.rand(8, hidden_dim * 4, hidden_dim, dtype=dtype).cuda() + fc1_biases = torch.rand(8, hidden_dim * 4, dtype=dtype).cuda() + + fc2_weights = torch.rand(8, hidden_dim, hidden_dim * 4, dtype=dtype).cuda() + fc2_biases = torch.rand(8, hidden_dim, dtype=dtype).cuda() + + ln_weight = torch.rand(hidden_dim).cuda() + ln_bias = torch.rand(hidden_dim).cuda() + + batch1_epoch1 = torch.rand(bsz, hidden_dim, dtype=dtype).cuda() + batch2_epoch1 = torch.rand(bsz, hidden_dim, dtype=dtype).cuda() + batch1_epoch2 = torch.rand(bsz, hidden_dim, dtype=dtype).cuda() + batch2_epoch2 = torch.rand(bsz, hidden_dim, dtype=dtype).cuda() + + # allreduce parameter and batches to make sure they are same at all ranks + torch.distributed.all_reduce(fc1_weights) + torch.distributed.all_reduce(fc1_biases) + torch.distributed.all_reduce(fc2_weights) + torch.distributed.all_reduce(fc2_biases) + torch.distributed.all_reduce(ln_weight) + torch.distributed.all_reduce(ln_bias) + torch.distributed.all_reduce(batch1_epoch1) + torch.distributed.all_reduce(batch2_epoch1) + torch.distributed.all_reduce(batch1_epoch2) + torch.distributed.all_reduce(batch2_epoch2) + + params_and_inputs = { + "fc1.weights": torch.unbind(fc1_weights, 0), + "fc1.biases": torch.unbind(fc1_biases, 0), + "fc2.weights": torch.unbind(fc2_weights, 0), + "fc2.biases": torch.unbind(fc2_biases, 0), + "ln.weight": ln_weight, + "ln.bias": ln_bias, + "batch1_epoch1": batch1_epoch1, + "batch2_epoch1": batch2_epoch1, + "batch1_epoch2": batch1_epoch2, + "batch2_epoch2": batch2_epoch2, + } + + return params_and_inputs + + +class MoEBlock(torch.nn.Module): + def __init__(self, hidden_dim): + super().__init__() + self.experts = torch.nn.ModuleList( + [ + MLP(hidden_dim), + MLP(hidden_dim), + MLP(hidden_dim), + MLP(hidden_dim), + MLP(hidden_dim), + MLP(hidden_dim), + MLP(hidden_dim), + MLP(hidden_dim), + ] + ) + + def forward(self, x): + # we simulate a sparse MoE by only invoking some of the experts. 
+ output = torch.zeros_like(x) + for i in range(0, 4): + output += self.experts[i](x) + return output + + +class Net(torch.nn.Module): + def __init__(self, hidden_dim) -> None: + super().__init__() + self.moe = MoEBlock(hidden_dim) + self.ln = torch.nn.LayerNorm(hidden_dim) + + def forward(self, x): + return self.moe(self.ln(x)) + + +class VeScaleDDPTest(DTensorTestBase): + @property + def world_size(self): + return 4 + + def gen_golden_output(self, params_and_inputs): + m = Net(HIDDEN_DIM).cuda() + m.ln.weight = torch.nn.Parameter(params_and_inputs["ln.weight"]) + m.ln.bias = torch.nn.Parameter(params_and_inputs["ln.bias"]) + for i in range(8): + m.moe.experts[i].fc1.weight = torch.nn.Parameter(params_and_inputs["fc1.weights"][i]) + m.moe.experts[i].fc1.bias = torch.nn.Parameter(params_and_inputs["fc1.biases"][i]) + m.moe.experts[i].fc2.weight = torch.nn.Parameter(params_and_inputs["fc2.weights"][i]) + m.moe.experts[i].fc2.bias = torch.nn.Parameter(params_and_inputs["fc2.biases"][i]) + + optimizer = torch.optim.Adam(m.parameters(), lr=0.01) + + # epoch 1 + optimizer.zero_grad() + output = m(params_and_inputs["batch1_epoch1"]) + output.sum().backward() + output = m(params_and_inputs["batch2_epoch1"]) + output.sum().backward() + + # manually reduce-mean the grad + for p in m.parameters(): + if p.grad is not None: + p.grad /= 2 + + optimizer.step() + + # epoch 2 + optimizer.zero_grad() + output = m(params_and_inputs["batch1_epoch2"]) + output.sum().backward() + output = m(params_and_inputs["batch2_epoch2"]) + output.sum().backward() + + # manually reduce-mean the grad + for p in m.parameters(): + if p.grad is not None: + p.grad /= 2 + + optimizer.step() + + return m + + @with_comms + def test_ddp_moe(self): + tp_parallel_size = 2 + + dp_size = self.world_size // tp_parallel_size + device_mesh = init_device_mesh(self.device_type, (dp_size, tp_parallel_size), mesh_dim_names=("DP", "TP")) + tp_sub_mesh = device_mesh["TP"] + dp_pg = device_mesh.get_dim_groups(0) + + params_and_inputs = get_unfied_param_and_data(BSZ, HIDDEN_DIM) + new_params_and_inputs = copy.deepcopy(params_and_inputs) + + ve_model = Net(HIDDEN_DIM).cuda(self.rank) + ve_model.ln.weight = torch.nn.Parameter(params_and_inputs["ln.weight"]) + ve_model.ln.bias = torch.nn.Parameter(params_and_inputs["ln.bias"]) + for i in range(8): + ve_model.moe.experts[i].fc1.weight = torch.nn.Parameter(params_and_inputs["fc1.weights"][i]) + ve_model.moe.experts[i].fc1.bias = torch.nn.Parameter(params_and_inputs["fc1.biases"][i]) + ve_model.moe.experts[i].fc2.weight = torch.nn.Parameter(params_and_inputs["fc2.weights"][i]) + ve_model.moe.experts[i].fc2.bias = torch.nn.Parameter(params_and_inputs["fc2.biases"][i]) + + ve_model = parallelize_module( + ve_model, tp_sub_mesh, {"parameter": PAIRWISE_PARAM_SHARDING_PLAN, "forward": FWD_RESAHRDING_PLAM} + ) + + ve_model = DDP( + ve_model, + data_pg_or_device_mesh=dp_pg, + accumulate_allreduce_grads_in_fp32=True, + overlap_grad_reduce=True, + use_distributed_optimizer=False, + bucket_size=2000000, + whitelist_module_types=[MoEBlock], + ) + + ve_optimizer = torch.optim.Adam(ve_model.parameters(), lr=0.01) + ve_optimizer = BasicOptimizer(ve_optimizer, models=ve_model, grad_hook=BasicOptimizerHook) + + # epoch 1 + ve_optimizer.zero_grad() + ve_model.zero_grad_buffer() + x = params_and_inputs["batch1_epoch1"] + if dist.get_rank() == 2 or dist.get_rank() == 3: + x = params_and_inputs["batch2_epoch1"] + ve_model(x).to_local().sum().backward() + # trigger grad all reducing synchronously if not 
overlap_grad_reduce, + # or wait asynchronous grad reduce finish. + ve_model.finish_grad_sync() + ve_optimizer.step() + + # epoch 2 + ve_optimizer.zero_grad() + ve_model.zero_grad_buffer() + x = params_and_inputs["batch1_epoch2"] + if dist.get_rank() == 2 or dist.get_rank() == 3: + x = params_and_inputs["batch2_epoch2"] + ve_model(x).to_local().sum().backward() + # trigger grad all reducing synchronously if not overlap_grad_reduce, + # or wait asynchronous grad reduce finish. + ve_model.finish_grad_sync() + ve_optimizer.step() + + golden_module = self.gen_golden_output(new_params_and_inputs) + for name, golden_param in golden_module.named_parameters(): + param = ve_model.module.get_parameter(name) + replicate_param = redistribute_dtensor(param.data, tp_sub_mesh, [Replicate()]) + torch.testing.assert_close(golden_param.data, replicate_param._local_tensor) + + +if __name__ == "__main__": + run_tests() diff --git a/vescale/ddp/distributed_data_parallel.py b/vescale/ddp/distributed_data_parallel.py index c8b5325..f14c423 100644 --- a/vescale/ddp/distributed_data_parallel.py +++ b/vescale/ddp/distributed_data_parallel.py @@ -4,7 +4,7 @@ # Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. ################################################################################ -from typing import Dict, Union +from typing import Dict, Union, List, Any import torch import torch.distributed.distributed_c10d as c10d @@ -38,6 +38,7 @@ class DistributedDataParallel(torch.nn.Module): per bucket _if_ overlap_grad_reduce is True and pp_rank is 0. bucket_size (int): the size of single bucket, only useful when bucketing is enabled. By default, 40000000. + whitelist_module_types (List[Type]): Types of sparse submodules. By default, None. Returns: A :class:`DistributedDataParallel` object. @@ -55,7 +56,8 @@ class DistributedDataParallel(torch.nn.Module): mlp = parallelize_module(MLP(), mesh, ..., ...) ddp_module = DDP( module=mlp, - data_pg_or_device_mesh=mesh + data_pg_or_device_mesh=mesh, + whitelist_module_types=[MoEBlock] ) # run the forward. ddp_module(torch.rand(xxx)) @@ -71,6 +73,7 @@ def __init__( use_distributed_optimizer: bool = False, disable_bucketing: bool = False, bucket_size: int = 40000000, # Unit: number of the elements + whitelist_module_types: List[Any] = None, **kwargs, ): super().__init__() @@ -166,6 +169,18 @@ def __init__( grad_acc.register_hook(self._make_param_hook(param, self.param_to_grad_buffer)) self.grad_accs.append(grad_acc) + # Register backward hook for submodules of sparse structure. + if whitelist_module_types is not None and self.overlap_grad_reduce: + for submod in self.module.modules(): + is_sparse = False + for t in whitelist_module_types: + if isinstance(submod, t): + is_sparse = True + break + if not is_sparse: + continue + submod.register_forward_pre_hook(self._make_sparse_module_pre_hook(), prepend=True) + def forward(self, *inputs, **kwargs): """ Calls the wrapped module's forward() method. @@ -208,6 +223,31 @@ def param_hook(*unused): return param_hook + def _make_sparse_module_backward_hook(self, sparse_module, param_to_grad_buffer): + """ + Creates the all-reduce / reduce-scatter hook for back propagation of sparse Modules, like MOE. + """ + + def backward_hook(*unused): + # we do nothing if not overlap_grad_reduce. + if not self.overlap_grad_reduce: + return + # force to mark all parameters in the sparse_module as ready for allreduce + # once we found the back propagation of the module is finished. 
+ for param in sparse_module.parameters(): + param_to_grad_buffer[param].register_grad_maybe_absent(param) + + return backward_hook + + def _make_sparse_module_pre_hook(self): + def sparse_module_pre_hook(module, args): + for x in args: + if isinstance(x, torch.Tensor) and x.requires_grad: + x.register_hook(self._make_sparse_module_backward_hook(module, self.param_to_grad_buffer)) + break + + return sparse_module_pre_hook + def start_grad_sync(self, *unused): """ Initiates grad sync (all-reduce or reduce-scatter) communication operations diff --git a/vescale/ddp/grad_buffer.py b/vescale/ddp/grad_buffer.py index 5053c01..5f382a6 100644 --- a/vescale/ddp/grad_buffer.py +++ b/vescale/ddp/grad_buffer.py @@ -58,6 +58,7 @@ def __init__( self.params_list = params self.params = set(params) self.params_with_grad = set() + self.whitelist_params = set() self.data = data # The distributed optimizer needs to keep track of this bucket's offset # within the full grad_buffer. @@ -76,6 +77,7 @@ def reset(self): Reset metadata in bucket in preparation for the next iteration of training. """ self.params_with_grad = set() + self.whitelist_params = set() self.communication_handle = None self.partial_grad_communication_handle = None self.communication_issued = False @@ -163,8 +165,9 @@ def finish_grad_sync(self): if self.communication_handle is None or (not self.communication_issued): warnings.warn( f"DDP Bucket expects {len(self.params)} params all having .grad" - f"but gets {len(self.params_with_grad)} grad available." - f"This may be due to unused model parameters. " + f"but gets {len(self.params_with_grad)} grad available, " + f"and gets {len(self.whitelist_params - self.params_with_grad)} params marked as absent in backward." + f"This may be due to unused and unmarked model parameters. " "We issue blocking communication for this bucket after other overlapped communications." ) self.start_grad_sync() @@ -172,6 +175,7 @@ def finish_grad_sync(self): assert self.communication_handle is not None and self.communication_issued, ( f"Communication call has not been issued for this bucket " f"({len(self.params_with_grad)}/{len(self.params)} params have grad available)" + f"({len(self.whitelist_params - self.params_with_grad)}/{len(self.params)} params have grad marked as absent in backward)" ) self.communication_handle.wait() @@ -186,8 +190,23 @@ def register_grad_ready(self, param: torch.nn.Parameter): assert param not in self.params_with_grad, "Cannot set grad twice" assert self.overlap_grad_reduce, "register_grad_ready() should be called only when overlapping grad reduce" self.params_with_grad.add(param) - # If all params in bucket have grads available, issue communication call. - if len(self.params_with_grad) == len(self.params): + # If all params in bucket have grads available or marked as absent in backward, issue communication call. + if len(self.params_with_grad.union(self.whitelist_params)) == len(self.params): + self.start_grad_sync() + + def register_grad_maybe_absent(self, param: torch.nn.Parameter): + """ + Registers grads for the passed-in param to be "ready" for grad sync. + + NOTE: This API should only be called when there is a sparse model structure, like MOE. 
+ """ + assert param in self.params, "Param is not in the bucket" + assert self.overlap_grad_reduce, "register_grad_ready() should be called only when overlapping grad reduce" + if param in self.params_with_grad: + return + self.whitelist_params.add(param) + # If all params in bucket have grads available or marked as absent in backward, issue communication call. + if len(self.params_with_grad.union(self.whitelist_params)) == len(self.params): self.start_grad_sync() def register_partial_grad_ready( @@ -462,3 +481,14 @@ def register_partial_grad_ready( """ bucket = self.param_to_bucket[param] bucket.register_partial_grad_ready(param, model_parallel_device_mesh, placements) + + def register_grad_maybe_absent(self, param: torch.nn.Parameter): + """ + Registers grads for the passed-in param to be "ready" for grad sync. + + NOTE: This API should only be called when there is a sparse model structure, like MOE. + """ + assert self.overlap_grad_reduce, "register_grad_ready() should only be called when overlap_grad_reduce is True" + if self.is_last_microbatch: + bucket = self.param_to_bucket[param] + bucket.register_grad_maybe_absent(param) diff --git a/vescale/dmodule/_dmodule.py b/vescale/dmodule/_dmodule.py index 50f861d..5e03a30 100644 --- a/vescale/dmodule/_dmodule.py +++ b/vescale/dmodule/_dmodule.py @@ -269,6 +269,7 @@ def init_parameters(module: nn.Module, is_sharded: bool): Non-appointed parameters and buffers will be `Replicate` (i.e., default plan). """ assert DModule.has_all_attributes(module) + device_type = module._device_mesh.device_type # pre-order traverse root and submodules for submod_path, submod in module.named_modules(): # get assigned plans from root @@ -282,6 +283,8 @@ def init_parameters(module: nn.Module, is_sharded: bool): if param_pi.placements is None: # default plan param_pi.placements = [Replicate()] * module._device_mesh.ndim param = DModule._distribute_parameter(param, module._device_mesh, param_pi, is_sharded) + # force to put param on given device, like cuda. + param = torch.nn.Parameter(param.data.to(device_type)) submod.register_parameter(param_name, param) # distribute immediate buffers @@ -292,6 +295,8 @@ def init_parameters(module: nn.Module, is_sharded: bool): if buffer_pi.placements is None: # default plan buffer_pi.placements = [Replicate()] * module._device_mesh.ndim buffer = DModule._distribute_parameter(buffer, module._device_mesh, buffer_pi, is_sharded) + # force to put buffer on given device, like cuda. + buffer = buffer.to(device_type) submod.register_buffer(buffer_name, buffer) @staticmethod @@ -394,24 +399,22 @@ def prepare_factory(module: nn.Module, factory: Union[bool, Dict[nn.Module, Unio if not fqn_submods: return - # verifiy that there is no nested submodule for factory dispatch mode - fqns = [fqn for fqn, _ in fqn_submods] - for fqn in fqns: - for other_fqn in fqns: - if fqn == other_fqn: - continue - if other_fqn.startswith(fqn): - raise NotImplementedError( - f"Nested submodules for dtensorizing factory is not supported yet: `{fqn}` and `{other_fqn}`!" 
- ) - - # normalize appointed factory, wrap the forward with factory dispatch mode + # turns off factory for all model patches as the highest override (inner most wrapper) + from vescale.model.patch.utils import is_patched from vescale.dmodule._factory import wrap_factory_mode + for submod in module.modules(): + if is_patched(submod): + wrap_factory_mode(False, submod) + + # normalize appointed factory, wrap the forward with factory dispatch mode for fqn, submod in fqn_submods: factory_placement: Union[bool, Dict] = factory[type(submod)] + if not factory_placement: # False or {} + wrap_factory_mode(False, submod) continue + if factory_placement is True: factory_pi = {} # all factories as default placement else: # Dict @@ -420,7 +423,7 @@ def prepare_factory(module: nn.Module, factory: Union[bool, Dict[nn.Module, Unio for f, p in factory_placement.items() } factory_pi = {f: p for f, p in factory_pi.items() if p is not None} - wrap_factory_mode(submod, module._device_mesh, factory_pi) + wrap_factory_mode(True, submod, module._device_mesh, factory_pi) """ ============ Bound Methods Below ============ """ @@ -516,7 +519,7 @@ def _param_str(param: nn.Parameter) -> str: if param_str.endswith(", "): param_str = param_str[:-2] param_str += ")" - elif isinstance(param.data, nn.Tensor): + elif isinstance(param.data, torch.Tensor): param_str = "TensorParam(" if show_shape: param_str += f"shape={param.data.shape}, " diff --git a/vescale/dmodule/_factory.py b/vescale/dmodule/_factory.py index 94750a5..6816a07 100644 --- a/vescale/dmodule/_factory.py +++ b/vescale/dmodule/_factory.py @@ -19,10 +19,17 @@ import warnings from inspect import signature import functools +import contextlib import torch from torch import nn from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._python_dispatch import ( + _pop_mode, + _push_mode, + _len_torch_dispatch_stack, + _get_current_dispatch_mode_stack, +) from vescale import dtensor from vescale.dtensor.placement_types import Replicate @@ -42,21 +49,15 @@ torch.empty: (aten.empty.memory_format, dtensor.empty), torch.full: (aten.full.default, dtensor.full), torch.randn: (aten.randn.default, dtensor.randn), + torch.rand: (aten.rand.default, dtensor.rand), torch.arange: ((aten.arange.default, aten.arange.start, aten.arange.start_step), dtensor.arange), } -class FactoryDispatchMode(TorchDispatchMode): - def __init__( - self, - _dispatch_key=None, - device_mesh: DeviceMesh = None, - aten_dfactory_pi: Dict[Callable, Tuple[Callable, PI]] = None, - ): - super().__init__(_dispatch_key) - assert device_mesh is not None +class FactoryDispatchModeOn(TorchDispatchMode): + def __init__(self, device_mesh: DeviceMesh, aten_dfactory_pi: Dict[Callable, Tuple[Callable, PI]]): + super().__init__() self.device_mesh = device_mesh - assert aten_dfactory_pi is not None self.aten_dfactory_pi = aten_dfactory_pi def __torch_dispatch__(self, func: Callable, _, args: Tuple, kwargs: Optional[Dict] = None): @@ -134,8 +135,52 @@ def _provide_args(device_mesh: DeviceMesh, factory_pi: Dict[Callable, PI]) -> Di return aten_dfactory_pi -def _provide_wrapped_forward( - origin_forward: Callable, device_mesh: DeviceMesh, aten_dfactory_pi: Dict[Callable, Tuple[Callable, PI]] +@contextlib.contextmanager +def FactoryDispatchModeOff(): # ref: `torch.utils._python_dispatch._disable_current_modes()` + # --- enter --- + saved_idx, saved_mode, saved_len = None, None, None + # turn off the lastest mode from the top of stack + if any(isinstance(_mode, FactoryDispatchModeOn) for _mode in 
_get_current_dispatch_mode_stack()): + saved_len = _len_torch_dispatch_stack() + _tmp_stack = [] + _idx = saved_len - 1 + while _idx >= 0: + _mode: TorchDispatchMode = _pop_mode() + if isinstance(_mode, FactoryDispatchModeOn): + saved_idx, saved_mode = _idx, _mode + break + else: + _tmp_stack.append(_mode) + _idx -= 1 + while _tmp_stack: + _push_mode(_tmp_stack.pop()) + assert _len_torch_dispatch_stack() == saved_len - 1, "Only one mode should be poped!" + # ------ + try: + yield + # --- exit --- + finally: + # restore the saved mode at the original idx in the stack + if saved_idx is not None: + _tmp_stack = [] + _idx = saved_len - 1 + while _idx >= 0: + if _idx == saved_idx: + _push_mode(saved_mode) + break + else: + _tmp_stack.append(_pop_mode()) + _idx -= 1 + while _tmp_stack: + _push_mode(_tmp_stack.pop()) + assert _len_torch_dispatch_stack() == saved_len, "Stack should be restored!" + # ------ + + +def _provide_wrapped_forward_on( + origin_forward: Callable, + device_mesh: DeviceMesh, + aten_dfactory_pi: Dict[Callable, Tuple[Callable, PI]], ): # new forward @functools.wraps(origin_forward) # copy signatures @@ -143,11 +188,31 @@ def forward(*args, **kwargs): if _IS_DEBUG: print(f"[DEBUG] forward({args}, {kwargs})") print(f"[DEBUG] signature(forward): {signature(forward)}") - # assert that arguments are correct - signature(forward).bind(*args, **kwargs) + signature(forward).bind(*args, **kwargs) # assert that arguments are correct print(f"[DEBUG] origin_forward({args}, {kwargs})") - # call original forward under factory mode - with FactoryDispatchMode(device_mesh=device_mesh, aten_dfactory_pi=aten_dfactory_pi): + # call original forward + with FactoryDispatchModeOff(): # isolate this factory mode on, in case of nest modes + with FactoryDispatchModeOn(device_mesh, aten_dfactory_pi): + return origin_forward(*args, **kwargs) + + if _IS_DEBUG: + print(f"[DEBUG] signature(origin_forward): {signature(origin_forward)}") + print(f"[DEBUG] signature(forward): {signature(forward)}") + + return forward + + +def _provide_wrapped_forward_off(origin_forward: Callable): + # new forward + @functools.wraps(origin_forward) # copy signatures + def forward(*args, **kwargs): + if _IS_DEBUG: + print(f"[DEBUG] forward({args}, {kwargs})") + print(f"[DEBUG] signature(forward): {signature(forward)}") + signature(forward).bind(*args, **kwargs) # assert that arguments are correct + print(f"[DEBUG] origin_forward({args}, {kwargs})") + # call original forward + with FactoryDispatchModeOff(): return origin_forward(*args, **kwargs) if _IS_DEBUG: @@ -157,14 +222,23 @@ def forward(*args, **kwargs): return forward -def wrap_factory_mode(mod: nn.Module, device_mesh: DeviceMesh, factory_pi: Dict[Callable, PI]) -> None: # noqa: B006 - # prepare args to factory mode (put here to avoid runtime overhead) - aten_dfactory_pi = _provide_args(device_mesh, factory_pi) +def wrap_factory_mode( + on: bool, mod: nn.Module, device_mesh: Optional[DeviceMesh] = None, factory_pi: Optional[Dict[Callable, PI]] = None +) -> None: # noqa: B006 + if on: # turn on factory mode + assert device_mesh is not None + assert factory_pi is not None + + # prepare args to factory mode (put here to avoid runtime overhead) + aten_dfactory_pi = _provide_args(device_mesh, factory_pi) + + # wrap forward with factory mode + # NOTE: bound method with `MethodType` will disable signature (either set by `__signature__` or `@functools.wraps``), + # which disables forward hooks appointed by forward plan (as `(x,) != (*args, **kwargs)` ) + # so we use unbound method 
here to keep the same signature + mod.forward = _provide_wrapped_forward_on(mod.forward, device_mesh, aten_dfactory_pi) + else: # turn off factory mode + mod.forward = _provide_wrapped_forward_off(mod.forward) - # wrap forward with factory mode - # NOTE: bound method with `MethodType` will disable signature (either set by `__signature__` or `@functools.wraps``), - # which disables forward hooks appointed by forward plan (as `(x,) != (*args, **kwargs)` ) - # so we use unbound method here to keep the same signature - mod.forward = _provide_wrapped_forward(mod.forward, device_mesh, aten_dfactory_pi) if _IS_DEBUG: print(f"[DEBUG] signature(mod.forward): {signature(mod.forward)}") diff --git a/vescale/dmodule/api.py b/vescale/dmodule/api.py index ba907e6..f2d1e00 100644 --- a/vescale/dmodule/api.py +++ b/vescale/dmodule/api.py @@ -136,13 +136,16 @@ def parallelize_module( - `True`: all submodules and all factory funcs will be converted to DTensor in `Replicate`. - `False` or `{}`: disable this factory function conversion to DTensor. - `{ submodule_cls : True }`: only this `submodule_cls`'s all factory function will be converted to DTensor in `Replicate`. - - `{ submodule_cls : False or [] }`: exclude this `submodule_cls` for factory function conversion to DTensor. + - `{ submodule_cls : False or {} }`: exclude this `submodule_cls` for factory function conversion to DTensor. - `{ submodule_cls : { factory_func : } }`: only this `submodule_cls`'s `factory_func` will be converted to DTensor in ``. + Nested Case: `{ submodule_cls_outer : True/False/{..}, submodule_cls_inner : True/False/{..} }` can have `submodule_cls_inner` nested in `submodule_cls_outer`, + in which case we let the inner `submodule_cls_inner` overrides `submodule_cls_outer` in `True/False/{..}`, i.e., like a context manager in Python. + Note: Currently, this factory converison: - only covers `forward()` - assumes same for `factory_func` - - does NOT support nested `submodule_cls` + - won't be affected by other TorchDispatchMode Returns: (Optional) this parallelized model instance @@ -211,13 +214,35 @@ def __init__(self): self.fc2 = nn.Linear(8, 8) def forward(self, x): - x = torch.zeros(x.shape) + x = torch.zeros(x.shape) # to be converted to DTensor zeros during runtime x = self.fc1(x) x = self.fc2(x) return x dmlp = parallelize_module(MLP(), ..., factory=True) # or factory = { MLP: {torch.zeros: [Replicate()]} } + Example:: using factory for nested classes + + class MLP(nn.Module): + ... + + def forward(self, x): + x = torch.zeros(x.shape) # to be converted to DTensor in Shard + ... + + class Block(nn.Module): + def __init__(self): + super().__init__() + self.mlp = MLP() + + def forward(self, x): + x = torch.zeros(x.shape) # to be converted to DTensor in Replicate + x = self.mlp(x) + return x + + dmlp = parallelize_module(MLP(), ..., factory={ Block : {torch.zeros: [Replicate()]} + MLP: {torch.zeros: [Shard(0)]} }) # inner class overrides + Example:: using gradient synchronization with customized target ... 
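For illustration of the inner-overrides-outer factory rule documented in the `parallelize_module` docstring above, a minimal sketch mirroring the `test_api_nested` cases earlier in this patch: the `Outer`/`Inner` module names are illustrative only, and `device_mesh` is assumed to be an already-constructed 1-D `DeviceMesh` (e.g., created under `torchrun` after process-group initialization).

```
import torch
from torch import nn
from vescale import DTensor, DeviceMesh
from vescale.dmodule.api import parallelize_module


class Inner(nn.Module):
    def forward(self, x):
        # factory call inside Inner.forward -> governed by Inner's factory setting
        return torch.zeros(x.shape, dtype=x.dtype, device=x.device)


class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x):
        # factory call inside Outer.forward -> governed by Outer's factory setting
        a = torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        b = self.inner(x)
        return a, b


# device_mesh is assumed to be created beforehand, e.g.:
#   device_mesh = DeviceMesh("cuda", list(range(world_size)))
# Outer is off, Inner is on: only the tensor created inside Inner.forward
# becomes a (Replicate) DTensor; the one created in Outer.forward stays a torch.Tensor.
dmodel = parallelize_module(Outer(), device_mesh, {}, factory={Outer: False, Inner: True})
a, b = dmodel(torch.ones(8, device=device_mesh.device_type))
assert not isinstance(a, DTensor) and isinstance(b, DTensor)
```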
diff --git a/vescale/dtensor/_collective_utils.py b/vescale/dtensor/_collective_utils.py index ec3ccab..a813935 100644 --- a/vescale/dtensor/_collective_utils.py +++ b/vescale/dtensor/_collective_utils.py @@ -9,8 +9,9 @@ ################################################################################ import logging -from typing import List, Optional import math +import copy +from typing import List, Optional import torch import torch.distributed._functional_collectives as funcol @@ -155,6 +156,78 @@ def mesh_all_to_all( return work + +def mesh_all_to_all_single( + tensor: torch.Tensor, + mesh: DeviceMesh, + original_shard_dim: int, + target_shard_dim: int, + mesh_dim: int = 0, + async_op: bool = False, +): + """ + Transpose the sharded tensor along a device mesh dimension, moving the shard from ``original_shard_dim`` to ``target_shard_dim``. + + Args: + tensor (torch.Tensor): tensor to all-to-all. + mesh (DeviceMesh): device mesh on which the communication happens. + original_shard_dim (int): the dim along which the source tensor is sharded. + target_shard_dim (int): the dim along which the transposed tensor is sharded. + mesh_dim (int, optional): the mesh dimension along which + the all-to-all communication is performed. + By default, 0. + async_op (bool, default False): unused argument; the + all-to-all is always synchronous. + + Returns: + A :class:`Tensor` object + """ + if DebugLogger.IS_DEBUG_MODE: + DebugLogger.log_communication( + mesh_all_to_all_single, tensor, mesh, original_shard_dim, target_shard_dim, mesh_dim + ) + + # if rank is not part of mesh, simply return tensor, which should be an empty tensor + if mesh.get_coordinate() is None: + return tensor + mesh_size = mesh.size(mesh_dim) + assert tensor.size(target_shard_dim) % mesh_size == 0, "we don't support uneven shard on ``target_shard_dim``" + input_rank = tensor.ndim + assert input_rank >= 2, "input must have at least 2 dims" + + target_shape = copy.deepcopy(list(tensor.shape)) + target_shape[original_shard_dim] *= mesh_size + target_shape[target_shard_dim] //= mesh_size + + dim_group = mesh.get_dim_groups(mesh_dim) + assert isinstance(dim_group, ProcessGroup) + + if target_shard_dim != 0: + k_new_shape = list(tensor.shape) + k_new_shape[target_shard_dim] //= mesh_size + k_new_shape[0] *= mesh_size + new_shape = list(tensor.shape) + new_shape[target_shard_dim] //= mesh_size + new_shape.insert(target_shard_dim, mesh_size) + indices = ( + [target_shard_dim] + list(range(0, target_shard_dim)) + list(range(target_shard_dim + 1, tensor.ndim + 1)) + ) + tensor = tensor.reshape(new_shape).permute(indices).reshape(k_new_shape) + + output = funcol.all_to_all_single(tensor, output_split_sizes=None, input_split_sizes=None, group=dim_group) + if original_shard_dim == 0: + return output + + n, *out_shape = list(output.shape) + + indices = ( + list(range(1, original_shard_dim)) + + [original_shard_dim, 0] + + list(range(original_shard_dim + 1, output.ndim + 1)) + ) + + return output.reshape(mesh_size, n // mesh_size, *out_shape).permute(indices).reshape(target_shape) + + def mesh_broadcast( tensor: torch.Tensor, mesh: DeviceMesh, diff --git a/vescale/dtensor/_diff.py b/vescale/dtensor/_diff.py index f87b5a6..fed72c5 100644 --- a/vescale/dtensor/_diff.py +++ b/vescale/dtensor/_diff.py @@ -18,79 +18,56 @@ import functools import os from typing import Callable +import logging -from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode, _push_mode -VESCALE_PARTIAL_MODE = os.environ.get("VESCALE_PARTIAL_MODE", "0") == "1" VESCALE_DISABLE_REDISTRIBUTE = 
os.environ.get("VESCALE_DISABLE_REDISTRIBUTE", "1") == "1" global VESCALE_SHARDING_SUGGETSION VESCALE_SHARDING_SUGGETSION = [] -def switch_partial_mode(func: Callable): +def dummy_p2p(func: Callable): @functools.wraps(func) def wrap(*args, **kwargs): - global VESCALE_PARTIAL_MODE - if VESCALE_PARTIAL_MODE: - with EnablePartialMode(): - out = func(*args, **kwargs) + global VESCALE_DUMMY_P2P + if VESCALE_DUMMY_P2P: + msg = f"{get_rank()}: {args}" + logging.info(msg) else: + if VESCALE_DUMP_INSTRUCTION: + if vescale_file_to_dump is not None: + vescale_file_to_dump.write(f"=========================\nrank:{get_rank()}: {args}, {kwargs}\n") + vescale_file_to_dump.flush() + msg = f"rank:{get_rank()}: {args}, {kwargs}\n=======================\n" + logging.info(msg) out = func(*args, **kwargs) - return out + if VESCALE_DUMP_INSTRUCTION: + if vescale_file_to_dump is not None: + vescale_file_to_dump.write(f"output: {out}\n") + vescale_file_to_dump.flush() + msg = f"output: {out}\n" + logging.info(msg) + return out return wrap -class EnablePartialMode(TorchDispatchMode): - """ - To enable the DTensor to be PartialSum for performance - By sometimes, we find there have some optimization chance - for partial state, so we enable to get a partial DTensor - by torch ops - - chance one: adjust the reshard AllReduceReassociate - The AllReduceReassociate can be simplify - allreduce(x) + allreduce(y) to allreduce(x + y), - there will be some alllreduce save for partial activation - - Note: - EnablePartialMode only influence the xxx_like op porpagation - rules. if you want this mode affect some other function, maybe - refer to ```switch_partial_mode``` , by wrapper any function - with ```@switch_partial_mode``` there will be also tracked by - EnablePartialMode - - Usage: - ``` - from vescale.dtensor.dispatch import EnablePartialMode - with EnablePartialMode(): - partial_tensor = torch.ones_like(other) - ``` - """ - - @staticmethod - def _enable(): - global VecaleParitalMode - VecaleParitalMode = True - - @staticmethod - def _disable(): - global VecaleParitalMode - VecaleParitalMode = False - - def __enter__(self): - EnablePartialMode._enable() - _push_mode(self, self.__dict__.get("_dispatch_key", None)) - return self +def manage_dump_file(func: Callable): + @functools.wraps(func) + def wrap(*args, **kwargs): + if VESCALE_DUMP_INSTRUCTION: + with open(f"instruction-{get_rank()}.txt", "w+") as file: + global vescale_file_to_dump + vescale_file_to_dump = file + out = func(*args, **kwargs) + else: + out = func(*args, **kwargs) - def __exit__(self, exc_type, exc_val, exc_tb): - EnablePartialMode._disable() - _pop_mode(self.__dict__.get("_dispatch_key", None)) + return out - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - return func(*args, **kwargs) + return wrap class DeferReshardMode: diff --git a/vescale/dtensor/_dispatch_bypass.py b/vescale/dtensor/_dispatch_bypass.py index 133500f..fb7bb37 100644 --- a/vescale/dtensor/_dispatch_bypass.py +++ b/vescale/dtensor/_dispatch_bypass.py @@ -33,6 +33,7 @@ def __init__(self): self.op_handlers = { aten.linear.default: BypassOpDispatch.decompose_handler, aten.is_same_size.default: BypassOpDispatch.is_same_size_handler, + aten.nonzero.default: BypassOpDispatch.nonzero_handler, } def apply( @@ -73,6 +74,31 @@ def is_same_size_handler( rhs = cast(torch.Tensor, args[1]) return lhs.shape == rhs.shape + @staticmethod + def nonzero_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], + ) -> object: + from 
vescale.dtensor import DTensor + + input_ = kwargs.get("input", args[0]) + assert isinstance(input_, DTensor) + input_spec = input_._spec + all_replicate = all(p.is_replicate() for p in input_spec.placements) + assert all_replicate, "input placement has to be replicate" + input_local = input_._local_tensor + output_local = op_call(input_local) + return DTensor( + local_tensor=output_local, + device_mesh=input_spec.mesh, + placements=input_spec.placements, + shape=output_local.shape, + dtype=output_local.dtype, + requires_grad=output_local.requires_grad, + stride=output_local.stride(), + ) + _bypass_op_dispatch = BypassOpDispatch() @@ -98,7 +124,6 @@ def __init__(self): aten._to_copy.default: BypassOpShardingProp.copy_handler, aten._local_scalar_dense.default: BypassOpShardingProp.scalar_handler, aten.equal.default: BypassOpShardingProp.scalar_handler, - aten.nonzero.default: BypassOpShardingProp.nonzero_handler, } def apply(self, op_info: OpInfo) -> bool: @@ -109,31 +134,6 @@ def apply(self, op_info: OpInfo) -> bool: else: return False - @staticmethod - def nonzero_handler(op_info: OpInfo) -> OutputSharding: - """ - Bypass nonzero because the output shape is dynamic. - We allow only replication on the input/ouput. - """ - op_schema = op_info.schema - input_spec = op_schema.args_schema[0] - all_replicate = all(p.is_replicate() for p in input_spec.placements) - assert all_replicate, "input placement has to be replicate" - input_local = op_info.local_args[0] - output_local = torch.nonzero(input_local) - out_tensor_meta = TensorMeta( - shape=output_local.shape, - stride=output_local.stride(), - dtype=output_local.dtype, - ) - return OutputSharding( - output_spec=DTensorSpec( - mesh=op_info.schema.args_spec[0].mesh, - placements=op_info.schema.args_spec[0].placements, - tensor_meta=out_tensor_meta, - ) - ) - @staticmethod def copy_handler(op_info: OpInfo) -> OutputSharding: op_schema = op_info.schema diff --git a/vescale/dtensor/_utils.py b/vescale/dtensor/_utils.py index 7fcc057..4127a61 100644 --- a/vescale/dtensor/_utils.py +++ b/vescale/dtensor/_utils.py @@ -84,7 +84,7 @@ def compute_local_shape_and_global_offset( for idx, placement in enumerate(placements): mesh_dim_size = mesh.size(idx) - if isinstance(placement, (Shard, InterleavedShard)): + if isinstance(placement, Shard): shard_dim = placement.dim local_offset = [0] * len(global_shape) assert shard_dim < len( diff --git a/vescale/dtensor/dispatch.py b/vescale/dtensor/dispatch.py index ba21802..64d5e63 100644 --- a/vescale/dtensor/dispatch.py +++ b/vescale/dtensor/dispatch.py @@ -18,7 +18,7 @@ import vescale.dtensor.dtensor as dtensor import vescale.dtensor.random as random -from vescale.dtensor._diff import VESCALE_DISABLE_REDISTRIBUTE, switch_partial_mode +from vescale.dtensor._diff import VESCALE_DISABLE_REDISTRIBUTE from vescale.dtensor.device_mesh import DeviceMesh from vescale.dtensor.op_schema import ( OpInfo, @@ -205,7 +205,8 @@ def to_dt(res, spec): spec.mesh, spec.placements, shape=spec.tensor_meta.shape, - dtype=spec.tensor_meta.dtype, + # use the local result dtype, as spec generated by sharding prop might be wrong. 
+ dtype=res.dtype, requires_grad=res.requires_grad, stride=spec.tensor_meta.stride, ) @@ -234,7 +235,6 @@ def to_dt(res, spec): return res -@switch_partial_mode def _operator_dispatch( op_call: torch._ops.OpOverload, args: Tuple[object, ...], @@ -359,6 +359,8 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: if local_results is None: # None for None return None else: # should return self instead of re-wrapping + args[0]._local_tensor = local_results + args[0]._spec = output_sharding.output_spec return args[0] elif _is_out_variant_op(op_call): # out variant could possibly have multiple out args (i.e. lu_unpack.out) diff --git a/vescale/dtensor/ops/basic_strategy.py b/vescale/dtensor/ops/basic_strategy.py index 2e7309b..803e672 100644 --- a/vescale/dtensor/ops/basic_strategy.py +++ b/vescale/dtensor/ops/basic_strategy.py @@ -9,11 +9,11 @@ ################################################################################ from dataclasses import dataclass -from typing import Dict, List, Tuple +from typing import List, Tuple from vescale.dtensor import DeviceMesh from vescale.dtensor.op_schema import OpStrategy, PlacementStrategy -from vescale.dtensor.placement_types import DTensorSpec, InterleavedShard, Partial, Placement, Replicate, Shard +from vescale.dtensor.placement_types import DTensorSpec, InterleavedShard, Partial, Replicate, Shard @dataclass @@ -85,97 +85,6 @@ def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims": ) -""" -for any einsum option there can be conclusion in below ways -a batch matmul peration -C_b_i_j = \\sum{A_b_i_k * B_b_k_j} -mesh shape is (n0,n1) - -# | parallel mapping | outspec | inputspec | cost -1 | i->0, j->1 | RS^0S^1 | RS^0R,RRS^1 | 0 -2 | i->0, k->1 | RS^0R | RS^0S^1,RS^1R | all-reduce(M/n0, 1) -3 | j->0, k->1 | RRS^0 | RRS^1,RS^1S^0 | all-reduce(M/n0, 1) -4 | b->0, i->1 | S^0S^1R | S^0S^1R,S^0RR | 0 -5 | b->0, k->1 | S^0RR | S^0RS^1,S^0S^1R | all-reduce(M/n0, 1) -6 | i->{0, 1} | RS^01R | RS^01R,RRR | 0 -7 | k->{0, 1} | RRR | RRS^01,RS^01R | all-reduce(M, {0, 1}) -""" - - -def deduce_out_mode( - lhs_mode: int, - rhs_mode: int, - edims: EinsumDims, - lhs_mesh_dims_map: Dict[str, int], - rhs_mesh_dims_map: Dict[str, int], - lhs_interleaved_shard_dims: Dict[str, int], - rhs_interleaved_shard_dims: Dict[str, int], -): - split_batch = lhs_mode & 4 or rhs_mode & 4 - split_concat = lhs_mode & 1 - lhs_split_spartial = lhs_mode & 2 - rhs_split_spartial = rhs_mode & 2 - - out_mode = 0 - out_mesh_dim_mapping = {} - out_interleaved_shard_dims = {} - reshard_cost = None - if split_batch: - out_mode |= 4 - for batch_dim in edims.batch_dims: - if batch_dim not in out_mesh_dim_mapping: - # batch dim is not sharded in lhs and rhs - if batch_dim not in lhs_mesh_dims_map and batch_dim not in rhs_mesh_dims_map: - continue - # make sure sharding information of batch_dim stays consistent between lhs and rhs - if (batch_dim not in lhs_mesh_dims_map) or (batch_dim not in rhs_mesh_dims_map): - raise ValueError("batch dim must be sharded in both lhs and rhs") - assert ( - lhs_mesh_dims_map[batch_dim] == rhs_mesh_dims_map[batch_dim] - ), f"batch dim sharding information inconsistent, {lhs_mesh_dims_map[batch_dim]} vs {rhs_mesh_dims_map[batch_dim]}" - if (batch_dim in lhs_interleaved_shard_dims and batch_dim not in rhs_interleaved_shard_dims) or ( - batch_dim in rhs_interleaved_shard_dims and batch_dim not in lhs_interleaved_shard_dims - ): - raise ValueError("batch dim sharding information inconsistent, found InterleavedShard and Shard") - if batch_dim in 
lhs_interleaved_shard_dims: - assert ( - lhs_interleaved_shard_dims[batch_dim] == rhs_interleaved_shard_dims[batch_dim] - ), f"batch dim sharding information inconsistent, found InterleavedShard({lhs_interleaved_shard_dims[batch_dim]}) vs Interleaved shard({rhs_interleaved_shard_dims[batch_dim]})" - out_mesh_dim_mapping[batch_dim] = lhs_mesh_dims_map[batch_dim] - if batch_dim in lhs_interleaved_shard_dims: - out_interleaved_shard_dims[batch_dim] = lhs_interleaved_shard_dims[batch_dim] - - if split_concat: - # output will be partial - reduce_mappings = {} - for reduce_dim in edims.contracting_dims: - assert len(lhs_mesh_dims_map[reduce_dim]) == len( - rhs_mesh_dims_map[reduce_dim] - ), "reduce dim in different mesh is not allowed" - if reduce_dim in lhs_interleaved_shard_dims and reduce_dim in rhs_interleaved_shard_dims: - assert ( - lhs_interleaved_shard_dims[reduce_dim] == rhs_interleaved_shard_dims[reduce_dim] - ), "reduce dim should be interleaved sharded of same interleaved size" - else: - assert ( - reduce_dim not in lhs_interleaved_shard_dims and reduce_dim not in rhs_interleaved_shard_dims - ), "one reduce dim is interleaved sharded, but the other not" - reduce_mappings[reduce_dim] = lhs_mesh_dims_map[reduce_dim] - reshard_cost = reduce_mappings - - if lhs_split_spartial: - for d in edims.lhs_out_only_dims: - if d in lhs_interleaved_shard_dims: - out_interleaved_shard_dims[d] = lhs_interleaved_shard_dims[d] - out_mode |= 2 - if rhs_split_spartial: - for d in edims.rhs_out_only_dims: - if d in rhs_interleaved_shard_dims: - out_interleaved_shard_dims[d] = rhs_interleaved_shard_dims[d] - out_mode |= 1 - return out_mode, out_interleaved_shard_dims, reshard_cost - - def gen_einsum_strategies( equation: str, mesh: DeviceMesh, @@ -203,28 +112,29 @@ def gen_einsum_strategies( inputs, output = equation.split("->") lhs, rhs = inputs.split(",") - # bitset mode to represent split - lhs_shard_dims = [lhs[shard.dim] for shard in lhs_placements if shard.is_shard() or shard.is_interleaved_shard()] - rhs_shard_dims = [rhs[shard.dim] for shard in rhs_placements if shard.is_shard() or shard.is_interleaved_shard()] - - def generate_interleaved_shard_dims(dims, placements): - interleaved_shard_dims = {} - for p in placements: - if not p.is_interleaved_shard() and not p.is_shard(): - continue - - input_dim = dims[p.dim] - if input_dim not in interleaved_shard_dims: - if p.is_interleaved_shard(): - interleaved_shard_dims[input_dim] = p.interleaved_size - else: - raise ValueError( - "vescale doesn't support mulitiple shard (and one of them is InterleavedShard) of one input dim" - ) - return interleaved_shard_dims - - lhs_interleaved_shard_dims = generate_interleaved_shard_dims(lhs, lhs_placements) - rhs_interleaved_shard_dims = generate_interleaved_shard_dims(rhs, rhs_placements) + # {"b": [(0, S), (1, IS)]} + lhs_shard_dim_infos = {} + rhs_shard_dim_infos = {} + + for i, p in enumerate(lhs_placements): + if not p.is_shard(): + continue + if lhs[p.dim] not in lhs_shard_dim_infos: + lhs_shard_dim_infos[lhs[p.dim]] = {} + if p.is_interleaved_shard(): + lhs_shard_dim_infos[lhs[p.dim]][i] = ("IS", p.interleaved_size) + else: + lhs_shard_dim_infos[lhs[p.dim]][i] = ("S", None) + + for i, p in enumerate(rhs_placements): + if not p.is_shard(): + continue + if rhs[p.dim] not in rhs_shard_dim_infos: + rhs_shard_dim_infos[rhs[p.dim]] = {} + if p.is_interleaved_shard(): + rhs_shard_dim_infos[rhs[p.dim]][i] = ("IS", p.interleaved_size) + else: + rhs_shard_dim_infos[rhs[p.dim]][i] = ("S", None) if linearity: lhs_spec = 
DTensorSpec(mesh, lhs_placements) @@ -233,93 +143,78 @@ def generate_interleaved_shard_dims(dims, placements): placement = PlacementStrategy(output_spec=out_spec, input_specs=[lhs_spec, rhs_spec]) return OpStrategy([placement]) - def construct_tensor_dim_to_mesh_dim(placements, dims): - maps = {} - for idx, placement in enumerate(placements): - if placement.is_shard() or placement.is_interleaved_shard(): - char = dims[placement.dim] - if char not in maps: - maps[char] = set() - maps[char].add(idx) - return maps - - def deduce_sharding_mode(shard_dim): - mode = 0 - for dim in shard_dim: - if dim in edims.batch_dims: - mode |= 1 << 2 # SRR - if dim in edims.contracting_dims: - mode |= 1 # RRS + out_shard_dim_infos = {} + out_reduce_dim_info = {} + """ + Validation Check And Generate OutShardDimInfo + """ + # 1. same batch and constrating dims + for d in edims.batch_dims + edims.contracting_dims: + if d not in lhs_shard_dim_infos and d not in rhs_shard_dim_infos: + continue + if d not in lhs_shard_dim_infos and d in rhs_shard_dim_infos: + raise ValueError(f"found rhs sharded on {d}, but lhs not") + if d in lhs_shard_dim_infos and d not in rhs_shard_dim_infos: + raise ValueError(f"found lhs sharded on {d}, but rhs not") + assert len(lhs_shard_dim_infos[d]) == len( + rhs_shard_dim_infos[d] + ), "lhs and rhs must be sharded on the same number of mesh dims" + for mesh_dim in lhs_shard_dim_infos[d]: + # assert lhs_shard_dim_infos[d][mesh_dim][0] != "P", "batch or contract dims must not be partial sharded" + assert mesh_dim in rhs_shard_dim_infos[d], f"found lhs sharded on mesh dim @{mesh_dim}, but rhs not" + lp = lhs_shard_dim_infos[d][mesh_dim] + rp = rhs_shard_dim_infos[d][mesh_dim] + assert ( + lp[0] == rp[0] and lp[1] == rp[1] + ), f"lhs and rhs must be samely sharded on mesh dim @{mesh_dim}, found {lp} and {rp}" + + if d in edims.batch_dims: + if d not in out_shard_dim_infos: + out_shard_dim_infos[d] = {} + out_shard_dim_infos[d][mesh_dim] = lp else: - mode |= 1 << 1 # RSR - return mode - - lhs_sharding_map = construct_tensor_dim_to_mesh_dim(lhs_placements, lhs) - rhs_sharding_map = construct_tensor_dim_to_mesh_dim(rhs_placements, rhs) - lhs_shard_mode = deduce_sharding_mode(lhs_shard_dims) - rhs_shard_mode = deduce_sharding_mode(rhs_shard_dims) - out_mode, out_interleaved_shard_dims, reshard_cost = deduce_out_mode( - lhs_shard_mode, - rhs_shard_mode, - edims, - lhs_sharding_map, - rhs_sharding_map, - lhs_interleaved_shard_dims, - rhs_interleaved_shard_dims, - ) + out_reduce_dim_info[mesh_dim] = ("P", None) + + # 2. lhs only dims + for d in edims.lhs_out_only_dims: + if d not in lhs_shard_dim_infos: + continue + out_shard_dim_infos[d] = {} + for mesh_dim in lhs_shard_dim_infos[d]: + out_shard_dim_infos[d][mesh_dim] = lhs_shard_dim_infos[d][mesh_dim] + + # 3. rhs only dims + for d in edims.rhs_out_only_dims: + if d not in rhs_shard_dim_infos: + continue + out_shard_dim_infos[d] = {} + for mesh_dim in rhs_shard_dim_infos[d]: + out_shard_dim_infos[d][mesh_dim] = rhs_shard_dim_infos[d][mesh_dim] + + # 4. 
no-shard dims + lhs_partial_mesh_dims = lhs_spec.sums + rhs_partial_mesh_dims = rhs_spec.sums + if lhs_partial_mesh_dims and rhs_partial_mesh_dims: + raise ValueError("rhs and lhs can not be both partial") + for mesh_dim in lhs_partial_mesh_dims + rhs_partial_mesh_dims: + out_reduce_dim_info[mesh_dim] = ("P", None) - # not split batch - # RS * SR , SS * SR, RS * SS placements = [Replicate()] * mesh.ndim - if out_mode & 4: - for dim in edims.batch_dims: - if dim in lhs_sharding_map or dim in rhs_sharding_map: - # lhs_sharding_map[dim] = rhs_sharding_map[dim] here. - # it's guaranteed in function `deduce_out_mode`. - for mesh_dim in lhs_sharding_map[dim]: - if dim not in out_interleaved_shard_dims: - placements[mesh_dim] = Shard(output.index(dim)) - else: - placements[mesh_dim] = InterleavedShard( - output.index(dim), interleaved_size=out_interleaved_shard_dims[dim] - ) - - def generate_placement(placements, dim_maps: Dict, type: Placement): - mesh_dims = [] - for dim in dim_maps: - mesh_dim = dim_maps[dim] - mesh_dims.extend(list(mesh_dim)) - for dim in mesh_dims: - placements[dim] = type - return placements - - if reshard_cost is not None: - placements = generate_placement(placements, reshard_cost, Partial()) - - for dim in edims.lhs_out_only_dims: - if dim in lhs_sharding_map: - if dim in lhs_interleaved_shard_dims: - placements = generate_placement( - placements, - {dim: lhs_sharding_map[dim]}, - InterleavedShard(output.index(dim), interleaved_size=out_interleaved_shard_dims[dim]), - ) - else: - placements = generate_placement(placements, {dim: lhs_sharding_map[dim]}, Shard(output.index(dim))) - - for dim in edims.rhs_out_only_dims: - if dim in rhs_sharding_map: - if dim in rhs_interleaved_shard_dims: - placements = generate_placement( - placements, - {dim: rhs_sharding_map[dim]}, - InterleavedShard(output.index(dim), interleaved_size=out_interleaved_shard_dims[dim]), - ) + for d in out_shard_dim_infos: + output_tensor_dim = output.index(d) + for mesh_dim in out_shard_dim_infos[d]: + if out_shard_dim_infos[d][mesh_dim][0] == "S": + placements[mesh_dim] = Shard(output_tensor_dim) + elif out_shard_dim_infos[d][mesh_dim][0] == "IS": + placements[mesh_dim] = InterleavedShard(output_tensor_dim, out_shard_dim_infos[d][mesh_dim][1]) else: - placements = generate_placement(placements, {dim: rhs_sharding_map[dim]}, Shard(output.index(dim))) + pass + for mesh_dim in out_reduce_dim_info: + if out_reduce_dim_info[mesh_dim][0] == "P": + placements[mesh_dim] = Partial() + else: + pass - assert (lhs_shard_mode & 4) == (rhs_shard_mode & 4), "vescale only support both split batch dim" - assert (lhs_shard_mode & 1) == (rhs_shard_mode & 1), "vescale only support both split concat dim" out_spec = DTensorSpec(mesh, tuple(placements)) placement = PlacementStrategy(output_spec=out_spec, input_specs=[lhs_spec, rhs_spec]) return OpStrategy([placement]) diff --git a/vescale/dtensor/ops/math_ops.py b/vescale/dtensor/ops/math_ops.py index 09a2e61..7aafeaf 100644 --- a/vescale/dtensor/ops/math_ops.py +++ b/vescale/dtensor/ops/math_ops.py @@ -31,7 +31,7 @@ register_op_strategy, register_prop_rule, ) -from vescale.dtensor.placement_types import DTensorSpec, Partial, Placement, Replicate, Shard +from vescale.dtensor.placement_types import DTensorSpec, Partial, Placement, Replicate, Shard, InterleavedShard aten = torch.ops.aten @@ -93,13 +93,17 @@ def map_placements_after_reduction( else: assert isinstance(placement, Shard) shard_dim = placement.dim + interleaved_size = getattr(placement, "interleaved_size", None) 
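# A minimal standalone sketch of how `map_placements_after_reduction` remaps a sharded
# dim here, assuming a sum reduction with keepdims=False; the tuples below are
# illustrative stand-ins for the real Shard / InterleavedShard / Partial placement
# objects, not vescale API calls.
def remap_shard_after_reduction(shard_dim, interleaved_size, reduction_dims, reduction_dims_map):
    new_shard_dim = reduction_dims_map[shard_dim]
    if new_shard_dim == -1 or shard_dim in reduction_dims:
        # the sharded dim was reduced away, so the result becomes Partial
        return ("Partial", "sum")
    if interleaved_size is None:
        return ("Shard", new_shard_dim)
    # the interleaving factor survives the reduction unchanged
    return ("InterleavedShard", new_shard_dim, interleaved_size)

# reducing dim 0 of a rank-3 tensor: old dim 1 -> new dim 0, old dim 2 -> new dim 1
dims_map = {0: -1, 1: 0, 2: 1}
assert remap_shard_after_reduction(1, 2, [0], dims_map) == ("InterleavedShard", 0, 2)
assert remap_shard_after_reduction(0, None, [0], dims_map) == ("Partial", "sum")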
new_shard_dim = reduction_dims_map[shard_dim] if new_shard_dim == -1 or shard_dim in reduction_dims: # if new_shard_dim collapsed or its in the reduction dims # (i.e. for the case where keepdims=True), we generate partial new_placements.append(Partial(reduction_op)) else: - new_placements.append(Shard(reduction_dims_map[shard_dim])) + if interleaved_size is None: + new_placements.append(Shard(reduction_dims_map[shard_dim])) + else: + new_placements.append(InterleavedShard(reduction_dims_map[shard_dim], interleaved_size)) return tuple(new_placements) diff --git a/vescale/dtensor/ops/pointwise_ops.py b/vescale/dtensor/ops/pointwise_ops.py index e5de8a8..0415514 100644 --- a/vescale/dtensor/ops/pointwise_ops.py +++ b/vescale/dtensor/ops/pointwise_ops.py @@ -52,6 +52,7 @@ linear_pointwise_ops = [ aten.div.Scalar, # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op. + aten.div_.Scalar, aten.to.dtype, aten.add.Tensor, aten.add_.Tensor, @@ -407,15 +408,30 @@ def pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False) -> StrategyType: - # (Hongyu): allow pointwise P mul/div R - if op_schema.op in [aten.mul.Tensor, aten.div.Tensor]: - placements_a = op_schema.args_schema[0].strategies[0].output_spec.placements - if isinstance(op_schema.args_schema[1], float): + # (Hongyu): allow pointwise P mul/div R and R mul P + # (Li): allow pointwise inplace P mul_/div_ R. It's crucial for inplace dropout. + partial_strategy_index = -1 + if op_schema.op in [aten.mul.Tensor, aten.mul_.Tensor, aten.div.Tensor, aten.div_.Tensor]: + partial_a = ( + op_schema.args_schema[0].strategies[0].output_spec.is_partial() + if isinstance(op_schema.args_schema[0], OpStrategy) + else False + ) + partial_b = ( + op_schema.args_schema[1].strategies[0].output_spec.is_partial() + if isinstance(op_schema.args_schema[1], OpStrategy) + else False + ) + if partial_a and partial_b: + linearity = False + elif partial_b and op_schema.op in [aten.div.Tensor, aten.div_.Tensor]: + linearity = False + else: linearity = True - elif isinstance(op_schema.args_schema[1], OpStrategy): - spec_b = op_schema.args_schema[1].strategies[0].output_spec - if len(placements_a) == 1 and placements_a[0].is_partial() and spec_b.is_replicated(): - linearity = True + if partial_a: + partial_strategy_index = 0 + if partial_b: + partial_strategy_index = 1 max_shards_strategy_index = -1 max_shards = -1 @@ -429,6 +445,9 @@ def pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = elif _is_out_variant_op(op_schema.op): # out variant op should follow the out kwarg strategy followed_strategy = op_schema.kwargs_schema["out"] + elif partial_strategy_index != -1: + # follow partial strategy on element-wise mul/div cases + followed_strategy = op_schema.args_schema[partial_strategy_index] else: # normal pointwise op, we choose to follow the arg with # the max shards in case operands needs reshard @@ -462,7 +481,7 @@ def pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = # clear the partial placemnet if op does not support linearity # by default we just replicate the partial, need to see if this # is optimal for all cases - raise RuntimeError("Vescale not support Partial with no linearity op") + raise RuntimeError(f"Vescale not support Partial with no linearity op {op_schema.op}") else: out_placements.append(placement) @@ -488,6 +507,7 @@ def pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = redistribute_cost=None, ) ) + 
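# A standalone sketch of the Partial-handling rule above for elementwise mul/div,
# assuming `partial_a` / `partial_b` indicate whether each operand carries a Partial
# placement; "mul" / "div" stand in for the aten mul(_)/div(_) overloads. P * R, R * P
# and P / R keep linearity (the Partial operand's strategy is followed), while
# P * P, P / P and R / P cannot stay Partial.
def mul_div_linearity(op, partial_a, partial_b):
    assert op in ("mul", "div")
    if partial_a and partial_b:
        return False
    if partial_b and op == "div":
        return False
    return True

assert mul_div_linearity("mul", True, False)       # P * R stays Partial
assert mul_div_linearity("mul", False, True)       # R * P stays Partial
assert mul_div_linearity("div", True, False)       # P / R stays Partial
assert not mul_div_linearity("div", False, True)   # R / P must be resolved first
assert not mul_div_linearity("mul", True, True)    # P * P must be resolved first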
return pointwise_strategy diff --git a/vescale/dtensor/ops/random_ops.py b/vescale/dtensor/ops/random_ops.py index 144bf1b..5b081e1 100644 --- a/vescale/dtensor/ops/random_ops.py +++ b/vescale/dtensor/ops/random_ops.py @@ -13,11 +13,18 @@ from vescale.dtensor import DeviceMesh from vescale.dtensor.op_schema import OpSchema, OpStrategy, PlacementStrategy, StrategyType from vescale.dtensor.ops.utils import register_op_strategy, is_tensor_partial +from vescale.dtensor.placement_types import DTensorSpec, Partial, Replicate aten = torch.ops.aten -@register_op_strategy([aten.normal_.default, aten.uniform_.default]) +@register_op_strategy( + [ + aten.normal_.default, + aten.uniform_.default, + aten.bernoulli_.float, + ] +) def random_op_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: self_strategy = op_schema.args_schema[0] assert isinstance(self_strategy, OpStrategy) @@ -26,14 +33,26 @@ def random_op_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: for arg_strategy in self_strategy.strategies: arg_spec = arg_strategy.output_spec if is_tensor_partial(arg_spec): - # TODO: figure out how inplace random op should behave when it's partial - raise RuntimeError(f"{op_schema.op} with Partial is not supported yet!") - random_strategy.strategies.append(PlacementStrategy(output_spec=arg_spec)) + # if the arg_spec have partial, accept partial + # in the input_specs but output replicate for + # those corresponding mesh dims + + output_spec = DTensorSpec( + mesh=arg_spec.mesh, + placements=tuple(Replicate() if isinstance(p, Partial) else p for p in arg_spec.placements), + ) + random_strategy.strategies.append( + PlacementStrategy( + output_spec=output_spec, + input_specs=(arg_spec,), + ) + ) + else: + random_strategy.strategies.append(PlacementStrategy(output_spec=arg_spec)) return random_strategy -# (Hongyu) allow partial placements for dropout @register_op_strategy(aten.native_dropout.default) def random_op_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: self_strategy = op_schema.args_schema[0] diff --git a/vescale/dtensor/ops/tensor_ops.py b/vescale/dtensor/ops/tensor_ops.py index c28cc5f..1c380fb 100644 --- a/vescale/dtensor/ops/tensor_ops.py +++ b/vescale/dtensor/ops/tensor_ops.py @@ -12,8 +12,8 @@ import warnings import copy +import numpy as np import torch -from torch.utils._python_dispatch import _get_current_dispatch_mode from vescale.dtensor._utils import compute_local_shape from vescale.dtensor.op_schema import ( @@ -23,11 +23,12 @@ PlacementStrategy, StrategyType, OpStrategy, + TupleStrategy, ) -from vescale.dtensor._diff import EnablePartialMode from vescale.dtensor.ops.common_rules import pointwise_rule from vescale.dtensor.ops.utils import ( is_tensor_dim_sharded, + is_tensor_dim_interleaved_sharded, is_tensor_partial, normalize_dim, prod, @@ -117,18 +118,17 @@ def create_like_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: # if the arg_spec have partial, accept partial # in the input_specs but output replicate for # those corresponding mesh dims - enable_partial = False - mode = _get_current_dispatch_mode() - if isinstance(mode, EnablePartialMode): - enable_partial = True output_spec = DTensorSpec( mesh=arg_spec.mesh, - placements=tuple( - Replicate() if (isinstance(p, Partial) and not enable_partial) else p for p in arg_spec.placements - ), + placements=tuple(Replicate() if isinstance(p, Partial) else p for p in arg_spec.placements), + ) + create_like_strategy.strategies.append( + PlacementStrategy( + output_spec=output_spec, 
+ input_specs=(arg_spec,), + ) ) - create_like_strategy.strategies.append(PlacementStrategy(output_spec=output_spec, input_specs=(arg_spec,))) else: create_like_strategy.strategies.append(PlacementStrategy(arg_spec)) @@ -136,27 +136,96 @@ def create_like_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: return create_like_strategy -@register_op_strategy( +@register_prop_rule( [ aten.new_empty.default, aten.new_full.default, aten.new_ones.default, - aten.new_zeros.default, + # aten.new_zeros.default, ], schema_info=RuntimeSchemaInfo(1, ["dtype"]), ) -def new_factory_strategy(mesh: DeviceMesh, _) -> StrategyType: +def new_factory_rule(op_schema: OpSchema) -> OutputSharding: # TODO: maybe we should generate all possible shardings intead of just stay # replicated for new factory methods - replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) - return OpStrategy([PlacementStrategy(replica_spec)]) + args = op_schema.args_schema + input_spec = args[0] + assert isinstance(input_spec, DTensorSpec) + mesh = input_spec.mesh + output_shape = args[1] + + # has partial spec + if any(p.is_partial() for p in input_spec.placements): + raise RuntimeError("constrcuting partial tensors using new-factory methods is ambigious") + + # no shard spec + if all(not p.is_shard() for p in input_spec.placements): + replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) + return OutputSharding(replica_spec) + # has shard spec, we refer to view op to do the sharding prop. + assert ( + input_spec.tensor_meta is not None + ), "tensor meta must not be None if you are constructing a sharded tensor using `new_zeros` or something like that" + original_numel = prod(input_spec.tensor_meta.shape) + target_numel = prod(output_shape) + assert original_numel == target_numel, "for now, we only support the same numel in new_factory methods" + + from vescale.dtensor.ops.vescale_view_ops import vescale_view_rule_prop, ops + + spec = ops[torch.Tensor.view] + output_sharding = vescale_view_rule_prop(op_schema=op_schema, spec=spec) + return output_sharding + + +# (Hongyu): support partial new_zeros +@register_prop_rule( + aten.new_zeros.default, + schema_info=RuntimeSchemaInfo(1, ["dtype"]), +) +def new_zeros_rule(op_schema: OpSchema) -> OutputSharding: + args = op_schema.args_schema + input_spec, output_shape = args[0], args[1] + assert isinstance(input_spec, DTensorSpec) + mesh = input_spec.mesh + output_stride = list(np.cumprod(output_shape[::-1])[:-1][::-1]) + output_stride.append(1) + + if input_spec.is_partial(): + partial_spec = DTensorSpec( + mesh=mesh, + placements=input_spec.placements, + tensor_meta=TensorMeta( + torch.Size(output_shape), + tuple(output_stride), + input_spec.tensor_meta.dtype, + ), + ) + return OutputSharding(partial_spec) + + # no shard spec + if all(not p.is_shard() for p in input_spec.placements): + replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) + return OutputSharding(replica_spec) + # has shard spec, we refer to view op to do the sharding prop. 
+ assert ( + input_spec.tensor_meta is not None + ), "tensor meta must not be None if you are constructing a sharded tensor using `new_zeros` or something like that" + original_numel = prod(input_spec.tensor_meta.shape) + target_numel = prod(output_shape) + assert original_numel == target_numel, "for now, we only support the same numel in new_factory methods" + + from vescale.dtensor.ops.vescale_view_ops import vescale_view_rule_prop, ops + + spec = ops[torch.Tensor.view] + output_sharding = vescale_view_rule_prop(op_schema=op_schema, spec=spec) + return output_sharding @register_prop_rule( aten.new_empty_strided.default, schema_info=RuntimeSchemaInfo(1, ["dtype"]), ) -def new_empty_strided_rule(op_schema: OpSchema) -> StrategyType: +def new_empty_strided_rule(op_schema: OpSchema) -> OutputSharding: # TODO: maybe we should generate all possible shardings intead of just stay # replicated for new factory methods args = op_schema.args_schema @@ -263,16 +332,17 @@ def gen_bucketize_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyTyp return bucketize_strategy -@register_op_strategy(aten.slice.Tensor, schema_info=RuntimeSchemaInfo(1)) -def gen_slice_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: +# NOTE: we change to use rule-based way because we want to change non tensor args +@register_prop_rule(aten.slice.Tensor, schema_info=RuntimeSchemaInfo(1)) +def prop_slice(op_schema: OpSchema) -> OutputSharding: """ forwards all shardings except the slice dimension. """ defaults = (None, 0, None, None, 1) - input_strategy, dim, start, end, step = op_schema.args_schema + defaults[len(op_schema.args_schema) :] - assert isinstance(input_strategy, OpStrategy) - input_shape = input_strategy.output_shape - input_ndim = input_strategy.output_ndim + input_spec, dim, start, end, step = op_schema.args_schema + defaults[len(op_schema.args_schema) :] + assert isinstance(input_spec, DTensorSpec) + input_shape = input_spec.tensor_meta.shape + input_ndim = len(input_shape) assert isinstance(dim, int) if start is None: start = 0 @@ -282,6 +352,8 @@ def gen_slice_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: assert isinstance(end, int) assert isinstance(step, int) + mesh = input_spec.mesh + # normalize args slice_dim = normalize_dim(dim, input_ndim) start = normalize_dim(start, input_shape[dim]) @@ -291,32 +363,56 @@ def gen_slice_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: # calculate slice tensor meta output_shape_list = list(input_shape) output_shape_list[dim] = end - start - out_tensor_meta = TensorMeta( - shape=torch.Size(output_shape_list), - stride=input_strategy.output_stride, - dtype=input_strategy.output_dtype, - ) - - slice_strategy = OpStrategy([]) - for arg_strategy in input_strategy.strategies: - arg_spec = arg_strategy.output_spec - if not is_tensor_dim_sharded(arg_spec, dim=slice_dim) or redundant_slice: - # only add the strategy if the slice dim is not sharded - out_spec = DTensorSpec(mesh, arg_spec.placements, out_tensor_meta) - slice_strategy.strategies.append(PlacementStrategy(output_spec=out_spec)) - if not slice_strategy.strategies: - # if all strategies are filtered out, unsharding all specs on slice dim - # of the input strategy, and use that as the op strategy - for arg_strategy in input_strategy.strategies: - arg_spec = arg_strategy.output_spec - unshard_spec = DTensorSpec( - mesh, - unshard_tensor_dim(arg_spec.placements, dim=slice_dim), - out_tensor_meta, + if not is_tensor_dim_sharded(input_spec, dim=slice_dim) or 
redundant_slice: + # only add the strategy if the slice dim is not sharded + out_spec = DTensorSpec(mesh, input_spec.placements) + return OutputSharding(output_spec=out_spec) + if is_tensor_dim_interleaved_sharded(input_spec, dim=slice_dim): + interleaved_size = None + interleaved_shard_mesh_dim = None + for i, p in enumerate(input_spec.placements): + if p.is_interleaved_shard(dim=slice_dim): + if interleaved_size is None: + interleaved_size = p.interleaved_size + interleaved_shard_mesh_dim = i + else: + raise NotImplementedError( + "for now, we don't support slice tensor along dim which is interleaved sharded two or more times" + ) + interleaved_unit_size = input_spec.tensor_meta.shape[slice_dim] // interleaved_size + if step != 1: + raise NotImplementedError("for now, we only support constant 1 step in slice op") + + slice_size = end - start + if slice_size % interleaved_unit_size != 0 or start % interleaved_unit_size != 0: + raise NotImplementedError( + "for now, we only support slice boundary strictly aligning with the sharding spec" ) - slice_strategy.strategies.append(PlacementStrategy(output_spec=unshard_spec)) - return slice_strategy + new_placements = list(copy.deepcopy(input_spec.placements)) + new_interleaved_size = slice_size // interleaved_unit_size + if new_interleaved_size == 1: + new_placements[interleaved_shard_mesh_dim] = Shard(slice_dim) + else: + new_placements[interleaved_shard_mesh_dim] = InterleavedShard(slice_dim, new_interleaved_size) + out_spec = DTensorSpec(mesh, new_placements) + return OutputSharding( + output_spec=out_spec, + schema_suggestions=[ + OpSchema( + op_schema.op, + args_schema=( + input_spec, + dim, + start // mesh.size(interleaved_shard_mesh_dim), + end // mesh.size(interleaved_shard_mesh_dim), + step, + ), + kwargs_schema=op_schema.kwargs_schema, + ) + ], + needs_redistribute=True, + ) @register_op_strategy([aten._local_scalar_dense.default]) @@ -391,14 +487,15 @@ def scatter_value(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: ) -@register_op_strategy([aten.index_put_.default, aten.index_put.default]) -def index_put(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - """Set the Output with the index sharding""" - value = op_schema.args_schema[2] - - value_spec: DTensorSpec = value.strategies[0].output_spec - output_spec = DTensorSpec(mesh, tuple(value_spec.placements)) - return OpStrategy([PlacementStrategy(output_spec)]) +# (Hongyu): allow partial index_put here +@register_prop_rule([aten.index_put_.default, aten.index_put.default]) +def index_put_rule(op_schema: OpSchema) -> OutputSharding: + src_spec: DTensorSpec = op_schema.args_schema[0] + value_spec: DTensorSpec = op_schema.args_schema[2] + assert ( + src_spec.placements == value_spec.placements + ), "Currently we only allow equal placements for src and value in index_put op" + return OutputSharding(src_spec) @register_op_strategy([aten.constant_pad_nd.default]) @@ -602,7 +699,7 @@ def place(vp: Placement, ip: Placement) -> Placement: return result -@register_prop_rule([aten.cat.default, aten.stack.default], schema_info=RuntimeSchemaInfo(1, needs_pytree=True)) +@register_prop_rule([aten.cat.default], schema_info=RuntimeSchemaInfo(1, needs_pytree=True)) def cat_rule(op_schema: OpSchema) -> OutputSharding: # torch.cat requires all tensors must either have the same shape (except # in the concatenating dimension) or be "empty". 
"Empty" here strictly means @@ -766,6 +863,102 @@ def is_empty(spec: DTensorSpec) -> bool: ) +def _derive_follow_placements_from_tuple_strategy( + tuple_strategy: TupleStrategy, +) -> Sequence[Placement]: + """ + derive the placements to follow from the tuple strategy, mainly used by + aten.stack, aten.cat, where each operand have the same shape, and correspondingly + expecting the same sharding + """ + + def merge_placement(cur_placement: Placement, new_placement: Placement) -> Placement: + # semantic if we already have a follow placement, we + # check each placement for the current arg placement + # to see if we want to merge/adjust the placement to follow + # the priority: Partial -> Shard -> Replicate + if cur_placement == new_placement: + return cur_placement + + if cur_placement.is_partial(): + if new_placement.is_shard(): + # follow new placement + return new_placement + elif new_placement.is_partial(): + # different partial types, we can't merge and have to replicate all here + return Replicate() + else: + # follow partial + return cur_placement + elif cur_placement.is_shard(): + if new_placement.is_shard(): + # cur/new placement are different sharding (i.e. different shard dim) + # currently fallback to replicate all args + return Replicate() + else: + # for partial/replicate, follow the current shard placement + return cur_placement + else: + # current replicate, just follow new placement + return new_placement + + follow_placements: Optional[List[Placement]] = None + for arg_strategy in tuple_strategy.childs: + assert isinstance(arg_strategy, OpStrategy) + for placement_strategy in arg_strategy.strategies: + arg_placements = placement_strategy.output_spec.placements + if follow_placements is None: + follow_placements = list(arg_placements) + continue + mesh_ndim = len(follow_placements) + assert follow_placements is not None + for mesh_idx in range(mesh_ndim): + # merge placements with the priority + follow_placements[mesh_idx] = merge_placement(follow_placements[mesh_idx], arg_placements[mesh_idx]) + assert follow_placements is not None, "follow placements should not be None!" 
+ return follow_placements + + +def normalize_shard_for_stack(placements: Sequence[Placement], insert_dim: int = 0) -> Sequence[Placement]: + # stack op would "insert" new dim, so all sharded dim >= the inserted dim need to + # be normalized with the new Shard placement + normalized_placements: List[Placement] = [] + for placement in placements: + if isinstance(placement, Shard) and placement.dim >= insert_dim: + normalized_placements.append(Shard(placement.dim + 1)) + else: + normalized_placements.append(placement) + return normalized_placements + + +@register_op_strategy(aten.stack.default, RuntimeSchemaInfo(1, needs_pytree=True)) +def stack_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + args_schema = op_schema.args_schema + input_tuple_strategy = args_schema[0] + assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}" + first_input_strategy = input_tuple_strategy.childs[0] + assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}" + common_input_ndim = first_input_strategy.output_ndim + dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0 + # normalize the dim to be within the common input ndim + dim = normalize_dim(dim, common_input_ndim) + + follow_placements = _derive_follow_placements_from_tuple_strategy(input_tuple_strategy) + follow_placements = normalize_shard_for_stack(follow_placements, dim) + + # create op strategy base on the follow placements + op_strategy = OpStrategy([]) + + input_specs = tuple(DTensorSpec(mesh, tuple(follow_placements)) for _ in range(len(input_tuple_strategy.childs))) + op_strategy.strategies.append( + PlacementStrategy( + output_spec=DTensorSpec(mesh, tuple(follow_placements)), + input_specs=input_specs, + ) + ) + return op_strategy + + @register_prop_rule([aten.split.Tensor, aten.split_with_sizes.default], schema_info=RuntimeSchemaInfo(1)) def split_rule(op_schema: OpSchema) -> OutputSharding: output_spec_list: List[DTensorSpec] = [] diff --git a/vescale/dtensor/ops/vescale_view_ops.py b/vescale/dtensor/ops/vescale_view_ops.py index 5abe592..a348bc1 100644 --- a/vescale/dtensor/ops/vescale_view_ops.py +++ b/vescale/dtensor/ops/vescale_view_ops.py @@ -13,10 +13,12 @@ InterleavedShard and Shard have inconsistent behavior. 
""" +import functools from typing import Callable, Optional, Sequence, Set, Tuple, cast import torch from torch import Tensor +import torch.distributed from vescale.dtensor._utils import compute_local_shape from vescale.dtensor.op_schema import OpSchema, OutputSharding, RuntimeSchemaInfo @@ -105,18 +107,24 @@ def collect_used_inputs(cmd: DimSpec) -> None: for mesh_dim in shard_map_from_input_dim_to_mesh_dim[sharded_input_dims[0]]: new_placements[mesh_dim] = Shard(out_dim) # interleaved shard on not the first of collapsed input dimensions + # multiple S on one tensor dim will be transformed to a IS followed by many S + # e.g., [S(1), S(1)] on -> [IS(1), S(1)] else: - assert ( - len(shard_map_from_input_dim_to_mesh_dim[sharded_input_dims[0]]) == 1 - ), "We now only support interleaved sharding on a single mesh dimension" - mesh_dim = shard_map_from_input_dim_to_mesh_dim[sharded_input_dims[0]][0] + # assert ( + # len(shard_map_from_input_dim_to_mesh_dim[sharded_input_dims[0]]) == 1 + # ), "We now only support interleaved sharding on a single mesh dimension" + # mesh_dim = shard_map_from_input_dim_to_mesh_dim[sharded_input_dims[0]][0] interleaved_size = 1 for id in cmd.input_dims: if id.input_dim == sharded_input_dims[0]: break else: interleaved_size *= local_in_shape[id.input_dim] - new_placements[mesh_dim] = InterleavedShard(out_dim, interleaved_size) + new_placements[shard_map_from_input_dim_to_mesh_dim[sharded_input_dims[0]][0]] = InterleavedShard( + out_dim, interleaved_size + ) + for mesh_dim in shard_map_from_input_dim_to_mesh_dim[sharded_input_dims[0]][1:]: + new_placements[mesh_dim] = Shard(out_dim) # none of collapsed input dims is sharded. Do nothing. elif isinstance(cmd, Split): @@ -189,20 +197,20 @@ def collect_used_inputs(cmd: DimSpec) -> None: continue if interleaved_size * sharded_dim_size < prev_size: continue - if interleaved_size * sharded_dim_size <= prev_size * out_dim_size: - if interleaved_size % prev_size != 0: - needs_reshard = True - continue - assert ( - len(shard_map_from_input_dim_to_mesh_dim[only_sharded_input_dim]) == 1 - ), "Interleaved sharding only supports one dimension being sharded." - for mesh_dim, mesh_dim_size in enumerate(mesh_sizes): - shardable_dims[only_sharded_input_dim, mesh_dim] = out_dim_size % mesh_dim_size == 0 - new_placements[shard_map_from_input_dim_to_mesh_dim[only_sharded_input_dim][0]] = ( - InterleavedShard(out_dim, interleaved_size // prev_size) - ) - else: + # if interleaved_size * sharded_dim_size <= prev_size * out_dim_size: + if interleaved_size <= prev_size: + continue + if interleaved_size % prev_size != 0: needs_reshard = True + continue + assert ( + len(shard_map_from_input_dim_to_mesh_dim[only_sharded_input_dim]) == 1 + ), "Interleaved sharding only supports one dimension being sharded." 
+ for mesh_dim, mesh_dim_size in enumerate(mesh_sizes): + shardable_dims[only_sharded_input_dim, mesh_dim] = out_dim_size % mesh_dim_size == 0 + new_placements[shard_map_from_input_dim_to_mesh_dim[only_sharded_input_dim][0]] = ( + InterleavedShard(out_dim, interleaved_size // prev_size) + ) else: raise RuntimeError("Unkown input dim for Split.") elif isinstance(cmd, Singleton): @@ -222,14 +230,18 @@ def collect_used_inputs(cmd: DimSpec) -> None: def remove_interleaved_shard(*args_schema, **kwargs_schema): def replace_interleaved_shard(spec: DTensorSpec) -> DTensorSpec: - # new_spec = copy.deepcopy(spec) new_spec = DTensorSpec(spec.mesh, spec.placements, spec.tensor_meta) placements = spec.placements interleaved_shard_dims = { placement.dim: placement for placement in placements if isinstance(placement, InterleavedShard) } + if not interleaved_shard_dims: return new_spec + prev_dim_interleaved_sharded_times = [0] * spec.ndim + for d in interleaved_shard_dims: + for fd in range(d + 1, spec.ndim): + prev_dim_interleaved_sharded_times[fd] += 1 if spec.tensor_meta is not None: new_shape = [] new_stride = [] @@ -247,13 +259,19 @@ def replace_interleaved_shard(spec: DTensorSpec) -> DTensorSpec: new_stride.append(spec.tensor_meta.stride[d]) new_spec.tensor_meta = TensorMeta(shape=new_shape, stride=new_stride, dtype=spec.tensor_meta.dtype) new_placements = [] - sharded_dim_offset = 1 for i, placement in enumerate(placements): + if not placement.is_shard(): + new_placements.append(placement) + continue + prev_is_dims = prev_dim_interleaved_sharded_times[placement.dim] + for j in range(i): + p = placements[j] + if p.is_interleaved_shard() and p.dim == placement.dim: + prev_is_dims += 1 if isinstance(placement, InterleavedShard): - new_placements.append(Shard(placement.dim + sharded_dim_offset)) - sharded_dim_offset += 1 + new_placements.append(Shard(placement.dim + prev_is_dims + 1)) else: - new_placements.append(placement) + new_placements.append(Shard(placement.dim + prev_is_dims)) new_spec.placements = tuple(new_placements) return new_spec @@ -266,40 +284,22 @@ def replace_interleaved_shard(spec: DTensorSpec) -> DTensorSpec: return new_args_schema, new_kwargs_schema -def register_rule_for_view_and_reshape_ops( - aten_op_overload: torch._ops.OpOverload, - local_op_name: Callable[..., torch.Tensor], - schema_info: Optional[RuntimeSchemaInfo] = None, -) -> None: - spec: Op = ops[local_op_name] +# lift up this function as a utility, one should provide a Op as argument +def vescale_view_rule_prop(op_schema: OpSchema, spec: Op) -> OutputSharding: + new_args_schema, new_kwargs_schema = remove_interleaved_shard(*op_schema.args_schema, **op_schema.kwargs_schema) + rules = spec.dim_map(*new_args_schema, **new_kwargs_schema) + input_dtensor_spec = cast(DTensorSpec, new_args_schema[0]) + mesh = input_dtensor_spec.mesh - @register_prop_rule(aten_op_overload, schema_info=schema_info) - def vescale_view_rule_prop(op_schema: OpSchema) -> OutputSharding: - new_args_schema, new_kwargs_schema = remove_interleaved_shard(*op_schema.args_schema, **op_schema.kwargs_schema) - rules = spec.dim_map(*new_args_schema, **new_kwargs_schema) - input_dtensor_spec = cast(DTensorSpec, new_args_schema[0]) - mesh = input_dtensor_spec.mesh - - assert isinstance(input_dtensor_spec, DTensorSpec), "Expected first input to be a DTensorSpec" - global_in_shape = input_dtensor_spec.shape - assert global_in_shape is not None, "Shape required." 
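# A small numeric sketch of the shape lowering performed below, assuming a 1-D device
# mesh of 2 ranks and an even Shard(0) placement (the helper name and shapes here are
# hypothetical). A global view/reshape target shape has to be rewritten into each
# rank's local shape before the local op is called.
def lower_view_shape_to_local(global_out_shape, shard_out_dim, mesh_size):
    local = list(global_out_shape)
    local[shard_out_dim] //= mesh_size  # even sharding assumed, as compute_local_shape would give
    return tuple(local)

# a global (8, 4) tensor sharded on dim 0 over 2 ranks stores (4, 4) per rank, so a
# global view to (8, 2, 2) is executed locally as a view to (4, 2, 2)
assert lower_view_shape_to_local((8, 2, 2), shard_out_dim=0, mesh_size=2) == (4, 2, 2)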
- - if TORCH_VERSION_BIGGER_THAN_2_2: - from torch._subclasses.fake_tensor import unset_fake_temporarily - from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing - - with disable_proxy_modes_tracing(), unset_fake_temporarily(): - ( - global_out_shape, - shard_out, - shardable_dims, - ) = propagate_shape_and_sharding( - input_dtensor_spec.placements, - tuple(global_in_shape), - rules, - tuple(mesh.mesh.shape), - ) - else: + assert isinstance(input_dtensor_spec, DTensorSpec), "Expected first input to be a DTensorSpec" + global_in_shape = input_dtensor_spec.shape + assert global_in_shape is not None, "Shape required." + + if TORCH_VERSION_BIGGER_THAN_2_2: + from torch._subclasses.fake_tensor import unset_fake_temporarily + from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing + + with disable_proxy_modes_tracing(), unset_fake_temporarily(): ( global_out_shape, shard_out, @@ -310,64 +310,156 @@ def vescale_view_rule_prop(op_schema: OpSchema) -> OutputSharding: rules, tuple(mesh.mesh.shape), ) + else: + ( + global_out_shape, + shard_out, + shardable_dims, + ) = propagate_shape_and_sharding( + input_dtensor_spec.placements, + tuple(global_in_shape), + rules, + tuple(mesh.mesh.shape), + ) + + if shard_out is not None: + # no reshard needed + output_dtensor_spec = DTensorSpec(mesh=mesh, placements=tuple(shard_out)) + + # We only need the local shape to lower the call into the local op + args = op_schema.args_schema + shape_argnum = spec.shape_argnum + if shape_argnum is not None: + # compute the local shape from the global shape, then return + # a resharding even if we don't really reshard, the only reason + # for this type of resharding is to lower the global shape to + # local shape + local_out_shape = compute_local_shape(list(global_out_shape), mesh, shard_out) + + suggested_schema = OpSchema( + op=op_schema.op, + args_schema=args[:shape_argnum] + (tuple(local_out_shape),) + args[shape_argnum + 1 :], + kwargs_schema=op_schema.kwargs_schema, + ) + return OutputSharding( + output_spec=output_dtensor_spec, + schema_suggestions=[suggested_schema], + needs_redistribute=True, + ) + + return OutputSharding(output_spec=output_dtensor_spec) - if shard_out is not None: - # no reshard needed - output_dtensor_spec = DTensorSpec(mesh=mesh, placements=tuple(shard_out)) - - # We only need the local shape to lower the call into the local op - args = op_schema.args_schema - shape_argnum = spec.shape_argnum - if shape_argnum is not None: - # compute the local shape from the global shape, then return - # a resharding even if we don't really reshard, the only reason - # for this type of resharding is to lower the global shape to - # local shape - local_out_shape = compute_local_shape(list(global_out_shape), mesh, shard_out) - - suggested_schema = OpSchema( + else: + # TODO: optimize this. we shouldn't simply blindly replicate + # unshardable dims ... + # FIXME: this can be wrong for situations where we have + # [Shard(0), Shard(0)] + # NOTE: generating suggested_placments for InterleavedShard is complex. + # Just Replicate tensor if it's Sharded. 
+ suggested_placements = [ + p if not isinstance(p, Shard) else Replicate() for _, p in enumerate(input_dtensor_spec.placements) + ] + return OutputSharding( + output_spec=None, + schema_suggestions=[ + OpSchema( op=op_schema.op, - args_schema=args[:shape_argnum] + (tuple(local_out_shape),) + args[shape_argnum + 1 :], + args_schema=( + DTensorSpec( + placements=tuple(suggested_placements), + mesh=input_dtensor_spec.mesh, + tensor_meta=input_dtensor_spec.tensor_meta, + ), + ) + + op_schema.args_schema[1:], kwargs_schema=op_schema.kwargs_schema, ) - return OutputSharding( - output_spec=output_dtensor_spec, - schema_suggestions=[suggested_schema], - needs_redistribute=True, - ) + ], + needs_redistribute=True, + ) - return OutputSharding(output_spec=output_dtensor_spec) - else: - # TODO: optimize this. we shouldn't simply blindly replicate - # unshardable dims ... - # FIXME: this can be wrong for situations where we have - # [Shard(0), Shard(0)] - # NOTE: generating suggested_placments for InterleavedShard is complex. - # Just Replicate tensor if it's Sharded. - suggested_placements = [ - p if not isinstance(p, Shard) else Replicate() for _, p in enumerate(input_dtensor_spec.placements) - ] - return OutputSharding( - output_spec=None, - schema_suggestions=[ - OpSchema( - op=op_schema.op, - args_schema=( - DTensorSpec( - placements=tuple(suggested_placements), - mesh=input_dtensor_spec.mesh, - tensor_meta=input_dtensor_spec.tensor_meta, - ), - ) - + op_schema.args_schema[1:], - kwargs_schema=op_schema.kwargs_schema, - ) - ], - needs_redistribute=True, - ) +def register_rule_for_view_and_reshape_ops( + aten_op_overload: torch._ops.OpOverload, + local_op_name: Callable[..., torch.Tensor], + schema_info: Optional[RuntimeSchemaInfo] = None, +) -> None: + spec: Op = ops[local_op_name] + register_prop_rule(aten_op_overload, schema_info=schema_info)(functools.partial(vescale_view_rule_prop, spec=spec)) register_rule_for_view_and_reshape_ops(aten.view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1)) register_rule_for_view_and_reshape_ops(aten.reshape.default, torch.reshape, schema_info=RuntimeSchemaInfo(1)) register_rule_for_view_and_reshape_ops(aten._unsafe_view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1)) + + +def _check_tensor_contiguous(shape, stride): + if stride[-1] != 1: + return False + for i in range(len(stride) - 1): + if stride[i] != stride[i + 1] * shape[i + 1]: + return False + + return True + + +def _construct_contiguous_stride(shape): + stride = [1] * len(shape) + for i in reversed(range(len(shape) - 1)): + stride[i] = stride[i + 1] * shape[i + 1] + return stride + + +@register_prop_rule(aten.as_strided.default, schema_info=RuntimeSchemaInfo(1)) +def prop_as_strided_rule(op_schema: OpSchema) -> OutputSharding: + args_schema = op_schema.args_schema + input_spec = args_schema[0] + output_shape = args_schema[1] + output_stride = args_schema[2] + memory_offset = args_schema[3] + + assert isinstance(input_spec, DTensorSpec) + assert memory_offset == 0, "for now, we only support 0 offset" + + assert _check_tensor_contiguous(output_shape, output_stride), "for now, we only support contiguous output" + + assert input_spec.tensor_meta is not None + input_stride = input_spec.tensor_meta.stride + input_shape = input_spec.tensor_meta.shape + assert _check_tensor_contiguous(input_shape, input_stride), "for now, we only support contiguous input" + + # we treat as_strided as view op. 
+ spec = ops[Tensor.view] + new_op_schema = OpSchema( + op_schema.op, args_schema=(args_schema[0], args_schema[1]), kwargs_schema=op_schema.kwargs_schema + ) + output_sharding = vescale_view_rule_prop(op_schema=new_op_schema, spec=spec) + + # fail + if output_sharding.output_spec is None: + suggest_op_schema = output_sharding.schema_suggestions[0] + return OutputSharding( + output_spec=None, + schema_suggestions=[ + OpSchema( + op=op_schema.op, + args_schema=(suggest_op_schema.args_schema[0],) + op_schema.args_schema[1:], + kwargs_schema=op_schema.kwargs_schema, + ) + ], + needs_redistribute=True, + ) + if output_sharding.needs_redistribute: + suggest_op_schema = output_sharding.schema_suggestions[0] + local_shape = suggest_op_schema.args_schema[1] + local_stride = _construct_contiguous_stride(local_shape) + new_suggest_op_schema = OpSchema( + op=op_schema.op, + args_schema=suggest_op_schema.args_schema + (local_stride, 0), + kwargs_schema=op_schema.kwargs_schema, + ) + + return OutputSharding( + output_spec=output_sharding.output_spec, schema_suggestions=[new_suggest_op_schema], needs_redistribute=True + ) + return output_sharding diff --git a/vescale/dtensor/placement_types.py b/vescale/dtensor/placement_types.py index 7a1d1b8..5332102 100644 --- a/vescale/dtensor/placement_types.py +++ b/vescale/dtensor/placement_types.py @@ -485,14 +485,16 @@ def dim_map(self) -> List[int]: # and int >=0 represent shard on that device mesh dim r = [-1] * self.ndim for i, placement in enumerate(self.placements): - if placement.is_shard() or placement.is_interleaved_shard(): + if placement.is_shard(): shard_dim = placement.dim - if r[shard_dim] > -1: - raise ValueError( - f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]}," - " DTensor operator implementation does not support things like hybrid" - " sharding strategies yet (i.e. [Shard(0), Shard(0)])" - ) + # NOTE: this might lead to other problems, pay attention. + # relax this check, allow shard one tensor dim twice. + # if r[shard_dim] > -1: + # raise ValueError( + # f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]}," + # " DTensor operator implementation does not support things like hybrid" + # " sharding strategies yet (i.e. 
[Shard(0), Shard(0)])" + # ) r[shard_dim] = i return r @@ -553,3 +555,9 @@ def is_replicated(self): return True if the current DTensorSpec replicates on all mesh dims (devices) """ return all(placement.is_replicate() for placement in self.placements) + + def is_partial(self): + """ + return True if the current DTensorSpec is partial on all mesh dims (devices) + """ + return len(self.placements) == 1 and self.placements[0].is_partial() diff --git a/vescale/dtensor/random.py b/vescale/dtensor/random.py index d2e5767..f6bbf69 100644 --- a/vescale/dtensor/random.py +++ b/vescale/dtensor/random.py @@ -24,7 +24,7 @@ _rng_tracker: Optional["RNGStateTracker"] = None -USE_THREAD_RNG_TRACKER = os.environ.get("VESCALE_SINGLE_DEVICE_RAND", "0") == "1" +USE_THREAD_RNG_TRACKER = os.environ.get("VESCALE_SINGLE_DEVICE_RAND", "1") == "1" def init_vescale_rng_tracker(device_type: str = "cuda"): diff --git a/vescale/dtensor/redistribute.py b/vescale/dtensor/redistribute.py index 7f22ba1..ee80dde 100644 --- a/vescale/dtensor/redistribute.py +++ b/vescale/dtensor/redistribute.py @@ -12,7 +12,6 @@ import torch import torch.distributed.distributed_c10d as c10d -from torch.utils._python_dispatch import _get_current_dispatch_mode import vescale.dtensor.dtensor as dtensor from vescale.dtensor._collective_utils import ( @@ -21,9 +20,9 @@ mesh_broadcast, mesh_reduce_scatter, mesh_scatter, + mesh_all_to_all_single, wait, ) -from vescale.dtensor._diff import EnablePartialMode, switch_partial_mode from vescale.dtensor.device_mesh import DeviceMesh from vescale.dtensor.op_schema import DTensorSpec from vescale.dtensor.placement_types import InterleavedShard, Partial, Placement, Replicate, Shard @@ -45,6 +44,8 @@ def _replicate_then_shard(val: _PlacementItem) -> int: return 0 +# NOTE: we don't need _decompose_reshard anymore, but we still keep this function +# in case of future usage. def _decompose_reshard(val: List[_PlacementItem]) -> List[_PlacementItem]: """ Decompose Si -> Sj into Si -> R -> Sj @@ -216,7 +217,6 @@ def _reduce_scatter_to_shard_with_pad( return output -@switch_partial_mode def redistribute_local_tensor( local_tensor: torch.Tensor, current_spec: DTensorSpec, @@ -244,7 +244,6 @@ def redistribute_local_tensor( current_placements = current_spec.placements target_placements = target_spec.placements sorted_placements = list(enumerate(zip(current_placements, target_placements))) - sorted_placements = _decompose_reshard(sorted_placements) sorted_placements.sort(key=_replicate_then_shard) for i, (current, target) in sorted_placements: @@ -350,23 +349,41 @@ def redistribute_local_tensor( elif current.is_interleaved_shard(): raise NotImplementedError("Redistribution from InterleavedShard to Shard is not suported") else: - # NOTE: this case shouldn't hit _decompose_sharding, decompose sharding should - # decompose Shard(0) -> Shard(1) into Shard(0) -> Replicate -> Shard(1) assert current.is_shard(), f"Current placement should be shard but found {current}" shard_spec = cast(Shard, current) - if shard_spec.dim != target_placement.dim: - # TODO: enable this with all_to_all - raise NotImplementedError("Changing sharding dim is not supported yet!") + if shard_spec.dim == target_placement.dim: + new_local_tensor = local_tensor + continue + if ( + local_tensor.size(target_placement.dim) % device_mesh.size(dim=i) != 0 + or current_spec.shape[shard_spec.dim] % device_mesh.size(dim=i) != 0 + ): + # detect uneven shard on the target dim, fall back to decomposition impl. 
+ # e.g., decompose Shard(0) -> Shard(1) into Shard(0) -> Replicate -> Shard(1) + new_local_tensor = _reshard_to_replicate_with_pad_one_dim( + local_tensor, current_spec.shape, device_mesh, i, shard_spec.dim + ) + shards, _ = target_placement._split_tensor( + new_local_tensor, + num_chunks=device_mesh.size(dim=i), + with_padding=False, + contiguous=False, + ) + new_local_tensor = shards[my_coordinate[i]].clone() + continue + new_local_tensor = mesh_all_to_all_single( + local_tensor, + mesh=device_mesh, + original_shard_dim=shard_spec.dim, + target_shard_dim=target_placement.dim, + mesh_dim=i, + async_op=False, # all-to-all is sync + ) elif target.is_partial(): if current.is_partial(): - mode = _get_current_dispatch_mode() - if isinstance(mode, EnablePartialMode): - # P -> P - new_local_tensor = local_tensor - else: - # P -> R - partial_spec = cast(Partial, current) - new_local_tensor = mesh_all_reduce(local_tensor, device_mesh, partial_spec.reduce_op, i) + # P -> R + partial_spec = cast(Partial, current) + new_local_tensor = mesh_all_reduce(local_tensor, device_mesh, partial_spec.reduce_op, i) elif current.is_replicate(): # For replicate -> partial, we zero out all other ranks of the current mesh dim # and leave only 1 rank have the data, to perform a "zero cost" reshard. diff --git a/vescale/initialize/deferred_init.py b/vescale/initialize/deferred_init.py index cb34f4e..fdab336 100644 --- a/vescale/initialize/deferred_init.py +++ b/vescale/initialize/deferred_init.py @@ -124,6 +124,7 @@ def materialize_dtensor( placements: Tuple[Placement] = normalize_placements( placements, device_mesh.ndim, tensor_ndim=tensor.ndim, none_as_replicate=True ) + has_shard_placement = any(p.is_shard() for p in placements) # get local tensor shape local_shape = compute_local_shape(global_shape, device_mesh, placements) torch_device = torch.device(device) @@ -139,9 +140,19 @@ def materialize_dtensor( random._rng_tracker = random.init_vescale_rng_tracker() assert random._rng_tracker is not None with random._rng_tracker._distribute_region(spec): - tensor = _C.materialize_tensor_with_local_shape(tensor, local_shape, torch_device) + # shortcut for parameters with no shard placements. TODO: what about Partial sharding + if not has_shard_placement: + tensor = _C.materialize_tensor(tensor) + tensor = tensor.to(device) + else: + tensor = _C.materialize_tensor_with_local_shape(tensor, local_shape, torch_device) else: - tensor = _C.materialize_tensor_with_local_shape(tensor, local_shape, torch_device) + # shortcut for parameters with no shard placements. 
TODO: what about Partial sharding + if not has_shard_placement: + tensor = _C.materialize_tensor(tensor) + tensor = tensor.to(device) + else: + tensor = _C.materialize_tensor_with_local_shape(tensor, local_shape, torch_device) # wrap as dtensor return DTensor( local_tensor=tensor, @@ -184,6 +195,7 @@ def materialize_dparameter( placements: Tuple[Placement] = normalize_placements( placements, device_mesh.ndim, tensor_ndim=param.data.ndim, none_as_replicate=True ) + has_shard_placement = any(p.is_shard() for p in placements) # get local tensor shape local_shape = compute_local_shape(global_shape, device_mesh, placements) torch_device = torch.device(device) @@ -199,9 +211,17 @@ def materialize_dparameter( random._rng_tracker = random.init_vescale_rng_tracker() assert random._rng_tracker is not None with random._rng_tracker._distribute_region(spec): - param = _C.materialize_tensor_with_local_shape(param, local_shape, torch_device) + # shortcut for parameters with no shard placements. TODO: what about Partial sharding + if not has_shard_placement: + param = _C.materialize_tensor(param) + else: + param = _C.materialize_tensor_with_local_shape(param, local_shape, torch_device) else: - param = _C.materialize_tensor_with_local_shape(param, local_shape, torch_device) + # shortcut for parameters with no shard placements. TODO: what about Partial sharding + if not has_shard_placement: + param = _C.materialize_tensor(param) + else: + param = _C.materialize_tensor_with_local_shape(param, local_shape, torch_device) # wrap parameter's data as dtensor dt = DTensor( local_tensor=param.data, diff --git a/vescale/model/patch/linear.py b/vescale/model/patch/linear.py index 5df87e3..0a96c97 100644 --- a/vescale/model/patch/linear.py +++ b/vescale/model/patch/linear.py @@ -26,6 +26,8 @@ from vescale.dtensor.placement_types import Shard from vescale.dmodule.placements_interface import PlacementsInterface +from .utils import is_patched, set_patched + def make_new_row_parallel_linear_forward(device_mesh: DeviceMesh, out_pi: PlacementsInterface): r""" @@ -108,15 +110,14 @@ def patch(root: torch.nn.Module): assert isinstance(submod.bias, DTensor) assert isinstance(output_pis, Sequence) and len(output_pis) == 1, "Linear has only a single output!" out_pi = output_pis[0] - assert ( - out_pi.placements and isinstance(out_pi.placements, Sequence) and len(out_pi.placements) == 1 - ), "Only 1D sharding is considered now!" + assert out_pi.placements and isinstance(out_pi.placements, Sequence) if any(p.is_partial() for p in submod.bias.placements) or any(p.is_partial() for p in out_pi.placements): warnings.warn( f"`{submod_path}` is a Row Parallel Linear with `Partial` bias/output, which can cause undefined result in Adam Optimizer.", UserWarning, ) + assert not is_patched(submod), "RowParallelLinear should only be patched once!" # replace nn.Linear's forward with customized forward. # NOTE: dyanmo doesn't support functools now, use function closure instead. # TODO: collaborate with upstream to support functools @@ -124,3 +125,4 @@ def patch(root: torch.nn.Module): make_new_row_parallel_linear_forward(device_mesh=submod.weight.device_mesh, out_pi=out_pi), submod, ) + set_patched(submod) diff --git a/vescale/model/patch/utils.py b/vescale/model/patch/utils.py new file mode 100644 index 0000000..b1fbd2b --- /dev/null +++ b/vescale/model/patch/utils.py @@ -0,0 +1,30 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +from torch import nn + +__all__ = ["is_patched", "set_patched"] + +_TAG_PATCHED = "PATCHED" + + +def is_patched(module: nn.Module) -> bool: + return hasattr(module, _TAG_PATCHED) + + +def set_patched(module: nn.Module) -> None: + setattr(module, _TAG_PATCHED, True) diff --git a/vescale/model/patch/vp_cross_entropy.py b/vescale/model/patch/vp_cross_entropy.py index 6671ea2..9dacd35 100644 --- a/vescale/model/patch/vp_cross_entropy.py +++ b/vescale/model/patch/vp_cross_entropy.py @@ -31,6 +31,8 @@ from vescale.dtensor.dtensor import DTensor from vescale.dtensor.placement_types import Partial, Replicate +from .utils import is_patched, set_patched + def _get_vocab_range(per_partition_vocab_size: int, rank, world_size: int) -> Sequence[int]: index_f = rank * per_partition_vocab_size @@ -237,8 +239,10 @@ def patch(root: torch.nn.Module) -> None: ) continue + assert not is_patched(submod), "VocabParallelCrossEntropy should only be patched once!" # replace nn.Linear's forward with customized forward. submod.forward = MethodType( partial(VocabParallelCrossEntropy.forward, device_mesh=None), submod, ) + set_patched(submod) diff --git a/vescale/model/patch/vp_embedding.py b/vescale/model/patch/vp_embedding.py index ed09fb4..0a4f3ea 100644 --- a/vescale/model/patch/vp_embedding.py +++ b/vescale/model/patch/vp_embedding.py @@ -32,6 +32,8 @@ from vescale.dtensor.dtensor import DTensor from vescale.dtensor.placement_types import Partial, Replicate +from .utils import is_patched, set_patched + class VocabParallelEmbedding: @staticmethod @@ -116,8 +118,10 @@ def patch(root: torch.nn.Module) -> None: ) continue + assert not is_patched(submod), "VocabParallelEmbedding should only be patched once!" # replace nn.Linear's forward with customized forward. submod.forward = MethodType( partial(VocabParallelEmbedding.forward, device_mesh=submod.weight.device_mesh), submod, ) + set_patched(submod)
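The `is_patched` / `set_patched` helpers above give `RowParallelLinear`, `VocabParallelCrossEntropy`, and `VocabParallelEmbedding` a uniform guard against patching a submodule's `forward` twice. A minimal sketch of the intended pattern is below; `patch_forward_once` and the toy replacement forward are illustrative names, not veScale APIs.

```python
import torch
from torch import nn

from vescale.model.patch.utils import is_patched, set_patched


def patch_forward_once(submod: nn.Module, new_forward) -> None:
    # A second patch would wrap the already-patched forward and silently change
    # semantics, so refuse it outright (mirrors the asserts added in the patch classes).
    assert not is_patched(submod), f"{type(submod).__name__} should only be patched once!"
    submod.forward = new_forward
    set_patched(submod)


linear = nn.Linear(4, 4)
patch_forward_once(linear, lambda x: torch.zeros_like(x))
assert is_patched(linear)  # a second patch_forward_once(linear, ...) now fails fast
```

The new `Shard(i) -> Shard(j)` path in `redistribute.py` replaces the old `NotImplementedError` for shard-dim changes: evenly divisible shapes go through a single `mesh_all_to_all_single`, uneven ones fall back to replicate-then-split. A sketch of how this surfaces at the DTensor level, assuming veScale keeps the PyTorch-DTensor-style `distribute_tensor` / `DTensor.redistribute` entry points (launch with `torchrun --standalone --nproc_per_node=4`):

```python
import torch

from vescale.dtensor.device_mesh import DeviceMesh
from vescale.dtensor.api import distribute_tensor  # assumed import path
from vescale.dtensor.placement_types import Shard

mesh = DeviceMesh("cuda", list(range(4)))
x = torch.randn(16, 8, device="cuda")

dt = distribute_tensor(x, mesh, [Shard(0)])  # each rank holds a (4, 8) slice
dt = dt.redistribute(mesh, [Shard(1)])       # even shapes: served by a single all-to-all
assert dt.to_local().shape == (16, 2)        # rows restored, columns now sharded
```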