From 9ec300f2aafc68e8dc503866954a7293214ada85 Mon Sep 17 00:00:00 2001 From: lzhengning Date: Tue, 28 Sep 2021 17:54:56 +0800 Subject: [PATCH 1/3] Polish RNN, GRU, and LSTM 1. Use Cudnn to speed rnn 2. Fix: document string not correctly rendered 3. Fix: RNN cannot accept relu nonlinearity 4. feat: add default hidden state when executing rnn/gru/lstm --- .../cuda/cudnn/inc/cudnn_rnn_descriptor.h | 154 ++++++ .../cuda/cudnn/ops/cudnn_rnn_backward_x_op.cc | 195 ++++++++ .../cuda/cudnn/ops/cudnn_rnn_backward_x_op.h | 38 ++ .../extern/cuda/cudnn/ops/cudnn_rnn_op.cc | 227 +++++++++ .../extern/cuda/cudnn/ops/cudnn_rnn_op.h | 36 ++ .../cuda/cudnn/src/cudnn_rnn_descriptor.cc | 74 +++ python/jittor/nn.py | 332 ++++++++----- python/jittor/test/test_rnn.py | 469 +++++++++++++++--- 8 files changed, 1340 insertions(+), 185 deletions(-) create mode 100644 python/jittor/extern/cuda/cudnn/inc/cudnn_rnn_descriptor.h create mode 100644 python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.cc create mode 100644 python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.h create mode 100644 python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.cc create mode 100644 python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.h create mode 100644 python/jittor/extern/cuda/cudnn/src/cudnn_rnn_descriptor.cc diff --git a/python/jittor/extern/cuda/cudnn/inc/cudnn_rnn_descriptor.h b/python/jittor/extern/cuda/cudnn/inc/cudnn_rnn_descriptor.h new file mode 100644 index 00000000..977aafb9 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/inc/cudnn_rnn_descriptor.h @@ -0,0 +1,154 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" +#include "cudnn_warper.h" +#include "executor.h" +#include "init.h" + + +namespace jittor { + +static inline cudnnRNNMode_t rnn_string_to_rnn_mode(string mode) { + if (mode == "relu") + return CUDNN_RNN_RELU; + if (mode == "tanh") + return CUDNN_RNN_TANH; + if (mode == "lstm") + return CUDNN_LSTM; + ASSERT(mode == "gru") << "rnn mode must be relu, tanh, lstm, or gru, but got " << mode; + return CUDNN_GRU; +} + +static inline int rnn_string_to_num_linear_layers(string mode) { + if (mode == "relu") + return 2; + if (mode == "tanh") + return 2; + if (mode == "lstm") + return 8; + ASSERT(mode == "gru") << "mode must be relu, tanh, lstm, or gru, but got " << mode; + return 6; +} + +/** A wrapper for CUDNN dropout descriptor + */ +struct DropoutDescriptor { + cudnnDropoutDescriptor_t desc; + size_t stateSize, stateAllocation; + float dropout; + void *stateSpace; + + DropoutDescriptor(cudnnHandle_t handle, float dropout) + : dropout(dropout), stateSpace(nullptr) { + checkCudaErrors(cudnnCreateDropoutDescriptor(&desc)); + if (dropout > 0) { + checkCudaErrors(cudnnDropoutGetStatesSize(handle, &stateSize)); + stateSpace = exe.temp_allocator->alloc(stateSize, stateAllocation); + checkCudaErrors(cudnnSetDropoutDescriptor( + desc, + cudnn_handle, + dropout, + stateSpace, + stateSize, + get_seed() + )); + } else { + checkCudaErrors(cudnnSetDropoutDescriptor( + desc, handle, 0, nullptr, 0, 0 + )); + } + } + ~DropoutDescriptor() { + checkCudaErrors(cudnnDestroyDropoutDescriptor(desc)); + if (stateSpace) + exe.temp_allocator->free(stateSpace, stateSize, stateAllocation); + } +}; + +/** A wrapper for CUDNN RNN descriptor + */ +struct RnnDescriptor { + cudnnHandle_t handle; + cudnnRNNDescriptor_t desc; + DropoutDescriptor dropoutDesc; + + RnnDescriptor(cudnnHandle_t handle, string mode, int hidden_size, int num_layers, + float dropout, bool bidirectional) : handle(handle), dropoutDesc(handle, dropout) { + checkCudaErrors(cudnnCreateRNNDescriptor(&desc)); + checkCudaErrors(cudnnSetRNNDescriptor_v6( + handle, + desc, + hidden_size, + num_layers, + dropoutDesc.desc, + CUDNN_LINEAR_INPUT, + bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, + rnn_string_to_rnn_mode(mode), + CUDNN_RNN_ALGO_STANDARD, + CUDNN_DATA_FLOAT + )); + } + + ~RnnDescriptor() { + checkCudaErrors(cudnnDestroyRNNDescriptor(desc)); + } + + size_t weight_space_size(const cudnnTensorDescriptor_t &xDesc) { + size_t size; + checkCudaErrors(cudnnGetRNNParamsSize( + handle, desc, xDesc, &size, CUDNN_DATA_FLOAT + )); + return size; + } + + size_t work_space_size(const cudnnTensorDescriptor_t *xDesc, int seq_length) { + size_t size; + checkCudaErrors(cudnnGetRNNWorkspaceSize( + handle, desc, seq_length, xDesc, &size + )); + return size; + } + + size_t reserve_space_size(const cudnnTensorDescriptor_t *xDesc, int seq_length) { + size_t size; + checkCudaErrors(cudnnGetRNNTrainingReserveSize( + handle, desc, seq_length, xDesc, &size + )); + return size; + } +}; + +/** + */ +struct RnnWeightDescriptor { + cudnnFilterDescriptor_t desc; + size_t size; + RnnWeightDescriptor(size_t size) : size(size) { + int dimW[3] = {(int) (size / sizeof(float)), 1, 1}; + checkCudaErrors(cudnnCreateFilterDescriptor(&desc)); + checkCudaErrors(cudnnSetFilterNdDescriptor(desc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dimW)); + } + ~RnnWeightDescriptor() { + cudnnDestroyFilterDescriptor(desc); + } +}; + +/** + Returns offsets of RNN linear parameters in a flatten array. + + Returns + ======= + list: [total size, param #1 offset, param #2 offset, ...] + + TODO: support cudnn rnn-v8; support proj_size + */ +// @pyjt(cudnn_rnn_weight_offset) +vector cudnn_rnn_weight_offset(string mode, int input_size, int hidden_size, int num_layers, int proj_size, bool bias, bool bidirectional); + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.cc new file mode 100644 index 00000000..347cf5bf --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.cc @@ -0,0 +1,195 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "cudnn_rnn_descriptor.h" +#include "cudnn_rnn_backward_x_op.h" +#include "cudnn_warper.h" +#include "executor.h" +#include "ops/op_register.h" + +namespace jittor { + +#pragma GCC diagnostic ignored "-Wunused-variable" + +#ifndef JIT + +CudnnRnnBackwardXOp::CudnnRnnBackwardXOp(Var *x, Var* hx, Var* cx, Var* y, Var* dy, Var* dhy, Var* dcy, Var* w, Var* reservation, + string mode, int input_size, int hidden_size, int num_layers, int proj_size, + double dropout, bool bias, bool bidirectional) + : x(x), hx(hx), cx(cx), y(y), dy(dy), dhy(dhy), dcy(dcy), w(w), reservation(reservation), + mode(mode), input_size(input_size), hidden_size(hidden_size), num_layers(num_layers), + proj_size(proj_size), dropout(dropout), bias(bias), bidirectional(bidirectional) { + + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + + ASSERTop(mode,==,"lstm"); + ASSERTop(proj_size,==,0); + init_rnn(); +} + +CudnnRnnBackwardXOp::CudnnRnnBackwardXOp(Var* x, Var* hx, Var* y, Var* dy, Var* dhy, Var* w, Var* reservation, + string mode, int input_size, int hidden_size, int num_layers, int proj_size, + double dropout, bool bias, bool bidirectional) + : x(x), hx(hx), cx(nullptr), y(y), dy(dy), dhy(dhy), dcy(nullptr), w(w), reservation(reservation), + mode(mode), input_size(input_size), hidden_size(hidden_size), num_layers(num_layers), + proj_size(proj_size), dropout(dropout), bias(bias), bidirectional(bidirectional) { + + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + + ASSERTop(mode,!=,"lstm"); + ASSERTop(proj_size,==,0); + init_rnn(); +} + +void CudnnRnnBackwardXOp::init_rnn() { + dx = create_output(nullptr, ns_float32); + dhx = create_output(nullptr, ns_float32); + + if (mode == "lstm") + dcx = create_output(nullptr, ns_float32); + else + dcx = nullptr; + + dw = create_output(nullptr, dtype_infer(x->ns, y->ns)); + + seq_length = y->shape[0]; + batch_size = y->shape[1]; +} + +void CudnnRnnBackwardXOp::infer_shape() { + dx->set_shape(NanoVector(seq_length, batch_size, input_size)); + + int num_directions = 1 + bidirectional; + if (proj_size > 0) + dhx->set_shape(NanoVector(num_layers * num_directions, batch_size, proj_size)); + else + dhx->set_shape(NanoVector(num_layers * num_directions, batch_size, hidden_size)); + + if (dcx) + dcx->set_shape(NanoVector(num_layers * num_directions, batch_size, hidden_size)); + + dw->set_shape(w->shape); +} + +void CudnnRnnBackwardXOp::jit_prepare(JK& jk) { + jk << _CS("[Tx:") << hx->dtype(); + jk << _CS("][Ty:") << y->dtype(); + jk << _CS("][Tw:") << w->dtype(); + jk << ']'; +} + +#else // JIT +#ifdef JIT_cuda + +template __inline__ cudnnDataType_t getDataType(); +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; } +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_FLOAT; } + +void CudnnRnnBackwardXOp::jit_run() { + int num_directions = 1 + bidirectional; + + int in_dims[3] = {batch_size, input_size, 1}; + int out_dims[3] = {batch_size, hidden_size * num_directions, 1}; + int in_strides[3] = {in_dims[1] * in_dims[2], in_dims[2], 1}; + int out_strides[3] = {out_dims[1] * out_dims[2], out_dims[2], 1}; + int hidden_dims[3] = {num_layers * num_directions, batch_size, hidden_size}; + int hidden_strides[3] = {hidden_dims[1] * hidden_dims[2], hidden_dims[2], 1}; + + vector xDesc(seq_length), dxDesc(seq_length); + vector yDesc(seq_length), dyDesc(seq_length); + + for (int i = 0; i < seq_length; ++i) { + checkCudaErrors(cudnnCreateTensorDescriptor(&xDesc[i])); + checkCudaErrors(cudnnCreateTensorDescriptor(&dxDesc[i])); + checkCudaErrors(cudnnCreateTensorDescriptor(&yDesc[i])); + checkCudaErrors(cudnnCreateTensorDescriptor(&dyDesc[i])); + checkCudaErrors(cudnnSetTensorNdDescriptor(xDesc[i], getDataType(), 3, in_dims, in_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dxDesc[i], getDataType(), 3, in_dims, in_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(yDesc[i], getDataType(), 3, out_dims, out_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dyDesc[i], getDataType(), 3, out_dims, out_strides)); + } + + cudnnTensorDescriptor_t dhyDesc, dcyDesc; + cudnnTensorDescriptor_t hxDesc, cxDesc, dhxDesc, dcxDesc; + checkCudaErrors(cudnnCreateTensorDescriptor(&hxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&cxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&dhxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&dcxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&dhyDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&dcyDesc)); + checkCudaErrors(cudnnSetTensorNdDescriptor(hxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(cxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dhxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dcxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dhyDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dcyDesc, getDataType(), 3, hidden_dims, hidden_strides)); + + RnnWeightDescriptor w_desc(w->size); + RnnDescriptor rnn_desc(cudnn_handle, mode, hidden_size, num_layers, dropout, bidirectional); + + void *work_space; + size_t work_space_size = rnn_desc.work_space_size(dxDesc.data(), seq_length); + size_t work_space_allocation; + if (work_space_size > 0) + work_space = exe.temp_allocator->alloc(work_space_size, work_space_allocation); + + size_t reserveSpaceSize = reservation->size; + + checkCudaErrors(cudnnRNNBackwardData( + cudnn_handle, rnn_desc.desc, + seq_length, + yDesc.data(), y->ptr(), + dyDesc.data(), dy->ptr(), + dhyDesc, dhy->ptr(), + dcyDesc, mode == "lstm" ? dcy->ptr(): nullptr, + w_desc.desc, w->ptr(), + hxDesc, hx->ptr(), + cxDesc, mode == "lstm" ? cx->ptr() : nullptr, + dxDesc.data(), dx->ptr(), + dhxDesc, dhx->ptr(), + dcxDesc, mode == "lstm" ? dcx->ptr() : nullptr, + work_space, work_space_size, + reservation->ptr(), reservation->size + )); + + checkCudaErrors(cudaMemset(dw->ptr(), 0, dw->size)); + + checkCudaErrors(cudnnRNNBackwardWeights( + cudnn_handle, rnn_desc.desc, + seq_length, + xDesc.data(), x->ptr(), + hxDesc, hx->ptr(), + yDesc.data(), y->ptr(), + work_space, work_space_size, + w_desc.desc, dw->ptr(), + reservation->ptr(), reservation->size + )); + + for (int i = 0; i < seq_length; ++i) { + checkCudaErrors(cudnnDestroyTensorDescriptor(xDesc[i])); + checkCudaErrors(cudnnDestroyTensorDescriptor(dxDesc[i])); + checkCudaErrors(cudnnDestroyTensorDescriptor(yDesc[i])); + checkCudaErrors(cudnnDestroyTensorDescriptor(dyDesc[i])); + } + + checkCudaErrors(cudnnDestroyTensorDescriptor(dhyDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(dcyDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(hxDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(cxDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(dhxDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(dcxDesc)); + + if (work_space) + exe.temp_allocator->free(work_space, work_space_size, work_space_allocation); +} + +#endif +#endif // JIT +} diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.h b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.h new file mode 100644 index 00000000..b4064131 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.h @@ -0,0 +1,38 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CudnnRnnBackwardXOp : Op { + Var* x, * hx, * cx; + Var* y, * dy, * dhy, * dcy; + Var* w; + Var* dx, * dhx, * dcx, * dw; + Var* reservation; + string mode; + int input_size, hidden_size, num_layers, proj_size, batch_size; + int seq_length; + float dropout; + bool bias, bidirectional; + + // @attrs(multiple_outputs) + CudnnRnnBackwardXOp(Var* x, Var* hx, Var* cx, Var* y, Var* dy, Var* dhy, Var* dcy, Var* w, Var* reservation, string mode, int input_size, int hidden_size, int num_layers, int proj_size, double dropout, bool bias, bool bidirectional); + + // @attrs(multiple_outputs) + CudnnRnnBackwardXOp(Var* x, Var* hx, Var* y, Var* dy, Var* dhy, Var* w, Var* reservation, string mode, int input_size, int hidden_size, int num_layers, int proj_size, double dropout, bool bias, bool bidirectional); + + void init_rnn(); + + const char* name() const override { return "cudnn_rnn_backward_x"; } + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.cc new file mode 100644 index 00000000..6772a6f8 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.cc @@ -0,0 +1,227 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "cudnn_rnn_descriptor.h" +#include "cudnn_rnn_op.h" +#include "cudnn_warper.h" +#include "executor.h" +#include "ops/op_register.h" + +using namespace std; + +namespace jittor { + +#pragma GCC diagnostic ignored "-Wunused-variable" + +#ifndef JIT + +CudnnRnnOp::CudnnRnnOp(Var* x, Var* hx, Var* cx, Var* w, + string mode, int input_size, int hidden_size, int num_layers, int proj_size, + double dropout, bool bias, bool bidirectional, bool is_train) + : x(x), hx(hx), cx(cx), w(w), mode(mode), input_size(input_size), hidden_size(hidden_size), + num_layers(num_layers), proj_size(proj_size), dropout(dropout), bias(bias), + bidirectional(bidirectional), is_train(is_train) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_grads, 1); + + ASSERTop(mode,==,"lstm"); + ASSERTop(proj_size,==,0); + init_rnn(); +} + +CudnnRnnOp::CudnnRnnOp(Var* x, Var* hx, Var* w, + string mode, int input_size, int hidden_size, int num_layers, int proj_size, + double dropout, bool bias, bool bidirectional, bool is_train) + : x(x), hx(hx), cx(nullptr), w(w), mode(mode), input_size(input_size), hidden_size(hidden_size), + num_layers(num_layers), proj_size(proj_size), dropout(dropout), bias(bias), + bidirectional(bidirectional), is_train(is_train) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_grads, 1); + + ASSERTop(mode,!=,"lstm"); + ASSERTop(proj_size,==,0); + init_rnn(); +} + +void CudnnRnnOp::init_rnn() { + y = create_output(nullptr, dtype_infer(x->ns, w->ns)); + hy = create_output(nullptr, dtype_infer(x->ns, w->ns)); + if (mode == "lstm") + cy = create_output(nullptr, dtype_infer(x->ns, w->ns)); + else + cy = nullptr; + + if (is_train) + reservation = create_output(nullptr, ns_float32); + else + reservation = nullptr; + + seq_length = x->shape[0]; + batch_size = x->shape[1]; +} + +void CudnnRnnOp::infer_shape() { + ASSERTop(x->shape.size(),==,3); + ASSERTop(x->shape[2],==,input_size); + + int num_directions = 1 + bidirectional; + + y->set_shape(NanoVector(seq_length, batch_size, hidden_size * num_directions)); + + if (proj_size > 0) + hy->set_shape(NanoVector(num_layers * num_directions, batch_size, proj_size)); + else + hy->set_shape(NanoVector(num_layers * num_directions, batch_size, hidden_size)); + + if (cy) + cy->set_shape(NanoVector(num_layers * num_directions, batch_size, hidden_size)); + + if (reservation) { + int in_dims[3] = {batch_size, input_size, 1}; + int in_strides[3] = {in_dims[1] * in_dims[2], in_dims[2], 1}; + + vector xDesc(seq_length); + RnnDescriptor rnn_desc(cudnn_handle, mode, hidden_size, num_layers, dropout, bidirectional); + for (int i = 0; i < seq_length; ++i) { + checkCudaErrors(cudnnCreateTensorDescriptor(&xDesc[i])); + checkCudaErrors(cudnnSetTensorNdDescriptor(xDesc[i], CUDNN_DATA_FLOAT, 3, in_dims, in_strides)); + } + reservation->set_shape(rnn_desc.reserve_space_size(xDesc.data(), seq_length)); + } +} + +void CudnnRnnOp::jit_prepare(JK& jk) { + jk << _CS("[Tx:") << x->dtype(); + jk << _CS("][Ty:") << y->dtype(); + jk << _CS("][Tw:") << w->dtype(); + jk << ']'; +} + +static auto make_backwardx_with_cx = get_op_info("cudnn_rnn_backward_x") + .get_constructor, Var*, Var*, Var*, Var*, Var*, Var*, Var*, Var*, Var*, string, int, int, int, int, double, bool, bool>(); +static auto make_backwardx_without_cx = get_op_info("cudnn_rnn_backward_x") + .get_constructor, Var*, Var*, Var*, Var*, Var*, Var*, Var*, string, int, int, int, int, double, bool, bool>(); + +void CudnnRnnOp::grads(Var** dout, VarPtr* dins) { + Var *dy = dout[0]; + Var *dhy = dout[1]; + Var *dcy = cx ? dout[2] : nullptr; + + vector dInput; + if (cx) + dInput = make_backwardx_with_cx(x, hx, cx, y, dy, dhy, dcy, w, reservation, mode, input_size, hidden_size, num_layers, proj_size, dropout, bias, bidirectional); + else + dInput = make_backwardx_without_cx(x, hx, y, dy, dhy, w, reservation, mode, input_size, hidden_size, num_layers, proj_size, dropout, bias, bidirectional); + + for (int i = 0; i < 3 + (cx != nullptr); ++i) + dins[i] = move(dInput[i]); +} + +#else // JIT +#ifdef JIT_cuda + +#pragma clang diagnostic ignored "-Wtautological-compare" + +template __inline__ cudnnDataType_t getDataType(); +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; } +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_FLOAT; } + +void CudnnRnnOp::jit_run() { + int num_directions = bidirectional + 1; + int num_linear_layers = rnn_string_to_num_linear_layers(mode); + + int in_dims[3] = {batch_size, input_size, 1}; + int out_dims[3] = {batch_size, hidden_size * num_directions, 1}; + int in_strides[3] = {in_dims[1] * in_dims[2], in_dims[2], 1}; + int out_strides[3] = {out_dims[1] * out_dims[2], out_dims[2], 1}; + int hidden_dims[3] = {num_layers * num_directions, batch_size, hidden_size}; + int hidden_strides[3] = {hidden_dims[1] * hidden_dims[2], hidden_dims[2], 1}; + + vector xDesc(seq_length); + vector yDesc(seq_length); + cudnnTensorDescriptor_t hxDesc, cxDesc; + cudnnTensorDescriptor_t hyDesc, cyDesc; + + for (int i = 0; i < seq_length; ++i) { + checkCudaErrors(cudnnCreateTensorDescriptor(&xDesc[i])); + checkCudaErrors(cudnnCreateTensorDescriptor(&yDesc[i])); + checkCudaErrors(cudnnSetTensorNdDescriptor(xDesc[i], getDataType(), 3, in_dims, in_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(yDesc[i], getDataType(), 3, out_dims, out_strides)); + } + + checkCudaErrors(cudnnCreateTensorDescriptor(&hxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&cxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&hyDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&cyDesc)); + + checkCudaErrors(cudnnSetTensorNdDescriptor(hxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(cxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + + checkCudaErrors(cudnnSetTensorNdDescriptor(hyDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(cyDesc, getDataType(), 3, hidden_dims, hidden_strides)); + + RnnDescriptor rnn_desc(cudnn_handle, mode, hidden_size, num_layers, dropout, bidirectional); + + void *work_space; + size_t work_space_size = rnn_desc.work_space_size(xDesc.data(), seq_length); + size_t work_space_allocation; + if (work_space_size > 0) + work_space = exe.temp_allocator->alloc(work_space_size, work_space_allocation); + + RnnWeightDescriptor w_desc(w->size); + + if (is_train) { + checkCudaErrors(cudnnRNNForwardTraining( + cudnn_handle, rnn_desc.desc, + seq_length, + xDesc.data(), x->ptr(), + hxDesc, hx->ptr(), + cxDesc, mode == "lstm" ? cx->ptr() : nullptr, + w_desc.desc, w->ptr(), + yDesc.data(), y->ptr(), + hyDesc, hy->ptr(), + cyDesc, mode == "lstm" ? cy->ptr() : nullptr, + work_space, work_space_size, + reservation->ptr(), reservation->size + )); + } else { + checkCudaErrors(cudnnRNNForwardInference( + cudnn_handle, rnn_desc.desc, + seq_length, + xDesc.data(), x->ptr(), + hxDesc, hx->ptr(), + cxDesc, mode == "lstm" ? cx->ptr() : nullptr, + w_desc.desc, w->ptr(), + yDesc.data(), y->ptr(), + hyDesc, hy->ptr(), + cyDesc, mode == "lstm" ? cy->ptr() : nullptr, + work_space, work_space_size + )); + } + + for (int i = 0; i < seq_length; i++) { + checkCudaErrors(cudnnDestroyTensorDescriptor(xDesc[i])); + checkCudaErrors(cudnnDestroyTensorDescriptor(yDesc[i])); + } + + checkCudaErrors(cudnnDestroyTensorDescriptor(hxDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(cxDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(hyDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(cyDesc)); + + if (work_space) + exe.temp_allocator->free(work_space, work_space_size, work_space_allocation); +} + +#endif +#endif // JIT + +} // jittor + diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.h b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.h new file mode 100644 index 00000000..709a1436 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.h @@ -0,0 +1,36 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CudnnRnnOp : Op { + Var* x, * hx, * cx, * y, * hy, * cy; + Var* w; + Var* reservation; + string mode; + int input_size, hidden_size, num_layers, proj_size; + int seq_length, batch_size; + float dropout; + bool bias, bidirectional, is_train; + + // @attrs(multiple_outputs) + CudnnRnnOp(Var* x, Var* hx, Var* cx, Var* w, string mode, int input_size, int hidden_size, int num_layers, int proj_size, double dropout, bool batch_first, bool bias, bool bidirectional); + // @attrs(multiple_outputs) + CudnnRnnOp(Var* x, Var* hx, Var* w, string mode, int input_size, int hidden_size, int num_layers, int proj_size, double dropout, bool batch_first, bool bias, bool bidirectional); + + void init_rnn(); + + const char* name() const override { return "cudnn_rnn"; } + void grads(Var** douts, VarPtr* dins) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/src/cudnn_rnn_descriptor.cc b/python/jittor/extern/cuda/cudnn/src/cudnn_rnn_descriptor.cc new file mode 100644 index 00000000..b4f4518b --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/src/cudnn_rnn_descriptor.cc @@ -0,0 +1,74 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "cudnn_rnn_descriptor.h" + +namespace jittor { + +vector cudnn_rnn_weight_offset(string mode, int input_size, int hidden_size, int num_layers, int proj_size, bool bias, bool bidirectional) { + // A pseudo mini-batch for fetching weight space size. + int dimX[] = {1, input_size, 1}; + int strideX[] = {input_size, 1, 1}; + cudnnTensorDescriptor_t xDesc; + checkCudaErrors(cudnnCreateTensorDescriptor(&xDesc)); + checkCudaErrors(cudnnSetTensorNdDescriptor(xDesc, CUDNN_DATA_FLOAT, 3, dimX, strideX)); + + RnnDescriptor rnn_desc = RnnDescriptor(cudnn_handle, mode, hidden_size, num_layers, 0, bidirectional); + int weightSpaceSize = rnn_desc.weight_space_size(xDesc); + RnnWeightDescriptor w_desc(weightSpaceSize); + + vector weight_offsets; + weight_offsets.push_back(weightSpaceSize / sizeof(float)); + + int num_directions = bidirectional + 1; + int num_linear_layers = rnn_string_to_num_linear_layers(mode); + + for (int layer = 0; layer < num_layers * num_directions; layer++) { + for (int linLayerID = 0; linLayerID < num_linear_layers; linLayerID++) { + cudnnFilterDescriptor_t linLayerMatDesc; + cudnnFilterDescriptor_t linLayerBiasDesc; + float *linLayerMat = nullptr; + float *linLayerBias = nullptr; + + checkCudaErrors(cudnnCreateFilterDescriptor(&linLayerMatDesc)); + checkCudaErrors(cudnnCreateFilterDescriptor(&linLayerBiasDesc)); + + checkCudaErrors(cudnnGetRNNLinLayerMatrixParams( + cudnn_handle, rnn_desc.desc, + layer, + xDesc, + w_desc.desc, + nullptr, + linLayerID, + linLayerMatDesc, + (void **) &linLayerMat + )); + weight_offsets.push_back(linLayerMat - (float *) nullptr); + + if (bias) { + checkCudaErrors(cudnnGetRNNLinLayerBiasParams( + cudnn_handle, rnn_desc.desc, + layer, + xDesc, + w_desc.desc, + nullptr, + linLayerID, + linLayerBiasDesc, + (void **) &linLayerBias + )); + weight_offsets.push_back(linLayerBias - (float *) nullptr); + } + } + } + + checkCudaErrors(cudnnDestroyTensorDescriptor(xDesc)); + + return weight_offsets; +} + + +} // jittor diff --git a/python/jittor/nn.py b/python/jittor/nn.py index bdfe04d6..86d9fffc 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -12,8 +12,9 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** from abc import abstractmethod +from sys import breakpointhook import jittor as jt -from jittor import init, Module +from jittor import flatten, init, Module import numpy as np import collections import math @@ -1881,6 +1882,7 @@ class RNNCell(jt.Module): return y + class GRUCell(jt.Module): def __init__(self, input_size, hidden_size, bias=True): ''' A gated recurrent unit (GRU) cell. @@ -1941,7 +1943,7 @@ class RNNBase(Module): def __init__(self, mode: str, input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, batch_first: bool = False, dropout: float = 0, bidirectional: bool = False, - proj_size: int = 0) -> None: + proj_size: int = 0, nonlinearity: str = None) -> None: super().__init__() self.mode = mode @@ -1953,6 +1955,7 @@ class RNNBase(Module): self.dropout = dropout self.bidirectional = bidirectional self.proj_size = proj_size + self.nonlinearity = nonlinearity if mode == 'LSTM': gate_size = 4 * hidden_size @@ -1994,6 +1997,56 @@ class RNNBase(Module): build_unit(f'bias_ih_l{layer}', gate_size) build_unit(f'bias_hh_l{layer}', gate_size) + def _cudnn_flatten_weights(self, cudnn_mode): + def copy_to_flatten_weight(param_name, offset_idx, num_gates): + def copy_to(param_name, offset_idx, idx): + cur_offset = self._cudnn_weight_offset[offset_idx] + param = getattr(self, param_name) + param = param[self.hidden_size * idx: self.hidden_size * (idx + 1)] + ft_weight[cur_offset:cur_offset + param.numel()] = param.flatten() + + if self.bias: + for idx in range(num_gates): + copy_to('weight' + param_name, offset_idx + idx * 2, idx) + copy_to('bias' + param_name, offset_idx + idx * 2 + 1, idx) + return num_gates * 2 + else: + for idx in range(num_gates): + copy_to('weight' + param_name, offset_idx + idx, idx) + return num_gates + + if jt.flags.use_cuda and jt.cudnn: + if getattr(self, '_cudnn_weight_size', None) is None: + offset_array = jt.cudnn.cudnn_rnn_weight_offset( + cudnn_mode, + self.input_size, + self.hidden_size, + self.num_layers, + self.proj_size, + self.bias, + self.bidirectional + ) + self._cudnn_weight_size = offset_array[0] + self._cudnn_weight_offset = offset_array[1:] + + num_gates = { + "RNN": 1, "LSTM": 4, "GRU": 3 + }[self.mode] + ft_weight = jt.zeros(self._cudnn_weight_size, dtype=jt.float32) + + cnt = 0 + for layer in range(self.num_layers): + suffix = '' + cnt += copy_to_flatten_weight(f'_ih_l{layer}' + suffix, cnt, num_gates) + cnt += copy_to_flatten_weight(f'_hh_l{layer}' + suffix, cnt, num_gates) + if self.bidirectional: + suffix = '_reverse' + cnt += copy_to_flatten_weight(f'_ih_l{layer}' + suffix, cnt, num_gates) + cnt += copy_to_flatten_weight(f'_hh_l{layer}' + suffix, cnt, num_gates) + return ft_weight + else: + raise RuntimeError("Not Cudnn found") + @abstractmethod def call_rnn_cell(self, input, hidden, suffix): pass @@ -2013,87 +2066,116 @@ class RNNBase(Module): return output, hidden - def execute(self, input, hx): + def _execute_cudnn_rnn(self, input, hx): + cudnn_mode = { + ('RNN', 'tanh'): 'tanh', + ('RNN', 'relu'): 'relu', + ('LSTM', None): 'lstm', + ('GRU', None): 'gru' + }[(self.mode, self.nonlinearity)] + ft_weight = self._cudnn_flatten_weights(cudnn_mode) + + if self.mode == 'LSTM': + ret = jt.cudnn.ops.cudnn_rnn(input, hx[0], hx[1], ft_weight, + cudnn_mode, self.input_size, self.hidden_size, self.num_layers, 0, + self.dropout, self.bias, self.bidirectional, self.is_training() + ) + return ret[0], (ret[1], ret[2]) + else: + ret = jt.cudnn.ops.cudnn_rnn(input, hx, ft_weight, + cudnn_mode, self.input_size, self.hidden_size, self.num_layers, 0, + self.dropout, self.bias, self.bidirectional, self.is_training() + ) + return ret[0], ret[1] + + def execute(self, input, hx=None): if self.batch_first: input = input.permute(1, 0, 2) num_directions = 2 if self.bidirectional else 1 if hx is None: - hx = self.default_init_state() + if self.mode in ['RNN', 'GRU']: + hx = jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype) + elif self.mode == 'LSTM': + hx = (jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype), + jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype)) - hidden_n = [] + if jt.flags.use_cuda and jt.cudnn and self.proj_size == 0: + return self._execute_cudnn_rnn(input, hx) + else: + hidden_n = [] - for l in range(self.num_layers): - output = [] + for l in range(self.num_layers): + output = [] - if isinstance(hx, tuple): - hidden = [h[l * num_directions] for h in hx] - else: - hidden = hx[l * num_directions] - - output, _hidden = self.call_rnn_sequence(input, hidden, f'l{l}') - hidden_n.append(_hidden) - - if self.bidirectional: if isinstance(hx, tuple): - hidden = [h[l * num_directions + 1] for h in hx] + hidden = [h[l * num_directions] for h in hx] else: - hidden = hx[l * num_directions + 1] + hidden = hx[l * num_directions] - output_b, _hidden = self.call_rnn_sequence(input, hidden, f'l{l}_reverse') - output = jt.concat([output, output_b], dim=-1) + output, _hidden = self.call_rnn_sequence(input, hidden, f'l{l}') hidden_n.append(_hidden) - if self.dropout > 0: - input = dropout(output, p=self.dropout) + if self.bidirectional: + if isinstance(hx, tuple): + hidden = [h[l * num_directions + 1] for h in hx] + else: + hidden = hx[l * num_directions + 1] + + output_b, _hidden = self.call_rnn_sequence(input, hidden, f'l{l}_reverse') + output = jt.concat([output, output_b], dim=-1) + hidden_n.append(_hidden) + + if self.dropout > 0: + input = dropout(output, p=self.dropout) + else: + input = output + + if isinstance(hx, tuple): + hidden_n = tuple(jt.stack(hn, dim=0) for hn in zip(*hidden_n)) else: - input = output + hidden_n = jt.stack(hidden_n, dim=0) - if isinstance(hx, tuple): - hidden_n = tuple(jt.stack(hn, dim=0) for hn in zip(*hidden_n)) - else: - hidden_n = jt.stack(hidden_n, dim=0) - - return output, hidden_n + return output, hidden_n class RNN(RNNBase): + ''' Applies a multi-layer Elman RNN with tanh ReLU non-linearity to an input sequence. + + :param input_size: The number of expected features in the input. + :type input_size: int + + :param hidden_size: The number of features in the hidden state. + :type hidden_size: int + + :param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two RNNs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final results. Default: 1 + :type num_layers: int, optinal + + :param nonlinearity: The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh' + :type nonlinearity: str, optional + + :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. + :type bias: bool, optional + + :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False + :type bias: bool, optional + + :param dropout: If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout. Default: 0 + :type dropout: float, optional + + :param bidirectional: If True, becomes a bidirectional RNN. Default: False + :type bidirectional: bool, optional + + Example: + >>> rnn = nn.RNN(10, 20, 2) + >>> input = jt.randn(5, 3, 10) + >>> h0 = jt.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + ''' def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, nonlinearity: str = 'tanh', bias: bool = True, batch_first: bool = False, dropout: float = 0, bidirectional: bool = False) -> None: - ''' Applies a multi-layer Elman RNN with tanh ReLU non-linearity to an input sequence. - - :param input_size: The number of expected features in the input. - :type input_size: int - - :param hidden_size: The number of features in the hidden state. - :type hidden_size: int - - :param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two RNNs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final results. Default: 1 - :type num_layers: int, optinal - - :param nonlinearity: The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh' - :type nonlinearity: str, optional - - :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. - :type bias: bool, optional - - :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False - :type bias: bool, optional - - :param dropout: If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout. Default: 0 - :type dropout: float, optional - - :param bidirectional: If True, becomes a bidirectional RNN. Default: False - :type bidirectional: bool, optional - - Example: - >>> rnn = nn.RNN(10, 20, 2) - >>> input = jt.randn(5, 3, 10) - >>> h0 = jt.randn(2, 3, 20) - >>> output, hn = rnn(input, h0) - ''' super().__init__('RNN', input_size, hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional) @@ -2112,47 +2194,48 @@ class RNN(RNNBase): if self.nonlinearity == 'tanh': h = jt.tanh(y) else: - h = jt.relu(y) + h = jt.nn.relu(y) return h, h class LSTM(RNNBase): + ''' Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence. + + :param input_size: The number of expected features in the input. + :type input_size: int + + :param hidden_size: The number of features in the hidden state. + :type hidden_size: int + + :param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of the first LSTM and computing the final results. Default: 1 + :type num_layers: int, optinal + + :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. + :type bias: bool, optional + + :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False + :type bias: bool, optional + + :param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0 + :type dropout: float, optional + + :param bidirectional: If True, becomes a bidirectional LSTM. Default: False + :type bidirectional: bool, optional + + :param proj_size: If > 0, will use LSTM with projections of corresponding size. Default: 0 + :type proj_size: int, optional + + Example: + >>> rnn = nn.LSTM(10, 20, 2) + >>> input = jt.randn(5, 3, 10) + >>> h0 = jt.randn(2, 3, 20) + >>> c0 = jt.randn(2, 3, 20) + >>> output, (hn, cn) = rnn(input, (h0, c0)) + ''' + def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False, proj_size=0): - ''' Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence. - - :param input_size: The number of expected features in the input. - :type input_size: int - - :param hidden_size: The number of features in the hidden state. - :type hidden_size: int - - :param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of the first LSTM and computing the final results. Default: 1 - :type num_layers: int, optinal - - :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. - :type bias: bool, optional - - :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False - :type bias: bool, optional - - :param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0 - :type dropout: float, optional - - :param bidirectional: If True, becomes a bidirectional LSTM. Default: False - :type bidirectional: bool, optional - - :param proj_size: If > 0, will use LSTM with projections of corresponding size. Default: 0 - :type proj_size: int, optional - - Example: - >>> rnn = nn.LSTM(10, 20, 2) - >>> input = jt.randn(5, 3, 10) - >>> h0 = jt.randn(2, 3, 20) - >>> c0 = jt.randn(2, 3, 20) - >>> output, (hn, cn) = rnn(input, (h0, c0)) - ''' super().__init__('LSTM', input_size, hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, proj_size=proj_size) @@ -2179,38 +2262,39 @@ class LSTM(RNNBase): class GRU(RNNBase): + ''' Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. + + :param input_size: The number of expected features in the input. + :type input_size: int + + :param hidden_size: The number of features in the hidden state. + :type hidden_size: int + + :param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two GRUs together to form a stacked GRU, with the second GRU taking in outputs of the first GRU and computing the final results. Default: 1 + :type num_layers: int, optinal + + :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. + :type bias: bool, optional + + :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False + :type bias: bool, optional + + :param dropout: If non-zero, introduces a Dropout layer on the outputs of each GRU layer except the last layer, with dropout probability equal to dropout. Default: 0 + :type dropout: float, optional + + :param bidirectional: If True, becomes a bidirectional GRU. Default: False + :type bidirectional: bool, optional + + Example: + >>> rnn = nn.GRU(10, 20, 2) + >>> input = jt.randn(5, 3, 10) + >>> h0 = jt.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + ''' + def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, batch_first: bool = False, dropout: float = 0, bidirectional: bool = False) -> None: - ''' Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. - - :param input_size: The number of expected features in the input. - :type input_size: int - - :param hidden_size: The number of features in the hidden state. - :type hidden_size: int - - :param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two GRUs together to form a stacked GRU, with the second GRU taking in outputs of the first GRU and computing the final results. Default: 1 - :type num_layers: int, optinal - - :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. - :type bias: bool, optional - - :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False - :type bias: bool, optional - - :param dropout: If non-zero, introduces a Dropout layer on the outputs of each GRU layer except the last layer, with dropout probability equal to dropout. Default: 0 - :type dropout: float, optional - - :param bidirectional: If True, becomes a bidirectional GRU. Default: False - :type bidirectional: bool, optional - - Example: - >>> rnn = nn.GRU(10, 20, 2) - >>> input = jt.randn(5, 3, 10) - >>> h0 = jt.randn(2, 3, 20) - >>> output, hn = rnn(input, h0) - ''' super().__init__('GRU', input_size, hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional) diff --git a/python/jittor/test/test_rnn.py b/python/jittor/test/test_rnn.py index 51e03668..9cc37fb6 100644 --- a/python/jittor/test/test_rnn.py +++ b/python/jittor/test/test_rnn.py @@ -7,15 +7,8 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** import unittest -import jittor as jt -import jittor.nn as nn -import numpy as np - - -skip_this_test = False - +from unittest.case import skipIf try: - jt.dirty_fix_pytorch_runtime_error() import torch import torch.nn as tnn except: @@ -23,72 +16,49 @@ except: tnn = None skip_this_test = True +import jittor as jt +import jittor.nn as nn +import numpy as np -def check_equal_1(t_rnn, j_rnn, input, h0): +skip_this_test = False + +def check_equal_1(t_rnn, j_rnn, input, h0, dev=None): j_rnn.load_state_dict(t_rnn.state_dict()) - t_output, th = t_rnn(torch.from_numpy(input), torch.from_numpy(h0)) + if dev: + t_output, th = t_rnn(torch.from_numpy(input).to(dev), torch.from_numpy(h0).to(dev)) + + else: + t_output, th = t_rnn(torch.from_numpy(input), torch.from_numpy(h0)) + t_output = t_output.detach().cpu().numpy() + th = th.detach().cpu().numpy() j_output, jh = j_rnn(jt.float32(input), jt.float32(h0)) + j_output, jh = j_output.data, jh.data - assert np.allclose(t_output.detach().numpy(), j_output.data, rtol=1e-03, atol=1e-06) - assert np.allclose(th.detach().numpy(), jh.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(t_output, j_output.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(th, jh.data, rtol=1e-03, atol=1e-06) - -def check_equal_2(t_rnn, j_rnn, input, h0, c0): +def check_equal_2(t_rnn, j_rnn, input, h0, c0, dev=None): j_rnn.load_state_dict(t_rnn.state_dict()) - t_output, (th, tc) = t_rnn(torch.from_numpy(input), - (torch.from_numpy(h0), torch.from_numpy(c0))) + if dev: + t_output, (th, tc) = t_rnn(torch.from_numpy(input).to(dev), + (torch.from_numpy(h0).to(dev), torch.from_numpy(c0).to(dev))) + else: + t_output, (th, tc) = t_rnn(torch.from_numpy(input).to(dev), + (torch.from_numpy(h0), torch.from_numpy(c0))) j_output, (jh, jc) = j_rnn(jt.float32(input), (jt.float32(h0), jt.float32(c0))) - assert np.allclose(t_output.detach().numpy(), j_output.data, rtol=1e-03, atol=1e-06) - assert np.allclose(th.detach().numpy(), jh.data, rtol=1e-03, atol=1e-06) - assert np.allclose(tc.detach().numpy(), jc.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(t_output.detach().cpu().numpy(), j_output.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(tc.detach().cpu().numpy(), jc.data, rtol=1e-03, atol=1e-06) @unittest.skipIf(skip_this_test, "No Torch found") class TestRNN(unittest.TestCase): - def test_rnn(self): - h0 = np.random.rand(1, 24, 200).astype(np.float32) - input = np.random.rand(32, 24, 100).astype(np.float32) - - t_rnn = tnn.RNN(100, 200) - j_rnn = nn.RNN(100, 200) - check_equal_1(t_rnn, j_rnn, input, h0) - - h0 = np.random.rand(4, 4, 200).astype(np.float32) - input = np.random.rand(5, 4, 100).astype(np.float32) - - t_rnn = tnn.RNN(100, 200, num_layers=4) - j_rnn = nn.RNN(100, 200, num_layers=4) - check_equal_1(t_rnn, j_rnn, input, h0) - - h0 = np.random.rand(2, 1, 200).astype(np.float32) - input = np.random.rand(5, 1, 100).astype(np.float32) - - t_rnn = tnn.RNN(100, 200, bidirectional=True) - j_rnn = nn.RNN(100, 200, bidirectional=True) - check_equal_1(t_rnn, j_rnn, input, h0) - - h0 = np.random.rand(4, 4, 200).astype(np.float32) - input = np.random.rand(5, 4, 100).astype(np.float32) - - t_rnn = tnn.RNN(100, 200, num_layers=2, bidirectional=True, bias=False) - j_rnn = nn.RNN(100, 200, num_layers=2, bidirectional=True, bias=False) - check_equal_1(t_rnn, j_rnn, input, h0) - - h0 = np.random.rand(2, 4, 200).astype(np.float32) - input = np.random.rand(5, 4, 100).astype(np.float32) - - t_rnn = tnn.RNN(100, 200, num_layers=2, dropout=0.5, bias=False) - j_rnn = nn.RNN(100, 200, num_layers=2, dropout=0.5, bias=False) - t_rnn.eval() - j_rnn.eval() - check_equal_1(t_rnn, j_rnn, input, h0) - def test_lstm_cell(self): np_h0 = torch.randn(3, 20).numpy() np_c0 = torch.randn(3, 20).numpy() @@ -173,7 +143,49 @@ class TestRNN(unittest.TestCase): j_output = j_output.data assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06) - def test_lstm(self): + def test_basic_rnn(self): + h0 = np.random.rand(1, 24, 200).astype(np.float32) + input = np.random.rand(32, 24, 100).astype(np.float32) + + t_rnn = tnn.RNN(100, 200) + j_rnn = nn.RNN(100, 200) + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_multilayer_rnn(self): + h0 = np.random.rand(4, 4, 200).astype(np.float32) + input = np.random.rand(5, 4, 100).astype(np.float32) + + t_rnn = tnn.RNN(100, 200, num_layers=4) + j_rnn = nn.RNN(100, 200, num_layers=4) + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_bidirectional_rnn(self): + h0 = np.random.rand(2, 1, 200).astype(np.float32) + input = np.random.rand(5, 1, 100).astype(np.float32) + + t_rnn = tnn.RNN(100, 200, bidirectional=True) + j_rnn = nn.RNN(100, 200, bidirectional=True) + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_no_bias_rnn(self): + h0 = np.random.rand(4, 4, 200).astype(np.float32) + input = np.random.rand(5, 4, 100).astype(np.float32) + + t_rnn = tnn.RNN(100, 200, num_layers=2, bidirectional=True, bias=False) + j_rnn = nn.RNN(100, 200, num_layers=2, bidirectional=True, bias=False) + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_dropout_rnn(self): + h0 = np.random.rand(2, 4, 200).astype(np.float32) + input = np.random.rand(5, 4, 100).astype(np.float32) + + t_rnn = tnn.RNN(100, 200, num_layers=2, dropout=0.5, bias=False) + j_rnn = nn.RNN(100, 200, num_layers=2, dropout=0.5, bias=False) + t_rnn.eval() + j_rnn.eval() + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_basic_lstm(self): h0 = np.random.rand(1, 24, 200).astype(np.float32) c0 = np.random.rand(1, 24, 200).astype(np.float32) input = np.random.rand(32, 24, 100).astype(np.float32) @@ -182,6 +194,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.LSTM(100, 200) check_equal_2(t_rnn, j_rnn, input, h0, c0) + def test_projection_lstm(self): proj_size = 13 h0 = np.random.rand(1, 24, proj_size).astype(np.float32) c0 = np.random.rand(1, 24, 200).astype(np.float32) @@ -190,6 +203,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.LSTM(100, 200, proj_size=proj_size) check_equal_2(t_rnn, j_rnn, input, h0, c0) + def test_multilayer_lstm(self): h0 = np.random.rand(4, 4, 200).astype(np.float32) c0 = np.random.rand(4, 4, 200).astype(np.float32) input = np.random.rand(5, 4, 100).astype(np.float32) @@ -198,6 +212,8 @@ class TestRNN(unittest.TestCase): j_rnn = nn.LSTM(100, 200, num_layers=4) check_equal_2(t_rnn, j_rnn, input, h0, c0) + def test_multilayer_projection_lstm(self): + proj_size = 8 h0 = np.random.rand(2, 4, proj_size).astype(np.float32) c0 = np.random.rand(2, 4, 20).astype(np.float32) input = np.random.rand(5, 4, 10).astype(np.float32) @@ -206,6 +222,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.LSTM(10, 20, num_layers=2, proj_size=proj_size) check_equal_2(t_rnn, j_rnn, input, h0, c0) + def test_bidirectional_lstm(self): h0 = np.random.rand(2, 1, 200).astype(np.float32) c0 = np.random.rand(2, 1, 200).astype(np.float32) input = np.random.rand(5, 1, 100).astype(np.float32) @@ -214,7 +231,8 @@ class TestRNN(unittest.TestCase): j_rnn = nn.LSTM(100, 200, bidirectional=True) check_equal_2(t_rnn, j_rnn, input, h0, c0) - proj_size = 13 + def test_bidirectional_projection_lstm(self): + proj_size = 10 h0 = np.random.rand(2, 4, proj_size).astype(np.float32) c0 = np.random.rand(2, 4, 200).astype(np.float32) input = np.random.rand(5, 4, 100).astype(np.float32) @@ -223,6 +241,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.LSTM(100, 200, bidirectional=True, proj_size=proj_size) check_equal_2(t_rnn, j_rnn, input, h0, c0) + def test_multilayer_bidirectional_projection_lstm(self): h0 = np.random.rand(4, 4, 200).astype(np.float32) c0 = np.random.rand(4, 4, 200).astype(np.float32) input = np.random.rand(5, 4, 100).astype(np.float32) @@ -231,7 +250,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.LSTM(100, 200, num_layers=2, bidirectional=True, bias=False) check_equal_2(t_rnn, j_rnn, input, h0, c0) - + def test_dropout_lstm(self): h0 = np.random.rand(2, 4, 200).astype(np.float32) c0 = np.random.rand(2, 4, 200).astype(np.float32) input = np.random.rand(5, 4, 100).astype(np.float32) @@ -242,7 +261,7 @@ class TestRNN(unittest.TestCase): j_rnn.eval() check_equal_2(t_rnn, j_rnn, input, h0, c0) - def test_gru(self): + def test_basic_gru(self): h0 = np.random.rand(1, 24, 200).astype(np.float32) input = np.random.rand(32, 24, 100).astype(np.float32) @@ -250,6 +269,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.GRU(100, 200) check_equal_1(t_rnn, j_rnn, input, h0) + def test_multilayer_gru(self): h0 = np.random.rand(4, 4, 200).astype(np.float32) input = np.random.rand(5, 4, 100).astype(np.float32) @@ -257,6 +277,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.GRU(100, 200, num_layers=4) check_equal_1(t_rnn, j_rnn, input, h0) + def test_bidirectional_gru(self): h0 = np.random.rand(2, 1, 200).astype(np.float32) input = np.random.rand(5, 1, 100).astype(np.float32) @@ -264,6 +285,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.GRU(100, 200, bidirectional=True) check_equal_1(t_rnn, j_rnn, input, h0) + def test_multilayer_bidirectional_gru(self): h0 = np.random.rand(4, 4, 200).astype(np.float32) input = np.random.rand(5, 4, 100).astype(np.float32) @@ -271,6 +293,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.GRU(100, 200, num_layers=2, bidirectional=True, bias=False) check_equal_1(t_rnn, j_rnn, input, h0) + def test_multilayer_dropout_gru(self): h0 = np.random.rand(2, 4, 200).astype(np.float32) input = np.random.rand(5, 4, 100).astype(np.float32) @@ -280,6 +303,330 @@ class TestRNN(unittest.TestCase): j_rnn.eval() check_equal_1(t_rnn, j_rnn, input, h0) + def test_rnn_default_hx(self): + input = np.random.rand(32, 24, 12).astype(np.float32) + h0 = np.zeros((1, 24, 24)).astype(np.float32) + + t_rnn = tnn.RNN(12, 24) + j_rnn = nn.RNN(12, 24) + j_rnn.load_state_dict(t_rnn.state_dict()) + t_output, th = t_rnn(torch.from_numpy(input)) + j_output, jh = j_rnn(jt.array(input)) + + np.testing.assert_allclose(t_output.detach().cpu().numpy(), j_output.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06) + + def test_lstm_default_hx(self): + input = np.random.rand(32, 24, 10).astype(np.float32) + t_rnn = tnn.LSTM(10, 20, num_layers=2, bidirectional=True) + j_rnn = nn.LSTM(10, 20, num_layers=2, bidirectional=True) + j_rnn.load_state_dict(t_rnn.state_dict()) + t_output, (th, tc) = t_rnn(torch.from_numpy(input)) + j_output, (jh, jc) = j_rnn(jt.array(input)) + np.testing.assert_allclose(t_output.detach().cpu().numpy(), j_output.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(tc.detach().cpu().numpy(), jc.data, rtol=1e-03, atol=1e-06) + + @skipIf(not jt.has_cuda, "No Cuda found") + @jt.flag_scope(use_cuda=1) + def test_cudnn_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(100, 200, nonlinearity='relu').to(dev) + + j_rnn = nn.RNN(100, 200, nonlinearity='relu') + j_rnn.train() + j_rnn.load_state_dict(t_rnn.state_dict()) + + h0 = np.random.rand(1, 24, 200).astype(np.float32) + input = np.random.rand(32, 24, 100).astype(np.float32) + + t_output, th = t_rnn(torch.from_numpy(input).to(dev), + torch.from_numpy(h0).to(dev)) + + j_output, jh = j_rnn(jt.array(input), jt.array(h0)) + + np.testing.assert_allclose(j_output.data, t_output.detach().cpu().numpy()) + np.testing.assert_allclose(jh.data, th.detach().cpu().numpy()) + + @skipIf(not jt.has_cuda, "No Cuda found") + @jt.flag_scope(use_cuda=1) + def test_cudnn_rnn_train(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(32, 64, nonlinearity='relu').to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + j_rnn = nn.RNN(32, 64, nonlinearity='relu') + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + + h0 = np.random.rand(1, 4, 64).astype(np.float32) + input = np.random.rand(12, 4, 32).astype(np.float32) + + for _ in range(10): + t_optim.zero_grad() + t_output, th = t_rnn(torch.from_numpy(input).to(dev), torch.from_numpy(h0).to(dev)) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + t_loss.backward() + t_optim.step() + + j_input, jh = jt.array(input), jt.array(h0) + j_output, jh = j_rnn(j_input, jh) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + j_optim.step(j_loss) + + np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-2) + np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-3, rtol=1e-2) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_basic_cudnn_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(100, 200, nonlinearity='relu').to(dev) + j_rnn = nn.RNN(100, 200, nonlinearity='relu') + + h0 = np.random.rand(1, 24, 200).astype(np.float32) + input = np.random.rand(32, 24, 100).astype(np.float32) + check_equal_1(t_rnn, j_rnn, input, h0, dev) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_multilayer_cudnn_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(100, 200, num_layers=4, nonlinearity='tanh').to(dev) + j_rnn = nn.RNN(100, 200, num_layers=4, nonlinearity='tanh') + + h0 = np.random.rand(4, 8, 200).astype(np.float32) + input = np.random.rand(5, 8, 100).astype(np.float32) + check_equal_1(t_rnn, j_rnn, input, h0, dev) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_bidirectional_cudnn_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(100, 200, bidirectional=True, nonlinearity='tanh').to(dev) + j_rnn = nn.RNN(100, 200, bidirectional=True, nonlinearity='tanh') + + h0 = np.random.rand(2, 8, 200).astype(np.float32) + input = np.random.rand(5, 8, 100).astype(np.float32) + check_equal_1(t_rnn, j_rnn, input, h0, dev) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_no_bias_cudnn_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(100, 200, bidirectional=True, bias=False, nonlinearity='tanh').to(dev) + j_rnn = nn.RNN(100, 200, bidirectional=True, bias=False, nonlinearity='tanh') + + h0 = np.random.rand(2, 8, 200).astype(np.float32) + input = np.random.rand(5, 8, 100).astype(np.float32) + check_equal_1(t_rnn, j_rnn, input, h0, dev) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_dropout_cudnn_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(100, 200, num_layers=2, dropout=0.5, nonlinearity='tanh').to(dev) + j_rnn = nn.RNN(100, 200, num_layers=2, dropout=0.5, nonlinearity='tanh') + t_rnn.eval() + j_rnn.eval() + + h0 = np.random.rand(2, 8, 200).astype(np.float32) + input = np.random.rand(5, 8, 100).astype(np.float32) + check_equal_1(t_rnn, j_rnn, input, h0, dev) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_basic_lstm_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.LSTM(100, 200).to(dev) + j_rnn = nn.LSTM(100, 200) + + h0 = np.random.rand(1, 24, 200).astype(np.float32) + c0 = np.random.rand(1, 24, 200).astype(np.float32) + input = np.random.rand(32, 24, 100).astype(np.float32) + check_equal_2(t_rnn, j_rnn, input, h0, c0, dev) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_cudnn_rnn_train(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(32, 64, nonlinearity='relu').to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + j_rnn = nn.RNN(32, 64, nonlinearity='relu') + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + + h0 = np.random.rand(1, 4, 64).astype(np.float32) + input = np.random.rand(12, 4, 32).astype(np.float32) + + for _ in range(10): + t_optim.zero_grad() + t_output, th = t_rnn(torch.from_numpy(input).to(dev), torch.from_numpy(h0).to(dev)) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + t_loss.backward() + t_optim.step() + + j_input, jh = jt.array(input), jt.array(h0) + j_output, jh = j_rnn(j_input, jh) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + j_optim.step(j_loss) + + np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-4) + np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-4, rtol=1e-4) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_cudnn_gru_train(self): + dev = torch.device('cuda:0') + t_rnn = tnn.GRU(32, 64).to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + j_rnn = nn.GRU(32, 64) + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + + h0 = np.random.rand(1, 4, 64).astype(np.float32) + input = np.random.rand(12, 4, 32).astype(np.float32) + + for _ in range(10): + t_optim.zero_grad() + t_output, th = t_rnn(torch.from_numpy(input).to(dev), torch.from_numpy(h0).to(dev)) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + t_loss.backward() + t_optim.step() + + j_input, jh = jt.array(input), jt.array(h0) + j_output, jh = j_rnn(j_input, jh) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + j_optim.step(j_loss) + + np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-4) + np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-4, rtol=1e-4) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_cudnn_lstm_train(self): + dev = torch.device('cuda:0') + t_rnn = tnn.LSTM(32, 64).to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + j_rnn = nn.LSTM(32, 64) + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + + h0 = np.random.rand(1, 4, 64).astype(np.float32) + c0 = np.random.rand(1, 4, 64).astype(np.float32) + input = np.random.rand(12, 4, 32).astype(np.float32) + + for _ in range(10): + t_optim.zero_grad() + t_output, (th, tc) = t_rnn(torch.from_numpy(input).to(dev), + (torch.from_numpy(h0).to(dev), torch.from_numpy(c0).to(dev))) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + (tc ** 2).sum() + t_loss.backward() + t_optim.step() + + j_input, jh0, jc0 = jt.array(input), jt.array(h0), jt.array(c0) + j_output, (jh, jc) = j_rnn(j_input, (jh0, jc0)) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + (jc ** 2).sum() + j_optim.step(j_loss) + + np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-4) + np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-4, rtol=1e-4) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_multilayer_bidirectional_cudnn_lstm_train(self): + dev = torch.device('cuda:0') + t_rnn = tnn.LSTM(32, 64, num_layers=4, bidirectional=True).to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + j_rnn = nn.LSTM(32, 64, num_layers=4, bidirectional=True) + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + + h0 = np.random.rand(8, 4, 64).astype(np.float32) + c0 = np.random.rand(8, 4, 64).astype(np.float32) + input = np.random.rand(12, 4, 32).astype(np.float32) + + for _ in range(10): + t_optim.zero_grad() + t_output, (th, tc) = t_rnn(torch.from_numpy(input).to(dev), + (torch.from_numpy(h0).to(dev), torch.from_numpy(c0).to(dev))) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + (tc ** 2).sum() + t_loss.backward() + t_optim.step() + + j_input, jh0, jc0 = jt.array(input), jt.array(h0), jt.array(c0) + j_output, (jh, jc) = j_rnn(j_input, (jh0, jc0)) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + (jc ** 2).sum() + j_optim.step(j_loss) + + np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-4) + np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-4, rtol=1e-4) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @jt.flag_scope(use_cuda=1) + def test_cudnn_rnn_speed(self): + from time import time + iters = 100 + + h0 = np.random.rand(1, 128, 256).astype(np.float32) + input = np.random.rand(128, 128, 128).astype(np.float32) + + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(128, 256, nonlinearity='relu').to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + t_input = torch.from_numpy(input).to(dev) + t_h0 = torch.from_numpy(h0).to(dev) + + start_time = time() + for i in range(iters): + t_optim.zero_grad() + t_output, th = t_rnn(t_input, t_h0) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + t_loss.backward() + t_optim.step() + print('torch time = ', time() - start_time) + + j_rnn = nn.RNN(128, 256, nonlinearity='relu') + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + j_input, j_h0 = jt.array(input), jt.array(h0) + + start_time = time() + for i in range(iters): + j_output, jh = j_rnn(j_input, j_h0) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + j_optim.step(j_loss) + jt.sync_all(True) + print('jittor Cudnn time = ', time() - start_time) + + jt_cudnn, jt.cudnn = jt.cudnn, None + j_rnn = nn.RNN(128, 256, nonlinearity='relu') + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + start_time = time() + for i in range(iters): + j_output, jh = j_rnn(j_input, j_h0) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + j_optim.step(j_loss) + jt.sync_all(True) + print('jittor native time = ', time() - start_time) + jt.cudnn = jt_cudnn + if __name__ == "__main__": unittest.main() \ No newline at end of file From e57f0e4879a8a804b560e8f3bd552e42dae9868a Mon Sep 17 00:00:00 2001 From: lzhengning Date: Thu, 30 Sep 2021 16:04:29 +0800 Subject: [PATCH 2/3] updated docs --- README.cn.md | 1 + README.md | 2 ++ README.src.md | 2 ++ doc/logo.png | Bin 0 -> 28818 bytes python/jittor/misc.py | 6 ++++-- python/jittor/src/ops/reduce_op.cc | 14 +++++++------- python/jittor/src/ops/reindex_reduce_op.h | 2 +- 7 files changed, 17 insertions(+), 10 deletions(-) create mode 100644 doc/logo.png diff --git a/README.cn.md b/README.cn.md index 40c12197..df0727e3 100644 --- a/README.cn.md +++ b/README.cn.md @@ -1,5 +1,6 @@ # Jittor: 即时编译深度学习框架 +![Jittor Logo](doc/logo.png) [快速开始](#快速开始) | [安装](#安装) | [教程](#教程) diff --git a/README.md b/README.md index 2a35af26..18918758 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # Jittor: a Just-in-time(JIT) deep learning framework +![Jittor Logo](doc/logo.png) + [Quickstart](#quickstart) | [Install](#install) | [Tutorial](#tutorial) | [Chinese](./README.cn.md) diff --git a/README.src.md b/README.src.md index 0be5b964..2294dd85 100644 --- a/README.src.md +++ b/README.src.md @@ -1,6 +1,8 @@ # Jittor: a Just-in-time(JIT) deep learning framework # Jittor: 即时编译深度学习框架 +![Jittor Logo](doc/logo.png) + [Quickstart](#quickstart) | [Install](#install) | [Tutorial](#tutorial) | [Chinese](./README.cn.md) [快速开始](#快速开始) | [安装](#安装) | [教程](#教程) diff --git a/doc/logo.png b/doc/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..0da769bc1dd37438a3de3ae2e9ea0e4282c504ef GIT binary patch literal 28818 zcmY(r1yoht^FDs*E+sDAf^;KDNF$AK0jW!OsB~XiN=mv@y1PTVTR=+bknZ{)-q+9j z{r&G+EY>}ov-h6aduE=Qc@7~eN-|jJBuS5ZwGM$b(ZT|3f^AtNVnM&O4#}pr;BzOIR3Md#3hpypNlaf*G%H;I zrit=H27`b<^`Qb+tiaZA^Icm;3iEJEOwxb}*P(rS&+N~i!=RsNf|IEeUrmNw98iJf zJ$*cJ172eYM(thW{`(!|iUEh&K=5v|CY;lGbcyrt)gvr;@Hi@~Wes;PA$Ko24+`AB zf9NnrHRvEmhZ+$Z%lsrm_p;G5a>xoz~>ES zNmh_N&6imQUfxA)er{50v%0~c?^Lh!NIc0BHN^1EWGUfJu>;4jF_i0@pL)|2HcxB1gxI?)le`z`0tfU+zu`EOXHa#A@%ks&k_2+m&Z*}4Okd2en`DGw)u!sPe#D; z|BUbg(4_Fv$<22I#Q*m)xbGMv4bpBisSE%2W^XZ&v2Y_fq&@Tg?@qoE0#^3S-9x^^U|jBg$gv{|)>M7`uPvds0LPGhKGoJnxd1O}(9n?S?=WD-I(7&d=TR0Qo8_;0(Y5bLOt`a7PR$0jG+ z2D7i&edkGu{AnxJl!zOICHr@dm`wOy-FT)mVlDc)&Q3kNzHs z{$NmV10hC%w;`&*KQnk=U2Hxf$e6{jF+#2Tu)jPdJ`EvLkK>Oq#<;x!u`)s zh~dHhZ-t#`LOAQx$7_3$>Hdy_1WXtCxwDzgLiF!cDj>LUU)_(p)~W~rLXJsnrtkzvBcVVjb|`QGhzR?3lZ-b(w2 z<%`6NvEhS_ID{MJU7F5!09y+`){S< zs0L?m&CW3f=KeGNNh;VxQVn--EZA6{9$M`1{~O|l*80Fv2kHNJP4E-sJtw4{Mdd>T zZpYv07z04k&`$)qDT9N)sup7Z+Zx~$DTTAX{@MvnDPCR3{LkxgB68eh838K%pI6DG z0Q3T$)4u0|Oz;Y1HzNKWH&+~#gTrLj-td3!wi5T}ny=IOe0E^?yfkymGeCzOr6Z)F zo53;x3IfpgvCjwzja)QlW+nRqWTt0$Q!%iynLh8~osVu;9;Ynb+yA4tOk-37c;%UW zqId7{RLvywd#lom#n0Dt-TCe@`t20 zQ^k_gF^AJbJEq9U(T+>Pj)j6rdLEWnx7EsDUbLV5GGF@xpYX*oGQz-R>2>2psG--1 zxxUYa6S*hTKic?E07HukF~?aTt}ZO`HUtcyq7G+{9L#Hwc|W5_>Fv}bn!YrH1N(_*N%!5*A+-8zVe}NqOC5(E705 zoDBD^G#*oW2eR89ta!w*CSQ*EcFUaSG;%nPF#Pjm2fc+DI?O)pescM{jZAV9>#Rhv zO0dsG4w9v|o2O&j#-9%$>uURzPgkPM>5C$c_phO{jkLd>@9z(6Z4rD)i;Bt@_&2_= zs=z%555cS4ca8fY8VFo0y^cI^JJa0D4P*~6>o3u1w&xj%pCJ3^9KZ-s4e(QRy{SSt zoqede3x`FvF`K@m8(9W~hA zA8$I!DP0{nu!bQU!#hPL!+!r9f)6-w;o@q8?U}m?NN_hN;0_D}yk-&MrU!<{>ZZu@ zy}bKcyMu=#e$B&cP73pXhSo+knAXPYD31Ukcfea<58c}PM3-eB6to`bU~|jpPQM!Y z*DT+Fat3)fZn?q6Y6z|b75Ate@K#nNpP;U0){>camNVDde_db$UUn-g9oqCh5s`Wa zDDmePTl(U7(dJT|wlQErT-kvy`#ss4T?Crzh&Mv$Z@`Eg=nADj5;_xEqh{FmZ9Fjn zfDl_;Y}@c<0S5n4@P&}U3_tLo``7*cQPXe&eBilHWWHPjUu@9z)A7&F(VdQ@`CuG? z7^r8ltbiU9fA*fEJRWs%>lh<=F{eRlcW0jh$u<~=#sag$V)n+xEE@NPgu~l8f`Yaa z5SJlxxCt>q2D?c(J@Sw-e&I*co1=$Ib41(CP{AqDcJnJ@t2=1T_U2`)^QyP|5|jm z^Rk!+1y}$J&3v??eH+`s`Yap2+v}Ku`GmM00)c#Q+N9u*Z<7%WfUTrU`e;|KCxeg! zIA{Z-&Q3An+sky!D=f0a?N_CZ0EF#;ziJ^OscY^S(y<=IeK8*$u}ieVa;owYX|o0m zE?lU_oplv~b07(B39kdTVY95lvw}zDly^q;ldMQ$midpLRlCn;ldHQ#||pF&2&m)i~XO|RODieHti#= zcz-O|UwpV5a={M=FDW}T%0)+KP53@gtW;WfLVAa(ajdYx-^-LE%S5l8Ac_6fDN*aM zxM)h-34&N}x*p~kWKNaS3IrK;DRXDGyl!2fX@WJE8Bds-?svEI>Amak85?R$v@dhDtkfOdo(nxArV2OUATXo_@xn#tJ{) z*2+m$tBW}!N}ng8-W{6sI@uR(H=ZJ^Xt={^cNez2OV7Vy8ez+rbU;~nd3cZ2r7AsFAS@6;=x@mfp0LEobZuc58+~M~yy2+fJF#gSNjje?Wfzxwx>SBnUL^>)+z z;!ybs;W>Gp@6AHHkCojrSX(4hL2rzB7 z*Vx=!+D!I@`4<#aoF-!sxE^H@*P1K6_OY6tXV1;g_pS_A$pnI>*a)0;H6`0KsAX;f z-y2W`zIKi3%VT6xfdD*WCW8lzL`6U>xbTe4)JL1FjMy>LE@Zk2vLXu=&c}_L#%}9x zBT8OM4pg9RlMJOBNU}>Gn9ph$cFAQ=de4s_+Pe=Wf0b6!IIWD24n#&*-Sdcv>WDoY zHaCxvtu7z7Z%_}Z#+gMB+#A;9wGQQUNC1bGPvfoFk=%y3w`dB0BCTJ(EbF&?8U6MXp!T6yuDk{^B%bP8yDbOnqa1?L2L~@X*hc`GLd-&BN z98Ejz6DX!o4Vn!8Tq_^&r>#Wyfpv;bi$g)kEZ^$#y`y5T*;~&!9|MP59tMSJr;zc- zca7%@PBtPRUG`?PXm6Cz&JHdT!Xt#}%FOYB%@DJ0t6!%i$rM;+5Uz!Thn?4;8JdvT zW^%xvLR(B$iABeq60~=5JFXb3cVp=uat#8qVqFc)kF=v+&Q>EwK@C zOMSi1LQ>MW%Hc*l8`t9T9yuhg#7W)sZ|B(7NZlCmJ5+uQO(Y*WSzn?lqZH1_v)jTz z28u!f@9D04BbA&@NDXHRHrris4rFHi(G$xW@^)i+9s&%$#u)S}iBbx)U$ z+(HgPH*xj51Z!YYNGE{COhvEnS_CKN{KNwjREFoA_$!AGvl(213;SC^d=>=-1Cg=E7(sPq1?_S2B%KFFDT82*?}rst&B& z5IOX3oM)T?U#-)3s?VbYq+wbLu96n{zy^EvRQw6*HXWAp;+mzXT;FM5*~FjtBI)n- za{RidYTlYbiI|7Z%YBe(gn>BrYlM)=j{@(eB&=MNs#(wK*$kI!(R~m|f0cSzE#|P8 z!KtMQdM~H>(s&?JLY^*6#ZJKBUi@i z_D_np1wxj3R@J+cw5Yy%tQ{;s^KXyp7&(H5w0Fgi6i-=(8fQimStB`9@V=@4CfBMU zxt*Zh=tA+TlR1Tgjho64vz+PNUFhlP81HxO>4ml5s|p)!UBoopdni|35QyWCW0KC7 z=Z!ErS!8gLP?1_kdK%y5e*ZRt-&8|yO^XufFJ-@ zu&g6!m_{&X%H0tN+8B1;5gmp!lxv*yiBYZBd3a7a4RG&R_nP|ELG%cjXkf=wLyf!< zq7@Nbhk3%)qHb`y(ld8N+MG6DHOSh~8#jphocM9%Jsi%{Nqx@yYhTJls@Jr55E4@8 zPSry%V>7(mYYaM-{`WxmxQF$wkzXy zG(>OwjG55BHcEA4x1$0Mvnt&AyB#c~uo2%=+j1i7^jLQ-J2c zIe~5{)xpGMC^sg7CH+Z^^n~|Q5^gPE9?fDte)8L%5DJl zCBOf=b5VV!tX@%oHiX+z6TNZgisaJ)UJiM1R_}ntF{V8J`XISmxK89r31n@+KYvLB+v_pZnOl12?wLaiP^F@e6))w2(Cp0?<4eOYozPdwMgz^xwt_oCn;RUTfcY$rx4FA8o{6@O${It^>yS>gCR_nX>h@Ft%d*>r_ltV^W@KE(KXhov{3Dab$|N&x}fp(}s(>!HBh? z(c&)ki<3{oukcb0*);!{@CHOM0z|O1v$2`X{b7K~#ebs0IL<1#pScN>8qhu+aV;i# znn?BvO-pluo#{HHI6dQZ+*yqrm69k39#MgH^!54!^nd-?Law$8f_076Br;ydA>_&4x5bog0}T*lu?OnOs1MLmJ1X)LbZo+60o z+6sPDH9lhUA$Q$nuFDX3AauQT?L2K3ndNI`4nZEI}p)2_9 zsB{2u)G#5Z(5q0xI~(GtVZiPE{QjLQ0KU|u?D~c3QdirgE-J_X1UR5Z{kr*vBINL+ zWCQ|^QGGZ`(OWD7ah`-kt@zU%x?UT%7b%z7$4)*dOa>iXU;o*O z?wH2t;P2<*UB>>gwVNqffSC?c$F?W^E`;~V`x1PhA2c?A%iG&|P=Xmdad4LaATtsi z3?X4QFTA6pz!i$kKBuF4NcgeSnw^CD5F*F>(MU;=iZe{osE&6^Up27fy*w=M9c>l8 z(lJ812H{I4@vFpdes1N!gv#QfvTM6@5_67V|jcZ zp-prlnS3f%{`u%qfz1BABf9zM*4a5eMzK<(lAMa+XTVOKb_3-G$&zs{`1-e`u($LF zD-Pl}PyQ>FlX-?|j5?pjffna?)*AqgM_OQOMGf<=m<-l*_8q;(28EY3&@iwAV-_wo z?>8+uPQwlR96u}XKjL4-ZKG!b(v;4j`(%WNx6>1TDUD}MzTz)#8mGvdl&3-Zk%mb- z2#G9PD2>4MqJYo`G6uw?bCh?!p%=7mMnCrRD)Z@psfqcDD9sRRW12QoYU@Y`9}Vka zU1MLqC1kX|VxXhj9^5pWy_^m?#Uo;$oi}ch2~bZfx3uuUXPGzQZjn7{$Ld%qSS2E7YvC8UMQ}SciV&4g3?ZP%W67iR-*8 z;hH2_SQYa7JO!y{rUXQB=arnZ?FL)Ni9j&ppv-u!vplj-LCR?eWRmp>6t6GMNfeZ- zD`{vYp#ORqAGpD`KJ^_UGP19#&b`Dw*)>%RSoxyfHGXaakpt&S=K&#V$V+YJnme@k zESIFX9;hoiZWer!RkSH}G#jL@Duo3OERle}kmNswy7DTee<@N|&+-db(8v1CNrp$z z;dtg?I|xX2?M-4_M_@)-w+FXe3mRf%Zw9m9)=l0K_5NTQ@de$UiOM`E9NG=G5dzV5 z5rDvY*#fv)&U0Z&63{p)`+2ReUE2JR^Mje!O{R;gQrA)*;_rTbPo+PUTkMy}dZS?D z`si|c;dB}QmKuE}+)>_z>YXZrJJ&v>fu_5-gXMVUhJW=N!~ zzZPWM+{Sm~*QxfXo$L!plqY!~5VVsguK%DUAh?ObMKwyT5bj5M*>**e)cIRw)MWmFNvSKAGS>X8LZD*N?K-%<{K=KqrUA zcW~Rqjy;<1Xev+x^P95rgdWIYS90e3(&v%4kd+djL+Ge@hm;2eaiQU1m9mI^Ryp^X zt2xAWDEBA#x8fLGXy*+4gvtv3n&a%nN4$%#7srd#^P`~~B!U)t4Fb$z_^9r5nL4~-bFi{Zd)b0tmsr{bTuQh%%1 z+USi-+xEOL->>9(hahiYF`Kx?`_#fB_U<&lug|U%4sjpcZGf`Zb(Uf0{PZ6g0`n># zWX<-L6p@?n3_rFW3_jjW_=bGJLN$VQgY?;t&qp6LQm$(0L+i;(SGiSTn%xb{238V) z=Q68%I;YD8Yr8a7Wuf5;> zno351hkS+Jm1ithOW)X5!lPpdLnu`Y8XpeW%krKx9c6sc9l2vEQ&z5|g#mSZ0BO8i*|=q#k4_sHBPi8M9gjA(`|>v9v1=H=!$e@N)L^*o zjYCH&lU~-)T3f5?F}wzS?S?O)5~oI^L^DB^$o5T`0ncx-O$=LfBz2q0s?fIkX37&s zHD;zFq9Rn5QW#Tra&odh+PY%iwRG$4_q&7b=+Lak+oKzcxEa~y%qRmOlg4b@TOOcH z6*w5qaJ!G2+u2306YKt27%1b=fm4$%^6FLp;hZ^s3l9$utM0>oorv$lb)K6=h`&32 z<_!5O)j0S#asY~=;W3RHYtHRxEjj4SQtiZ8Ep* zE{^Sq1#nXo3|3T{Nk8@_>pb3$tSmmvJC%%N?6liGF1sg{CO!6ytk{t#0xb%Fzu=M9*F8EP-fx_B{a(=;6IQzai*!ZsL@W$tD zFot)mO7?ZWI7YH9)ops_H9NLUZZpF+EZaA^3K~8YW=m30C_sf92=D9Jx2&)oP+nO% z@SEK;nINk^)-fl0umV(;SATeOwp~n=>ZIn=vL44F?4w*rvw18dicZN_byzzxvTt%I zGV$KMBq(dGHhZ!5*K;x!LzSO|BXI1h#j27f+p*;>SMQ;%LNd!vhwBSXO+n4PC>@^= zM+XcNAAj*iwy!L$al311eHbs}{~5&Yr`-cSRN9HkdJ2uNEj_2AlPGl)w8Pu>r&fwJ zwYj}_6^uq~IIZnQzvHA+t8T}5%&dIDuUMhex^MeX3`C4(%Zl7)W15K_ycb|u0#qszUcVuzNZ&IL=W6B)$J}oEzQn1Y&FZ4+(KK@n$C59~X?QuUo?(q` z*nQ!g9K#AE6}??c77>LC?xlsEgvoQ8`Uf$CLj zeSIcRL>KeO1Sso`x5)y`&)O_^e`3hV4<>#ao#09k+AamJZyzE)BmWJMGZg@Y!~IXb zkDLTP-g?rzPEJJjl?(LHSD(In%E!BNG}$sy<`D(o9*i-QYF9FCM)9S6>L&#j5$Yp!Psay(oX4e$jF^R6O9#q?hQZ90O zW_&Fiu^=ry!@(o5-WhCm*QVp>I(+puL?Bri9rWHd>K^%fyd~Sd`}Z za`=o+N6;>jh0JsaQ}-QFxvP^t>^*zhE~BSm(PleG4l)4S>JvE4bczHFJg1GmYW`p@ z%t7fh&ATN|@oX7r5qPZ6Z9ic3JS#0O8Gc9eU5C;Pb z2As~Q1>4#1Kef=6kwt#(NQo-ddw4t4V%Uz2Mi82;z38t}Zbu4Ehy$*!AO?F+S=h@?eOrO$xBbE2@}lf)3g}$ zsJ$)h@`HjrS{3zSpRBd+rR;8OD{R#D&tXlZr_k@2Dw%cp%fD?Z0Ju|Zf)cNhmCY4? z46AZkNsMd1KW|EB(kw1M8Vt*j0#MKu{gfCgZ6(%ZF_x^lubQ4-`{tIYd(Jog#lhis z_c1_Cznbuu5rIJxBSE-u_0I|c1&Yay(xS2l| zC}(rm%;{S3U)RWlYMPo7|-G9|u5QKfirIi24K11>^`# zebPc@0#xHbkA+DbjiMrxj4_*t z&&m~#Oe`#}##CE20~Y>g&Ir$JUkO6rpLQU)lw(hkO{2vEUBu~Q0m-t>7(Q3)0izM$ zj*76mA|PvZDTc){mGwuvk>3fMJ2XFmK zFTyfCg(j*#?bZTM+fM(sux$$C#)EPsv!KG-Lh?;Y6J~ER(So*`$JrUN^3Y}xQG4R{ z`6ko0)_Knj>DOT&RiUj;dATB=UZA`e>-8?zC@c;_M}PV1k0oPCQ9LiRiVCsWO54Vf zPr5gYPXU2;Z~Vi4+L2FNC0H9b$$i8pWEH=4((coO^E9_oQl_gi6`?$AnboKUWiuV=LXRC^Vk8IO;vIbSdLEWdW$idcxsF z<)m2!bh(|+d7yi4a|ytM>ly_1Z<8XX2axG*H?gE%Q|i=tS1jucZ8w_!`YMLft0c?e zc^;Pe-0*#YX;w9*yD-!zC-!|DTXs?QMnJk~Py!OiXC0vBW0oxe)`z2S3Je|X=4xFc z>v@^-*mXt$1ng4{u;X%A@sf#Dn{j~VTP=4b|0QW|J#7ghKpy2hX2Cn=I^~$~ z2!ex&RFu|IIaBnl_J9E7!$~q+*%?A+=1dhBXWihGVrgk1Wo{#Glca{}kCn%hj71d8 zn#Yxx+ml}rE4}n z+=f1W=ww3=%%7*@N0NM5)%C{mFpn_j(s-(c$tn3{F+VrHBQJ@y4saBo*=oV}#Rg>B z;rQ7R1XU{nxkTwz{2-@jAl7tCpT`P&Y7HDzgnsUiqEY5|0-ANAZ}OtSJX#P(Q*a2>wuO+WzRz2Uq441lOeHC*hJ6yp5~_+ z)|o!m#@m9CJtGzEYXK(Bd=ET=X9bM@0^Y`Z(uIqZG zQZOTL@+6tH1_6uPM$gz$bq6>n^TUL2>^3=s9xlAZS;;+B{VVk6&P_A3C#SZ@JTLdV zFA>E2DwgtzEP@%2zZ5|*kTgZBW_wMly-I&pH*Qc_AKpUmwCd5s3)lB7J9lkS!R!yI zDKFCiV|qfrQ0Q=?q6;_C!pn72H8*=F%0tMn%hE?vOV&8H5erVPg80mXjWlB+kwzE^ zPwuC!i)U-3%nLaJIQ(Oz=VhC3V1dEGKN*{oeU~>P!_nj17ybSG>|`K8V8^Ph50dU= zpIuJ&PW^d_#1lT3U|K(KleEIt{;|h?_~$tZ$r&v4TBmA3;3YnggEDF^eo3n*CuZ_Y z0@^s6LeN`>5xC!z4j7UAaG*``LF0gn96~?w#Aj{Ny3^N;_pkpc3S`4|+7=t)BeI8Q zM-_ zz&qcs2a{3e+D`+|m}xx_R5a5B-^WqiTI-6sxE=tYMcEn~^&fw((0&p6Nm;@Wf)uYo zUBop!;~X$Z#Om%n)?$06$7B4)L8T-YsF^SfT?NMmpOvg1`qg3niG_R*awePHjmglONL9P6_&c)JA$D#@^{Sphrfi`5w6T+~p{~v@30Rf7 zC0#bxKV8QHAx%YYC~lkUhfPNamrt<9!fDj-ylm@j@2j15&zZ9~If3lz_g!G_p-oyM zKqjf%Xi4w7QrGXrqIX^h%LneCD7s8e`o8dYilN~!>uy!M0X|TQ$6fw3>EZ2H-2H~? zkb1HZK}|K0$3P@5V4vEOd%KB*Yd3UB%5JpJOh){d(>^L!^XORD4n2jack<(PR(p)V zW~G)y5ZF%KB?R5{e&a!bSD4|LNpIa4_32d6`dFY>@^5BAqNT{}-RTp{9dloXnG>2@ zQ4!@s{FCl~Us!LtuQV~MNtvF+rVokisd zXU_z-^hn?+hsLc0u+6zTZRwr+3@<-x;;9W_`!QlKlrS6EUe(j=w1%sN^g43o{#Zup z9iK*xlmmqsK-HlDYCCG23UOWE7Lwfy#Y{}pw@1L}%*YT^g_N|L)NeitJsnoXxB!+A z^%}e;!7`d5Qw^wVo-j6BPkg8a9CvE-YKbtwU;0!il#`z?vT-(0wC$JHXAG%vcNZVS z0a{5r6xbLakHwd8-#)Cp}8tS z5?1M6M@zk{V#>%jVbPuHnsYvyJ#T&YvEJzA8QBd29;?U%r`^WJVpo~LH@mQ)DEedT z{l^2vZY6gT^kz4Si8~A;E055&(JkiUq)^_YZt06DT?*L;jUd}6YLN21p~M?A-*Mwr z4;MwoNGD0}WtEjH8}k_~d}n~taF^zRAvdKDFlAh+%1(D;A(KHeYV?yqm+fB(14deI zEW5Ea_}2D`%+}#r{AR)a=H$zd-p`I3_RemAc(#6$#8Gzoj)|%NyiP8+fyUersNU@e z1w^Z!Px?KEha}EuZDg%ruOyjfX_UrZYr)k0Q3tMRS+t3k|N?osXV+dff$E59etRWzl$1 zb*DuGp!LrQ(6r<)c?=dPr86QV+rKep=GuBuXNlF~;blopaIWEv7A*|La5a_O?IX=j zPm?LfDfU3)dj70hzzc)a|TDExSG6_>-K-x<03$ z|5ADbWO0xu6Ipp1K7Rr~UG`fLZOn_+JxEj3hW2rys)sSwo>M$5`dHVB)7`dP1u*-J z`|apYJ8+1$pk%(+=Qzm-aQQ}d=n!(pFzi!TarQyKW`@ns2Y%@NRvFZq4L@ATOhVl6 zkKTt%$u))mt|LCdtWrmH$myi0vup-b{6K`?%Om%w?=6)|ZBHfbhFsC5?zZ{$1cROP z#UIkv6QI3$uXT#1|0^=fu=1jxbtq$r&+UAuIc^Ery{*~%&ym1c#T40DES@>7s>k-3 zbva`WmKVVJIaE$bYbwHEh_nRi>Zz1?#DXEqa@`GK=m`mPZy(e1y&I#y6{~s%-<>g2%=?2GK!+VO!*wYnDwWT`YW^VUb zWNq2>j#+X83molCch$RyLrU){b(<_A8(tHaej} zOi&!Zvt&km*h@O>Ng$TFFE)-oUtGL5S3uSzKZB=y=d@f1)aT5h_oZ~ya6iZKyuSV9 z61AOrXk{oj}wW zv0Gw;$LmId=N>NR_Hu*$*`~#-w$tD0s0Cl7#|YNcZ#Ia&$iqqYc#PRWdDv=Zmxl%( ziXU)is}jsDT3<*PH(}|mS&}T76Y4(%D+AW2r24aPv5P;l7>4UN(ba89pQoxu5tFp; z8V01;0z3308y%fR&-a)9-dR@O=p$?sy4#{V5M8~EBGf;m8PHc%_QQ z;z=&t3mLe8P`Da*;bS#Uju`zbQeFc_Q4cZqS2BmP=EGzspEqHF>pmYoOw8UH?UeE` zA&V8`k`Uev|Bm#|%{`YDyPEVvzSejuFavB#c^{D%kS%Kx9- zuGB4a`}njP!i(aOnge#z%AN_xe^Wt;Ja(v3+IK5hJG{=TIeeX=99x((-b3D`U&HEV zSw-7CJ~K~~#r4AAKt{n4hhHRH3|2LKH7&$LQ7 zI>F$@GcO{s!BSM`B3n|>T$=rK?~p*L)E9=2kqHui8y2FZd}qZrs~jWD5mV$tvP+sV zNj5H^op1}_0?w-PjSAI!IpI=|y3nJ~;|hQYhl-6?0hNF86#uU^6sk_Mtfx~VK%b^O zMs7KGz2P1(OXLDn&QwTkHh;=|Bz{I5#g7}mHXnJ)6bCLAmMgyDL@rD`23z{HG=;%} zWq`;k*$cK?@I1H7_O7R)w z%rX_#H8D7Z3={*#nVECBx}mdu)8woQy?i4qAqa`n6EVx7L+gpF<>epBN94Zjz8;eD zPgXN;oeDGqd=yf-EaF2mnuCH~WD!WJ_W-<`@3Apns=m#D@Am_ z$Yicz5YZ6#FUO{}RjMHt*2jCtoy8w?v%-=_eYZ{`H(I!YxPtF~g{GF(#fCK4Gd16w z2R;?Ih@{RX>&{1**0vk<(|qYmO)Ys;<}zPqm?F*aR(zg<)EaIT?b$>qo_%q)aflk3 zTB@`Gs&?w_T9R(_QuL?LGuL0=in_WB;QWhJsJNJ>KXiv6$!EvaR<(a7=2P^#Q=RYt zYQFIg-0~JPXY<@dcGF~0pJ(w8=k=PNXm;B#`j6s6nb_=AK3pUa)N7^-9NbynoPBLa zxc){cvDuBjM3JM{h>$@XJ0{vv^aby#;m-*Rb^8-E zL4R#rr|aA8yjTNX`UR0oK5`x+q$Me61_Rh^LQthXSt_-V38Sp++oyx0c$~;a`LGxQ zsN$^D+ZX*Z{w$>}Mu|*9CK((mXr9K#v3OAY`ACYF2_j)epQJzi&5E?10+3JUukSP4 zDfuG0E(X&Hx5{|~Wj~CCcwVxJKS27{Pe0JDNX*0)z_sXk zE4T>2EU}>4Z55vhl$dpx_3TEb$d0KVmf{=PlImtZ(Q$lU-E}8)J89uAsaoJwsW-1K z{XLx*VVW%8F}$e*5UL=q{pqzI?tiY;PMLiZ`6iI*@K34Ih|xy;oo?Ywuu+i~f_;#p znxP@}SUbSk+lhQzlzBDbKZhSl#=WXKB8EhV)fdP5PkjZ!4XGgc&nTr56PVj>m8V$r z8#wVZN0)`|t869pBqSz-fH9vnZrz0-R+^x@Kqn?T{{GgYGN2_rDYwlIW?6dM2)nW`y!XCyCLp>Ny5kWrmYQqbx|d)jqS;Z{0q;H5_`g6zExpLGc;+ znR8if4`WSH{uV!iQ?uv2>b()GGXa7{;a6xRgSYxxrQLzZ&SCzOe}A20@gQ%qEuV}) zC~m#V-|}dhy02pcEW4-(%E+N4_vcvu<_v$Tz8om+%viDQ(-1&((^alMP1Vbkkqc(C z2OW^45Nx%7E~y(IlCw>c6q1&VI8!mDrzC&ZQ4N=)=i_sAo_O&Rz#0$!o2@^~1;z!I zjd)h}2y;Y8b?U*K0L+wAu|=o_Ljp5eam}o)urDin`ePN@y*Ve)yPIJ-5d<3goC*XK zRmxyS&!W=8%Nv3BGBFp20@SU7+sVW)Ay0hWIkE~|ZG`gdPaXX!`RW?&cLpEbA%O7) zCL_GSCx6bik9M>^w}oQOrVi^$0D!||x`^<`t;SV)64_?^Rp0(8bl*?<6IhHMM@OT^ z`kPLFv5Q8c&20>WH^kA6000Y*cJ{3jwrv5_fh(BP_g^$wrCUYJEK6z^0%v7%ditpf zdKg=N(S!-9z9-gQe6vzN``mmiCEPP`mnbaj9|=mbyZqsU!rvDT@uE1g0hxk#jtQep zq+Ptd=(|efbf$=>_T>Z_$LFJ?s8mh;JB7FBLmiTRL2t%z8dpndt=Ya(<$DA~v8Ji_ zOn7U`Z)Lz26~K7ww?u1r7@6ihpxkX_5#+Wo8Nx$sjwQi`)~4o)BCA>YBXy4C7O>`*_imx>Q+O z+k>hT9!}p-=NlLYEeH`(oyT>Yf1_u~k1)eN8=GqcVzmWL{th)1G>p=_mq0LUmI@}H zCgXseRE5cHwx5xX-@hM$!^q-2vM-+}WxXXNb(@+7Ug4N`D<_mIFU!oB!^^p)9PhrT zE8@!I_PTK*nw)|-RC>pR!iEF=?vWn3kADMnU+pZ!m1kFPb5W&}A8Cy8s^qgapPk|z zgR@}5AEvt4H3!V=*iOx@6TbL?HLqxPKkEG|MDN$3@QF4PfM%kI%N)wpT5{f`YiiP5 zsKrm4h2bIZlE{KequHuvZM7?1{!$1ADwF2Rql<%nD{idufHox1)?Bf?o6p+5_P093j0I?Ii){aKLvmNNc@O!UUAK6eE6q4FZef5k;r+^one7}^|`1)9n-=5&SjAD z@lXRDS~g9FnDd-K8fa{xp%q1aT9RRQeAdx0%QtnvZ27!@{q{Yc+PIn4q4)!JHI;u? zLO9tXr+^hl5wA3hm%I=D**TMCkd#L8UHGSOLJV~;a>opRg zQ%kl`Lml;>NuZtr1n%HR2nRM05ll9o%%fuMo;|O<+h!F;t?2B8TZ=ob&ac&aBQ-$c zN`h$nohs_saGWXVaPoKMnT6>xAQo?*`V(#TB~zX7gzacz+)dCX>@ZjJW8QI$zTsn$))c=~P^(#-?^bI=Ou@4@Gp57Z?Z>T|kPp-P z@!ws;qOn+|gPys!1p$DvdsHvHinkVXI>MT;zNcTajKJK0bK0oN>(r3B931c6^#%b* z`2Agg{0n%gg~%Uy>WUGM9aLQl!pic|uwnPA)<#&d^tTB_S+%zD*H4Hk@S} z$OX3CpBOIl#>%>yWeIbwj03+Y-%0n282-zyn41ZA&nikY0C?tkG)fQQ;jri@;b_fRKy5cN9ZFhjyI_HHM~4{9OwQl@y6c6u0q*4l7t%ToR@GGU=6vUTEDpCZEz_ePeou?X9nICx!!po88IPd&#JLyHdN7BnYc}8l)w`Pzw zHl?5HwZWkz!@7<~3O*x&eyUnLPmJWMc5FmspDjJf!1f96tx&(=R@JylijB#DSh5{HWwg)?8MR!U&XVY^!5zJc^r5@XmLontB%S)ODN|qh9|w z(%!Zrgv?P#%gS%Y2@olQi}aevOBr$49h#O)LoqrCbILjvEUne5D#gngyqOk6w#l7n z(>V&P7*-it;>$YJ0N~@-vAUXh$%}#sV-j-)$bfBy;oFvwwxX97m6e?aGgHO{gFS#QvypnKFTd3WtVj;GDbmV%x77+q zv%|Y2SKJ+mWMnJ!Gx}1#z$$IguXo47=#i+9>fa>rNT<-dskOJ26UU{332tQ7sH}H5HjYl+4CfW;wmtXfP8g^fF8XtEHgGa zJV}CLf((@d!hq@SF0({bw1CF;YVRRSz<^qPDeEu(#YAU#o=?a5bg+?zUNsDJ9rsv!K`qcwS-i=Zdow_C0OdO^%znfy?SkjHn=zXN=-9qoArOvg{gl)=OhmlB^XjGg+-70_=DlC<)YoxDaBKi?4SrJ8QwwLdTAp5MO~Qi9URL|=Z4 z1eW%NqAcn`xT}TH;ko)!2PRk?6~h3NP{=8;qXXZnL%kkXvIN7YH(2rYnbKE@Hu(3-3u zIQ}9cCb)U(&K3$9J|&1sV0CqrkYwTiuc)t%i>eE^CZs`7KoAg6B&0z=I;BH~?rsUC zOFEQrP(ZqSDCuSp1O(}jhL;>Vr2E^0_ulW%bAD&e?0EJ*Ppq}JsVj4kIGy%=z%_N9 zW7Hob83N_oI2(5S0Dw%Z)U2V2&uBSQ??Rv!@|dD0CM(nK~Z%Ib>`o~nPQDIGFk zry#&{vau1&yf|r4S2rxry|&rxdm27cP;fTqmusd^EMm;r#X47JgDR;qji$RegaPn= z^b#}Ku0Aven~4Uke1e|_H{u|VkzXi_)~Y}hOcHGFv8#mm@-@P4T-nItI}#t+nS76k zt1fExQb^#uA>R^CFl;mPJu&fb8EEuE2{7)@^@UBAw^S4#Mm3Odb`AXhcBmIN!e5KxHCK)p~p1H88?KIg!i~=h^)6o?0k+hL{v%9!#HC8j}-j90F1y z8p2e%eJti9(;-~XroFjI`fx`qc`4DztGEU*a}_k;DE%7z1{!=aCl|JMT(P|-7*3#})29M(|lS_PWV ze%;vznjqD2BjM5}oAQdJUwfE6)C;o4P}msgwt7zp&Tn(Yk7d zschJxIHh!NKS&yXF3hU?OR8Ooy33Wf#NfG*FWAzWULRc}XBvMjDDHAPEUVD3qW*tL zl?rQk>+i3tD~U0KG4Ywh+ST~U9#S+bf=#P;&`6d0mLBlre89xNKU38;CY|8Pck7bN z^cR~Z+!o@6xHHLWLg-{p#b~nyD*(4G4SWOpA}jB$qT-7M=D(sZu)7ZnJm+XVIK6z% znycV$p?-N3#OH5)rR=C(|7Bqfj(8nHebrcpzwe+ml}2{f`6cq>%cRi(dlD_6QWX@H8~|hkQ{*$};Zb5VXp6;sis`F#(UNkeQDE z*GVAoDXA`C+cKmQCUDdWOPP5p;!zRFp11jc#;MmI8sRXnTD{aT+do1>e5bu|R0P`t zW+1Uq(fgwkoetRrOE$dbzZOYQ#&T}3m=Zky`;X)4DdciLFM5*i>6b0I;?yftWgLHH zgP}Dzh9OSqv_>bZdkwSONP4vzQi5eR~r*$jI*}CC;gvS%;GP3M!LG- zZmpGJD!yt*u5*PK%cT>i3jVUXgFCX81!)KTru$6FomUB8%r7SNKBRY4e2vgQ*%QfX zFyLt^3p~{tKTpOaHD&~b#NFxgf#JrT4d%7EsI(YW;rG3?@bJA6H&O3ggNulF$UZ#K zWJ=TG2k?XhulJQOTx8Twwj=u_gkxTAg+_hmI;E^PDTpsCN_!$+Y<`2QC8}%wP;S5<6dfCO81o zqc2rSx1aI#EREPq0E_|CW4CC5WOT!r+aEuq^$#QS28-p%SNL-ph?JyM-O#tzL)ehf zNpYJSq+$SJ@llVL)mJO*K87!fQze~{w%T}HvGbZ5X9tTFt`}|Wo-$Jhi-A@(o*CbZ z`^Jfx-EMhNy1dvXO1V>39|`2JjY%evSgXGV$K^~fw+1BIQHkHY6q&ze*LHoGKWx%?G#n3SeBTq+s5KEJF=flbDN~UOs9c< zl{te1-e;OvRbBseAZeb0iSrQ`9h%he7yEG?10zdbH(?#pq$9T-ygPTK?Yo+6X2 zzW*q4(mC&i-;*cA$aF*5o1Zg&o6cd8$jz3A?cs&6VINB=K2?HTyLs5N=pJ8<>SMts zeR5Iq{QB(2-94UF`75k+I`7S{I*RRi)~o3)Ef$C?Qc#r+W_Si!~P}`upvEU z4!h4O<5n4kD)9<6BkE=`g36GZ6^WmsqpfPc-EJP3k0!fGm8d3@bI$P5%pZY2vv$+a zm1L6#1(mmjxY91F02+8-^!tay#kZPqb)!x3PoGLOmegOon+t?6tB}h??cUZ%ydk&2 zEY>ibhn=(w&t9~x1aWhnotPu6G6O^-s}*%qVHBsnbbMn0<_-Qzxyi0Ai zF5vqfp++0K{mT*1zDEZ;%>rG%3DT`>0h((j299u(n+ha(;1#XuNu=`-E@pDgh9nl+P_9lV=Od%AQKPWQ8745(!Cm_ORgKRz4OCSe*G3hd+=@V4I9b zfBt(pcwikqCed5m;OPVQ4?t!u_L#@5WZTJhI9_1f#dzU0NC&)VVq%T;d@^J*s0h;8};#di0SXHXO=9ZDapz4Y3rMi<4$k-y-KxT%O4{t3It4O zx#sBNT5L7XrVcD%aUhx;6a^f4=p*0DP8-s?tsDB$A~p>hA!ksp!q`desR8Wz=#%?A zXEGY(Pwo7YA`nOh=+{gKi0GWGu-vbS9+>l&hhB2><6H=kj`JEb>)bPc;3{$p^whW> zP%4hTzkFZ?mlMzG`tNCr`f_epnoq-qn8xq%eMP-W(P5dx9FZMo-yPE zZv~-fEhG*9MJC^g(-B?wxsZDkr(?31ctoQllcAZk51Vo&_np=_Y(vZ}3V>Sj?AUpT z;wh^N<9S$FSG@XBE4H!V{2%_CAZ!-Uk*1tZ0~pS%ED7i*{H~4D{|0vGM-bBV&pYEc zd55q}Uuy+_%*@tA2LcBLt~ZgNIucISPx@Ecv`B;6$qfhnO>HE&p;p(9;(AGSe2=BJ z?zw2ju{LmU)E0B3z6kNMpSt(~f?Xe~_mG7rB7A;TYq3|arJk@1YFfi!zNXE8bKbxc zJ&rESypVJ%u)vnQU40|buyiU>*uvU&M(dnxrxH6%p&5Wq+AD}3rjJn_ zbs$lSmUTnJLnb-3Ea# zmD%Tp0i=7jYkWVrr8n_TJRx#Zi9*2?G_Be=eqzw)_aF*$YJ`>aK^A5xXXKuq$UZ2M z`>JYBBr@7b^M>L-_>%QT^t_4SgJyriJ6+eoTy;U6nz=9Ml?eFI|9S8b?U_P*cDyuS zOpGb+C@Oq%jM7&s4cVvD@*BbU=g&s0%oQeVMz48&HQ)*4<~grL`8Li49PF=LG~x@H z*RQVCU0AwRN$+j4{o<>j#cu_7KUD28`Z3osKHT(rVIpIwt5(C03t zRP$6BpM77ir74@>xT%%h_`(wPV!zf~ZV2{HZQXo*nHUJ~4ck|39AND&);K-Sx{%I9 z#x~G+jVlEUX`B|60)Hm}M6cs$v{sX`(b*S5szywPl3kr3o!E+Q>G^{{TRvFs)zR|b zfNQepe)t71jq{i2MiwzaE|JL(Xmf`MtDaCV0II_)8bf(B32Ag92Gms!6eV6{*p35~BA1Pd$N(U3{s8 z{cPGzGjLg}J}lry8~9sAmRsjy^pY-**}BQVhzpf?pJKU)Z& zc!N!VH=kQ5*l}}`Jm>-%A zbakzLS!Sg7>&_q!O+iFoOr#s;K?_SZa}eqj$v2}5eX}lrk|sShn{dB46^T&6;F)51RKb+wD6rVfKVTb)Z#!$u8;<{BXjAFUT~D3~@* zAt{+lmF1nmNhdTiiEM`e-1)iO3AX8X1BB+xA70pt^%`D!KkGfmm;y6=j2yym5wM5e zKW+hkDVwau_5kOk`s`l!PKxMNMCVUG^*B}UBZJ@aTj#xYj|$uJTCix6A^Q3=?~2y~ zKHuS38fz}|;D<_!No$hm5JS|Cv*nvQxRFQy8~>MwHt*gI4N7C_jGajASVvVn@NrM* zD%Wo+r!ToW{5$jJt?1NB;nHrI52G0IM9$3V)+WVnI$#yHYdstl@(gasnhx@dq?S9fxw2bFeXq-S3oxsrY>&+ipM}AZ zFOaNH62d4 zfN<=>FF%rB{-Gp;Qfz_S^w;HujV6E5$|;lwNP^PJ^}qy;S$z)2?`go#1_7WRLDyo(og{pCmJ?cpjJ~P_>tI4V+ zS7~H8J61MY>=Ct=!0qrbn>J@@yOXG(s&Mk9IcpHq2X(p7>2Y@%*9t7VcRwpmNHKfluGLy68$X+Uu;@VL_K8 zx$p$=I77rm8Py9bZ_+pl6%KIwCDJ_4pe}Zg&0k%u|0tsWTK44*?9II&;9{mAQKe-Z8ADcdD>*Isr~4C z*FIGIrkees?8VeFW7DDojQU6mwws{ivm!4t%Ye({SvVYSOcKK!I}nu-cXsq7vgp_^9x!ia-&L`D7Y}YCa^)=^9RBEOE=Ebr5E0FQ~{gVz=chJin$LPa{q?!boxG& zD=7$O{^+8?eU6iqQv!fyR2{Dd4sxw1HGixxnUCrhFF~%Kbv68y5W#KiX3mhdBd<0z zzX__K`Q?|zpzC|L7Q__Q7naDmh-U$uD4kY)Ik1P& zwEsfF!-Zue&<#HC#v&0BNR;xCvMUcl#wE+73W0A)a2)++amHS>`IsrDT-%oZAQ51; zI#KPEAJin*zlQwQpiKhVc6=CQgWcWN2S4l2&(;9n)^d6^- zEM<&6Uw$kh*lS+TNbELxI0FHe7;zFF;lDqXyUQ8{C3@Wn9{2ucR}eGJB)+h%QqLHr z;!OM@<=ractmLmKe8SSpA;?R{tl_;q8xBgpcq^KmR203w0E0r&I->L(*Z9%gOWU`9 z{0)ykU>LK13S)cAP8!8$gfl;cA=tLiFo-u40D=AV$pktcQl0YYLP1*0O(r*rh!P^T znm*`a;Iq&ab6d}Yi21&GZKPuV{f#4+Y;x&Dj7BK&|wisMQqpFo^c zNyydz%4}*|%Djt!Do>zLJv8IVn*z7`%JMm6rD7On(2t6)AQ!>H+p#m$gcRfe4E(6f z3E56*OMe+sC6-1SBqgOrO7+VDWteplVCm~xY{{A-J8 zAL~x)K7rBsSIU4=S$f~vP$TL$AO$S)U`G~lPy5N^eu4Sjgl_B3>f3V`P#KHl5APc@ z7y&t(HoUPQ5>J)|uL>HYLOF3wPJjbz76QThOsQzh7-=ij#b%%8L+uoW2UHd*%$Fg7 z)5ubbrXwU~v#$u&Z8iAVo1ip=uj<998k!(_*gY()orpKz33=M5x68g(I5uTpEqwqf z4RZqHJO_a{qGIJ`fMmC)Ik>1cJm7`IpIO-;LbQZ^dkDo>e|L~vtlMK$6#nRl546is zlY9NHvQ2p%Wue~q&1onslY~tUbHct#*#V@x0>o6WpA7?I(&61sujMGs?R+Nrk7Ld> zdF&ECZq=t?Sawm{<9yf**&9fHc@(GX2xSZoKf+(fCs)&-iTQAO+LcYvBYQ;Nvbh=Z zrcke4Vg39p!-=)9dTNpl)-Zw;8~vRr@fOWW zy>Rmch*p26e3Y(a#s>?+S?})k&_3ymBw`^v5lrLSo{2t*ERK<(_}Me)6r zqr>SfTGvfGpRA&M~DNN0BFqB(gs2~(F$*H~AKk++e?z0I>`RQvdyg)7>3dV8I4_$4y zarJ}hbvAaqH}vtVeS%{9`8#>UQ{k9acMv9YL>Wv$+@~LxgK7j4Uek6qFh0&J+U=CF zvA;=eC>%aww+d+D96WF7W$V`8*$zG>iIH`~yX@)kPc`r5WA|1PzCN8vRnI55=>Ygeab27V_&N9P=`<^{8Rm3C6KIwz)Kg~@s4T%}-7Qn#)$ntkx7ZQZ8uJFAUYCAFod zA&|6Y{q^kt=O6^A$kHOu2>2Of&niw6&-yDL5w{Rm*5oz$Byu;Ydc5zcl5bFpWouhN z;0X;-QF*{SF08sO3ymX;xvp^ndYQ2*nY>s^CP-Q)yf4YnzRQF*|Cn!s;paSHM9<1* z)C>Ig(l(0euy)wL$=P#ZVeY1-V%)0H5PSsk5X(z|``H6B6&Yi`L6YY*RKDWULe>UQ z$Naa5YR;*0^^9x<9J-JPzIl~}S|xxZ*N24~RTBqtOiVr^VPM_-vBz7--JCr&S`Q^- zLgjRQz3}3YEXr_@JjJq>fJx2=|WFi}6n?8l0r5q(G{54g`_Qt6+xi#h4Cz-sVWxW1#JuJ_B4gxuZ~HNyQNRHei?+X9N*=^);qWKV1 za|^y1jaLp5%%TeOK3QZ|@Owj=+KhdzAREx#r}kH+l-&Qr#%FbnB~f=t$zk^nnjUIYFAp?`oa+mFD9L zyU4PaHU6R}2z}A?R4)Ec%tJgGB(!4#FD_Y9>eEKQdMXYn1lFZleHAf4D!6gn!dsdM zD`BF0D1N)EZl>BN5onM=w>+6rT#8tmcqD1(+-td53Bz-<;1`3j6FXdJmk6!Vx`s#CimtSTELWAr`?XVX;j&gdVS`b6z;A`G;`4Aaaai zl%FIH$SuS#7=*(Soswv*N$$_%*5)TAAWyg+glx0tarPcKT*lu$#Rv_ zxLkHVEGBZ|Q{H?VR73-FEl)7FQ7{ui$chTSb*gbhl$gHf)o`Sobc`GKF(Ks)w!Vby z$n$7aW4tJKCfypCJJ?%DH8>{p(dFXAtlmh6Hyj+$)HZr15`h0v(ga~aPm?mE!)$MN0*ar>vQWH8nG`Bzt-m)Y=(u55p=|K@qsU$xfL zjlwPFLb!@rcE#Mb^56@u11k<>{Q}E)r*qe@r5Y&ZwNnCLNWAp+xuo&8TpsZ9|1hXt zH2ak#C_$|Lx;Ndpsqwu3gVU++!4+PCi0@74243`^viYManvx6MslmlKTY?nDM?KYS zX}@V4*sEQknYZ$1O~2Jd=hj0oWtazrhX?V0gyHj0Qole|a@bJu7Ro|l$cbn$$U{ra zqi503n*c+!_1A!1&$qvOu<2Ic_;$G3c`ufxY-s5HVt>`CQVC5{t6p$IhYa+wo9x5A zj*w60w*?jz5dNXL)y`)cHKzcH!j@yWfKbTuyGNZz7C*dy)vR zXFQCA8QDHG-4FlIx}VA3_jGQ4(7<`>C?YDJx&)H*_ptX#{hxOC0g6X36F+Nm_zf)3 zh(i~y4Y}s#K}($eddABZ&;#Q{ec|O$3^)2%92c8N5T`8G?Bq1q;xWs6J4z^W)@N&5 zKSfA?RyTVRm_Ay-W1r-E@iUz86|BU)CH#VYuC=uJ)P8aH9wu`$P-YO_94d?XLw zy2*7h#1M2HG`_A2)$^VFoqPP&$Z&Ra)CZ9wbdr7uCA_wlQk&uZ>`R2ZI*UbkRCWuN zn7*Au38Lza)7)8aR~oT`(ML$fi)C33Jz)flCz`t~}>q+MmX_x(WUQl3`u1?EEZt&w4S)*iX$q?u8sBYr+?ssJ!kD8ex-0v%u zdGX_23$qC3I0hsONo){9E?w zkj?FI*zxId;u@off)PqAWMY1PJpP^J%*oFQig-Y)HzyCDotS`o&iok|&G4^Xu8!tt zeb9Lj<9!zRQW)RR{k*SF1RLT;F#-CdZESpd?Bl*8#xn5s+UBVR{=}_xLh7)~?zwtX zS2X3E(JND>%OB|B)X3H?k5lu9QOg43*-O?DJCWMach4a0Dh;Q7x(oYEM4QT+vmZH_ z=rWkmxTWf}u)FaR*D&kv)>sd5Yr^r3>{j~Z(YddvzhdD&wn$EUjo_^K+ONT7LSwg+ z0Rwh#I{Tsxk+IV@aeCshk~PWSJ+1X6wXlR!9x}Zp@5!ynk%s1nq-URO*{e&)*W7uz z?GzDz4Mt?y19~#wicN^bJr`q75Iz`}*EDq3kywZ-%#g@BLhA#QSg0b&;^hA?!<3WDp zZPuS+PPq&Mv3XIC!QE>(^@Z&A)U^svg`dG_ZK+T8hfDAL-5pWSc9XUJ8ZmFz)`C{k z+0UvDaf6833_&lm#3Oi)%is;dsd;0m;H`z&RS!>@@!@87*F?0gr^#(r{zR=C-=Fbl zAgbCuf03R(FE8ttPo;>BPr(ok-Q4QlVO2^Wq7*eB5Jgl!F?Ak}%9p$>PMP=i1{+vL zN#I~&{+{1us6L|-mCyKOMGj@1<2)~$frqT0D9s~Ls1z|`-`m*fwybDRXw?K8#n``e ztA|2BoAWuatfAMZpO>%FV8z}7B@IIlTZIkv|F-X!%#Jkp%47IN|NLpd_WGj6VZ$KZ zGcXW^A-Jp|M`(a>^oP#*3VJN;0zJXlc?AoiU9)$ZXB8^5#+baYI}IJ|ade!fA$I^Exn(q8l-i&;1Ra=*X5TOh1EGy^8^qQwvshJ26QEHie~|Kxot2PtRI zN4kJ1)uPrnfbpc@h^MxtZd_DeH2ebvMI+C0e0Fb@d}-EC18>G%R*!8Ag04T!>`1E< zW&ZG8XFE5l5+L;#4RJg{Wa+R-sOnk;~lPM?P0D9 zN%}65+$i1sPLHq20j{UXmkf5Tc;Ju#E>+Z6E*#$fc62Tih(kEQ9U@{>r9aP!+pv*G zNzlrmB}TniGj}L-V|}u9lkM2a^#>Iw{hB^(vrs-~F^NjU1XqbTCfxanzx!F2AJo;% zy)4TY{0WY+5eMf7J$61$@LrQ{sH85RN4{JY6Lj>&-bqev?W*}idd2vsJm|;lDR{_?omm?r>feh%!&qMo-pj@M<8K?c zNd_KaqC-`ZF>rfx+ss6EiJj{>e)ordXo)4iq@ykQNB$JtvMy*=A`ck*selqo+Ay1% z2<6d$SAV-}Mg?)3#cM7aAC#)?a1UwNUGGsr?>~B>sdYnlHLj73lLy>p@KFEf13BJ# z(5`@cs}Y^qhiT3U{`9yr&>NQgoNdn&yb7&R@GuHv78(P3#JLW7EImC7xE9w11C^Nw zU9xMNIhh@?F+>^$imDbcGp02&Jd-2jp2SkiOSjXopSzFC3hoP5;;kI@3Pq_{1p_2rj;BP0PXBTdmp&?!Z8?hc+| zmPd(o?wtL|Y|^L+M>b~_$3)FBgg8de_{w$c-^e@%Bjel&P%NBK&i@;lPsDAxljUU} zlVvzjZKnpm-5nklz(WM?V%O5R**7cSwKF#qg`qQ|!DgNlskO zh67mZ-9$3Wpv0PW;G2B=RsxE^sBZzRQ@+F(f2rC@J8XtdEa$I0dxbF2&D*rXTef;BQs`nH^;eP)bb+E zZ3c)wEA1ieLyil)m*8nYL(_8Xzt{AG*YGS?6ms=S$6~&>n1hhi{Ii9kX94nAPeK1% z-KG9#U~(f+7(yc&BrpKx8@!B`4*XW5PgXgYAC^~cTz)semEtI|lEITfV+V(Cb4cln z!E@rwpmBw_EBZ12mI{6kX{7X?c7&24F|Q0{%qxg)`n$`I%#nk<>qX#X?^4 z0#mt({AXyqz*Jt(DZ+VBiuuja&m=)W78ObX;meO24;tbbdM>U6MO*z#= reduce_ops = { * [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). - * [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. ---------------- @@ -67,7 +67,7 @@ unordered_set reduce_ops = { * [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). - * [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. ---------------- @@ -98,7 +98,7 @@ unordered_set reduce_ops = { * [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). - * [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. ---------------- @@ -129,7 +129,7 @@ unordered_set reduce_ops = { * [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). - * [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. ---------------- @@ -160,7 +160,7 @@ unordered_set reduce_ops = { * [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). - * [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. ---------------- @@ -191,7 +191,7 @@ unordered_set reduce_ops = { * [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). - * [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. ---------------- @@ -226,7 +226,7 @@ unordered_set reduce_ops = { * [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). - * [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. ---------------- diff --git a/python/jittor/src/ops/reindex_reduce_op.h b/python/jittor/src/ops/reindex_reduce_op.h index 37d7e376..41d29197 100644 --- a/python/jittor/src/ops/reindex_reduce_op.h +++ b/python/jittor/src/ops/reindex_reduce_op.h @@ -56,7 +56,7 @@ struct ReindexReduceOp : Op { * [in] shape: the output shape, a integer array - * [in] indexes: array of c++ style integer expression, its length should be the same with length of shape, some buildin variables it can use are:: + * [in] indexes: array of c++ style integer expression, its length should be the same with length of output shape, some buildin variables it can use are:: XDIM, xshape0, ..., xshapem, xstride0, ..., xstridem YDIM, yshape0, ..., yshapen, ystride0, ..., ystriden From 03b892d473ae686648ff97f1cae2945ffbd1d7d5 Mon Sep 17 00:00:00 2001 From: lzhengning Date: Thu, 30 Sep 2021 17:50:49 +0800 Subject: [PATCH 3/3] update logo --- doc/logo.png | Bin 28818 -> 31170 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/doc/logo.png b/doc/logo.png index 0da769bc1dd37438a3de3ae2e9ea0e4282c504ef..7bbc74888a275bb986904df98ae60e08ed03e935 100644 GIT binary patch literal 31170 zcmc$G1zS~Hv^I!Hx1f}C2qGmQjWkly($XE$-6);X9n#&>-6`GO-O}IOc+R=^{(^5k zkB{56*Pd(6F-N}Ro%(;05n-m+=+h0;RF6cuoP9c zhJr%Eg#3a2J@0l41w{-cCip?#A#raW$whuW>Cc0Km9%KLM2#jEoZyGIA7s+Jp2`=* zd^8p_*}q3RtF|?{oW$F$vQ=dp-ny?kmOsHSucY@0{@)e4*{O^3$fZ$ z5^^78a+2zYT?7M$HFB*8$7JTs{;PV;pUsZi8~t%T45xQIi-(h=zkid0sr|id0+x54 z{QHl2QPF=tLhb*5{Ol3jv1it^d<^|K8@_kMB+yV7ReoJ?e`h!1;qw$~N?y(Jl%J&i z>nozacLiZEF`>Q)9Is1ikO(%Sr~CZh$0H?#Gm%IxPX7FTniwxE3<@eijM@TOQ3Rd_ zY#ii5rvt0@xO>BtNA?r};qN^`Bx!MRBs5B>DZ$^3-wHNxhLUlXG3Z+5=3{YJ*GvDM zV=Dg>s69Qcgz_|wf8;6S4Vk9kN#i$DlJ+@P0|D^F$tvY0N2vVd#E3-T~u zZ{sLCzu)%$pbz4|uSbC_3k`|$Jrelvc=_RxlS2`=%3e-x+I=FF0rRj<{6FIeYTQg? zczu4g{_kvZAhRi` zWbjd?WhMMOLnwj}AgRxpL!R-8=wO?${}qZb5K1#gn!JeLXxvxu`SC)xg={G7H6a&C zK3UdxOK}rPk-y(~Q1d~?#*nprx*z|~!-bHC_3zLWhmcA+_#p}cx%etb!F;ZLRR~G2 z6GHypgOY*hpy9hE30pyZHgnJac8da{ETM0p3>mfpR~I&Lq{RZ!Afc8;lUPu(*2VbKmZi` zl^fgtC=^fN9(B238Tks}OdJ_eu7!9QH~kFt+I z_SR?Ht4(CXSDBDR3Pl=XR*#oVFyt%xxg9^*|CZPb(O-r?Xzs7`$(Y*ezg3w0eVA{I z47`N<-O7+z_qc&v3HyNmc)6j5)|YBz94VTnu6HOKe@|JNf*p zS(XqVW9imwAoB9Ru}vUjzYXFplU5YbwY$gGf|&W^HF;D~QE_Al!w1(%M0d$r?Ql^A zFYOd~OZ;X+y>~eltWwOB=h#IPdCVm`if&KTVep2$HP<)#wzUbRYMc=X$ zcTxY|XM+q*7;9-$JSi%=#a~DH_jU~;M4+*QJMg(vHvjCGpEg8pNqZU2Wj^Rg5=h2` zo3(%4S+gX>e5d(GG0s2M9Q+-@5(}(sxbr8rUKxjW5$q&H29KATFGLz9;oE*rH^ZSVzebR;MMipFr>Hr%b*M^c zK1K)+Nk3MEsv_%m7iq52@F?nx$6wY4+<9>44M|t7i5J(RTk#{pI zP^=}D$k6>W#O2a5VDd!>memThc0oi;{Wrg76zbkOT^|Q?H{wv`?Y0Y`)kjk2Q(m9# zRvucI^kvwE>>K~3I|>r$8jr-)qlw)Vsq%D5Ip@Q|MKTxM*O3xPk55iP0V6E>YP)^` zy!{#c!XhNZ$Yd84eI1btH4_aOyUSyjk?wmmSr597;}R&#aM6Jf_;mg!epW>0Z``$x zSMlMF-orq7h@RXao+u1N(d5$voo!M>@jVAd)r_5`WY&k#iq*3N|JBqG4*`L=H>t?t zw$?P^xmd1|KSur_(4UaoMlT_LzlgHDR?#Yp%AiFhVk`_a+S&$LeAB%rP=Xl1D&|Ok z6Y)#{EA*o=Q9QWahfFW@%oTBLY%&h@3Dgq5@E46L$t2>zTlE_9*Wo%GzpW`0Q@Egr zYZ(Qe@QwB5Uh5{KlQOqnM()59_c}61ScKZxwhu3Bj8wZ9OGCs_1Kt~E_H>l0;@4*2 zU89%d+ZxR2EHdsc0xl>IS>TvzJPAJriHSmu4BS*#>fMpKqICt~lke=}O(x(#L0Mun z6%$XwXWCc1?m-blh{DE%xc-kLWe9xo3Cfi$n()#hM1lpK_>RN@h04E97qI)PDTJ?I zX}At_#~U$=G=5P5qM*2KeDVv|sKxPMe}&p&G1J&i!&X^AOWvY@DMC&;3s#WsTLnzG z2tQa-?ZzL);^1c`AczZ=PEExo0=FP?i;CKJM#xC`6sj3E$PvTJ;d~DL&`h2ZGHxDO zt$`L*)H^?^e!HipAT)XyM$`NF{qa!V-U?eFCtu=#c&ve#@Dtnp%h^?S4!hRWl!Opf z!-qy1i%g^`f=6;=-PZ9jluUui3$;s5d55jt&fWU$tRm9!c zLF+@cNO>6yMWVK?yJXWcXTi;>a_Q=Ray2X7DfLp9!~Eyk+Mh8|-p=Zwrh3ZF8Kgr- zy7%oHXcvWCw?dL$z2$_->$9A!&L#NYV(h+OW?WukyF2D=)Z|&j#z)7#v$Qo1(slB7i zoR#kL-f+uKV|e}D!nT<9mgS-yiBB1Ed6Z(|6ytqc@cL-7lc_*NUP9F5+A*)|q>hT^ zuN8&5S2Q)7l?ZSbxz2h0wMG?x7SP^mzB9II2XV?0H{;8Rvz^_eVa z%P1kd#lyZGFYS!q&HgV92}zrZJ$B$Pfh(GV6n!$w-*yJ3g_joX>eNDz#i+mzlxn-R z$9{zol9MZ^xt`B!LwyIj|-DF-$I~CNt;NZ3G?ZpP6SD1L=k590=3n`BY@(sS!~e-TSWq7KPcI`zn$y3g33Q&(~_5O?k={Hm2mA??hRy zdb;kgebGOY)N&KV&sW@g+0^*wlqAcSVLBIV!lhBj&?%LxaeEg`bf5)8ZXiKiyi+`><{hU8WpcvE1fUt zNZb|XPCSsTy&U0n<4jF=jU@O%hegn&_TJ=bE%$}md$W!17ZRFl62Kg2DR&I_S@ds- zXZvwVVc2!u&Q(*+$G<|xymp_0CVna`yMi$=xuexo++5^ox*=kNMwc+>y93{>56NzP z*gb;-gO-il5e2ygS zmtZ@oHBxe++x6VZn<-6LpdFfOJTEcS z{6LC@1F~~tv%Yu6XC8>`VeIn>x4r9fOhmz1#(f4_!ei0O!Vp8cG}Gf1@T0TMppX1aw;z*YfQIlm*8@X z7nvw;Ao_hjem5C9S6!WFYp=(zZ93olp#jbPtfevHm%YSWySClVJZ#QR_nJ~Q@s*uC z0{4dSZqM5|waVSJ>{qj?>jq4Hi(S!XqfpK8?_ZeQA~Vh((#9H8_?ICmBvbP!Fd_r8!sMc_I>QsbIIw|_`jeoo2c6uJ(>-ZvsM@X*ZXMOW^B6AeRVBCulXiQ8? z?7j_=Gs&;U;>i)J8$lMX>Ms+x|0McexQFUP!eF)bRn-xksm&J$#m*02Ynt2Rw9CST zZ3V%sx+$?ZP|eRk3<-aKPEy#=-qBGLtLjWQK=g~*?rLb%5<}NHks$7jp->QYKc0W>65D|@ zhheM8Dxb)mtfi{CVwe%P-TgCsY1bib7_+n$=Pk{RT$Eu3$CDBg;Hyn3=#LICS5bbf z65{Rs!cwEVlxC>!H9W5J;-YJ_Pwoxtq5@7N~Okbf{ShyYhM`-iwH$- zYo&NIr0J9Kzx`+85=e-{tSBBZJ%ap>Xwr2vqQX;U9HU$1#d!mT_)h!zhA_H4A;#@|^qkn%4 z`4vn2|3o@#Gu}|<;-;L4DkG@GPI%JA`QXuL&pA_CuOTh%R9im~d}ganBBG$(G(GcJ zJDJk-yW*(Q05*ia4&Gd%rtzhcjlM+q;QBLqhf+$kq$NuJh6wUd@B8kdXy%4mo4#QB z2gg%Bbt+!qC}2OFFauoT)*RHhVp*_4vQ89xWw%n$9#4Qi1@fDUKc7=L@VO+A7xfvO zKtRa2Q!Gta6l&Rdl}#B3-X`-l1n#9r)X9@WU4-S=0cDODl$yvFphTH0F?ADk-w=7P z@TjUR{VKg%39Z&`B{$^({l z4{5tuy*87$>FdZ=df->@{TK#G<|b5Re+fhcc+f)hHm<&zlV;IyTKV-5O-)al_c)o& z;=D3jFwVIxicl8pdr$kJg8+X3K~yB>AxALlUIj%+!X13Rn=o*dL<;Dh_zbZcU8b#r zOYCl+Hd?~DU=g6hr;;N3fjfQ&2schq3fMGFQk|z?!&`^CnMy0V4`~_bKChWFwH(rZ zUYM_v2~aUXP@#so0NAGo#mSTW4TNRoy@>n)DdZZli3_^#6+@vl5VK*Q47Tbs5tUai zx7utdroSPVoOF%&aFS?3U%v>p-PRW6ii1`rYw%@hiu)?mdv#C%hJ(buYX0rs6)Ta( zJjRBIPxn+Qk$E%k4?E3yZu&UCn!XD3WC`CW=H+<0MYitDz}sI)$IiKUhLS>5Uvxkv zY8dmv!8siL)wo5)UeNqY%V1zd_uT<@lDMu;gba2rEm|up+8bLgAwjsje$%7|4u3OH z1!~6E8V_i7NI%?czNoC1%CU8Hbg0e!CVl4SxFz^IS1!B3s3CI2SZa)=s^LkC&i2Au z8yL}l?Z=-6xJx3x~AmV=WLZL-94qC^cQRCs>#4QnB+ApzQwzm zdd2biu7}RL^C52Xop`tUw(|_~y{$HOR$g=A7dJuCAL?Wo^NQi=6M}YBSRac2xWb0Q zli+;&MAvc)-wX>BZjg<=^Vl$4{K#*oZHUcoCIr$G zijxZK)Sa`b)YidmA{prw*-SmIP=C92n?RDnYgdr0h3B$`%TIHavnZFJ#qI4aH? z`fR>cX_0V+AzUdO9-Q~GfRv4T18*EtodW3xIXS_tHA>AB;4M|wU>ACETFc32fJ5<8 zH?s`x4UiMCT~uP{<%d7^G9qc^`7^ul@S{%CHbn>|8P3f5!L^eW=*kOA-*1NO6*7x7 z%bhcCPmc<#DYOCPy*;I$B_hv`NZEj>!H=h3ZyrsH#L{N;f_=b#_?Y zkNc3-I@{0W=ARz(QAClbSTf4}6}R0|_m3Yw6TFABs!>q~lD!kax@)bBqy7B-uLcu@ zX7v+r$ksaN#O14KWqKgV8RF`Sc2OZH$hYba_oU>;KZ*}(hhvLCz^O>-GSW`o-K-I% zq5AuDcf#oP+2jaJ_eGz+Z$->^-&wO)&)*f=uN(u!WFsDD?5LOgi)oO-b#y2is(Pu5 zYI?w0{qa;UZ?E7%7<;vJ{qlse3VC(D}B-@RQN@l>7a~(S4maVSSlpYnG~r&D?aFEhJ!OG!G|d#o((Z| zDps3?noot)Ls(gWEG|5Ox12K_xMs_v4!a!TSZoFQJcwi z7w;|I*!<*2u9I1olUX&RX_Vxy!&gn`UnVuR8qnCHUl4FQom7FTKoH*W2q;@>-7=kZ z{FICRq-2}(f&}3v=k(eGYm#B)^9)nGd{YB_W_CBcGRVB5bJbR|{9XQBt*-ycrW|CU^_i^dsqhB&(e~}k|%BEHU zF7MoLV{+{ntIf!%r%%n?mNw^&TKL_fEg|DIFbJa=;sxwraI%JINixrU&Z0HGU(6oVc_Fq zUEDCo`5Bs;CRhJC4u&DUJEh%xt*0@Z?K6?ekTWBuvG1jEL*q3!-zY4s(9WNM2elcm zsi;uXZX-W`3y+EEF#zlWy{f)B!OFq5RD46|O^leIIt-jI*VJizorx7!x={Cne>8KG zJLnU+qZUgt8>kEoQ!Q+Vbbz8gUoWt0pxf?(Cnad+7O66R2fw*+d00jQA} zki;Ov85$}9v@VI?IC`>y;yaz%z1+5nV$=ZU+4tv;H5(~Ul+%UQ0;8kV7W_D5f1zuK zi}Pfji}qx}*mr*@thF%c|MLNE|0Q4kd(7eFpzylp$Vf=)Pdzp@Z1ehq70cQxi_Sz2 z6~~SKO<(?X0&2R5jS{Nu20$|D_h!DXbyDFJid+943P!$oxb5nDcB|cDYay?YtP*BD zPpw)xQSd(R{@iA+z2%(j?cPnT(eiEmVPpACv8p(#exGj0&gstUxYSoM+?i`TCc@^W zee5LZhLg`pGd_~cNm{Q1uS%joDTDo0gOHL8xjz(24d5RhulkVc8HeC!gmrlPjdd6< z)6MR_`#%4Xd=(g}F+%#?x^m5VPKFNCY>;Oo@KTBl-K6rfDq*u z<0?mT8*N1bM@mC0Wh7lt(=moic40s9QRtwmiqv!SnW8*ur2#UN+ujIV;0;IResn=} z7&H9Aufurv-Kh9jyQ?4l{G%^~Jc_EmdVkAQibYKpIiBS_3<(WYw*>YrS*_aYzXM;9 zL1Rf9ON0>-#`A)KcBQ`dqcHr22-E?9u=qmPU^O(UeGEQ4*WjLx$jzzHR8+Ev9%NjN z$zDqsQVT^lk4wQYQz1I-NRh5oS0(0*#D2om{B!h5Wl1mqluk_7XTR}mTy3|b@!1-b z7A7k^UWrKr9MZ_T;-8X@m84tcpd5DEI!9Yv{aK2#o1r5@b>wY4z3V%OT5LrXLr^zc zb;_n9WFg#3e1yT9lq<}<2)_@?5qhFh=Z9H-D0bc7j}KB5II(WG871`^j0r+VOqn!& z14978_|1Ee>NmLMurUC;*DRWm-shELIOY=ZU8O$Tq|@5v&wn)II> zkif1{%z_^yJCV>Wryn}EK6vL%qDGAtAAoB;P#SGQ_9K-__R98N4~Xg}7Kqv^=Og{2c#9rNW;& zKP1**bqGCBLS%+P<#;UthqNvKL=iNVe6zf900cGWh`Q{ipZ_+W=3PTFNM;qQ-Zs`1 z%uwM^J<y^+W5_a*{s9%0E^xNHpK;K?-(SEkg9>)&4geRV252;guAPxSMzvMEi@| zak0S-NV(KI+lfMWfYnCd{n9?#?SVrq~8|L3n z@wwjq6h>EL2iUZbZ96+C-JIW+*6`m0C~&E!h~||fUA8V0X?BLL6cr?hB7-2h8TI1x zTYC?is?$+~-e5Lm@d<+3H+NqLUnQ2vAe)$gel1$fbcE@WOrD%MyIZ@C1pkVm-W6L!QPIy2N_b!6*6!=7$wrsBtX}POnFaTvh(PO7 zQSGEs?_>s zZ5a@bvkR4u^yO0u_Ek}hiqXSdSFDTh&r3k|0CXAS1yS8}Md~Lrv$jk4hCoP)vG*fbf6%q{VIy1T8TDeGW=~yQ-+b!Zus>4zw%9#Zp6f8=#uc`M^~#hw z5>o?vKjEFC(;K##ceZTh#Vgl#QyNGf+#p#PPOERGU0{ZG4IO=QY7)$*By{M|<$ES< zICCdCEv)$&EK|-zJc~C*LrR3Q+YIlbc3vcfH_MJ#|whmDi4Piil~;KL9)T<9#Co8IMvF^*uo%;5&*qu`mI^s^GV+18!xj>_^pexRs4=9Ir`6I3fht| z#gN5)&q^_L-$O$y93WtEYQEupI4*;ha@da?$50v~a4ZS@h?f1;l%HSA!r7+(0Ks2d z=iXnv`$BRW1UGIDstoQ(mxdHG734)SaAQ8XGQGqLl+PaCpxN^!;>BmmO`$hmhs{&V z>EaeX4x#40v)O54HK?Q{RuOFc9xWkd17dxOIORx04}UOZ$)z02dX%dtTWHd5kVcr1>N~k@$?IC90-7A+iEoxYCF)AR58>rXa zI6=sFmJGz}zOV;4f}r?p&9_m>&D$Hm*aKm_`f8I`V#c`L2)}Vo zTU&UMa_5%0g~C%1OY!p))hUToxGjWzaL1>lSaaMvzd@{(cv>MQxMp(gqG?P{)_P^y3Ca&& z5J=wB(jUcu@}p{ZdO<2zMqlMj<*-j+2IHjC{l#=)6gpjnu2s%N`{-!`rQNkY{zRC7 zCM1L_T?Y5?96wgWU|pD4tL768v+4&=q^0~PofdVpwHw=tTv>FFMmNO|Y?{}W>7{}g z2n9oP3J8oNgG5>BYNs@l9~}5xT>9q~`wvIP@1-aaEaI4XrZ&1wDvYe;nY0J@v4oZc z@8~6v2K6?hMbe(O$I)hoV$HBr%KrEfmXKeM$Q#;wcp9*%bmT`+Ghe0B>LuvIz>~3l z`9(v}`6DJK&{5D6K!xNl`Ss1RoM^?U2SJ)=A7`RSV96+X$?HLuZVYiSqtor94{l+S z;*<&RWz_A5df9l!-IzG8+w+gsm*bafEkJaP8T+^F!WKCnYOxHAo79w_o(|K?2c=0` zo)z^}mVYli^VM}SfO~jjHcveYfkLXir9UcDf)7pVm7X=8$9=HcndD!<>e)bPz2zWVG0SYKVto#UF#JGgNSiQ{2^s?C#r^7u>glb{uxm$PFPd1YJ>Br zH%bjFw?zHh=fxaEfULMn7y^>%|V5?7f^8&;d}_r6f8QPL!b!LFn@aJ zonip#lwVodgl^NeMdp%>jD%MT2U04ej6(wAa|DFuj*}6Xlm%xUV*Nqh(dKRt9f_}|Lsja z*{lf*3yb|V^k-BL-jdKK^w>qc$KOG=L#v31xpFValx|XH)gCjFf3%_6bzjj(kL*1- zO;Z-J0SWQ$mdm{uQi+S0X5$!$_Piz-{CtT%K0F-u6|}`6hJ7z>=pr#qBmy^?DrGSZ z?r$hMhs_9KvTUyN$%<@`^!;DYv$}aqY)kuKD4LX z_DH!TWpK1w0do`+-5ND&)`TuZK+x)^J-R_WILPc!&y~R**}e?(bR~POH5uJwmitm| z9#?@^CrQHu>H@H`EI}j;MDjhF(h})BVHWdC41=-6Ns^IeK(CBkV z+l_&QpaG%;e=nSil# z(Ez|;-o_H$QNw)u{yxEWY4uk--mt&8_~^RV$~%ey0^ z(7YTFQ}|{88BucA^o#TRL>^V4drk^0yGTj{ML;zgqoU|Il#&#e09>_{NcM;>n$V=@ z0|Rsddh6oTMFD4kwq+atDd;!gGT*}6d@u-1+Rh@Imzxj9MK}$*k=2$LvN0Kr3Xa6g zGZ3=z0cI2kksvAkQIj%gP?L@=P`~`4ab~r{Leg+Wa9ME_5Iq4se)S5d(a3bP1by+V zfe`G?(}P*nTyA3bwmtD3h>@s{^7|+=+J89s2v{w{MjonBpy#WRdF|YImOaqKN+>V! z%pQQWZcLXUJ-slQ952nfepJ1RRCqDOeV2A|1y{!Xo2`*W>gtpnbIttr&Y6}j_W&`1 zw-4?1KAP{rGj~YsxU%5Ei54Bke{b?u?eTO^*5T`GT1C3d#No89 zwF4KEw9vmxpVz&WU`5Qzx={#vVhzLIUyLt6Wo04$F(ic)>-HxR;ZfhF-jRx+j;&AY z{E&*xxKtM(2=`{3Gr}ZMl_S z;)DQf&vOo=~By8DFPG>2C z(Orf;x|>Pz?_;uamHnn-%g5>6l^rqaVyoO)Jn?v?EP=~G>bk5*Q<)>*J{Gd1QzC+>IWi2FTd#M4#=_epZB zL3s1-J2@kFlDo^XO#5Eb_Vb+9UA!)bSI~;{MAPP?EXwPr*JXvdiu2!-;j3nh-_KWQ z&Gd(#+EDZaBgW3Ir_iTF#1VFIx(MJB%oAz|jW4J_W4LR4qEPJ}V(WO&7ulF3RBVM}2czprc0COzP z`1@BcUX<&Cs?p78{1%q|rrCzUWhURniRLp#`1;VFzNx*ID;%X=Aw8y6!>ja){XmfT z6z3G8?8lReZ(8&0^+u!5IQkBV^7k`GpxaG_QANU-fV80;`w%sbsO8)H<}{$~f>U=u z%Z{R-w4du<8r(u@ra~lakDsCygPp>9Zj-ac-K+iLF4+P9TH(=_qu}wR<~@JsgSxEy z-?ZPF=xJs4rdZK8OY%Abl>}@>q;nHHWQ@Hx|9n;VG$-pOI@lC93$0^ZB%Kr%^9$)d zRUvib<0f&r;0nmQ`!wk82;+%4GugE_}&H{0JV6FQnX?`Wl1tZqjGxu3K?kJ?;{Ky6j3dT57whS6 z=0|kkRAjf~)`zARIJu4dus_`KN#5~+^*dZf6OJY~@1b7(p>sIOYC}LMUDOv2*2#g_ zj!vPk$;+tC$Oc=f?VGY|l~LaiT(}|f(#(H`k|ePT*TQS$+7TqdC39PTsWj|>X>Q4m ziOKi82Z<|z*2lpA8yja!w1l{Z)Sa;C!sSO2Mn8mZojX&OBM*aWtZ;(fR>WXb?&uW@aLC_si~rh7HPhV z;+sYDX^-R6a0QH2D%7i8zrMhDUzJiXov?^t84ePkp9FGS)fXxN{>zm8-obswFCmhOUK)BlrBaKk* z6RINkv_Gs)T!d|1D??70n=AcZ!+By`wQ(>JoJBr#Ixqo3@Qs& zI>UDm##@z(i)9a}kt!v>@X9O7U-|G;M%re(M427>nP&#$2RwWJLs< zZI8U#CABh~y3gpe4`m*T^64sSXTQJps*>n97C>pq$)MoC>fQsAhQAX05UFol{%nQr zwZ3Sb;kTx@WkDd4gCZt`879v}==Vr`z|_pM=>^i#bERxYKgD%%J9WK4OuH>AxHc8W z;c2j;%FwBt9+axz-m$Q7GKwzES69nA)8FJ<-JZut_qdEzERjOk^WoT@YT2lYXeyQ3 zU4MqE+~W%C(#2njjcsgfz>&#JzcUnd@asi`$_T3_?v}3# zMwWVkFEBAlm@r|M-B9VcHTmXqJi6_apQO{C+z+w1!RAv11nYkIN~OS__Y4$32$l?> zZRDLKCR}kZ!1bY}!C^y!qP_Rn-S_A|4FjhEY&+4rW)j zIgmy=!I2-q^4T|WL|UVKIR#L|^Z9p>f^;aBWd~u@YrwUp^fp@&S@~6R)YsvC0ex_; z#;VGK+mxdy0(p$Xw#4ZZKmS}LguHrVEqhS%(B)k87N6~FH%yigk|7>zq3XmnOA1=x z8bY^mKGChi$DVevHvO7AmW~$rr<+L=RiX+@iZ#Rz`7Ulgj45dGUZke`F{r?tScMfy!lXsZL%oEy?>e_B;vP?JkZwZ0 zz1SbR^@Efnh8#Ju$P*h73aH{Dn}XLkh9a_abU9ojm)hyT*DP$4v)z=cNh40H!n~91 z{=`|T0D%5XGdS%v=O(!zm=3}t=Q}KL(s|?-@VCM?TquxR)Vpc6@UJ@S$=#(2Iwf3T zA<4U~0b(Gd4Rpo$=+ zi>(4wHYS@z9rn-j{{E%J`FUizPgwdEH;Vx*8FX#O3kX-|S>D{bLQ&ENx=bthi+I&8 zfEtWkY{$wr82afKzI)-e=$Xthfz0@mX_}#zh0)?cQy!y(yaeVBQZSH%g-Z;Mp4sd5 zesxZk6T;$Iv;5{g@4f_gNBP5l5T_LJ&$akL7iWwanj~@buA1DgoI=iEg6C#i*DmdE zxKsV#h84*OU{#VP%30AzK95r9e}wpVAUZxaYwe;rZA)rPYB*Lzu11XkHGqAn0^Jv| zd!rhlUfr*>!;gMI>eg)sPVoUkp>48n)4^E>ou_2O6ryE?vDPVsyx_IA2)q5R4cu`; zrVotsu#=tuertp51<3kE4&&HfCLfT)eFU<;n-5cAI(Jsm4P+iiKji|h0(m;(b?MFz zQ6719Le)iTqH1Q(wb&9Yj1NG^B(!B+ia<|i4-dXF{R!W*IRGP9PJx`ZBj4^WkU}8( z=FG~qF01ofmbfjNQFIjbz`y`~JPrXVBV*A~hXHR{CWrRtEDPJrO|(ShF=t6Q7SY|L z%r)!uf|Kbp4M)vUFa4HY_y`Oc0Wr%3EwezdDvqcg{5l#oo=|LCR)<;9Pivsr!QMGJ zL&qp%Y|1ZNIvpNt^>=B+^hdmqf`)(YO6E=hbH$mYVf&!`gTpL2$y`|-IZgZISWmQs z=byZp_d{{e$>V+#VR=>(6Nb?UFjgh&PzJ8*p)B3P*7>4=gC#v2J*Anle92@is{QxNB%|%T9=RkUw*|24a8)&3NOD#(jNA5`)8c!lKB$ zUf)w(hMek%SD$UO@v=ULGayFzAN+b?xPqiHwP||(v@bs>U?;xgB^6(uYm^-g-^y5F zC&d~JPiS;6GS=BxZM8FttL+f>_Xi2?_-So@ISmHXAV^P2O6!~DPJ(5W>Tn=|%uV!x zI@_JNY|YYGIV7!+&ww-=42Wq zp+X4+zr2+oRHl}V0^=qoz)kFUBmvRY^qJM#NB8m7$-(6ZP9R2nJ5(*b(WrW0KhX=O zQ+WU-znRz*N_Gm3o~FB~E-7Q_FVA{T@9*4)uB^tk?#ap0kxt24SHC5#E#LRY^fos^ zvMb7RtUv`Gb2Ch=vIz3UmF3Bwoo0x>X5Bd+G)ha*jT(ZGr^;7odfesiSDVbEdi(+<11&ieS{ye-1WAOj-?f-j zv{NGZPq&-`e5!YroI0SbAeP#Q`PM1D(tcC!+gxpN9T&o)LSZgc-UsaTtQ{#?*hA|<+YAjxI-8UWnv_({H#J3B4<2J& z*;PaWG4k^DKR60Hkgze+`{Lijg_tr=$@~9ao&zPi9<~OH5@HL(aPwhD2L$VAB5b zqW!q`q2IBWus#~OA8U^1 z;Uqr6f>8B~ik}BKTq9Qqe-w(x&4*fn>MX!7CwKZ&k^GaJ<>qrn2gBkd>%RY28cVfg{Zw7X-OCkQnK2Q;1XC9`6$Fw; zMEyC4Fx#I~-(IB>e3~c!9;088+*D~6$5FQREA*{iOd3*;=U{ZjTB!xGO~CrEhxXh=-ULBkpr&PRFj17CA z-~<0`1a#R?2mJPmwmXZz2k#)nefg|iVSpI-)y;}|CY>P?Yg|}Q@aFO>=)hW{f*c1W z5t5dAr-0T0uMa`y$0-skE;{Yx&3^o6^FM!nqTUt_7UjUq@#B@5)#s};D`FVOc&p#R zXMhX8|9Bn@-EpBoud6}{s@tOu%l_ozqvRh??4G>FyF^Z1<){F%i3&;lcbV0ItyJH* z7n!2W=&<;0zwKFU2sNbvV~&Gsv;q%)Kwm}mn99+T+yL z^32>lxV1q)N%76g{cw3ovYE(nEc+f!VJmI_@WibN%23n@EOucp(jQe)S4rdoph8|E z0EAzB2ln$p($(NT$BunWDAb&P^3F0M3X(e1}DOQfkv23|q4rX{u!eK}DL_&!CX*YK0xSfNh z*w{-_H*8Ad2Hb3CGfEh12h`2#R>?l?Gpq=HA@?`dFMIq};O*YUev#Mj4uFf}QXZR+ zmiho3ScjQ=R?t%+6K!yErX(NHnPSLf^#WQD9E^-w)&8k`JyOgO_kU!KBa`;)@|Q40NKelad;#`Di1&kYke`U$z)ZNM_yb3(PE)#qwH_9_ChL&O}}Ys ztyh!>R`H(gLtP##bnz;W@BYu2atm1{^SAwDIy1TK-(G?$j^lYDO9O(_NM?E>;UU47 zgV_G98!%9({f-ea<_g1#M#A_X2AK-145{3F@1{QY8|iwzd=t!3iuG{_pGzi{x20X? zo=n)#Or-cd;iW|bZjda)l7F0mI z{u)H_I7usG-omTo1L=j+%=RcR+7zh6S*v25ZfkR4h2HhP53*1OzAwqww06;7@Z-Pj zbvu3?a_U`=cDVAZ^v)l(jNaS$*Wi@l(LjOACg8ml-My*OH>n{h*d0k)Q9+xHDN2lg zwml>$r9!JdC}T4wow0J_(N;bn-9>j$wO({sqTYh)n)9AzW&-1VUAi!$xJx+z*7?4V z)@2r)pLENhr1!pCpJ%E2`Qh_g$@pujpTi@MK7q1A1nfnsuec%4c5RUdIP{xcu8CDf zO4=bpwK6wv)OLnQ--DA736*7D2j$J4$WS}qBA^a-&U6Q`)>P6HbBeZ1cyqpglS;9g z4B+JA@|p`;Yd@JIus=Vfi_?^PI8$yT8-4mC(1l}AdT`fyp&615N_Scxu?eJEc>3Xy zq6J`3)7pj1QO*o)RCjD409ZhP6(Zz|7xsg>PFuKr8#Sg(KecS6Dn)<+xqu*&0clqT zEmd7vAU(&qAw#_cKDWTvW3=^jr;E)t+kaes_&X|(UA3{_uI87Wczpu_P3ea-hNHn@ z2RH#F@8I=uv}+v~*b@(Yhd^RH*{MBMBvK^bYdN67YOr&X!#UXPYB`80l^z5P$85fU zLLv~bLXY8n<+65uKjh0OlU;0Q)bc3b7S-l~vqc7Sq%=T(8*u3JGY=Av{mywhfQ}(g zT_fy}IWqxH#daZ}T5{N4PvwrV>ty7PAiEa|UiN%76g1ZeWX->83eY~BfZx>|L|!AMa^c=p^d#?g zbBKvsBZ7eAUOt~)jGn&2^av6>PwIzqRO=Zrae}j&$1Sgd>XfcTUG@e{`~0FjjvoY2L%m>?G0WnIM_Ma`*w0fxKU>NkRw1z6vndQ&@-@v(UuE1z^` zcE#c1Fca4zochM#&wLKIB;9S};d*^ZWb(Gl%W4ewG(ajK9`NO{=hPcO4*Nlh)vbb? z@WAV4Moy;V=9|O3o7(;u7mlJ-ebx72PR+ze42I!jn`2J&IFEjYy?}}%SA&33qdZT< z3^h+$I=VvWKV@Rzcsx}E!Filv;wvZ*lJasxy|=p>O`O}?a7G<9iv2CShDWCLdy+d; zkOx4R6Ualb(wV`HbVVs-w(Rd*Kkufr0#JjlDVIr5VvndvQAx~28<0)lITRTF(2gjB z+C?0QZZ?iit?6iUfU}snc)VJ(na!Ch^+D~RKRaYF%DZe%m9>97j@dWm%QLEAt^GWw zrn)p(W8poGDwn?9QFSX_D$G?viL$qpJcGl_Gww58s6b5iaH`ru0zvQyk{au)!Nuh2 za$-6=T{8_%z4kmjCTPAK)AZMzw!lGwu(3L%H|>ry!$-y*L*?J64n0aw_{>0#EJ#hi z83p5a6D%e$ZI8=Nij&9M@FHvAtMuSecaD@)=?zABeLT z+b*ilVK|*LmTXX9K~SbUk_%LjS#tAK_0IX1Y%maPvsfR&ZE+N#<1|=4;mVf%2Dg*_-zTE;w0AM8$u30ShACFNM~i-*u<9 zJt%S3z(yDoGl-H(n5Sj%nicR^-5c{4jlsr4=gQp2jwj)92 z^6*XK`GRaw9`w+tcP;_Xb$*a^(8cjT;6px7kn2ZvvOKAQzUs6H=#`_%>&5FGjmG*VnL6kme z$OF95R_#}Twm2xTx(2cVe-@+K zTduF&VWf(|k^rLT-vZQo>}RX0>;gA(^bu3SRerEGb7$`iml2KPy#+2P_v*vjiLZSv z=J#~?oU>u(3{6GlbGd|qpF>f{k(S_Be4@f1cuoV%64aEdaW@7Y0Ayi+4#8HD9Fv^U zEFSQu#^P%8h>~Ba*Maci1}KwcisP{9h0hsR!av$%YL{0(vVHTqNUITvc44E=%*bDm z2dHK1=2LNyOcA&Py+9I`Sj>$M4LxEhdwRZo5cnEpvQ@z?@4H3&4zl?slNtJGT*%fb zV^uouuiT~9jcnmzVa&%(4fZ!57uk1?m_O>w^O^Go=A@aP)veQeZ=B&S6DhY;=9qF_ zZs!v$GJ&LD=ps@fu-;$A=b3+q#x=pg6?^q6HzJi5k`SX=OstxPExRzStG22p z>w97$WZ<+y{5AMnfWD5z#zTndmnR34I!xym#D#Z=2+x~N>L$YOQ&>iXhp(MY>QsiZ zGMTOYTI|+R51)7#v{? zfAAfFach)D)s&Q|!?%P@43%d9P>Gh=s}GwX9(h)iN9BTutnHpUAM(5f`2xzofd5oU zIx6Ot{ft6ISB>Y*hYRc(oZXWL3}yz%7klhX>zv=uIaS@GGu4Qt7wZ;JLp!Qcltv`NjzuTCtqVH_akv3>My(Yuyw@ekp9hwCPAH`4CkHvck8+pr+%V) z$A%>~Hd$N6jj9jR zb7S$ISMO#I>P$Kc3bhYR(H8wjtd%sGW2_XNjXFUycaYCtQ zTX(jRjl>q)pLzghrPa17KF&xu7;(Sj=n6M(3S{w08d}h0yi%tE=uoL|5VG%CSS&ez zVbjzBY|@B#^!;0!@|W0#J@vk9()Q1m~XD!T=HN)w#NAZP-Po5{VH4{`g6S{YfiFfChYWcLfH6($$lYWUBGmT3GgKQuXV1L7(8y*{F|oO%pJVm z`1+Qz8+f^lHVIsxiw(zP0B9e6R)o^ja3K@H#V2!v{Yoz6^MLc&d$z6B(PuXTGmF0) z;Gs!m`!chezP?mr_)w!D75roH{-=@hwpJDN^uibEO&{pj5ox=L@f?foHrnvaQ%S-J zG3Gy~1k?H&+i980)n>J8y`9L;9$LOx0t0$H%6xhDxacH(V-0+odMO^?uL8pz_`vqt z&q$lunZT0;ldd^%r=!oOSGuV+f6IZ})s8>`NWTZ*aZ&?&r#K6>PCH&Z#hQ{MNdR1= z*<%){8&jsCA5}h%=PBkJl&h1`ohj*m<2wySQ?qX>=EJRKK!HGzgUt}16NxbsC)abW zygJFYt~=!?7EWQ$Nk=kkYT8Sii%Cb5&A%Dq7fHFEvnx;oPmHS2InpR0R}#uQ9}}&V z{m)4~I5l-K*{w>Ju2}S(aom%d5R)~7cYyj%yz-_a*xFm}ZZp2mg93RUj`CK?)DI?q z3TL+au@p3|Qf8gQhc_dg8(Iz;~Qiw=ecy^$5x# z(o?)QQ)487J!ReKXc?l2WE!Kz#_t9SWV4re2@1;!%tQ3}eBxa#E;;X0Aly23Qit%# z$&cR3CJdT;l2%rxYB>iz>eD=?ip@#Xs(czzy>qD&pM(&P&uG+=@@#t3YOI_m1Q z;BNCfxCN$Gtyh5q>b%gWK~dN{U2yIxQwilPS{{%v1zCFq;K}4NpLG~0%l7TMW}DWe zP=p8l!+^1SZ6@z3izLBLcQD@)O$q9Os{q9txeRd8jHb-&++aM`O3-pb@ar z?x5G*3mYq(^+*6465tDAyx{b9Iv@DtqteS6URai`JbvclNM|D9_C#_$saCIH%GWx*t7aUau%#v{!L1ig4HK%k{}OHqcK1TJb81f>F~j8CGnyYz6W}?OXgZ2A z)V$G#rJDMCoch0U-3hoR*^n8^1}HW*;$@V<3tLuR|I9V{P04hJe(3rXDVP5ycq`*+ zjoM@{73%(Nl{3q}{HXKnk6V+-aK5+&A}Vg-OsYqy+!UDP3weRUnWzOj^wpz<{dnz* z8Res>bX%_Z_bN2u)gIo~upvRy?+3Gu{6HnAvC#}JAz+h>7fGAPow?n{txc-p0Qgx5a}oN_*PcBY?`(9 z)m4S06P@njmfO{`spjD>9b?eOqZ8x7<$QdUZ+BLwOAneKnN*!mX?q#z)LQcBSLP1e zzj7!SIl9A~*opo;gZs|Py?8ABkX2E?*OH-}jswTCyP7GbS8;?kI^1+7<*S+PaXk1h ztx3GDZX)Dqn)q*h7kNZ(q>s?``s{aG0y2J_CYY&o((BM9-(2u*sAlm%$-<{BQZdr} z?AIz=x6XD21vl|~<~R$@cZ&QC0KQnBM{{8F-tg^po#tbKp7gPVs?1C>a9+p-aon(+IEHL> zNdKI0oIpwaV3mN0VQFqo7x-@&ou#?uuI7SZ!fC*q=i{wC{)Y{ndER?7Jck3-pLpH~ zA41%hl%~cAB&}0J5@K;meEj0<(fw-WZ?qwQ6o+-x&VC5A1j`9mR#{kEyJXf(BtT19kM=%Hyhvbk+uu9&biZ0N zR`KY94U^q2P0nKSv%BnFY+jP}#R|OlT}>$ZdGWok8=r;i`3!k0H;2aaitZj+T-e;_ zI;-&t!rI#+=*9l{%wB^VtGd0gi8-84O|!ako8vG{xgV=z6RGe-jiD8m)7x=f<@tq3 z>Fo{f{W7{QH3C^TvxaRzvI$avJuTp8QkPwNQf#QOnHLo{2Q57~Q9YGr%h(jM_7=EW zf~`=?zSQ_wM$XjAS#_1v6tw)gX{Oxs>h)-r^z_lF<5Q(OZ7X*G_x&)#@|7`89k2p2iryx829OqQZspU>5E#bhIZq8(~ zZWh-mEdI_c1&MBwN44Bx4d zThn-T>|!fQKbnWj3QCu8lS)tcd7MtaAP;#JOsQ9bbYzs%vw!{F>%PXoWV z=~eOVh$AGl_x9H=j9A@|YT!UBDePnxNpfpmX-#YrG6F$Jk zb{o1zne_Z%3<~{;MbcAEXFZK}Nqp}c(0Hl#29Ji`87LxsQZY}+ng+7N%L8oLuMNWP z6ew0z*KQGd$!`N#!c@=3tKk2tWvq=l7@sgy!K^1iMf{zb2HuOJ<|H2F1YEKM@%(aW)z z=Of&w!z3ed(?a$&i>Cdqjn~1?^8`H<`#Mj>1l*ok2xcY= zDn~uzJ{2%w5rvI$9pCKj;4Y41hq0D%%B23Zo*4NuEv3Re@$m3a@!_n-?OH%%6bzo3 z?1d zsPo%FZ@(e#l=E|%*0$2G+1#iU)>mKKml)7pMf)gm*U&IOiUs%aa~Qs#zRZn{j?^hL}_sYVPUaC@g~vW5oH^UDS_TFOT4=3LD>f& zF{}>w4Gc2!IZ==lv-(W$S9Q02BiD&KI)r}?Vf@L+d3Z3QWm+F`>WnH@nVNT+T$=v0PVu&xG&OJ^)GFyPt#+7lOQ&-Ng50McFV~ zRLwoQsgXKl6V*e>=ydBw0`_>=1TGd_f2G8>7rKxTJbOO$#@@c?M6Ajs!eVH8`q2Sd zCi|u1ly!iG&PV25UlR;-nOIsuitH8)3(^# z6-?}8J`O?+Hjm#9j*Qm}+uy~OwHkfgeU`r>9GGS;yOEAnaX6UK->78fs}28>K$eE$ zGE*y%l|P-VK=`!to&E=+Z~shY+xTE83)dnCzx?>S{)7QEw~}UD(BR+{2GHsa@#1_Z3mLax$O zCl)8q66@shc}&xn#PesndiHR-1F&qD!;cc0637msU%N)?wS91V;5>bKtp{aY_+iww zuo9oa+e19PcD`(e)Y^E7aU;LCt&N}V(@?6geqA9uw@^bCAYDEhKdpa1<97E~-zO|vu1s7QV;6zEQ118Q`C|x)8 z`Ee9{J`O)8@;v@+0GL9PptIbjTCW6>rN>%vSmy+@*w!G&rwi;crECL#Z|L&CCLC+C{}U^omj93zrU(kv~po{Y7RlsX8d;#nfnxIWYte3yR?=ZuBUB(%w+UbS_z(KRIu#r7W0@wP%93nI^TO;$b* z8o+xi;NEj%tcwFWvY2ono0*Jp z)?A^hu-NuX4DSmMPfsTy`)c%JIHXZbBzM)_BWe~p3yjW&f60=MH(9(ys@7AX+cylaD$EqS!S(6);@cC3{Fj z>SWd7)l;{zBii`n7M`-3qw?7H-w8cG^_{(uFScLy;MORcGaIK4WN$#vWA?k2UDrFSk;~+&(Ga zhFJxyO-WW#xSp4orQ{+NP}-FV@KSF+$(XXrO-M^Iml3T_?^0#Ck=QsPMHlvKeA0N7 zEZoBksx!0_Hy9Qn0S%SY#Q{+0nx#9@_0-Am-o(igoXgwlcW;{x(%c0EolhBB*Tkov zTgZh(PtHw5Pbt<)lG0i|zu=&*7RYMX<6=k`sQu>=*LsC)9W}h>5q~LYfxY0)joD#)?i;ta){0g>3#e-qz z+3FV+b|UKbEq4)$W0@vPaYkjDR`%wS$6q?9&H}y=*OL;5eHe!qk@LPU+>7|V^rnHm z4Vkt{)*7dyYmvKv(~oxXFuHj5d38W*)2(czvxLE$BF!_=o{FY2wv4&2Md8jv;w`K7 zmU>wRTX}Z9*jp6g-(n=IzJ5;Fh7f;RjX{U_9ML~z@as`MJ$)l^(fs6YjjzIwzB+?* zF^aa&Kr~g3*hewFT-KUw6*qdUt|zkYn-qPvJH8n@nmC~T*8TWpSxQ3OEa2uv7S-OF zqqC|4u6)dDrY=GzO>R=!8jFS4AHcp1@c6oplH=HHukE=hOLh&lA-Zw4p&~@P?Fot- zEfuEoGM#&Np|>BazoZ`&l1xV7+uxSSjPjfYE7vgE5Bp_2i6BUho!+xp>ySjv0^U-O3KqHHm z*Nv`I++uz{YkGWzT4PHgV4Xa+tY1o=`aXh$gYKZvN=uWusvQeL?23n&MqN zLP;v;5cuGsiFMJZw!9}7u|{?EG}_fG4Q#Kj%MAS^R}+LqE1NhjX*8>^lbCXuKC5h< zPp?R27y$?4rJp|nVzaTwDH>Q-o`BO65L*zUbsJ`qw1XDMwe>ISQ?$;I^v5Pwj1#l3 zttG{uUR1Ai%!s`CT9jh9)a;XHeLua5<7@8(!c0~*lTi!a9%d;%`;o#RXb^oWT~acr z$MX%It?s)^svdhm%uqp=S7^rS6#j;|)I+g&WVytMrB^Vo@BV$aLz_(?M%wn~s*hOa zy0$#5a!g*8Q}zU(I>V&9b4!3LN<3^KZkJ(v!ND;*EAMIgI)*co$bS>kF{)QG{6k;7A+(-g+b_Cx2^QvWElr@-;d zUhoR(7gTm@T=+iyGS|Wm#a}0|)M&|vtg+&z|6{QpAcaW0Z=fl#y=pC0B^8 zq^L*(*slQ1Lhaxb_60ZjUW-h>LsC%4zn*ozHV+?`UduTu&(yBn&LygPoorX$h*gmd zN1+KAoeTsH#+d-a*)^xzexXZ_KNNg>cA!jksQSAIyvY4J+T*&_Xjj&PHu1~Cq!+^V z1{T!bsJkKoaN8dfk_t_xT_=Tf(SPE*`~GrkR2@6wMjPNP-(tNUe+Wz*;-`Qkg?K*8 zp_Mbwoxh3E&e|J@n>5y!wvujLI%A64$7JxTC=-8S>Dl$=8IV1QPnXuY^|x^r`sx%G z<(9aNmr!=>>UCCFSzDJZq2$s}C}e}Xvzw~cI?UXzM&~%8&Idix2W+n<&v&?#qcb#FCb+cnyP3aLc13A){;s-rb6_xxhBJ;IK3=mz8EA$NF=iZ5z@ z{hE=dUs$QuGxe8B!{qGpq<6-{7f)O-8jyq_**Be(+!Q|&>YZ@78s`_8BT-ANj=sB~ zqf28Z@yZY1u0)ejC5+By?{uVEddcg+%6GsflPq|PKA5aEg}Haza}#xcCviyH4pzOc ze)OQ@0R_GArJ;zF_Lg?Bn!|P=P5)1eYrDhA{+_*~LP&aJk?opaE|PY6_nv7_2T~Yc zUHjbWYJN&~kF;ZkNL6Sjt4vXar8>3Wi@4a46>l^{%Cy^C3yZVab%$r4JGV|i%u1Sz zeLlkZx(7?YbN{*X-0^1-dHK*3aovBb3Z#hz=T6#7=tXc6w-cX81Kqc5Dnq`Q%k(e5 zpde!6wN2jND!t%L!E;CT@|kHJ9H~YOu8HhJ292+VSb)e zqwiPu;UgKzSjwm)gt+J8uKLMmH$HrXu*XmL;0o&{W>d9y7`RcygZqN3Ltd(d+JkAm ziBHfhf;z7X$0x%sR}9ywi*JF;!HL$nqlHUMtOYvAC<0PeTTizoRuw*J?jod<4qE=zX1%QtCMD z11H7!@mBo~xk+EA1?PZw`p(utPRX?lebBi=CAqn|rFNQ1Eu=DfhrVVggkL2;=p6Wt zqQ%k()%W}{?qqhCQ*^fQKTj#kIr=}E zTu);0@yQg=m90czN24bK7JV>oWtUB6bp9*LxlcA+xGToGb7*T>x1?Z-GYp(HDxNkV zh$-rt%}|up4!?-`kq=Em#16MTV+33Jlfyzo_?%B$945SOK!{aTU%_%LqIFHw%1UG} ze)CesH)zZP$~!q2{+Yy0SS8~HNp}hGkp0sGagy!v1rx!GoLn(=7N;vO=Fx2r*xDWO ziONEs9K2XsGKiE(%8@WPHI+1j#}7TJ@kvg6L=zdQFI!tn^Slcd>v`>#m{Vw$)8A>Q zOv3-p)$JgWVsCt&lX|P7!mTMKeB8(~hIz2BHJ~rEd#h+?%bwxFl`$Z_=crej$;9<+ zLxTmW1mQE5bZOl21mx>X8#!~_-%~51Pge%Cos}woW&e~le)NE_=L-QD{TRBu)JpWAb_mop|&HTH;}R@hxTbk|f< z%0rg}OmnU|A$_aH=OVKgHYdHB0$eprnpQZ*#}BAb!yyU3r0NSlb=*g#pFJX0X~B=N z)D4x}Xq@g-j#P5Fmb7*hKG|x&SH*LTN_DDcYmEk(Lzt4}KP@JdPIreT>g7S3E1^tH z^A#8Doq;l~C{m3nWKC%WGhaMiim&&q=y>@tS}2=4ei2st)!>r!cy#9`Y7a&c_?@@M zIAbu@Cd7QU2W~AKek;gr8Q>J0toUA8T7GgmZe57FPc6Sy{m`(`!=I5*dkX(GUP7_e)Xn;K;9tY(NGuCT8tbE z*_E*$i>ej&B)=aHR0HBWg3_^M)yDZggr`}6o0F>pR{BL~JZX3<_Lo(}3d zeqG*Kh@pgaFx${Q3Q9k-Q&#-j6O4}>%vOi4#jIH)+UhngL-N1(6CmCno&I>eUZs>b zUFyI479T(= zRNbyg*!Lh;KlU?d@OP=>!{10t@na&LRylv-w zO&KD9%mo=yzrX06gxEPWMe5H={e})M7}PG;EUCa1R79)p|IL;sGx&4CB8aB>noqIreMD z{Q9`Rv`j_l^nuV|Za+oZr?1?^GQcTZHV>kOEx6`2f$T^n^;8o7Oxhr?^xt&P@xTX6 z15^*_!c+e?lHA7z^DS@P!cw2iAo@2{NA@o(&p-V2FlgRE-xT}we>5}2NBkgTrJE_? zi}_!UTt}E+e7q=-$`QA+cMmA|lladBXlO3^kN>Q|Hv!Q42V#0LP-GUd`iE@ZGJrFi zjp~!CEa;&E+W(UQ0|*&i9Sox=eU5Vb@vp=Xf`D1gvnQ2XYIr%JjsB4%Lgc^hYG<`R zxc?u??jX3J=RN6DMSHox(d5DO{{)7&EP^!XYw+~fFP<|f^!{B!@)rntilG2;lqat6 zcRwH__-4}7*7WXubyE2bqCX7#9pUS7)nmvzOt^j?#lTfA_%j$R5(LebO20@)|Hr16 zh~)zhM@+1e#b^^q{iSaz3V2qo-HSfRZz;j6pn#Z4rG)GE`gse2K?*sl*JWoF{1F7H zWQeC-Nav6rt)SQZ_i|S-QAo5Z_N&d?%)cu-KnNBnrygVB^C)%*Qk%-yAMP4JEF8Gc zCW^}DbMpI3kYj`>5e}rSqQ$}}spa{Pd7Dv#U|U9dcgB-qO;{eSq|;6B9-3Q23lqvwN?Lx>Maypj}ov-h6aduE=Qc@7~eN-|jJBuS5ZwGM$b(ZT|3f^AtNVnM&O4#}pr;BzOIR3Md#3hpypNlaf*G%H;I zrit=H27`b<^`Qb+tiaZA^Icm;3iEJEOwxb}*P(rS&+N~i!=RsNf|IEeUrmNw98iJf zJ$*cJ172eYM(thW{`(!|iUEh&K=5v|CY;lGbcyrt)gvr;@Hi@~Wes;PA$Ko24+`AB zf9NnrHRvEmhZ+$Z%lsrm_p;G5a>xoz~>ES zNmh_N&6imQUfxA)er{50v%0~c?^Lh!NIc0BHN^1EWGUfJu>;4jF_i0@pL)|2HcxB1gxI?)le`z`0tfU+zu`EOXHa#A@%ks&k_2+m&Z*}4Okd2en`DGw)u!sPe#D; z|BUbg(4_Fv$<22I#Q*m)xbGMv4bpBisSE%2W^XZ&v2Y_fq&@Tg?@qoE0#^3S-9x^^U|jBg$gv{|)>M7`uPvds0LPGhKGoJnxd1O}(9n?S?=WD-I(7&d=TR0Qo8_;0(Y5bLOt`a7PR$0jG+ z2D7i&edkGu{AnxJl!zOICHr@dm`wOy-FT)mVlDc)&Q3kNzHs z{$NmV10hC%w;`&*KQnk=U2Hxf$e6{jF+#2Tu)jPdJ`EvLkK>Oq#<;x!u`)s zh~dHhZ-t#`LOAQx$7_3$>Hdy_1WXtCxwDzgLiF!cDj>LUU)_(p)~W~rLXJsnrtkzvBcVVjb|`QGhzR?3lZ-b(w2 z<%`6NvEhS_ID{MJU7F5!09y+`){S< zs0L?m&CW3f=KeGNNh;VxQVn--EZA6{9$M`1{~O|l*80Fv2kHNJP4E-sJtw4{Mdd>T zZpYv07z04k&`$)qDT9N)sup7Z+Zx~$DTTAX{@MvnDPCR3{LkxgB68eh838K%pI6DG z0Q3T$)4u0|Oz;Y1HzNKWH&+~#gTrLj-td3!wi5T}ny=IOe0E^?yfkymGeCzOr6Z)F zo53;x3IfpgvCjwzja)QlW+nRqWTt0$Q!%iynLh8~osVu;9;Ynb+yA4tOk-37c;%UW zqId7{RLvywd#lom#n0Dt-TCe@`t20 zQ^k_gF^AJbJEq9U(T+>Pj)j6rdLEWnx7EsDUbLV5GGF@xpYX*oGQz-R>2>2psG--1 zxxUYa6S*hTKic?E07HukF~?aTt}ZO`HUtcyq7G+{9L#Hwc|W5_>Fv}bn!YrH1N(_*N%!5*A+-8zVe}NqOC5(E705 zoDBD^G#*oW2eR89ta!w*CSQ*EcFUaSG;%nPF#Pjm2fc+DI?O)pescM{jZAV9>#Rhv zO0dsG4w9v|o2O&j#-9%$>uURzPgkPM>5C$c_phO{jkLd>@9z(6Z4rD)i;Bt@_&2_= zs=z%555cS4ca8fY8VFo0y^cI^JJa0D4P*~6>o3u1w&xj%pCJ3^9KZ-s4e(QRy{SSt zoqede3x`FvF`K@m8(9W~hA zA8$I!DP0{nu!bQU!#hPL!+!r9f)6-w;o@q8?U}m?NN_hN;0_D}yk-&MrU!<{>ZZu@ zy}bKcyMu=#e$B&cP73pXhSo+knAXPYD31Ukcfea<58c}PM3-eB6to`bU~|jpPQM!Y z*DT+Fat3)fZn?q6Y6z|b75Ate@K#nNpP;U0){>camNVDde_db$UUn-g9oqCh5s`Wa zDDmePTl(U7(dJT|wlQErT-kvy`#ss4T?Crzh&Mv$Z@`Eg=nADj5;_xEqh{FmZ9Fjn zfDl_;Y}@c<0S5n4@P&}U3_tLo``7*cQPXe&eBilHWWHPjUu@9z)A7&F(VdQ@`CuG? z7^r8ltbiU9fA*fEJRWs%>lh<=F{eRlcW0jh$u<~=#sag$V)n+xEE@NPgu~l8f`Yaa z5SJlxxCt>q2D?c(J@Sw-e&I*co1=$Ib41(CP{AqDcJnJ@t2=1T_U2`)^QyP|5|jm z^Rk!+1y}$J&3v??eH+`s`Yap2+v}Ku`GmM00)c#Q+N9u*Z<7%WfUTrU`e;|KCxeg! zIA{Z-&Q3An+sky!D=f0a?N_CZ0EF#;ziJ^OscY^S(y<=IeK8*$u}ieVa;owYX|o0m zE?lU_oplv~b07(B39kdTVY95lvw}zDly^q;ldMQ$midpLRlCn;ldHQ#||pF&2&m)i~XO|RODieHti#= zcz-O|UwpV5a={M=FDW}T%0)+KP53@gtW;WfLVAa(ajdYx-^-LE%S5l8Ac_6fDN*aM zxM)h-34&N}x*p~kWKNaS3IrK;DRXDGyl!2fX@WJE8Bds-?svEI>Amak85?R$v@dhDtkfOdo(nxArV2OUATXo_@xn#tJ{) z*2+m$tBW}!N}ng8-W{6sI@uR(H=ZJ^Xt={^cNez2OV7Vy8ez+rbU;~nd3cZ2r7AsFAS@6;=x@mfp0LEobZuc58+~M~yy2+fJF#gSNjje?Wfzxwx>SBnUL^>)+z z;!ybs;W>Gp@6AHHkCojrSX(4hL2rzB7 z*Vx=!+D!I@`4<#aoF-!sxE^H@*P1K6_OY6tXV1;g_pS_A$pnI>*a)0;H6`0KsAX;f z-y2W`zIKi3%VT6xfdD*WCW8lzL`6U>xbTe4)JL1FjMy>LE@Zk2vLXu=&c}_L#%}9x zBT8OM4pg9RlMJOBNU}>Gn9ph$cFAQ=de4s_+Pe=Wf0b6!IIWD24n#&*-Sdcv>WDoY zHaCxvtu7z7Z%_}Z#+gMB+#A;9wGQQUNC1bGPvfoFk=%y3w`dB0BCTJ(EbF&?8U6MXp!T6yuDk{^B%bP8yDbOnqa1?L2L~@X*hc`GLd-&BN z98Ejz6DX!o4Vn!8Tq_^&r>#Wyfpv;bi$g)kEZ^$#y`y5T*;~&!9|MP59tMSJr;zc- zca7%@PBtPRUG`?PXm6Cz&JHdT!Xt#}%FOYB%@DJ0t6!%i$rM;+5Uz!Thn?4;8JdvT zW^%xvLR(B$iABeq60~=5JFXb3cVp=uat#8qVqFc)kF=v+&Q>EwK@C zOMSi1LQ>MW%Hc*l8`t9T9yuhg#7W)sZ|B(7NZlCmJ5+uQO(Y*WSzn?lqZH1_v)jTz z28u!f@9D04BbA&@NDXHRHrris4rFHi(G$xW@^)i+9s&%$#u)S}iBbx)U$ z+(HgPH*xj51Z!YYNGE{COhvEnS_CKN{KNwjREFoA_$!AGvl(213;SC^d=>=-1Cg=E7(sPq1?_S2B%KFFDT82*?}rst&B& z5IOX3oM)T?U#-)3s?VbYq+wbLu96n{zy^EvRQw6*HXWAp;+mzXT;FM5*~FjtBI)n- za{RidYTlYbiI|7Z%YBe(gn>BrYlM)=j{@(eB&=MNs#(wK*$kI!(R~m|f0cSzE#|P8 z!KtMQdM~H>(s&?JLY^*6#ZJKBUi@i z_D_np1wxj3R@J+cw5Yy%tQ{;s^KXyp7&(H5w0Fgi6i-=(8fQimStB`9@V=@4CfBMU zxt*Zh=tA+TlR1Tgjho64vz+PNUFhlP81HxO>4ml5s|p)!UBoopdni|35QyWCW0KC7 z=Z!ErS!8gLP?1_kdK%y5e*ZRt-&8|yO^XufFJ-@ zu&g6!m_{&X%H0tN+8B1;5gmp!lxv*yiBYZBd3a7a4RG&R_nP|ELG%cjXkf=wLyf!< zq7@Nbhk3%)qHb`y(ld8N+MG6DHOSh~8#jphocM9%Jsi%{Nqx@yYhTJls@Jr55E4@8 zPSry%V>7(mYYaM-{`WxmxQF$wkzXy zG(>OwjG55BHcEA4x1$0Mvnt&AyB#c~uo2%=+j1i7^jLQ-J2c zIe~5{)xpGMC^sg7CH+Z^^n~|Q5^gPE9?fDte)8L%5DJl zCBOf=b5VV!tX@%oHiX+z6TNZgisaJ)UJiM1R_}ntF{V8J`XISmxK89r31n@+KYvLB+v_pZnOl12?wLaiP^F@e6))w2(Cp0?<4eOYozPdwMgz^xwt_oCn;RUTfcY$rx4FA8o{6@O${It^>yS>gCR_nX>h@Ft%d*>r_ltV^W@KE(KXhov{3Dab$|N&x}fp(}s(>!HBh? z(c&)ki<3{oukcb0*);!{@CHOM0z|O1v$2`X{b7K~#ebs0IL<1#pScN>8qhu+aV;i# znn?BvO-pluo#{HHI6dQZ+*yqrm69k39#MgH^!54!^nd-?Law$8f_076Br;ydA>_&4x5bog0}T*lu?OnOs1MLmJ1X)LbZo+60o z+6sPDH9lhUA$Q$nuFDX3AauQT?L2K3ndNI`4nZEI}p)2_9 zsB{2u)G#5Z(5q0xI~(GtVZiPE{QjLQ0KU|u?D~c3QdirgE-J_X1UR5Z{kr*vBINL+ zWCQ|^QGGZ`(OWD7ah`-kt@zU%x?UT%7b%z7$4)*dOa>iXU;o*O z?wH2t;P2<*UB>>gwVNqffSC?c$F?W^E`;~V`x1PhA2c?A%iG&|P=Xmdad4LaATtsi z3?X4QFTA6pz!i$kKBuF4NcgeSnw^CD5F*F>(MU;=iZe{osE&6^Up27fy*w=M9c>l8 z(lJ812H{I4@vFpdes1N!gv#QfvTM6@5_67V|jcZ zp-prlnS3f%{`u%qfz1BABf9zM*4a5eMzK<(lAMa+XTVOKb_3-G$&zs{`1-e`u($LF zD-Pl}PyQ>FlX-?|j5?pjffna?)*AqgM_OQOMGf<=m<-l*_8q;(28EY3&@iwAV-_wo z?>8+uPQwlR96u}XKjL4-ZKG!b(v;4j`(%WNx6>1TDUD}MzTz)#8mGvdl&3-Zk%mb- z2#G9PD2>4MqJYo`G6uw?bCh?!p%=7mMnCrRD)Z@psfqcDD9sRRW12QoYU@Y`9}Vka zU1MLqC1kX|VxXhj9^5pWy_^m?#Uo;$oi}ch2~bZfx3uuUXPGzQZjn7{$Ld%qSS2E7YvC8UMQ}SciV&4g3?ZP%W67iR-*8 z;hH2_SQYa7JO!y{rUXQB=arnZ?FL)Ni9j&ppv-u!vplj-LCR?eWRmp>6t6GMNfeZ- zD`{vYp#ORqAGpD`KJ^_UGP19#&b`Dw*)>%RSoxyfHGXaakpt&S=K&#V$V+YJnme@k zESIFX9;hoiZWer!RkSH}G#jL@Duo3OERle}kmNswy7DTee<@N|&+-db(8v1CNrp$z z;dtg?I|xX2?M-4_M_@)-w+FXe3mRf%Zw9m9)=l0K_5NTQ@de$UiOM`E9NG=G5dzV5 z5rDvY*#fv)&U0Z&63{p)`+2ReUE2JR^Mje!O{R;gQrA)*;_rTbPo+PUTkMy}dZS?D z`si|c;dB}QmKuE}+)>_z>YXZrJJ&v>fu_5-gXMVUhJW=N!~ zzZPWM+{Sm~*QxfXo$L!plqY!~5VVsguK%DUAh?ObMKwyT5bj5M*>**e)cIRw)MWmFNvSKAGS>X8LZD*N?K-%<{K=KqrUA zcW~Rqjy;<1Xev+x^P95rgdWIYS90e3(&v%4kd+djL+Ge@hm;2eaiQU1m9mI^Ryp^X zt2xAWDEBA#x8fLGXy*+4gvtv3n&a%nN4$%#7srd#^P`~~B!U)t4Fb$z_^9r5nL4~-bFi{Zd)b0tmsr{bTuQh%%1 z+USi-+xEOL->>9(hahiYF`Kx?`_#fB_U<&lug|U%4sjpcZGf`Zb(Uf0{PZ6g0`n># zWX<-L6p@?n3_rFW3_jjW_=bGJLN$VQgY?;t&qp6LQm$(0L+i;(SGiSTn%xb{238V) z=Q68%I;YD8Yr8a7Wuf5;> zno351hkS+Jm1ithOW)X5!lPpdLnu`Y8XpeW%krKx9c6sc9l2vEQ&z5|g#mSZ0BO8i*|=q#k4_sHBPi8M9gjA(`|>v9v1=H=!$e@N)L^*o zjYCH&lU~-)T3f5?F}wzS?S?O)5~oI^L^DB^$o5T`0ncx-O$=LfBz2q0s?fIkX37&s zHD;zFq9Rn5QW#Tra&odh+PY%iwRG$4_q&7b=+Lak+oKzcxEa~y%qRmOlg4b@TOOcH z6*w5qaJ!G2+u2306YKt27%1b=fm4$%^6FLp;hZ^s3l9$utM0>oorv$lb)K6=h`&32 z<_!5O)j0S#asY~=;W3RHYtHRxEjj4SQtiZ8Ep* zE{^Sq1#nXo3|3T{Nk8@_>pb3$tSmmvJC%%N?6liGF1sg{CO!6ytk{t#0xb%Fzu=M9*F8EP-fx_B{a(=;6IQzai*!ZsL@W$tD zFot)mO7?ZWI7YH9)ops_H9NLUZZpF+EZaA^3K~8YW=m30C_sf92=D9Jx2&)oP+nO% z@SEK;nINk^)-fl0umV(;SATeOwp~n=>ZIn=vL44F?4w*rvw18dicZN_byzzxvTt%I zGV$KMBq(dGHhZ!5*K;x!LzSO|BXI1h#j27f+p*;>SMQ;%LNd!vhwBSXO+n4PC>@^= zM+XcNAAj*iwy!L$al311eHbs}{~5&Yr`-cSRN9HkdJ2uNEj_2AlPGl)w8Pu>r&fwJ zwYj}_6^uq~IIZnQzvHA+t8T}5%&dIDuUMhex^MeX3`C4(%Zl7)W15K_ycb|u0#qszUcVuzNZ&IL=W6B)$J}oEzQn1Y&FZ4+(KK@n$C59~X?QuUo?(q` z*nQ!g9K#AE6}??c77>LC?xlsEgvoQ8`Uf$CLj zeSIcRL>KeO1Sso`x5)y`&)O_^e`3hV4<>#ao#09k+AamJZyzE)BmWJMGZg@Y!~IXb zkDLTP-g?rzPEJJjl?(LHSD(In%E!BNG}$sy<`D(o9*i-QYF9FCM)9S6>L&#j5$Yp!Psay(oX4e$jF^R6O9#q?hQZ90O zW_&Fiu^=ry!@(o5-WhCm*QVp>I(+puL?Bri9rWHd>K^%fyd~Sd`}Z za`=o+N6;>jh0JsaQ}-QFxvP^t>^*zhE~BSm(PleG4l)4S>JvE4bczHFJg1GmYW`p@ z%t7fh&ATN|@oX7r5qPZ6Z9ic3JS#0O8Gc9eU5C;Pb z2As~Q1>4#1Kef=6kwt#(NQo-ddw4t4V%Uz2Mi82;z38t}Zbu4Ehy$*!AO?F+S=h@?eOrO$xBbE2@}lf)3g}$ zsJ$)h@`HjrS{3zSpRBd+rR;8OD{R#D&tXlZr_k@2Dw%cp%fD?Z0Ju|Zf)cNhmCY4? z46AZkNsMd1KW|EB(kw1M8Vt*j0#MKu{gfCgZ6(%ZF_x^lubQ4-`{tIYd(Jog#lhis z_c1_Cznbuu5rIJxBSE-u_0I|c1&Yay(xS2l| zC}(rm%;{S3U)RWlYMPo7|-G9|u5QKfirIi24K11>^`# zebPc@0#xHbkA+DbjiMrxj4_*t z&&m~#Oe`#}##CE20~Y>g&Ir$JUkO6rpLQU)lw(hkO{2vEUBu~Q0m-t>7(Q3)0izM$ zj*76mA|PvZDTc){mGwuvk>3fMJ2XFmK zFTyfCg(j*#?bZTM+fM(sux$$C#)EPsv!KG-Lh?;Y6J~ER(So*`$JrUN^3Y}xQG4R{ z`6ko0)_Knj>DOT&RiUj;dATB=UZA`e>-8?zC@c;_M}PV1k0oPCQ9LiRiVCsWO54Vf zPr5gYPXU2;Z~Vi4+L2FNC0H9b$$i8pWEH=4((coO^E9_oQl_gi6`?$AnboKUWiuV=LXRC^Vk8IO;vIbSdLEWdW$idcxsF z<)m2!bh(|+d7yi4a|ytM>ly_1Z<8XX2axG*H?gE%Q|i=tS1jucZ8w_!`YMLft0c?e zc^;Pe-0*#YX;w9*yD-!zC-!|DTXs?QMnJk~Py!OiXC0vBW0oxe)`z2S3Je|X=4xFc z>v@^-*mXt$1ng4{u;X%A@sf#Dn{j~VTP=4b|0QW|J#7ghKpy2hX2Cn=I^~$~ z2!ex&RFu|IIaBnl_J9E7!$~q+*%?A+=1dhBXWihGVrgk1Wo{#Glca{}kCn%hj71d8 zn#Yxx+ml}rE4}n z+=f1W=ww3=%%7*@N0NM5)%C{mFpn_j(s-(c$tn3{F+VrHBQJ@y4saBo*=oV}#Rg>B z;rQ7R1XU{nxkTwz{2-@jAl7tCpT`P&Y7HDzgnsUiqEY5|0-ANAZ}OtSJX#P(Q*a2>wuO+WzRz2Uq441lOeHC*hJ6yp5~_+ z)|o!m#@m9CJtGzEYXK(Bd=ET=X9bM@0^Y`Z(uIqZG zQZOTL@+6tH1_6uPM$gz$bq6>n^TUL2>^3=s9xlAZS;;+B{VVk6&P_A3C#SZ@JTLdV zFA>E2DwgtzEP@%2zZ5|*kTgZBW_wMly-I&pH*Qc_AKpUmwCd5s3)lB7J9lkS!R!yI zDKFCiV|qfrQ0Q=?q6;_C!pn72H8*=F%0tMn%hE?vOV&8H5erVPg80mXjWlB+kwzE^ zPwuC!i)U-3%nLaJIQ(Oz=VhC3V1dEGKN*{oeU~>P!_nj17ybSG>|`K8V8^Ph50dU= zpIuJ&PW^d_#1lT3U|K(KleEIt{;|h?_~$tZ$r&v4TBmA3;3YnggEDF^eo3n*CuZ_Y z0@^s6LeN`>5xC!z4j7UAaG*``LF0gn96~?w#Aj{Ny3^N;_pkpc3S`4|+7=t)BeI8Q zM-_ zz&qcs2a{3e+D`+|m}xx_R5a5B-^WqiTI-6sxE=tYMcEn~^&fw((0&p6Nm;@Wf)uYo zUBop!;~X$Z#Om%n)?$06$7B4)L8T-YsF^SfT?NMmpOvg1`qg3niG_R*awePHjmglONL9P6_&c)JA$D#@^{Sphrfi`5w6T+~p{~v@30Rf7 zC0#bxKV8QHAx%YYC~lkUhfPNamrt<9!fDj-ylm@j@2j15&zZ9~If3lz_g!G_p-oyM zKqjf%Xi4w7QrGXrqIX^h%LneCD7s8e`o8dYilN~!>uy!M0X|TQ$6fw3>EZ2H-2H~? zkb1HZK}|K0$3P@5V4vEOd%KB*Yd3UB%5JpJOh){d(>^L!^XORD4n2jack<(PR(p)V zW~G)y5ZF%KB?R5{e&a!bSD4|LNpIa4_32d6`dFY>@^5BAqNT{}-RTp{9dloXnG>2@ zQ4!@s{FCl~Us!LtuQV~MNtvF+rVokisd zXU_z-^hn?+hsLc0u+6zTZRwr+3@<-x;;9W_`!QlKlrS6EUe(j=w1%sN^g43o{#Zup z9iK*xlmmqsK-HlDYCCG23UOWE7Lwfy#Y{}pw@1L}%*YT^g_N|L)NeitJsnoXxB!+A z^%}e;!7`d5Qw^wVo-j6BPkg8a9CvE-YKbtwU;0!il#`z?vT-(0wC$JHXAG%vcNZVS z0a{5r6xbLakHwd8-#)Cp}8tS z5?1M6M@zk{V#>%jVbPuHnsYvyJ#T&YvEJzA8QBd29;?U%r`^WJVpo~LH@mQ)DEedT z{l^2vZY6gT^kz4Si8~A;E055&(JkiUq)^_YZt06DT?*L;jUd}6YLN21p~M?A-*Mwr z4;MwoNGD0}WtEjH8}k_~d}n~taF^zRAvdKDFlAh+%1(D;A(KHeYV?yqm+fB(14deI zEW5Ea_}2D`%+}#r{AR)a=H$zd-p`I3_RemAc(#6$#8Gzoj)|%NyiP8+fyUersNU@e z1w^Z!Px?KEha}EuZDg%ruOyjfX_UrZYr)k0Q3tMRS+t3k|N?osXV+dff$E59etRWzl$1 zb*DuGp!LrQ(6r<)c?=dPr86QV+rKep=GuBuXNlF~;blopaIWEv7A*|La5a_O?IX=j zPm?LfDfU3)dj70hzzc)a|TDExSG6_>-K-x<03$ z|5ADbWO0xu6Ipp1K7Rr~UG`fLZOn_+JxEj3hW2rys)sSwo>M$5`dHVB)7`dP1u*-J z`|apYJ8+1$pk%(+=Qzm-aQQ}d=n!(pFzi!TarQyKW`@ns2Y%@NRvFZq4L@ATOhVl6 zkKTt%$u))mt|LCdtWrmH$myi0vup-b{6K`?%Om%w?=6)|ZBHfbhFsC5?zZ{$1cROP z#UIkv6QI3$uXT#1|0^=fu=1jxbtq$r&+UAuIc^Ery{*~%&ym1c#T40DES@>7s>k-3 zbva`WmKVVJIaE$bYbwHEh_nRi>Zz1?#DXEqa@`GK=m`mPZy(e1y&I#y6{~s%-<>g2%=?2GK!+VO!*wYnDwWT`YW^VUb zWNq2>j#+X83molCch$RyLrU){b(<_A8(tHaej} zOi&!Zvt&km*h@O>Ng$TFFE)-oUtGL5S3uSzKZB=y=d@f1)aT5h_oZ~ya6iZKyuSV9 z61AOrXk{oj}wW zv0Gw;$LmId=N>NR_Hu*$*`~#-w$tD0s0Cl7#|YNcZ#Ia&$iqqYc#PRWdDv=Zmxl%( ziXU)is}jsDT3<*PH(}|mS&}T76Y4(%D+AW2r24aPv5P;l7>4UN(ba89pQoxu5tFp; z8V01;0z3308y%fR&-a)9-dR@O=p$?sy4#{V5M8~EBGf;m8PHc%_QQ z;z=&t3mLe8P`Da*;bS#Uju`zbQeFc_Q4cZqS2BmP=EGzspEqHF>pmYoOw8UH?UeE` zA&V8`k`Uev|Bm#|%{`YDyPEVvzSejuFavB#c^{D%kS%Kx9- zuGB4a`}njP!i(aOnge#z%AN_xe^Wt;Ja(v3+IK5hJG{=TIeeX=99x((-b3D`U&HEV zSw-7CJ~K~~#r4AAKt{n4hhHRH3|2LKH7&$LQ7 zI>F$@GcO{s!BSM`B3n|>T$=rK?~p*L)E9=2kqHui8y2FZd}qZrs~jWD5mV$tvP+sV zNj5H^op1}_0?w-PjSAI!IpI=|y3nJ~;|hQYhl-6?0hNF86#uU^6sk_Mtfx~VK%b^O zMs7KGz2P1(OXLDn&QwTkHh;=|Bz{I5#g7}mHXnJ)6bCLAmMgyDL@rD`23z{HG=;%} zWq`;k*$cK?@I1H7_O7R)w z%rX_#H8D7Z3={*#nVECBx}mdu)8woQy?i4qAqa`n6EVx7L+gpF<>epBN94Zjz8;eD zPgXN;oeDGqd=yf-EaF2mnuCH~WD!WJ_W-<`@3Apns=m#D@Am_ z$Yicz5YZ6#FUO{}RjMHt*2jCtoy8w?v%-=_eYZ{`H(I!YxPtF~g{GF(#fCK4Gd16w z2R;?Ih@{RX>&{1**0vk<(|qYmO)Ys;<}zPqm?F*aR(zg<)EaIT?b$>qo_%q)aflk3 zTB@`Gs&?w_T9R(_QuL?LGuL0=in_WB;QWhJsJNJ>KXiv6$!EvaR<(a7=2P^#Q=RYt zYQFIg-0~JPXY<@dcGF~0pJ(w8=k=PNXm;B#`j6s6nb_=AK3pUa)N7^-9NbynoPBLa zxc){cvDuBjM3JM{h>$@XJ0{vv^aby#;m-*Rb^8-E zL4R#rr|aA8yjTNX`UR0oK5`x+q$Me61_Rh^LQthXSt_-V38Sp++oyx0c$~;a`LGxQ zsN$^D+ZX*Z{w$>}Mu|*9CK((mXr9K#v3OAY`ACYF2_j)epQJzi&5E?10+3JUukSP4 zDfuG0E(X&Hx5{|~Wj~CCcwVxJKS27{Pe0JDNX*0)z_sXk zE4T>2EU}>4Z55vhl$dpx_3TEb$d0KVmf{=PlImtZ(Q$lU-E}8)J89uAsaoJwsW-1K z{XLx*VVW%8F}$e*5UL=q{pqzI?tiY;PMLiZ`6iI*@K34Ih|xy;oo?Ywuu+i~f_;#p znxP@}SUbSk+lhQzlzBDbKZhSl#=WXKB8EhV)fdP5PkjZ!4XGgc&nTr56PVj>m8V$r z8#wVZN0)`|t869pBqSz-fH9vnZrz0-R+^x@Kqn?T{{GgYGN2_rDYwlIW?6dM2)nW`y!XCyCLp>Ny5kWrmYQqbx|d)jqS;Z{0q;H5_`g6zExpLGc;+ znR8if4`WSH{uV!iQ?uv2>b()GGXa7{;a6xRgSYxxrQLzZ&SCzOe}A20@gQ%qEuV}) zC~m#V-|}dhy02pcEW4-(%E+N4_vcvu<_v$Tz8om+%viDQ(-1&((^alMP1Vbkkqc(C z2OW^45Nx%7E~y(IlCw>c6q1&VI8!mDrzC&ZQ4N=)=i_sAo_O&Rz#0$!o2@^~1;z!I zjd)h}2y;Y8b?U*K0L+wAu|=o_Ljp5eam}o)urDin`ePN@y*Ve)yPIJ-5d<3goC*XK zRmxyS&!W=8%Nv3BGBFp20@SU7+sVW)Ay0hWIkE~|ZG`gdPaXX!`RW?&cLpEbA%O7) zCL_GSCx6bik9M>^w}oQOrVi^$0D!||x`^<`t;SV)64_?^Rp0(8bl*?<6IhHMM@OT^ z`kPLFv5Q8c&20>WH^kA6000Y*cJ{3jwrv5_fh(BP_g^$wrCUYJEK6z^0%v7%ditpf zdKg=N(S!-9z9-gQe6vzN``mmiCEPP`mnbaj9|=mbyZqsU!rvDT@uE1g0hxk#jtQep zq+Ptd=(|efbf$=>_T>Z_$LFJ?s8mh;JB7FBLmiTRL2t%z8dpndt=Ya(<$DA~v8Ji_ zOn7U`Z)Lz26~K7ww?u1r7@6ihpxkX_5#+Wo8Nx$sjwQi`)~4o)BCA>YBXy4C7O>`*_imx>Q+O z+k>hT9!}p-=NlLYEeH`(oyT>Yf1_u~k1)eN8=GqcVzmWL{th)1G>p=_mq0LUmI@}H zCgXseRE5cHwx5xX-@hM$!^q-2vM-+}WxXXNb(@+7Ug4N`D<_mIFU!oB!^^p)9PhrT zE8@!I_PTK*nw)|-RC>pR!iEF=?vWn3kADMnU+pZ!m1kFPb5W&}A8Cy8s^qgapPk|z zgR@}5AEvt4H3!V=*iOx@6TbL?HLqxPKkEG|MDN$3@QF4PfM%kI%N)wpT5{f`YiiP5 zsKrm4h2bIZlE{KequHuvZM7?1{!$1ADwF2Rql<%nD{idufHox1)?Bf?o6p+5_P093j0I?Ii){aKLvmNNc@O!UUAK6eE6q4FZef5k;r+^one7}^|`1)9n-=5&SjAD z@lXRDS~g9FnDd-K8fa{xp%q1aT9RRQeAdx0%QtnvZ27!@{q{Yc+PIn4q4)!JHI;u? zLO9tXr+^hl5wA3hm%I=D**TMCkd#L8UHGSOLJV~;a>opRg zQ%kl`Lml;>NuZtr1n%HR2nRM05ll9o%%fuMo;|O<+h!F;t?2B8TZ=ob&ac&aBQ-$c zN`h$nohs_saGWXVaPoKMnT6>xAQo?*`V(#TB~zX7gzacz+)dCX>@ZjJW8QI$zTsn$))c=~P^(#-?^bI=Ou@4@Gp57Z?Z>T|kPp-P z@!ws;qOn+|gPys!1p$DvdsHvHinkVXI>MT;zNcTajKJK0bK0oN>(r3B931c6^#%b* z`2Agg{0n%gg~%Uy>WUGM9aLQl!pic|uwnPA)<#&d^tTB_S+%zD*H4Hk@S} z$OX3CpBOIl#>%>yWeIbwj03+Y-%0n282-zyn41ZA&nikY0C?tkG)fQQ;jri@;b_fRKy5cN9ZFhjyI_HHM~4{9OwQl@y6c6u0q*4l7t%ToR@GGU=6vUTEDpCZEz_ePeou?X9nICx!!po88IPd&#JLyHdN7BnYc}8l)w`Pzw zHl?5HwZWkz!@7<~3O*x&eyUnLPmJWMc5FmspDjJf!1f96tx&(=R@JylijB#DSh5{HWwg)?8MR!U&XVY^!5zJc^r5@XmLontB%S)ODN|qh9|w z(%!Zrgv?P#%gS%Y2@olQi}aevOBr$49h#O)LoqrCbILjvEUne5D#gngyqOk6w#l7n z(>V&P7*-it;>$YJ0N~@-vAUXh$%}#sV-j-)$bfBy;oFvwwxX97m6e?aGgHO{gFS#QvypnKFTd3WtVj;GDbmV%x77+q zv%|Y2SKJ+mWMnJ!Gx}1#z$$IguXo47=#i+9>fa>rNT<-dskOJ26UU{332tQ7sH}H5HjYl+4CfW;wmtXfP8g^fF8XtEHgGa zJV}CLf((@d!hq@SF0({bw1CF;YVRRSz<^qPDeEu(#YAU#o=?a5bg+?zUNsDJ9rsv!K`qcwS-i=Zdow_C0OdO^%znfy?SkjHn=zXN=-9qoArOvg{gl)=OhmlB^XjGg+-70_=DlC<)YoxDaBKi?4SrJ8QwwLdTAp5MO~Qi9URL|=Z4 z1eW%NqAcn`xT}TH;ko)!2PRk?6~h3NP{=8;qXXZnL%kkXvIN7YH(2rYnbKE@Hu(3-3u zIQ}9cCb)U(&K3$9J|&1sV0CqrkYwTiuc)t%i>eE^CZs`7KoAg6B&0z=I;BH~?rsUC zOFEQrP(ZqSDCuSp1O(}jhL;>Vr2E^0_ulW%bAD&e?0EJ*Ppq}JsVj4kIGy%=z%_N9 zW7Hob83N_oI2(5S0Dw%Z)U2V2&uBSQ??Rv!@|dD0CM(nK~Z%Ib>`o~nPQDIGFk zry#&{vau1&yf|r4S2rxry|&rxdm27cP;fTqmusd^EMm;r#X47JgDR;qji$RegaPn= z^b#}Ku0Aven~4Uke1e|_H{u|VkzXi_)~Y}hOcHGFv8#mm@-@P4T-nItI}#t+nS76k zt1fExQb^#uA>R^CFl;mPJu&fb8EEuE2{7)@^@UBAw^S4#Mm3Odb`AXhcBmIN!e5KxHCK)p~p1H88?KIg!i~=h^)6o?0k+hL{v%9!#HC8j}-j90F1y z8p2e%eJti9(;-~XroFjI`fx`qc`4DztGEU*a}_k;DE%7z1{!=aCl|JMT(P|-7*3#})29M(|lS_PWV ze%;vznjqD2BjM5}oAQdJUwfE6)C;o4P}msgwt7zp&Tn(Yk7d zschJxIHh!NKS&yXF3hU?OR8Ooy33Wf#NfG*FWAzWULRc}XBvMjDDHAPEUVD3qW*tL zl?rQk>+i3tD~U0KG4Ywh+ST~U9#S+bf=#P;&`6d0mLBlre89xNKU38;CY|8Pck7bN z^cR~Z+!o@6xHHLWLg-{p#b~nyD*(4G4SWOpA}jB$qT-7M=D(sZu)7ZnJm+XVIK6z% znycV$p?-N3#OH5)rR=C(|7Bqfj(8nHebrcpzwe+ml}2{f`6cq>%cRi(dlD_6QWX@H8~|hkQ{*$};Zb5VXp6;sis`F#(UNkeQDE z*GVAoDXA`C+cKmQCUDdWOPP5p;!zRFp11jc#;MmI8sRXnTD{aT+do1>e5bu|R0P`t zW+1Uq(fgwkoetRrOE$dbzZOYQ#&T}3m=Zky`;X)4DdciLFM5*i>6b0I;?yftWgLHH zgP}Dzh9OSqv_>bZdkwSONP4vzQi5eR~r*$jI*}CC;gvS%;GP3M!LG- zZmpGJD!yt*u5*PK%cT>i3jVUXgFCX81!)KTru$6FomUB8%r7SNKBRY4e2vgQ*%QfX zFyLt^3p~{tKTpOaHD&~b#NFxgf#JrT4d%7EsI(YW;rG3?@bJA6H&O3ggNulF$UZ#K zWJ=TG2k?XhulJQOTx8Twwj=u_gkxTAg+_hmI;E^PDTpsCN_!$+Y<`2QC8}%wP;S5<6dfCO81o zqc2rSx1aI#EREPq0E_|CW4CC5WOT!r+aEuq^$#QS28-p%SNL-ph?JyM-O#tzL)ehf zNpYJSq+$SJ@llVL)mJO*K87!fQze~{w%T}HvGbZ5X9tTFt`}|Wo-$Jhi-A@(o*CbZ z`^Jfx-EMhNy1dvXO1V>39|`2JjY%evSgXGV$K^~fw+1BIQHkHY6q&ze*LHoGKWx%?G#n3SeBTq+s5KEJF=flbDN~UOs9c< zl{te1-e;OvRbBseAZeb0iSrQ`9h%he7yEG?10zdbH(?#pq$9T-ygPTK?Yo+6X2 zzW*q4(mC&i-;*cA$aF*5o1Zg&o6cd8$jz3A?cs&6VINB=K2?HTyLs5N=pJ8<>SMts zeR5Iq{QB(2-94UF`75k+I`7S{I*RRi)~o3)Ef$C?Qc#r+W_Si!~P}`upvEU z4!h4O<5n4kD)9<6BkE=`g36GZ6^WmsqpfPc-EJP3k0!fGm8d3@bI$P5%pZY2vv$+a zm1L6#1(mmjxY91F02+8-^!tay#kZPqb)!x3PoGLOmegOon+t?6tB}h??cUZ%ydk&2 zEY>ibhn=(w&t9~x1aWhnotPu6G6O^-s}*%qVHBsnbbMn0<_-Qzxyi0Ai zF5vqfp++0K{mT*1zDEZ;%>rG%3DT`>0h((j299u(n+ha(;1#XuNu=`-E@pDgh9nl+P_9lV=Od%AQKPWQ8745(!Cm_ORgKRz4OCSe*G3hd+=@V4I9b zfBt(pcwikqCed5m;OPVQ4?t!u_L#@5WZTJhI9_1f#dzU0NC&)VVq%T;d@^J*s0h;8};#di0SXHXO=9ZDapz4Y3rMi<4$k-y-KxT%O4{t3It4O zx#sBNT5L7XrVcD%aUhx;6a^f4=p*0DP8-s?tsDB$A~p>hA!ksp!q`desR8Wz=#%?A zXEGY(Pwo7YA`nOh=+{gKi0GWGu-vbS9+>l&hhB2><6H=kj`JEb>)bPc;3{$p^whW> zP%4hTzkFZ?mlMzG`tNCr`f_epnoq-qn8xq%eMP-W(P5dx9FZMo-yPE zZv~-fEhG*9MJC^g(-B?wxsZDkr(?31ctoQllcAZk51Vo&_np=_Y(vZ}3V>Sj?AUpT z;wh^N<9S$FSG@XBE4H!V{2%_CAZ!-Uk*1tZ0~pS%ED7i*{H~4D{|0vGM-bBV&pYEc zd55q}Uuy+_%*@tA2LcBLt~ZgNIucISPx@Ecv`B;6$qfhnO>HE&p;p(9;(AGSe2=BJ z?zw2ju{LmU)E0B3z6kNMpSt(~f?Xe~_mG7rB7A;TYq3|arJk@1YFfi!zNXE8bKbxc zJ&rESypVJ%u)vnQU40|buyiU>*uvU&M(dnxrxH6%p&5Wq+AD}3rjJn_ zbs$lSmUTnJLnb-3Ea# zmD%Tp0i=7jYkWVrr8n_TJRx#Zi9*2?G_Be=eqzw)_aF*$YJ`>aK^A5xXXKuq$UZ2M z`>JYBBr@7b^M>L-_>%QT^t_4SgJyriJ6+eoTy;U6nz=9Ml?eFI|9S8b?U_P*cDyuS zOpGb+C@Oq%jM7&s4cVvD@*BbU=g&s0%oQeVMz48&HQ)*4<~grL`8Li49PF=LG~x@H z*RQVCU0AwRN$+j4{o<>j#cu_7KUD28`Z3osKHT(rVIpIwt5(C03t zRP$6BpM77ir74@>xT%%h_`(wPV!zf~ZV2{HZQXo*nHUJ~4ck|39AND&);K-Sx{%I9 z#x~G+jVlEUX`B|60)Hm}M6cs$v{sX`(b*S5szywPl3kr3o!E+Q>G^{{TRvFs)zR|b zfNQepe)t71jq{i2MiwzaE|JL(Xmf`MtDaCV0II_)8bf(B32Ag92Gms!6eV6{*p35~BA1Pd$N(U3{s8 z{cPGzGjLg}J}lry8~9sAmRsjy^pY-**}BQVhzpf?pJKU)Z& zc!N!VH=kQ5*l}}`Jm>-%A zbakzLS!Sg7>&_q!O+iFoOr#s;K?_SZa}eqj$v2}5eX}lrk|sShn{dB46^T&6;F)51RKb+wD6rVfKVTb)Z#!$u8;<{BXjAFUT~D3~@* zAt{+lmF1nmNhdTiiEM`e-1)iO3AX8X1BB+xA70pt^%`D!KkGfmm;y6=j2yym5wM5e zKW+hkDVwau_5kOk`s`l!PKxMNMCVUG^*B}UBZJ@aTj#xYj|$uJTCix6A^Q3=?~2y~ zKHuS38fz}|;D<_!No$hm5JS|Cv*nvQxRFQy8~>MwHt*gI4N7C_jGajASVvVn@NrM* zD%Wo+r!ToW{5$jJt?1NB;nHrI52G0IM9$3V)+WVnI$#yHYdstl@(gasnhx@dq?S9fxw2bFeXq-S3oxsrY>&+ipM}AZ zFOaNH62d4 zfN<=>FF%rB{-Gp;Qfz_S^w;HujV6E5$|;lwNP^PJ^}qy;S$z)2?`go#1_7WRLDyo(og{pCmJ?cpjJ~P_>tI4V+ zS7~H8J61MY>=Ct=!0qrbn>J@@yOXG(s&Mk9IcpHq2X(p7>2Y@%*9t7VcRwpmNHKfluGLy68$X+Uu;@VL_K8 zx$p$=I77rm8Py9bZ_+pl6%KIwCDJ_4pe}Zg&0k%u|0tsWTK44*?9II&;9{mAQKe-Z8ADcdD>*Isr~4C z*FIGIrkees?8VeFW7DDojQU6mwws{ivm!4t%Ye({SvVYSOcKK!I}nu-cXsq7vgp_^9x!ia-&L`D7Y}YCa^)=^9RBEOE=Ebr5E0FQ~{gVz=chJin$LPa{q?!boxG& zD=7$O{^+8?eU6iqQv!fyR2{Dd4sxw1HGixxnUCrhFF~%Kbv68y5W#KiX3mhdBd<0z zzX__K`Q?|zpzC|L7Q__Q7naDmh-U$uD4kY)Ik1P& zwEsfF!-Zue&<#HC#v&0BNR;xCvMUcl#wE+73W0A)a2)++amHS>`IsrDT-%oZAQ51; zI#KPEAJin*zlQwQpiKhVc6=CQgWcWN2S4l2&(;9n)^d6^- zEM<&6Uw$kh*lS+TNbELxI0FHe7;zFF;lDqXyUQ8{C3@Wn9{2ucR}eGJB)+h%QqLHr z;!OM@<=ractmLmKe8SSpA;?R{tl_;q8xBgpcq^KmR203w0E0r&I->L(*Z9%gOWU`9 z{0)ykU>LK13S)cAP8!8$gfl;cA=tLiFo-u40D=AV$pktcQl0YYLP1*0O(r*rh!P^T znm*`a;Iq&ab6d}Yi21&GZKPuV{f#4+Y;x&Dj7BK&|wisMQqpFo^c zNyydz%4}*|%Djt!Do>zLJv8IVn*z7`%JMm6rD7On(2t6)AQ!>H+p#m$gcRfe4E(6f z3E56*OMe+sC6-1SBqgOrO7+VDWteplVCm~xY{{A-J8 zAL~x)K7rBsSIU4=S$f~vP$TL$AO$S)U`G~lPy5N^eu4Sjgl_B3>f3V`P#KHl5APc@ z7y&t(HoUPQ5>J)|uL>HYLOF3wPJjbz76QThOsQzh7-=ij#b%%8L+uoW2UHd*%$Fg7 z)5ubbrXwU~v#$u&Z8iAVo1ip=uj<998k!(_*gY()orpKz33=M5x68g(I5uTpEqwqf z4RZqHJO_a{qGIJ`fMmC)Ik>1cJm7`IpIO-;LbQZ^dkDo>e|L~vtlMK$6#nRl546is zlY9NHvQ2p%Wue~q&1onslY~tUbHct#*#V@x0>o6WpA7?I(&61sujMGs?R+Nrk7Ld> zdF&ECZq=t?Sawm{<9yf**&9fHc@(GX2xSZoKf+(fCs)&-iTQAO+LcYvBYQ;Nvbh=Z zrcke4Vg39p!-=)9dTNpl)-Zw;8~vRr@fOWW zy>Rmch*p26e3Y(a#s>?+S?})k&_3ymBw`^v5lrLSo{2t*ERK<(_}Me)6r zqr>SfTGvfGpRA&M~DNN0BFqB(gs2~(F$*H~AKk++e?z0I>`RQvdyg)7>3dV8I4_$4y zarJ}hbvAaqH}vtVeS%{9`8#>UQ{k9acMv9YL>Wv$+@~LxgK7j4Uek6qFh0&J+U=CF zvA;=eC>%aww+d+D96WF7W$V`8*$zG>iIH`~yX@)kPc`r5WA|1PzCN8vRnI55=>Ygeab27V_&N9P=`<^{8Rm3C6KIwz)Kg~@s4T%}-7Qn#)$ntkx7ZQZ8uJFAUYCAFod zA&|6Y{q^kt=O6^A$kHOu2>2Of&niw6&-yDL5w{Rm*5oz$Byu;Ydc5zcl5bFpWouhN z;0X;-QF*{SF08sO3ymX;xvp^ndYQ2*nY>s^CP-Q)yf4YnzRQF*|Cn!s;paSHM9<1* z)C>Ig(l(0euy)wL$=P#ZVeY1-V%)0H5PSsk5X(z|``H6B6&Yi`L6YY*RKDWULe>UQ z$Naa5YR;*0^^9x<9J-JPzIl~}S|xxZ*N24~RTBqtOiVr^VPM_-vBz7--JCr&S`Q^- zLgjRQz3}3YEXr_@JjJq>fJx2=|WFi}6n?8l0r5q(G{54g`_Qt6+xi#h4Cz-sVWxW1#JuJ_B4gxuZ~HNyQNRHei?+X9N*=^);qWKV1 za|^y1jaLp5%%TeOK3QZ|@Owj=+KhdzAREx#r}kH+l-&Qr#%FbnB~f=t$zk^nnjUIYFAp?`oa+mFD9L zyU4PaHU6R}2z}A?R4)Ec%tJgGB(!4#FD_Y9>eEKQdMXYn1lFZleHAf4D!6gn!dsdM zD`BF0D1N)EZl>BN5onM=w>+6rT#8tmcqD1(+-td53Bz-<;1`3j6FXdJmk6!Vx`s#CimtSTELWAr`?XVX;j&gdVS`b6z;A`G;`4Aaaai zl%FIH$SuS#7=*(Soswv*N$$_%*5)TAAWyg+glx0tarPcKT*lu$#Rv_ zxLkHVEGBZ|Q{H?VR73-FEl)7FQ7{ui$chTSb*gbhl$gHf)o`Sobc`GKF(Ks)w!Vby z$n$7aW4tJKCfypCJJ?%DH8>{p(dFXAtlmh6Hyj+$)HZr15`h0v(ga~aPm?mE!)$MN0*ar>vQWH8nG`Bzt-m)Y=(u55p=|K@qsU$xfL zjlwPFLb!@rcE#Mb^56@u11k<>{Q}E)r*qe@r5Y&ZwNnCLNWAp+xuo&8TpsZ9|1hXt zH2ak#C_$|Lx;Ndpsqwu3gVU++!4+PCi0@74243`^viYManvx6MslmlKTY?nDM?KYS zX}@V4*sEQknYZ$1O~2Jd=hj0oWtazrhX?V0gyHj0Qole|a@bJu7Ro|l$cbn$$U{ra zqi503n*c+!_1A!1&$qvOu<2Ic_;$G3c`ufxY-s5HVt>`CQVC5{t6p$IhYa+wo9x5A zj*w60w*?jz5dNXL)y`)cHKzcH!j@yWfKbTuyGNZz7C*dy)vR zXFQCA8QDHG-4FlIx}VA3_jGQ4(7<`>C?YDJx&)H*_ptX#{hxOC0g6X36F+Nm_zf)3 zh(i~y4Y}s#K}($eddABZ&;#Q{ec|O$3^)2%92c8N5T`8G?Bq1q;xWs6J4z^W)@N&5 zKSfA?RyTVRm_Ay-W1r-E@iUz86|BU)CH#VYuC=uJ)P8aH9wu`$P-YO_94d?XLw zy2*7h#1M2HG`_A2)$^VFoqPP&$Z&Ra)CZ9wbdr7uCA_wlQk&uZ>`R2ZI*UbkRCWuN zn7*Au38Lza)7)8aR~oT`(ML$fi)C33Jz)flCz`t~}>q+MmX_x(WUQl3`u1?EEZt&w4S)*iX$q?u8sBYr+?ssJ!kD8ex-0v%u zdGX_23$qC3I0hsONo){9E?w zkj?FI*zxId;u@off)PqAWMY1PJpP^J%*oFQig-Y)HzyCDotS`o&iok|&G4^Xu8!tt zeb9Lj<9!zRQW)RR{k*SF1RLT;F#-CdZESpd?Bl*8#xn5s+UBVR{=}_xLh7)~?zwtX zS2X3E(JND>%OB|B)X3H?k5lu9QOg43*-O?DJCWMach4a0Dh;Q7x(oYEM4QT+vmZH_ z=rWkmxTWf}u)FaR*D&kv)>sd5Yr^r3>{j~Z(YddvzhdD&wn$EUjo_^K+ONT7LSwg+ z0Rwh#I{Tsxk+IV@aeCshk~PWSJ+1X6wXlR!9x}Zp@5!ynk%s1nq-URO*{e&)*W7uz z?GzDz4Mt?y19~#wicN^bJr`q75Iz`}*EDq3kywZ-%#g@BLhA#QSg0b&;^hA?!<3WDp zZPuS+PPq&Mv3XIC!QE>(^@Z&A)U^svg`dG_ZK+T8hfDAL-5pWSc9XUJ8ZmFz)`C{k z+0UvDaf6833_&lm#3Oi)%is;dsd;0m;H`z&RS!>@@!@87*F?0gr^#(r{zR=C-=Fbl zAgbCuf03R(FE8ttPo;>BPr(ok-Q4QlVO2^Wq7*eB5Jgl!F?Ak}%9p$>PMP=i1{+vL zN#I~&{+{1us6L|-mCyKOMGj@1<2)~$frqT0D9s~Ls1z|`-`m*fwybDRXw?K8#n``e ztA|2BoAWuatfAMZpO>%FV8z}7B@IIlTZIkv|F-X!%#Jkp%47IN|NLpd_WGj6VZ$KZ zGcXW^A-Jp|M`(a>^oP#*3VJN;0zJXlc?AoiU9)$ZXB8^5#+baYI}IJ|ade!fA$I^Exn(q8l-i&;1Ra=*X5TOh1EGy^8^qQwvshJ26QEHie~|Kxot2PtRI zN4kJ1)uPrnfbpc@h^MxtZd_DeH2ebvMI+C0e0Fb@d}-EC18>G%R*!8Ag04T!>`1E< zW&ZG8XFE5l5+L;#4RJg{Wa+R-sOnk;~lPM?P0D9 zN%}65+$i1sPLHq20j{UXmkf5Tc;Ju#E>+Z6E*#$fc62Tih(kEQ9U@{>r9aP!+pv*G zNzlrmB}TniGj}L-V|}u9lkM2a^#>Iw{hB^(vrs-~F^NjU1XqbTCfxanzx!F2AJo;% zy)4TY{0WY+5eMf7J$61$@LrQ{sH85RN4{JY6Lj>&-bqev?W*}idd2vsJm|;lDR{_?omm?r>feh%!&qMo-pj@M<8K?c zNd_KaqC-`ZF>rfx+ss6EiJj{>e)ordXo)4iq@ykQNB$JtvMy*=A`ck*selqo+Ay1% z2<6d$SAV-}Mg?)3#cM7aAC#)?a1UwNUGGsr?>~B>sdYnlHLj73lLy>p@KFEf13BJ# z(5`@cs}Y^qhiT3U{`9yr&>NQgoNdn&yb7&R@GuHv78(P3#JLW7EImC7xE9w11C^Nw zU9xMNIhh@?F+>^$imDbcGp02&Jd-2jp2SkiOSjXopSzFC3hoP5;;kI@3Pq_{1p_2rj;BP0PXBTdmp&?!Z8?hc+| zmPd(o?wtL|Y|^L+M>b~_$3)FBgg8de_{w$c-^e@%Bjel&P%NBK&i@;lPsDAxljUU} zlVvzjZKnpm-5nklz(WM?V%O5R**7cSwKF#qg`qQ|!DgNlskO zh67mZ-9$3Wpv0PW;G2B=RsxE^sBZzRQ@+F(f2rC@J8XtdEa$I0dxbF2&D*rXTef;BQs`nH^;eP)bb+E zZ3c)wEA1ieLyil)m*8nYL(_8Xzt{AG*YGS?6ms=S$6~&>n1hhi{Ii9kX94nAPeK1% z-KG9#U~(f+7(yc&BrpKx8@!B`4*XW5PgXgYAC^~cTz)semEtI|lEITfV+V(Cb4cln z!E@rwpmBw_EBZ12mI{6kX{7X?c7&24F|Q0{%qxg)`n$`I%#nk<>qX#X?^4 z0#mt({AXyqz*Jt(DZ+VBiuuja&m=)W78ObX;meO24;tbbdM>U6MO*z#=