From ac60206cd11ad3222a9bccc32f8b11ce68106b40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=89=8D?= <1104239712@qq.com> Date: Sun, 9 Feb 2025 15:17:40 +0800 Subject: [PATCH] ctgantraining --- ctgan/.idea/.gitignore | 8 + ctgan/.idea/ctgan.iml | 12 + .../inspectionProfiles/profiles_settings.xml | 6 + ctgan/.idea/misc.xml | 4 + ctgan/.idea/modules.xml | 8 + ctgan/.idea/vcs.xml | 6 + ctgan/__init__.py | 16 + ctgan/__main__.py | 231 +++++++ ctgan/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 545 bytes ctgan/__pycache__/__main__.cpython-311.pyc | Bin 0 -> 2550 bytes .../__pycache__/data_sampler.cpython-311.pyc | Bin 0 -> 9189 bytes .../data_transformer.cpython-311.pyc | Bin 0 -> 13850 bytes ctgan/__pycache__/load_data.cpython-311.pyc | Bin 0 -> 168 bytes ctgan/config.py | 48 ++ ctgan/data.py | 88 +++ ctgan/data_sampler.py | 153 +++++ ctgan/data_transformer.py | 265 ++++++++ ctgan/dataprocess_using_gan.py | 141 +++++ ctgan/load_data.py | 73 +++ ctgan/synthesizers/__init__.py | 10 + .../__pycache__/__init__.cpython-311.pyc | Bin 0 -> 776 bytes .../__pycache__/base.cpython-311.pyc | Bin 0 -> 8812 bytes .../__pycache__/ctgan.cpython-311.pyc | Bin 0 -> 29580 bytes .../__pycache__/tvae.cpython-311.pyc | Bin 0 -> 13794 bytes ctgan/synthesizers/base.py | 151 +++++ ctgan/synthesizers/ctgan.py | 564 ++++++++++++++++++ ctgan/synthesizers/tvae.py | 246 ++++++++ dataprocess/Dockerfile | 20 + dataprocess/dataprocess.py | 191 ++++++ 29 files changed, 2241 insertions(+) create mode 100644 ctgan/.idea/.gitignore create mode 100644 ctgan/.idea/ctgan.iml create mode 100644 ctgan/.idea/inspectionProfiles/profiles_settings.xml create mode 100644 ctgan/.idea/misc.xml create mode 100644 ctgan/.idea/modules.xml create mode 100644 ctgan/.idea/vcs.xml create mode 100644 ctgan/__init__.py create mode 100644 ctgan/__main__.py create mode 100644 ctgan/__pycache__/__init__.cpython-311.pyc create mode 100644 ctgan/__pycache__/__main__.cpython-311.pyc create mode 100644 
ctgan/__pycache__/data_sampler.cpython-311.pyc create mode 100644 ctgan/__pycache__/data_transformer.cpython-311.pyc create mode 100644 ctgan/__pycache__/load_data.cpython-311.pyc create mode 100644 ctgan/config.py create mode 100644 ctgan/data.py create mode 100644 ctgan/data_sampler.py create mode 100644 ctgan/data_transformer.py create mode 100644 ctgan/dataprocess_using_gan.py create mode 100644 ctgan/load_data.py create mode 100644 ctgan/synthesizers/__init__.py create mode 100644 ctgan/synthesizers/__pycache__/__init__.cpython-311.pyc create mode 100644 ctgan/synthesizers/__pycache__/base.cpython-311.pyc create mode 100644 ctgan/synthesizers/__pycache__/ctgan.cpython-311.pyc create mode 100644 ctgan/synthesizers/__pycache__/tvae.cpython-311.pyc create mode 100644 ctgan/synthesizers/base.py create mode 100644 ctgan/synthesizers/ctgan.py create mode 100644 ctgan/synthesizers/tvae.py create mode 100644 dataprocess/Dockerfile create mode 100644 dataprocess/dataprocess.py diff --git a/ctgan/.idea/.gitignore b/ctgan/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/ctgan/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/ctgan/.idea/ctgan.iml b/ctgan/.idea/ctgan.iml new file mode 100644 index 0000000..8b8c395 --- /dev/null +++ b/ctgan/.idea/ctgan.iml @@ -0,0 +1,12 @@ +<?xml version="1.0" encoding="UTF-8"?> +<module type="PYTHON_MODULE" version="4"> + <component name="NewModuleRootManager"> + <content url="file://$MODULE_DIR$" /> + <orderEntry type="inheritedJdk" /> + <orderEntry type="sourceFolder" forTests="false" /> + </component> + <component name="PyDocumentationSettings"> + <option name="format" value="PLAIN" /> + <option name="myDocStringFormat" value="Plain" /> + </component> +</module> \ No newline at end of file diff --git 
a/ctgan/.idea/inspectionProfiles/profiles_settings.xml b/ctgan/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/ctgan/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ +<component name="InspectionProjectProfileManager"> + <settings> + <option name="USE_PROJECT_PROFILE" value="false" /> + <version value="1.0" /> + </settings> +</component> \ No newline at end of file diff --git a/ctgan/.idea/misc.xml b/ctgan/.idea/misc.xml new file mode 100644 index 0000000..4f4e5a4 --- /dev/null +++ b/ctgan/.idea/misc.xml @@ -0,0 +1,4 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project version="4"> + <component name="ProjectRootManager" version="2" project-jdk-name="D:\anconda" project-jdk-type="Python SDK" /> +</project> \ No newline at end of file diff --git a/ctgan/.idea/modules.xml b/ctgan/.idea/modules.xml new file mode 100644 index 0000000..b65f96f --- /dev/null +++ b/ctgan/.idea/modules.xml @@ -0,0 +1,8 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project version="4"> + <component name="ProjectModuleManager"> + <modules> + <module fileurl="file://$PROJECT_DIR$/.idea/ctgan.iml" filepath="$PROJECT_DIR$/.idea/ctgan.iml" /> + </modules> + </component> +</project> \ No newline at end of file diff --git a/ctgan/.idea/vcs.xml b/ctgan/.idea/vcs.xml new file mode 100644 index 0000000..6c0b863 --- /dev/null +++ b/ctgan/.idea/vcs.xml @@ -0,0 +1,6 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project version="4"> + <component name="VcsDirectoryMappings"> + <mapping directory="$PROJECT_DIR$/.." vcs="Git" /> + </component> +</project> \ No newline at end of file diff --git a/ctgan/__init__.py b/ctgan/__init__.py new file mode 100644 index 0000000..f693c93 --- /dev/null +++ b/ctgan/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- + +"""Top-level package for ctgan.""" + +__author__ = 'Qian Wang, Inc.' 
+__email__ = 'info@sdv.dev' +__version__ = '0.10.2.dev0' + +from ctgan.load_data import final_data +from ctgan.synthesizers.ctgan import CTGAN +from ctgan.synthesizers.tvae import TVAE +from ctgan.config import HyperparametersParser + +__all__ = ('CTGAN', 'TVAE', 'final_data','HyperparametersParser') + + diff --git a/ctgan/__main__.py b/ctgan/__main__.py new file mode 100644 index 0000000..6a38b64 --- /dev/null +++ b/ctgan/__main__.py @@ -0,0 +1,231 @@ +# import os +# import argparse +# import pandas as pd +# from ctgan import CTGAN +# import logging +# +# # Set up logging +# logging.basicConfig( +# level=logging.INFO, +# format="%(asctime)s [%(levelname)s] %(message)s", +# handlers=[ +# logging.StreamHandler() # Log to stdout +# ] +# ) +# +# def train_ctgan(input_data, output_model, epochs, discrete_columns,embedding_dim, +# generator_dim, discriminator_dim,generator_lr,discriminator_lr ,batch_size): +# logging.info("Starting CTGAN training...") +# data = pd.read_csv(input_data) +# logging.info(f"Data loaded from {input_data} with shape {data.shape}") +# ctgan = CTGAN( +# embedding_dim=embedding_dim, +# generator_dim=generator_dim, +# discriminator_dim=discriminator_dim, +# generator_lr=generator_lr, +# discriminator_lr=discriminator_lr, +# batch_size=batch_size, +# epochs=epochs, +# verbose=True +# ) +# +# logging.info(f"CTGAN initialized with {epochs} epochs") +# ctgan.fit(data, discrete_columns) +# logging.info(f"Training completed. 
Saving model to {output_model}") +# ctgan.save(output_model) +# logging.info("Model saved successfully.") +# +# if __name__ == "__main__": +# # Define default values that match SageMaker hyperparameters +# input_data = os.getenv("SM_CHANNEL_DATA1", "/opt/ml/input/data/data1/processed_data.csv") +# output_model = os.getenv("SM_OUTPUT_MODEL", "/opt/ml/model/ctgan_model.pt") +# epochs = int(os.getenv("SM_EPOCHS", 80)) +# discrete_columns = os.getenv("SM_DISCRETE_COLUMNS", "label").split(",") +# +# embedding_dim = int(os.getenv("SM_EMBEDDING_DIM", 128)) +# generator_dim = tuple(map(int, os.getenv("SM_GENERATOR_DIM", "256,256").split(","))) +# discriminator_dim = tuple(map(int, os.getenv("SM_DISCRIMINATOR_DIM", "256,256").split(","))) +# batch_size = int(os.getenv("SM_BATCH_SIZE", 500)) +# generator_lr = float(os.getenv("SM_GENERATOR_LR", 2e-4)) +# discriminator_lr = float(os.getenv("SM_DISCRIMINATOR_LR", 2e-4)) +# +# +# # Log loaded hyperparameters for debugging +# logging.info("Loaded hyperparameters:") +# logging.info(f"Input Data Path: {input_data}") +# logging.info(f"Output Model Path: {output_model}") +# logging.info(f"Epochs: {epochs}") +# logging.info(f"Discrete Columns: {discrete_columns}") +# +# # Call the training function +# train_ctgan(input_data, output_model, epochs, discrete_columns,embedding_dim,generator_dim, +# discriminator_dim,generator_lr,discriminator_lr,batch_size) +import datetime +import json +import os +import random + +import torch +import torch.nn as nn +import pandas as pd +import logging +from itertools import product +from ctgan import HyperparametersParser +from sklearn.metrics import mean_squared_error +from sklearn.model_selection import ParameterGrid +from ctgan import CTGAN + + +# 设置日志 +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler()] +) + + +# 定义模型训练函数 +def train_ctgan(input_data, output_model, epochs, discrete_columns, embedding_dim, + generator_dim, 
discriminator_dim, generator_lr, discriminator_lr, batch_size): + logging.info("Starting CTGAN training...") + data = pd.read_csv(input_data) + logging.info(f"Data loaded from {input_data} with shape {data.shape}") + + # 初始化 CTGAN 模型 + ctgan = CTGAN( + embedding_dim=embedding_dim, + generator_dim=generator_dim, + discriminator_dim=discriminator_dim, + generator_lr=generator_lr, + discriminator_lr=discriminator_lr, + batch_size=batch_size, + epochs=epochs, + verbose=True + ) + logging.info(f"begin train ...") + + # 训练模型 + ctgan.fit(data, discrete_columns) + + # 保存模型 + ctgan.save(output_model) + logging.info(f"Model saved to {output_model}") + return ctgan + + +# 网格搜索主程序 +def evaluate_ctgan(ctgan, input_data, discrete_columns, n_samples=1000): + """ + 使用 CTGAN 的 `sample` 方法评估生成样本的质量。 + + 参数: + ctgan: 训练好的 CTGAN 模型。 + input_data: 输入数据文件路径。 + discrete_columns: 离散列的名称列表。 + n_samples: 每个类别生成的样本数量。 + + 返回: + accuracy: 分类准确性。 + loss: 分类损失。 + """ + # 读取真实数据 + data = pd.read_csv(input_data) + logging.info(f"Data loaded for evaluation with shape: {data.shape}") + + # 获取真实标签 + label_column = 'label' # 假设第一个离散列为目标列 + unique_labels = data[label_column].unique() + + # 生成少数类样本 + generated_data = [] + for label in unique_labels: + logging.info(f"Generating samples for label: {label}") + generated_samples = ctgan.sample( + n=n_samples, + condition_column=label_column, + condition_value=label + ) + generated_data.append(generated_samples) + + # 合并生成样本 + generated_data = pd.concat(generated_data, ignore_index=True) + logging.info(f"Generated data shape: {generated_data.shape}") + # 获取真实数据和生成数据的分布 + real_distribution = data[label_column].value_counts(normalize=True) + generated_distribution = generated_data[label_column].value_counts(normalize=True) + # 使用均方误差 (MSE) 评估分布差异 + common_labels = real_distribution.index.intersection(generated_distribution.index) + mse = mean_squared_error(real_distribution[common_labels], generated_distribution[common_labels]) + logging.info(f"[validate] 
Generated data distribution MSE: {mse:.4f}") + # 假设生成样本的标签完全正确,计算伪分类的准确率 + predicted_labels = generated_data[label_column].values + real_labels = data[label_column].sample(len(generated_data), random_state=42).values + correct_predictions = (predicted_labels == real_labels).sum() + accuracy = correct_predictions / len(real_labels) + # 使用二分类交叉熵计算损失 + criterion = nn.BCEWithLogitsLoss() + predicted_scores_tensor = torch.tensor(predicted_labels, dtype=torch.float32) + real_labels_tensor = torch.tensor(real_labels, dtype=torch.float32) + loss = criterion(predicted_scores_tensor, real_labels_tensor) + logging.info('[validate] loss: {:.4f}, acc: {:.2f}'.format(loss.item(), accuracy * 100)) + return accuracy, loss.item() + +def save_metadata_to_json(output_path, model_path, score, params): + metadata = { + "ModelPath": model_path, + "Score": score, + "Parameters": params + } + json_path = output_model+ "metadata.json" + try: + with open(json_path, "w") as json_file: + json.dump(metadata, json_file, indent=4) + logging.info(f"Metadata saved to {json_path}") + except Exception as e: + logging.warning(f"Error saving metadata to JSON: {e}") + + +if __name__ == "__main__": + # 定义数据和超参数 + input_data = os.getenv("SM_CHANNEL_DATA1", "/opt/ml/input/data/train/processed_data.csv") + output_model = os.getenv("SM_OUTPUT_MODEL", "/opt/ml/model") + hyperparameters_file = "/opt/ml/input/config/hyperparameters.json" + + # 解析超参数 + parser = HyperparametersParser(hyperparameters_file) + hyperparameters = parser.parse_hyperparameters() + + # 提取超参数 + epochs = hyperparameters.get("epochs", 40) + discrete_columns = hyperparameters.get("discrete_columns", "label").split(",") + embedding_dim = hyperparameters.get("embedding_dim", 128) + generator_dim = eval(hyperparameters.get("generator_dim", "(256, 256)")) + discriminator_dim = eval(hyperparameters.get("discriminator_dim", "(256, 256)")) + generator_lr = hyperparameters.get("generator_lr", 1e-4) + discriminator_lr = 
hyperparameters.get("discriminator_lr", 1e-4) + batch_size = hyperparameters.get("batch_size", 100) + + # 训练模型 + timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + model_path = os.path.join(output_model, f"ctgan_model_{timestamp}.pt") + logging.info(f"Training with parameters: {hyperparameters}") + + ctgan = train_ctgan( + input_data=input_data, + output_model=model_path, + epochs=epochs, + discrete_columns=discrete_columns, + embedding_dim=embedding_dim, + generator_dim=generator_dim, + discriminator_dim=discriminator_dim, + generator_lr=generator_lr, + discriminator_lr=discriminator_lr, + batch_size=batch_size + ) + + # 评估模型 + score = evaluate_ctgan(ctgan, input_data, discrete_columns) + logging.info(f"Evaluation Score: {score}") + + # # 保存元数据 + # save_metadata_to_json(output_model, model_path, score, hyperparameters) \ No newline at end of file diff --git a/ctgan/__pycache__/__init__.cpython-311.pyc b/ctgan/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22ad241c302e2fbe3a07d78fc8523242805d766d GIT binary patch literal 545 zcmZ3^%ge>Uz`$U6I4j+ok%8echy%k+P{wCF1_p-d3@Hpz3@MCJj44b}OexG!%qc7> ztT`;XtWm6t3@I!rY&mSX>{0AsHhT_7E@u=cBSQ*D6jusoFoPynl}t!}fo@J}S!#|# zL1J=tVtT4VT7Hp2a!Gn(o?aDyU}j>TLU>|cx{iWpUb0>lPi9_PzC&?JnO;h2SrxZ| zo}q!B5r}Q@667pR##>xznR$sh@hOQViJFYJSe--M9sM*JZ?S}gIl5{x7qKufFch(Z z2xbNb20u;iTkP@iDf!9q@wd3*;}c6uGV+V!<8N`s$EW5dX6D4l-{OvsFH0>d&dkq? 
zkH5taaac}%VhY%bB9JM!Bp{;2m3bu@sl}O9sYS(lU?H$|x5UA0JrwDZvcy!dJ3#g) z=H$f3uVnZP^1?4|7ps_*ti-(Z{D7kTtkmR^n3BYt?2^o~)EKZiG4b)4d6^~g@p=W7 zzc_4i^HWN5QtgTa85kHC85kIfqZt?&J}@&fGJarVU{t=qp!9%Ou7T?Yk5mKa4L<1x i?gtEN7f{g+2CWOI=mvwv1yuBaOS6F;1dI3?7#IK=EuJ(0 literal 0 HcmV?d00001 diff --git a/ctgan/__pycache__/__main__.cpython-311.pyc b/ctgan/__pycache__/__main__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..182e053419b2c39b7bce116a95f6bb164b049920 GIT binary patch literal 2550 zcmZ3^%ge>Uz`*dKuqfS=je+4Yhy%lHP{wBsCI*J-3@HpLj5!QZ5SlTH5zJ?bVoqU5 zVaj32WsPEGgvhZ)u`w~EvSqPD^@G?PFm?(vSdJ6MU&g?|uo}i;NM~HbvW%I5VKp;M zh>;<MwS^&yE0sHiEt?4>T6BhqA(bJE2gC;96!sL3WlRhVtHH7$ktkj?6`UztXe#)a z7*hGsRB@;9psHHN$iT3g5o96=M+u;*=S|^5Rlh(Gqz)aY@S~X{1P%dVn7R}JRCz`Q zG<m^P-q{RunFy#7O64O$m2j#65voK|g>b0iO65)EOBF~JTE@b_u$l$t2S$bz#$Yhk z6nzO2^wVU##p)d5?&w#guBwq(oLrKbo2pr?5Ur|_lUkOVla~nM#ww_4<faxEC#I)r z7Hcx!Vg<|HVoS>}%1tb}#gUPimy(lORGbXAh=G9tlq^BX<MSj&P_it6M-T%;7Mz{J zh#p)etRQzma27~Em|4Sss;-s=hj~mj46_+hm}^+F;BJ}CFqZ{YHzPv{H`rJ(QNxtN zvW68kiPW;<FrTf4X*NR&`&?#p6Khx(fP4%#5t*oAsbNMnn~?#B*&H>D*rKzBbpbq! 
zqL_>scJOdT31d!lx27-zGiY-8RY?VxBo>uq=A|oulYl}=QDSCZW?s6Uo?expOJYf) zLQZ~SN@|KiT2X$kLKTlfd1gt5LUBf7L8?NCCetmp)PnrvjN)7DWvN9;`NgSK;t+E( z^D;{^6LT`FQd1yWs@N4EN~=^tpmr!E=jRsWq?V+n=qUszmVq3Ro1c=JqfnBsP$lUL zW)>%wrKTtpmnJ8t78j?L=HyiBX>#0R&&f|u&&*4|#gdtqmVb+>AmtWEQEFmJd~$Kw zEmpAWia-hH7IRu=$t@O;A)0KrxH9t!N=xEX5=#<q@#L45fH+_ai$G=GErFEG;^d;# zlGOO*{G8I<yy9CdV5O|dCFzNI#h|EF0E1uJE><xqS&4b+`2j`wS*gh-F(rvP*(I4} zsWD)+G4b)aiJ5uv@p=W7x46MU8V{Bz5@%pw5MW?nC_cr&z|g>Omxa5<=YpW_MHam) zEP5AM^zL%=_E=A-y2!0`g<Gk?1uE~kAnJmk=LJE}9U*(5tcxsOS6I9*uz20&5u72m zLgWIE-bEh0D?EA)u24-@E6gqkT0zi77V9f4))!c;Z?JH-dp3D?q+AeDxyYh=g+=uO ziz-xu+YX)!f^HWC-7d1YUtw{-z~T;4?bYPfQFD<+<_e3<1r`~oQk{z|x>s0qFTl`e zP@=rW0rU1P-eBK&PrrcBka!oz5XUMN{rrLw{oEXVa0KXsLRB9eDw%odV4fZ*-uQ!k z<NZTJK)QVWU0i*tWMJCBk)#g}!FVuRub|`>XRvR)Yk<GANAOEf#du3N*f-wAGuS!E zHN-XE+21GB*Dv@MYffTPYR)f4ohor3NIK4_EJ!UXNGwXsO)W_+Dz>T;^aP~}1#nsq zNG!>)Qm7L42PX~%ND_n!a=3z1x|KqekP9@SC^$nBij{(<$StP);#(YvMd<~JMa8MN z*a{N!QWA@cKy_b{C<6lnI6W4DUCN!5Se%*coS&DLnSP7K)6dQS7H@D#QEFnY2c#4Q zClgR`g38|b_*-o0sU@j-WksM$t4N4}f#DW&W?l&-)v^{B<Ybl<f$FFtP>HsZp-2K` zofrcH!!Hh--29Z%oK(9a9R>ylP|;All#zkq12ZEd;|&JR1~9zAz~2CdHyDI3z|ai_ z@e8Qv1_NIM7=B=5U=(R!zrrAVgG;bS=LWaQ0}j~^_A4B+7dhmwaL9jPX5wS}z|O|V z^pSy$k?9MF_y8sZI2rXmFu(~h1{R)<s!Pn07nmg<2ue*5nj$&F{(?@>MZw@Jg24@5 zH-u#-7)`OfAgg!5!0)24{}o~X2A><e0v*8<B*BWUS1?@^)V?C9eUVq^3a?Is+YNDr z2}M&XFDRPrV81Bta7EnVy13gVakm4-7sb7=h<jgP5&6Kzz$)0_4$5{~7g@BguxMXk z(Z0bV01+|-<ut<!f`%7ajIOX4U0^YSOPMaHxgcnILD2Lfi`f+xvkNR{a4D4;ZWjbq zE(jvB6kN*Uf|}U{L5mB578hA8udrBNV6g<bK%l{U0@n<m6>1lZoG!9BUtw{+07D<- N8Caw)FiV1?5df9yZ6N>v literal 0 HcmV?d00001 diff --git a/ctgan/__pycache__/data_sampler.cpython-311.pyc b/ctgan/__pycache__/data_sampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7b71467eacb17d46d3c02c760b931be117cb265 GIT binary patch literal 9189 zcmZ3^%ge>Uz`$VFos^y-!NBks#DQT}DC4sp0|Uc!h7^V<h7`sWrX0p7CME`Vh7{%& zh7^`m=4H$b46B);Y8j$fQka4nG+C>JT@p(YgA;QLa#D*Fa`RJ4b5iwQf>ik>gGd-= 
zgfc$cfGtU7h+<4(h+;}%jABk<Y5^I?+`<sWn!?h;5XF|l+QJaUp2F6`5XF(g-og;Y znZnV+5XF_k8O)%`b&DIx$|@fujzTfm9mNVI8L0}%`FSasC7Jnoi8%^osmUezMGA>| zDGJH?MMbH_1t7`1bcK|}l0=2H{2~SC5O+sEy<`@sGeHz6#6bT2%)!dQ&@M2Yp#<g$ z1_p*2#$}8Q46ETHH4O1^b_!z)Lzh4aEL<5F7&-*f7*m*9IBJ+&7?v?GFsufd55}mb z)UuSYgN4CF7Th&8EGf)uSe7v{Fsx<*OMzU#5D&K@ouP&yo((3G!dlA;vMF|TEgKU< zr$7xGsx3uLHEb!Yk}%~Awd^&lHSAe%o2wWY7-~7v8EQF81mTJpkX4|X=fV(cQ_EGu zxquU)8o@&G5hzT-QZ<Yz3@Hq)Of{@ETp;_>m=NwkH67Wds4|QUJ+?IrS#W=2x4(uH zT~`f9Eo)DI4J&dCff8E{8~)g-;$>i{W$js0!&<|b#SQib0|P?}JHqAJ;Mgf<@?vOW zEI|lE;-3Rs{HL(4;Y5wsT5gmOBa{YsVeTMF7at*AHQZ@TsNq(_5YG?ORl`!l5HA4Z zA%`P4rE_8PDTZnAv;|79DXg_TDXg`;;B<Q$N4i}^Ou9u02XK6V+=-OlI9i#IQ(Fq_ z8g8^0;HzOS5=Svd2xJBb*YKsVgVHOiZ)*8!_)yd%?BzpMS;fu3P|IJ#U&D}v2u+?8 zl-vb!AE=nD;jLlFg2!(S4_FK-pVlyB!Q&As%Zr>#!ReK^h9`v`%;Ev3`}Jt)o~MO@ z$aH@ug&~+hlh-d9UeGZxFn|h3VFm_<&tBki9i@n5K=pbGBP4s|3)C>AAZ7O~xOpke z5Ea1;nk;@VLBd6#fV{<CoS#;bn^>XAbc;DBHSZR0Vsc4lSt6)#k59|fWW2?loS##g zn-`y%mzG})(x>oC-^D5>B`Yy6JwKo*KPxr4B&H-WC%YuGEHx&%Bt0=N22}LNLn@3S zy@JYHLYc+!DVfE|MX4pJ@euQ>43H`pz4-Xdyv&mLcs-k({N%)(Vmm!d9Xt#S48>{; z3=9nncZJ0!R9_TUz9Ou=pzxBg#s%T97128c*Th}X^*j-INjK~QPuOQrl0pejP>~J} zPb~%phN+Cx89+sQEPpK{6GJBxdZn|TH-%v$Q;$F}gC=7UsQ6yVu#))}lb%5lsB|g< zC7ohW1Su#q6oFj!i_0b_v$!NVKexcHN&~;&5Gq0GwAhM)f#F93!v_W?B`ySULrC_5 zfzt&8rvug(4g9Yd_<vw#QsTP7!waSn>WV<=^cHJoUWu{MO2%8PDJ7K!so=Op@jS?_ zV9(2d6BF3$oVAQKj9H*?0^^BHJzT*IMWC{0CF3pTl+4_fOhusFd5g2SAQ74ZK#4{{ z0qo@>ka@)WRUMR+rNHrbLCOt+Rz%LQoMQ`NUyyPG`vn}@C~k*GoC9*iN#Tk(P<e!? 
zAW-}#9n7$j5ga&05)2FsMWAvLB^pWelMX0N8bRXk0#6bIE%3hJ6aqq_2^XCbuQ(-M z;7RI;o*~#1cZF9TqT~Wk64<YrjJMcRQqvMkb4qCE<03hbSE=G(6OezMApS)pK``rv zvfc&9U=Rw4zv!57#W4Y#P;c;v&tUBFyTT&}=7N=i39$cHGTvfItjH`z@jPis6%vu4 zbX81#+60xbMWCDkc9o`Z5hnu!!!4e8a6y)nnwMUZaf`XQH1`%$UcoKas??(V;#*Sa zMOu7rVo6bEMSO9|ErIx=%#`?~%J}5O5>SCxTm(vxw^)l3^U_mqu_YE1q~@jEV$aLZ z1DSP;wKyZOAoUh=ZeqnP2@GRFt!<E@(ik#%AoEkx^NT8P$zez)CzgPv;tPuMlWqye z=b@`8zQr34Hz6~%_!e_ce!8aMEtcZcoU~gkAiv+@&Mz$~C@le({<nB@^3&tfic$+p zQ}dE5ZwX;Bk*T-@Qb0<B0+1;+<rWtrByKTg-QvkjtU$P$E4j3&C^fGnJ~QPOKa36Y zFE88&AUkg%xF8QfN?=g_RX_s{;KJz^2dqx162Vp)gCe-tjTv0>-j$JGz%tWwj%NqU z6%NS{ER3Atj2{>f<X1KZIYn@yMI&w~>VQ-AM^MrI1w?#gV3Onh0wOMONPXpC(AHg} zalzR2MEC`*(2H83SF}PwcBp^gVo+2Dv1C86G4SzsFx`+=`@q1>>A?sgJG^i3D_-DN zT%diCU+)UP-UkLQUJu3_GU_WB=lD(Fydf$+LvxPqMN#z$Ob=w`Cve@Bl$%jGr}m<x z?iESh4-AZgZXgc`x-m|0ydffYT}1Vgi0aa$C0QFpHY9@J4xt^1JH<9+ZjryJ=6FTT z@qpqLHJ6Jbu2)1{FM!bv0g)>Ls@DZ{E(z$Yh`K0XaYex5f`G*hQSk{(6PO+d2u)3f zl-fJAFG#svlybWw<#tib{fe0TMFDq^k{kSj6GEm$Enu7ydx>A^0t|g+V~~)(0V^pl z@FZ=pzQB`oktgX2Ptpw$(GRT55?mh{m?gNrfQSyBuUrh8+G{v27&@OYzn~F(Q6u<@ zMldMgRPSo*uCcga;&vkaf_B(N?XWA_VIXmhyLtv|GA>wno$$F}9(hqO>WW?zNL>4d zvMGpGxGSrG7K>o+4I^8eQ3=u~^??HtQKDiUTsNeoJNUs-#t)7%5Q!OOe7K_Qu7c7M zwgn|C9Ko<?gVKu96&_3Kb~s*8u)C;WcSXT&0`CoB1yDTN-BnRr;<mzehv7wSn~N&8 zS5#~#_});^-O99uZA0xv1J^4Ct``+tC-B|~PME-QMZkE2*8%<mnn!dG@Lw?TzbFuJ z!^RsVuDc@ZqJZTMdyhS}Ag;j%sfz+OHw5GtST3=>C}1#w|Av6njKqoj6S#<v;t%p5 ze}7;Ak>GMO8B_p+n-UBR44|$fs3iaV1>E4MVL(*GHH@Ir5TdY_sRX17!ePjQn@|F0 zF)(DoRV;va*`ON0R1H%WsEGh(ry!Y))<$Mp!;ID$VXk3X0CFc-Co+)&Qj0nW!JG}Y zdjY5|h^!RM##9e(XG2;VARR@D3=9mK97To<3=Bn}M(i!_)QW<{yp;Hq%-rHzY(<HA zDfzj#SU_}<E~xBeOD@d?)i`X4Md_gSIZKf#NLUE19StfuSwRJ2@h!%pVvrL-#UZ2! 
zl#B-%m!BJ-oROcIoC>Z8;xqG7QY)%dQOZPI%6vf8eG<6I1}+ecFPOMp5peHdz01Sb z<2To6j^%v2S#~Q_*Jy7@T%&hU+VYCD<pIVEJWdyRoUZUVbuiu#5S`95iD!Z01fGim zDpv$l78G6*(74WTaEaewL*fN~gNytYSNJV1a9Cg_AW)QoBL4FZM$p(9asq;NcZ!%m zBV{N_1)gjcfb@bR2bn-gY@pBrOCp<%nrLbm(UJ^mas{~&Vge)bXh|9q=9p3~Gm71) zI>Bau2&_6;QdnzPP;?_)jy_h#l*R<^%Q9e3s5NLK54Ef{3|Sl?n?bk+sq0w73hIm3 zvZb)rpt>Bl-E2_1nQK^Tm}*$lSb`Ze*%B`^GBCKO=A{-TmZahuJp>ICmJ}ss=4Ixk z>v3^`k}wDdrIwTy<rQ1O1wewy`FSY{8cB&I$r%b23d+z%LS|~QCaOA++7PURm3k=V z<t7$qBbf)wUkZ?ZDmGIc74q{^bu;oyV1u8Tc`2F6i6xnN>0l2Qr{<(4m!zgBpqYhg zFPfx6Cd^k@{Nb5blA4}cq)?PvP?TDnnpXnupey93DS%v$YA!TDD&sR#;xqF=txu@$ zu$UVRb16)dLS`P3UI3SEez(}ua`F>PjE#y!K_wfgsMF*u0`-W%rCgC4sLW!6ls~uF z62a}?B2Z7T$QY#F7ewfTibi%w@x}^{*CJ4R{1yu+rf;!;e0WO;=0Jq6ZgC@t6@zRB z#kc~rb9##pGX4+`GNdduxk?JT#6y(}1Qm0Sz{T7LDFz;aE8L)A1#XQC+!{B8#V43N zm6e|pd0kfblC199q&1m4Sgu$ET+|D^q8E4}IO3vg<Q3V-3mj5U`2~9_W{56exuRfn zLBe>4<pqA%i~O!v_+2|#?jqzYE=X7&ki5X}eUab$3cq&;%MDYr4wl}ap2!KM7kOo` z@XB70H@v`Wc#*^ChOFXsS?f!())#F2F3S2}k@dg8A$3Dc>WUb+AQp4K;1zRGET)6| zhMqpE5i_bU@M>M;(0;%#JR#(QwC)9c#|!+97dRX-i&;>>fr`q{Z3K$h8V1B@(gJuf z3o15X7<Kr)mZ^rRhAD+H8)P~Ia*bcZgfgIqC{UPC3t#Yn9#aWY4UMShvDe!u=7M|% zb{fi1JDMG+Hr6oKfNFkb<ms4NW)xFF9S5+fHJ~CNKIoptl)|!x6|J&hNnxu&opyk` z34Oe&h8bFTf(lLcDmP+E%ydYpnwg)cP+FX-kYALUo(Y<g0d<$aT{JyNCe>uT#aK|} z56WwZT*?i}rJx20C=(}^l%(cC@~kRIAt>V(1%Ozf+yU+nBl4{#D>%0n1%c!pL0Oz5 zIX@>pGo_-qC>WHPK(!k<H$oaB;?VpCb0IuaDr3o%=!$Yc25tf8&IjE5=y~&jgv=EQ z>+2HEmn57oO1NH;aP8o|A)&e;cSGSNHH#|}mSFB$rZsFUYA<ToUeU0<C}G>ddskR= zisf}-<x9fKOP!W@u1(&cdr{Bfik`zo6~`+ojt5Ln_+AiJz9<}ZML4R111bArWbGT; zI%s*X;sUSwMGlPz{K6gFko*P;4se=bU|;~H08ozjYyrt{%nfEGFgXSWf_V%%En=i0 zrZr4xDUh{>u?A^6DwrXip_a7-o`xA1YM5(SQ1!Sl#CpUqF)-A!fhMmOAiEvzP8=>q z%BgUr3^fdh>8~0#L^)7{G`WMC3sHHD3_YIczR+Z@dJdTwg-$aT<>w;OKddDIkt<Ft z$t+H*gjdesc}&n`CHfpDQp3ZsC>^cU0U8a#-o8Mw4ap3wjh8%3Gx8LQ^2>`MbvbDK z2<&P^mF}m>c#EYtvno}S1Cl*KIrA1fsL7t0R|3wMx0tIk3yQQsnGu|0H9-|B2PBJ{ zfpRBf-Yw?j{G40N`2{7nm@`u<ZZRg`V$3fFHT=NY6kcm`Lvt#~yH(;?awk{<lp`mC 
zbL0mx1|I$n>$^fCQ#7v&DPIy&Uh2GpWufO1&kc?jgp@A|nOzYw>u|myEIxyIisc0! znFo9V9n5z{#Ah(i2%hLW#TPWr0v@d3bYtvv>2T~QyuriY@73iswP;G^b#e7e;_53{ zu4vfqNW7xqa6s^ihVw;nmn-5f7kON-@VH(8BTy~p+*5ghLjscBK(iFE^ao;tvm0nw zdpamF!el@*E{s)bkZ1xodP?$35_1&tN^_G^i$L)KZC_-j7VBv;f*UNFjNnj(%!5G6 zHlCEsTyS|74{BvdAO{(;Og|{hKvM?|3=jB)ukb5g;829P7UVjZn?O||i2u0{+-XLh z@<CL-HB2=~9pyBp6h@3}&s>AlX|7>RVMb&-1gn;%hNXrvg#{r3p34IDHkeb`QrKFV zYFMC^5_2$vCVQ0}_DTub3f4oaVZiM-j<Up@($x6O6!3_?pC$_={UTa_8lZFvsY-6~ zf!b-{H3i_YMsU?~i@7W{8I)e*i$SFjIHf`YT?1P6fo7BPQsUE!K%HFBx&!dM6{O`3 zHmAx2IYNllH5n9HVc-P*luw{1V+QMtk}L9t7sQQrXkOrRxya{oh0mpf`39fR6+Yb+ zc^7O0ujmC`;S26yzQMuWQ8y#&3cuz>4$Thc8)D)e%pHzkSlr>($$f)gsDld}?|z!x zpfSq4#N5>Q_*-1@@sI@`@$t8K;^PZT6LX+4?D6p_`N{F|Mb4o74Q?=pgIF;jA{|6z zf{1(&0ZQXw!@;3c1WHX{0#q&+_k#V?zyN_aSi&x_gneLfWMl;mI-*c+Ow6p79~j^S zmn5qb;|B&L(u{{y<O2hokP&8;{J;Pww3t|xJ}|%u4mMWd4-A+D#HAodYO)p;gUTt^ zywco)N^ml{#g?Cxm6}`vc0YJd3Zfi5*T!K32`sy!G6n_)P!=x^Wnf_Vz|6?Vc!NRm j0)ylO2B8Zu^nopyiBb3i19ozP+DEX=7chyb3T!O^Xst}y literal 0 HcmV?d00001 diff --git a/ctgan/__pycache__/data_transformer.cpython-311.pyc b/ctgan/__pycache__/data_transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eef72269ccd8a69bdbac53fbddfa87b182e24a6e GIT binary patch literal 13850 zcmZ3^%ge>Uz`$VFos=G-#lY|w#DQT}DC6@i1_p-d3@Hpz3@MB$OgW6XOi@gXAU1Oj za}-MoV+wN)YZO}wV+u<SdoD*5M=oa+Cs>{}hbxymiaVDliieSbi6NCKg)N0Wg(ID5 z850A;Y9^@p3{kvbInH#36s|>#XmWf^4DJjm+${_#JgNN4m>C#WGsE=>q%a3FX!2Hx zyCjw*h7={{6{qDF<)#)X<mRW8=A`Pq1gX?yyv3E5n46kXQd*Fc>Zi$ciz6VhC^07| zHRl$4N@`AGWon8h(=9RQoYLZw)FP+E;?xws{G!~%oXo1!qFcQFd8r=xC9ZkN`6;PI zw>W|e67xLs((-RHr)1{d;!R90$t+7O$;{7-Ps_U{?3|xdnwtlA3P|Bf)?3`kVBvU> zGvJ((%7WBeoFy<<-{Q_MEh#81iO&QXE(qnMWag&k6=&w>6(@uI0>iMN`<w=jnpB1; z#uSDqrWD2~<`m%+mKMe+mK4?&hA7q)wibpcwiJ;Rjuyr!_7u(*hA55{t`>$UPEe3W zai#FIFhp^u@U}2S@ucvzFhudD@V78T@ui5S2(~ar@uvv2FhmKY2m~`|irwOOK@H}r zJs<%ExNlR7^tiYb6ciME^HWlD6q57vN;30G^Gk~rApS2_D9<d(P)Jm8N~}yR&P>d6 
z_w`js%u7+og9W!jNk(d}LP@?tqC#<UVoqX_LQ-l;d1`8&LbR@-j)Gw<*m#hZM1``{ z<dXa%Jq4G{;^d;#l2o`QiAAXjso>~JQ7A3W%u82DR6vPDJ+LV+K`H-vwa7a6hgBu^ z$?#Ogz`y_sUJ(DY5(6mZ*D%&F#KVG~p@yl3A)XP&O<@dX&}2&FW@KP+E=o--NmWQx zNJ%V7RDee;JXAm-<5-klYz4L+!~@~n#EST2NYX1-(8$ay(L|L8sqsy$$jmLxRmdyN zO-e0N$WK#nPb@74g`h%8W^qYTW>P6A1r{r0=E1{O!QIyv;!%(ZAY7iBnVwM+Uy@Oj zTAY!elcJ!JmXn`|&AxE3N(F?<wEQB4MC_J=Vk;?CAv?981YsR0m+{6Ud+nBBJi7gw z%(qyIQ*+X8@t`OaKv!1`iV6h<g<r-lRxv4AiFxVy0Y&*)smUcVC5buNC7ETZG07$A ziFq-gfR9HG_=3t?9P#m)d6^~g@l_I_)QX%~VKNd73=G9m3=9kn3?GCT<g_|iu5d`* zP&5S5au+y2=rbt3lR@bM!~kJX>H#_Ivj-z2jbwpD!8nCc5=?`LTBZ_^DiD)_fgy#d z1XijtFff#`LxmX_7-|@oF)}c$hNr6<hAbACENY2Z!;l3xV+}KEiCD{A!@PhErVp8} zVXQ$bbr~5-_z@~<n6o%wJXDv|vXltJWEdD~m{M43m`g+vA}OpXY^W|omaSn)VMmos zXG&qMVX0zZU|7SkjFo|5HQYDB44Rx(A3}-}(UP$ODEVQQCeXqLl(aFk6(}hvXcVOA zfg;_lC^0t`OTKahXVo;2EV#5v%*=xn7m)G_$%3HNlG36)w5$s%&LOo5W|nnSNJzj^ z86+eq<R^iWE+VfKfl}2i?zE!(+<1t+n#{K(ax;rTMOS=TVoqsle0pkLY7w}WDFUU_ zB54K&hFe^SAbAN&H@8?y67w={u@~p3mE<N?6!9=HFlch$Vuu)7Tm(uNx0rJ?^KLQc zq~;ZgGcYg|fs+0$=CsU`Tg=6!xwj<C5_2+B;*;}p3-a?)^FWnRN@j9mNq$igH%K`s ziD<GGfs*$v7Ellu@q%2!l%9KwFRv8jeyGOcVo<DtqDG-e2Bd^PJ}t8(9^v9Dc~Hcl z<VI8lDxgfc29hbo7{sNg<Sa0plHb97msk0Qs?l9ZnYqCWL}o_LiCn>QMZ@Z%r1cd^ z>l?C)S7ePpu(R-SePm$a<N5+3I+#9iF>rGCvvsj`luU@e$f0<JL-7KK;$2CZ89`TM zEiX!1U6Hiv;CsribVEY<hJ?xuF^Rj9iVGAMgk6->xgx2vHe*BGfy#?|fmiebFGvL4 zkdt4acTvvZ12>xh*GC360j@6~;sXO4Cs&ai0|Ns}F~P~ezyK-$K3{@XCBzgH%NQ6K zR>N}`I8%cX30wrmz?!L%3l(OhLIkCFsbwhvMJUvu8m1Jc8s-vET*5?HQkc=QEn^MK z8WxNU%Ubo$Ewe-+KQC1wBfmreQtuXlb8HG~#f+4jhJJSCfO)P+6_hW)S&h9&4a5e8 ze~~7Lr2$I#>?x(WxtXcO7#XNY2c!lR=C_#hGg6@$D6u3pJ-;Y36(a-jf-?}z+A3MB z*#{==0LnahSTYZvbZ34~eh2dn4(@*TF7^p56S6OIC|%)Dy1=0X$`HYGqAyBnUXj#9 z<Oii2Vv?ZzptnKuqNL3gNt+82wm0N-a3=qi3@B|WP&ok30HAViIzu`Ga$70|)UpDV zUEpMw0%}>o6FO?`SOO|~z^WLK>Oa&ZTI`#`T*D9#Hz9?kg`tMYg#j%)F@c)KD2bgh z3lwQ!(^J?$g&Ua3fLae?n8REG%G?m8NN%cOtYKb*R_E0+)-WzW@&Qy60|T1dvfw7x zFvNpe7+{6qFoS2;8pbS8as$h_FvRxOvedBDFqJVB$<{C|03|1|3^Gx}Si_Qz)Fwgo 
zS1oIvObx>VK4cv*HX)rg42a@v0a9xMrk{bKhBZq7#zVE2k)ej6f-wso`!%cxb`>wE z7)7+QYZ$WN{>Nnxy6HWADIDMy1-O7J$w*a5NI=Up2?<Drl3Qj;u|guK+J>~o(K><& z3E)BvQG+4+8+j=TiSRakLIO%(1F8*{F=3jJ><!7tELJEkNi9%F%qh-SNX{?KD=7wb zgcOvaH8-ePpPN`xlvx38uYkKu#RZ9Z3ZSlHQ7*V$fYzP}aZ82=ECT}rC=-Gjai2kD z4kQ!C3f3}~AgMubfKFuU5ejA~GGkz1Sjl*cxwxcAlj#<dp200PsH;HPOaW56|KhUA z$t*4bH5=@z)X@T2FD<h~&n727IWec$P7k3l9aKhv>dOX(3-aC&w1I6Y^8&Yp+)KD2 z{0s8lkSZ0)p%i+i$P(lurXp)l;mn*@np=$Iks@0L1_q*ikio#f@EPg@Y0nE_v_g2U z^8)6X9&<b({0q{aU>_7IgREr)R}p@ig0}>q9ehv}#)AT{ND$;8PEex?+#tQhm06sb zS6q^qmz;WwsUYPRCn#LOE!kVFMTvRoskc~*GZG6@Z}Dd4rKJ|dL)w+ax7boZBMZg1 zgrKd-_@cz}c&L~(Obp&h0uMJp900DpOhJ{BKFBF-i3J6zc_~Gp?5oKJY4dS_%}Oju zyd?lj(NGhMZ}C9V1$dC77}72jg$kkCRt##0f}4hr)B~<8tAs(`Lv9bIWtK#Ml1K<6 zxML>3pr{P$kjdUqR07d*HzZ{{SbF%L@=MOpSn9OI=OVx16@J4H46K|kj5lPJI$V1E z9zfESmDdG%?~C%@SLD4f@W|W%$J18U4H;YV4g_8?bH8Zdam4@>T^FT2uSk1#_}mqj zoKmyEV?)jbap#NT&R4{pJ3JmJX<b*cxTIupQOW8G2swfrA%8<#7es5^kXF7Tt^I+4 zgVUQ4LUwrH;Fr1}r+0xLgl=f)T-UI=q+zwg;z0BP-z(NZ7d3*fXarvnP`x0a`T!(r zb4kNyN5p~Z1G!ghLoRBBUeO2zN%06=;ZeNKqj8Bx<ASE$6-}>;Jl<D$yf5&0-{29r zz@v19N9lr^^%XVWi#)y%?hTL@o$EZtmw1dXn0Q{~@w&p}bpeb%2qGK^PL?QDH7H)d zRkaWkxJ%1~DBMxHw9GY3HOz==ashJYgq8JZ725((CWFht7-$tU3ll>PQx?btU<D~` znG7||H7uEoOBnkYYnf^o5v|!;)*99Y@D4IqIh3ejLahjD*^pbhSs*`xRibo$*--V= zFd{m`Y3Sy$*Ra(v&t^zrn#)957}T(*Fr+ZFGNmx4G1-8+{19`PNi}ByC)fjMM2QGO z0M$3O94Ra{>?olCuS{z=5RC~`T{Rpjte{Rh)UE8n44UkIpf(&s6}KLIz%HeVOAl1+ zgL{jh5}kp8;e$Q6ume{en!Ml=?-oZ%etceOZb9WO_OzVDl9JTCTWpEN#hK}OMWANG zEvCGJTdY;7Mft@=pjI=aFl0+C1{Gzum^1TAZm}go3O$zM%&JsPPH=ft6a*@*Ky@0V zyb*w!oeJv0f(BxXbU`H<H@y5wO}WJeDR^(OX6B`&RuqH65>gf^Kr6FbqKF|R<PNJU zdI5!1Svjc4VgVOfH^ii1oz@%D@(Z}HNb6jeHoYWmdQsZ^inRH4Y5Pmk_7|lcuSh#y z;E=e%FV#`fQ+<Ko_5#1{jO+{A))!=LAiNHi8@x(acy&6MZiq=t&zY37z~rKs#uYJ* z4(_j93_>DPB<32=sJ<wzeMMUPg09mQ-H;1Xp%(?it_X&8u-y<>T9AD~*$#rPh}&Nm z_qZhPaZ%jsinvz?TSp-XPEhP9>EyW~AU&h{f{YmiT@f(9F5qxUz~Q2R(-i@y>jIva z1U%0wUlj1WBH-7-dc)YPgRQryr*ekq9I1=^Dp&Yb7No97T$p)*U;iS9!2^EL4qnV6 
z29&u$_;VZsxQJmy6fxj#E+P#hvCy+LF$GRA!%9X!O(t;aDhdIOgs`9|8Dnr%F)%QI z;vX_fA&eBKunMk98BfH+6m@{i0;Svrh8uj!SNL>0m~Ze4_1AXQ&Iq~4uXKf9={mpm zC4TLT{JK~8buVz}Vg?PU3jz+B58$9lXQ*WcbzWi%Kn;APU}CLdtYJlN&eyWluppYw zDU8|RvbIR8h8cUQTf?vbWDimRLzp!zHE08Qs3ig;LytQiGqIP(gv~6h0ZrmCRpo^w zW`jo00}}I65{nf;fevo%fSbGU91LqfD->7eC1(`n=jE3cBQ>V|!E-2}c08y}1(|Q~ z1C=`piA6<;mBmQSWYE~YCM&q$C<1kUiu^!@Ng{|y0i_pqNP6)Hl^{}B8#3T#(=7q0 z9H`|67J-yXpq?1Gi(d?iSWps!7cgS+NRCIOxGFV}pOKqD*wyub(%Ljg=bno}Mh;YX zsM>?ZpW+!I<OLqZ8&djLq^vu<ZwO1zh@RnlMOdrD<pwC18f@^n#P4u{-{FRU!gT@l zO9JXESgvb0UD9y6;2d!wD(<32{1uJ(ivkH(1QITQ(F0_K4wp0>E;xo=h={qU5qm`= z_M$-C6@j=5VDx}r{sM<QX37TjK0qn%vjTXi9G<cZ38(B7##(li)XfA<-SSKf$V2Hh z42X6EYO2Ld)wS$-au~Ydsh6<s8s-HcA3-t=5?RBJs+W<Wr-+FGd7vSMc?}Dyni^(Q zelUY3E2%jlGf$xaI-yBcPVg%NWrdgj|NsB5$#{z`FFq?jsaTU8DQAGZTLjM;kc<Kf zFij3b9w`D%Mc(2pD9TSxEiO(iM#}}a*r77TpkWYbc7U}1Zwbf4+>ewQFiTxbMRPzo zVFN5Dh>CY`-T-9;4o)>t8(a-nHV~8Qa7W7vQnqu|7pPoOvbiW_dqv9jf?e>1(8!BY zQCFm*Ca_One;{QIk+-@iWqn1;`hrd1h2ZduQW00AA||lk;1RvfBXfyIW=8l$9@Q&6 zsu#fM0l!jj#f+j0{7R4-V1XrC=?^Zap$P~S#Na$K4V*_%d*&rbDzo5hw81Z{6vh^` zDJ4XEIR&Y$ikjjuhP6;8tPr+=#+uiljR}K>wUF#!L9~ZcShK+%D*_EIf&2#MArqiU zRYYqX)i1TodD1lu3qS=YvKBBKzrGSsFB+_lfguajwga<km{XY6pw(>);KOTRB~YS7 z49Wlv&7q8#gWA7U>fk!6D6t$d5CdsWDkLf(jnL>(aqtB^UQuKZO6j1Ohc;3n0~NaH z6$ZST#5Ulv5mdD~f(X!<1*DDc1mc3s&HTa$9_V37tjH|Z<SB9mNw|RscM#zLB0NC^ zXwaj`7sLXMFBgS_sz>hR{JiAElGMD!lGLI|kO(VeY8l$*0@a~KfgnAgE<!P==m3=f z;5HYyg;phr(<kddK0gS~;}~6iCQw(OtnT~=4h9~P9=jQeSGZ+n6kp<2y1=b;Ltgp1 zywxRns~Z9mGgPifS$$+?QsDc-z@)%;fk);7kIV-S1_hPt^5&Q1&3AAdP`ToibWuL} zihS|~9+{{7f<1K?<n%TKUf?&s$Zvjy-~2ki-6ejzy`>lVU9a%FUf^(rRE)@(0hGYO znf@9$)7K!khFOv4jghh;6SNb_kiyu?l!lb8QO3)0bdFIw+^7Y?8npRg?3G*zeC(M4 zZG60z4Y^B==s?%NT+9sW{35#;TpgmU1gc@HVMAow6r@20aEF<>(YXZFRD<TxEO-`P z0PoO3rNC4T8=`Xy4K-F|`$2;Xh$0KE2x6~c$F3jj5>^}|?PvoD!3>&gelJ1!s|cLa zK!boqg`kW864T@ZC;lRMHoL_FYPH@HN>43`hm<Sv#gLIZaIPvU0qF$iIJU&1bkH0x zXw2>wTTx<ON`CGwHb}Sp7E5MMe)28BqSUg~qT*CUoeHb2LA_@9G9*y5^%h%kW_oU7 
z@hzs}5^$Fp+;F`mS)7`anp~2a5)Uyrvnn+OG9itn5|@q7%tNyex%Z4wRN+;=AC&2D zfipd<>wA}9u%~8*_X<WZG+N=f!e|BS6&by~EIWd(+Xh{-4Z3I>0vb892)$wvdVxRe zB7fKw{;=!(v6uK`uf`=`<WIT6pVGl{S3zkB=L+o&i6FQ`aEH(ip$(~5v}_I-UQlqo zsNi}<!L@^>Bcvy0hGI|rMGn~;+=3UlW$$Wet>L_`VR1>rVu#2c$peO9n0UZskNgR> z3xUCBw4ft)JgPV3%sX6qJSQZA;0(bDDV_cg_yw=?D_`PQUf^<(U-JsT<^>K-NNE7d z{ou69z`y`64R(P$fGDK_Gp^DACD$`!9}>fx>(SQR)w1Ah-l8{M84=YtA{T*HL-~Od zyCyqY>MjC}jTBXY(r_(^0QHcdiwPiڎs8<e7e&{O*q_LzJSogb18@>L(qxd8! zae=BNL|+}6xVDOHNWO07bIHu-qM6?nGrx<5{#OkBFYpIk<PW&QAJD;agJ0-6zse<k zl?5Re`L(X_YhBQGxT5WIh2QrAhc7rr{ZK|!;ZgbqT)?L=)-tCv)Uu>A)UqN^ow+c? zcGQB3KKMv#Eqe`P3R4PGD^m?~4RZ}UXe17`xUc1?VMA04Da_D*ktELEJzEV2Xxw%g z3j@PyPz?-DcD0;&Vl@odlT8g<4F_7G$XUaJy;`W@M6FC18ERN@sANMV8w@=lpYk#= z)H3&U;;2l}5@9VD^5P^!g;2v)!@7nA)%7?k2?j(Zf$E-Gt{R3cP$C6~RIx}67xqa~ z)a+HuRl^1D#mxh^6_Hwn;DHQCv917XLSl5-6iSOzA+vwQiMgo?pz%+IM9{*%g3=Pu zQajK*CcKrWkXD+P3|jez7}jvjEXhbMQphWS40tHy7bz5^fLp8JStW!*@K7$K>j%lM z{E${1r~wTb4`~8r#}*I)9vK0J1f;nKo9QVoNlea0)LxJ}$_J#u1w?>H-MAq=MDQ4G zJxH_x)I5|bN=?o$OD%#A)W(CVA;{<nxSa^9qKZK6_aacCa7z@c`dfSmVUYe}P%8k` zu!Lf8`|_3`M!i&}h+bM?C^-izBlN(-xF2{J`1m`RZm6nX;840DF4@6zLsj#Ns_O>^ zK29+(0qSQ8sb3MY>~OvzB{!oG1Q#eRC|cpXuwsGk1$q4|QU;*uAeAfph|w1fo$DH2 zmo&U?h-j{`-obW&=_50frqCA#CQTv8q|gTr22I`T8XlK4JiuyD<UepQ@Ty<u(Y?f@ zdqL0nB9F@z9+wL|E;o22uJg!W;*r0gV0uNt@gk4Y6&|MxJWe-w_$DZKdUiN<I6ak= zo+E!j)A^#L%N0qN4!*mR(sN`M<n740AnAQk())^}_jSpjOOio1K(k27ADNk?xV|tj zNpW@XePCk{5b3G=z`)39enVFKx~$11S(A&hW>;j*I{Y5+%U<A+1s8$XOTabo5-^3a zmIbZ!QvuEAAjU9jU}Km}j0}02HLOVG9BllDwT3x`8MI8Gmbr!n%3|qJMXnB0SZWw& zGo-N2MH^#jVyt0B^jdMv57jWGgVe!G<a9>#9%L<h4I7S<6Ky02drua%+(4OO<3R4m z)^MPfOYjk1F%;jj*D%dyNZ|ncnR6~Hs-7lBP-7QIS&42w*BsD@EO(W<duj=IAwp(K zF=){-sO(A4EKAK(NPv!cf>uB#=z)upB12H_1~qW21Vb`ltFghe!3qggG6@PP`KiTu z>Lm)P6`92)pd}3;<3QO~4_y9Kaf5V~ftEWeBvi@cP!F+JA)$&XK~IzUmJqA~0G)-7 z&rG?+2~h)La)6DC&rH$eh7@k#G7Qv%xWyF)vdOimD8C53nF1=einKt!0X4OYDnWz6 z(D^2i{#&eIQ&DEXOY=%ni;8+ddcY+Yyb*JYJwGosJ|n*b+eDLmaz3cZ3|bZmnydnc 
z6{LL~Uy>i6nNp>XUQQ9v0Gc@{76F%952R&4<H91MpfNE>amC3Q!}x&#M1l$|$XKxV zC4T!0{Ps5lL?^I9hs3-t@!MVCx4R*ru)y&mXnI=V0gvbm!A`#p&koP8Yz)G(7i6P0 zNMDeRf}o2+(N~0`FYrV^fKGfs8dIS&+%M=kU(xfwpc`<3Kky=d;1&MB4wf52qSJLJ z=`LWoD5QEtNcFmq-X$Tui$VrhgbX$aUJ)|xV86@F*HPYS(_=Hk@DjJo1#X$ULNXU* z6IPgBkWB!gz!(q-W?vLaydso%fhX|;8-ud(4N0ZD5>j(G7pPy9(7Ga_^?{v5fa?PT zgn+Ki)MP3G&65^^)-P)^f~P2pKszfm1wjkNLCHBj{uWn!d@gtcNqqb*p7{8}(!?C7 z410WhN`7*De32|DszJ*niogpdz{~v*1Ek<F67cW{XmYa%QIU3mv`q#PpmjCi(h^)r z6@eD#fC*6bTl^c`^J!p!zz<9;tQ;R0Km-pLtI-DrR6;_EmFoiooUjlBap8o4EGuY5 z0tRWv¬T0Zu3}vFd$bfD;^itO_3(;DmwzNES}0@Uyxxf;7WPB@R}t4-9ZZgqhWk z@dE=A$)&(5_kjVGaAA}LITe$DgbOH0GzE*MgW3zA4K$#xC}<;05vb2}ixoVZ3~7(B z6@aIVz%7wmY+3n9IhjdCpeAk+Xt3p$Kv7DG9=y$zS_D}MT?A^I7l8(&!7DRwu|d|) zfFl8tF2JK?pvB9-IBX#4!LDc?0|Nu7N+`Y#n&@U`WMsU-AbWv9_5p*^1yuBaLFfVu z-Cz*D07EwzR4!mcHyE5QU_%eMMLVn}WGygVp?^Wi>LR!G6>jSWjvHbM6S8iwNZw!x rz5qraS((L|zA!M0GhJX1{=ieg%*geD0XsP%?IT#`3z)=I1rB}y%zUhQ literal 0 HcmV?d00001 diff --git a/ctgan/__pycache__/load_data.cpython-311.pyc b/ctgan/__pycache__/load_data.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a7c66a7be7e09ca9ccedc03198c88c49b47c286 GIT binary patch literal 168 zcmZ3^%ge>Uz`(#RoR{v&&A{*&#DQT(DC09o$#jNvhA4&<hF}IwMn6r)TU=?Gd5Jmk zDTyVCD;Yk6RQ}R&v5HB_O3X{o4=BpdN=+__DM`%9F3BuQjY%#^Pt1$S$xloH>(DEx v{Ka9Do1apelWJGQ#=yV;va48-fq~%zGb1D8hae^f9=--H5G-P0U|;|MN){;! 
literal 0 HcmV?d00001 diff --git a/ctgan/config.py b/ctgan/config.py new file mode 100644 index 0000000..b04e5c0 --- /dev/null +++ b/ctgan/config.py @@ -0,0 +1,48 @@ +import json +import logging + + +class HyperparametersParser: + def __init__(self, hyperparameters_file): + self.hyperparameters_file = hyperparameters_file + self.hyperparameters = {} + + def parse_hyperparameters(self): + """ + 解析超参数文件并进行类型转换。 + """ + try: + with open(self.hyperparameters_file, "r") as f: + injected_hyperparameters = json.load(f) + + logging.info("Parsing hyperparameters...") + for key, value in injected_hyperparameters.items(): + logging.info(f"Raw hyperparameter - Key: {key}, Value: {value}") + self.hyperparameters[key] = self._convert_type(value) + except FileNotFoundError: + logging.warning(f"Hyperparameters file {self.hyperparameters_file} not found. Using default parameters.") + except Exception as e: + logging.error(f"Error parsing hyperparameters: {e}") + return self.hyperparameters + + @staticmethod + def _convert_type(value): + """ + 将超参数值转换为合适的 Python 类型(字符串、整数、浮点数或布尔值)。 + """ + if isinstance(value, str): + # 判断是否为布尔值 + if value.lower() == "true": + return True + elif value.lower() == "false": + return False + + # 判断是否为整数或浮点数 + try: + if "." 
in value: + return float(value) + else: + return int(value) + except ValueError: + pass # 如果无法转换,保留字符串 + return value diff --git a/ctgan/data.py b/ctgan/data.py new file mode 100644 index 0000000..8a48da0 --- /dev/null +++ b/ctgan/data.py @@ -0,0 +1,88 @@ +"""Data loading.""" + +import json + +import numpy as np +import pandas as pd + + +def read_csv(csv_filename, meta_filename=None, header=True, discrete=None): + """Read a csv file.""" + data = pd.read_csv(csv_filename, header='infer' if header else None) + + if meta_filename: + with open(meta_filename) as meta_file: + metadata = json.load(meta_file) + + discrete_columns = [ + column['name'] for column in metadata['columns'] if column['type'] != 'continuous' + ] + + elif discrete: + discrete_columns = discrete.split(',') + if not header: + discrete_columns = [int(i) for i in discrete_columns] + + else: + discrete_columns = [] + + return data, discrete_columns + + +def read_tsv(data_filename, meta_filename): + """Read a tsv file.""" + with open(meta_filename) as f: + column_info = f.readlines() + + column_info_raw = [x.replace('{', ' ').replace('}', ' ').split() for x in column_info] + + discrete = [] + continuous = [] + column_info = [] + + for idx, item in enumerate(column_info_raw): + if item[0] == 'C': + continuous.append(idx) + column_info.append((float(item[1]), float(item[2]))) + else: + assert item[0] == 'D' + discrete.append(idx) + column_info.append(item[1:]) + + meta = { + 'continuous_columns': continuous, + 'discrete_columns': discrete, + 'column_info': column_info, + } + + with open(data_filename) as f: + lines = f.readlines() + + data = [] + for row in lines: + row_raw = row.split() + row = [] + for idx, col in enumerate(row_raw): + if idx in continuous: + row.append(col) + else: + assert idx in discrete + row.append(column_info[idx].index(col)) + + data.append(row) + + return np.asarray(data, dtype='float32'), meta['discrete_columns'] + + +def write_tsv(data, meta, output_filename): + """Write to a 
tsv file.""" + with open(output_filename, 'w') as f: + for row in data: + for idx, col in enumerate(row): + if idx in meta['continuous_columns']: + print(col, end=' ', file=f) + else: + assert idx in meta['discrete_columns'] + print(meta['column_info'][idx][int(col)], end=' ', file=f) + + print(file=f) diff --git a/ctgan/data_sampler.py b/ctgan/data_sampler.py new file mode 100644 index 0000000..d9b67d1 --- /dev/null +++ b/ctgan/data_sampler.py @@ -0,0 +1,153 @@ +"""DataSampler module.""" + +import numpy as np + + +class DataSampler(object): + """DataSampler samples the conditional vector and corresponding data for CTGAN.""" + + def __init__(self, data, output_info, log_frequency): + self._data_length = len(data) + + def is_discrete_column(column_info): + return len(column_info) == 1 and column_info[0].activation_fn == 'softmax' + + n_discrete_columns = sum([ + 1 for column_info in output_info if is_discrete_column(column_info) + ]) + + self._discrete_column_matrix_st = np.zeros(n_discrete_columns, dtype='int32') + + # Store the row id for each category in each discrete column. + # For example _rid_by_cat_cols[a][b] is a list of all rows with the + # a-th discrete column equal value b. 
+ self._rid_by_cat_cols = [] + + # Compute _rid_by_cat_cols + st = 0 + for column_info in output_info: + if is_discrete_column(column_info): + span_info = column_info[0] + ed = st + span_info.dim + + rid_by_cat = [] + for j in range(span_info.dim): + rid_by_cat.append(np.nonzero(data[:, st + j])[0]) + self._rid_by_cat_cols.append(rid_by_cat) + st = ed + else: + st += sum([span_info.dim for span_info in column_info]) + assert st == data.shape[1] + + # Prepare an interval matrix for efficiently sample conditional vector + max_category = max( + [column_info[0].dim for column_info in output_info if is_discrete_column(column_info)], + default=0, + ) + + self._discrete_column_cond_st = np.zeros(n_discrete_columns, dtype='int32') + self._discrete_column_n_category = np.zeros(n_discrete_columns, dtype='int32') + self._discrete_column_category_prob = np.zeros((n_discrete_columns, max_category)) + self._n_discrete_columns = n_discrete_columns + self._n_categories = sum([ + column_info[0].dim for column_info in output_info if is_discrete_column(column_info) + ]) + + st = 0 + current_id = 0 + current_cond_st = 0 + for column_info in output_info: + if is_discrete_column(column_info): + span_info = column_info[0] + ed = st + span_info.dim + category_freq = np.sum(data[:, st:ed], axis=0) + if log_frequency: + category_freq = np.log(category_freq + 1) + category_prob = category_freq / np.sum(category_freq) + self._discrete_column_category_prob[current_id, : span_info.dim] = category_prob + self._discrete_column_cond_st[current_id] = current_cond_st + self._discrete_column_n_category[current_id] = span_info.dim + current_cond_st += span_info.dim + current_id += 1 + st = ed + else: + st += sum([span_info.dim for span_info in column_info]) + + def _random_choice_prob_index(self, discrete_column_id): + probs = self._discrete_column_category_prob[discrete_column_id] + r = np.expand_dims(np.random.rand(probs.shape[0]), axis=1) + return (probs.cumsum(axis=1) > r).argmax(axis=1) + + def 
sample_condvec(self, batch): + """Generate the conditional vector for training. + + Returns: + cond (batch x #categories): + The conditional vector. + mask (batch x #discrete columns): + A one-hot vector indicating the selected discrete column. + discrete column id (batch): + Integer representation of mask. + category_id_in_col (batch): + Selected category in the selected discrete column. + """ + if self._n_discrete_columns == 0: + return None + + discrete_column_id = np.random.choice(np.arange(self._n_discrete_columns), batch) + + cond = np.zeros((batch, self._n_categories), dtype='float32') + mask = np.zeros((batch, self._n_discrete_columns), dtype='float32') + mask[np.arange(batch), discrete_column_id] = 1 + category_id_in_col = self._random_choice_prob_index(discrete_column_id) + category_id = self._discrete_column_cond_st[discrete_column_id] + category_id_in_col + cond[np.arange(batch), category_id] = 1 + + return cond, mask, discrete_column_id, category_id_in_col + + def sample_original_condvec(self, batch): + """Generate the conditional vector for generation use original frequency.""" + if self._n_discrete_columns == 0: + return None + + category_freq = self._discrete_column_category_prob.flatten() + category_freq = category_freq[category_freq != 0] + category_freq = category_freq / np.sum(category_freq) + col_idxs = np.random.choice(np.arange(len(category_freq)), batch, p=category_freq) + cond = np.zeros((batch, self._n_categories), dtype='float32') + cond[np.arange(batch), col_idxs] = 1 + + return cond + + def sample_data(self, data, n, col, opt): + """Sample data from original training data satisfying the sampled conditional vector. + + Args: + data: + The training data. + + Returns: + n: + n rows of matrix data. 
+ """ + if col is None: + idx = np.random.randint(len(data), size=n) + return data[idx] + + idx = [] + for c, o in zip(col, opt): + idx.append(np.random.choice(self._rid_by_cat_cols[c][o])) + + return data[idx] + + def dim_cond_vec(self): + """Return the total number of categories.""" + return self._n_categories + + def generate_cond_from_condition_column_info(self, condition_info, batch): + """Generate the condition vector.""" + vec = np.zeros((batch, self._n_categories), dtype='float32') + id_ = self._discrete_column_matrix_st[condition_info['discrete_column_id']] + id_ += condition_info['value_id'] + vec[:, id_] = 1 + return vec diff --git a/ctgan/data_transformer.py b/ctgan/data_transformer.py new file mode 100644 index 0000000..7ed8b0e --- /dev/null +++ b/ctgan/data_transformer.py @@ -0,0 +1,265 @@ +"""DataTransformer module.""" + +from collections import namedtuple + +import numpy as np +import pandas as pd +from joblib import Parallel, delayed +from rdt.transformers import ClusterBasedNormalizer, OneHotEncoder + +SpanInfo = namedtuple('SpanInfo', ['dim', 'activation_fn']) +ColumnTransformInfo = namedtuple( + 'ColumnTransformInfo', + ['column_name', 'column_type', 'transform', 'output_info', 'output_dimensions'], +) + + +class DataTransformer(object): + """Data Transformer. + + Model continuous columns with a BayesianGMM and normalize them to a scalar between [-1, 1] + and a vector. Discrete columns are encoded using a OneHotEncoder. + """ + + def __init__(self, max_clusters=10, weight_threshold=0.005): + """Create a data transformer. + + Args: + max_clusters (int): + Maximum number of Gaussian distributions in Bayesian GMM. + weight_threshold (float): + Weight threshold for a Gaussian distribution to be kept. + """ + self._max_clusters = max_clusters + self._weight_threshold = weight_threshold + + def _fit_continuous(self, data): + """Train Bayesian GMM for continuous columns. + + Args: + data (pd.DataFrame): + A dataframe containing a column. 
+ + Returns: + namedtuple: + A ``ColumnTransformInfo`` object. + """ + column_name = data.columns[0] + gm = ClusterBasedNormalizer( + missing_value_generation='from_column', + max_clusters=min(len(data), self._max_clusters), + weight_threshold=self._weight_threshold, + ) + gm.fit(data, column_name) + num_components = sum(gm.valid_component_indicator) + + return ColumnTransformInfo( + column_name=column_name, + column_type='continuous', + transform=gm, + output_info=[SpanInfo(1, 'tanh'), SpanInfo(num_components, 'softmax')], + output_dimensions=1 + num_components, + ) + + def _fit_discrete(self, data): + """Fit one hot encoder for discrete column. + + Args: + data (pd.DataFrame): + A dataframe containing a column. + + Returns: + namedtuple: + A ``ColumnTransformInfo`` object. + """ + column_name = data.columns[0] + ohe = OneHotEncoder() + ohe.fit(data, column_name) + num_categories = len(ohe.dummies) + + return ColumnTransformInfo( + column_name=column_name, + column_type='discrete', + transform=ohe, + output_info=[SpanInfo(num_categories, 'softmax')], + output_dimensions=num_categories, + ) + + def fit(self, raw_data, discrete_columns=()): + """Fit the ``DataTransformer``. + + Fits a ``ClusterBasedNormalizer`` for continuous columns and a + ``OneHotEncoder`` for discrete columns. + + This step also counts the #columns in matrix data and span information. 
+ """ + self.output_info_list = [] + self.output_dimensions = 0 + self.dataframe = True + + if not isinstance(raw_data, pd.DataFrame): + self.dataframe = False + # work around for RDT issue #328 Fitting with numerical column names fails + discrete_columns = [str(column) for column in discrete_columns] + column_names = [str(num) for num in range(raw_data.shape[1])] + raw_data = pd.DataFrame(raw_data, columns=column_names) + + self._column_raw_dtypes = raw_data.infer_objects().dtypes + self._column_transform_info_list = [] + for column_name in raw_data.columns: + if column_name in discrete_columns: + column_transform_info = self._fit_discrete(raw_data[[column_name]]) + else: + column_transform_info = self._fit_continuous(raw_data[[column_name]]) + + self.output_info_list.append(column_transform_info.output_info) + self.output_dimensions += column_transform_info.output_dimensions + self._column_transform_info_list.append(column_transform_info) + + def _transform_continuous(self, column_transform_info, data): + column_name = data.columns[0] + flattened_column = data[column_name].to_numpy().flatten() + data = data.assign(**{column_name: flattened_column}) + gm = column_transform_info.transform + transformed = gm.transform(data) + + # Converts the transformed data to the appropriate output format. + # The first column (ending in '.normalized') stays the same, + # but the lable encoded column (ending in '.component') is one hot encoded. 
+ output = np.zeros((len(transformed), column_transform_info.output_dimensions)) + output[:, 0] = transformed[f'{column_name}.normalized'].to_numpy() + index = transformed[f'{column_name}.component'].to_numpy().astype(int) + output[np.arange(index.size), index + 1] = 1.0 + + return output + + def _transform_discrete(self, column_transform_info, data): + ohe = column_transform_info.transform + return ohe.transform(data).to_numpy() + + def _synchronous_transform(self, raw_data, column_transform_info_list): + """Take a Pandas DataFrame and transform columns synchronous. + + Outputs a list with Numpy arrays. + """ + column_data_list = [] + for column_transform_info in column_transform_info_list: + column_name = column_transform_info.column_name + data = raw_data[[column_name]] + if column_transform_info.column_type == 'continuous': + column_data_list.append(self._transform_continuous(column_transform_info, data)) + else: + column_data_list.append(self._transform_discrete(column_transform_info, data)) + + return column_data_list + + def _parallel_transform(self, raw_data, column_transform_info_list): + """Take a Pandas DataFrame and transform columns in parallel. + + Outputs a list with Numpy arrays. + """ + processes = [] + for column_transform_info in column_transform_info_list: + column_name = column_transform_info.column_name + data = raw_data[[column_name]] + process = None + if column_transform_info.column_type == 'continuous': + process = delayed(self._transform_continuous)(column_transform_info, data) + else: + process = delayed(self._transform_discrete)(column_transform_info, data) + processes.append(process) + + return Parallel(n_jobs=-1)(processes) + + def transform(self, raw_data): + """Take raw data and output a matrix data.""" + if not isinstance(raw_data, pd.DataFrame): + column_names = [str(num) for num in range(raw_data.shape[1])] + raw_data = pd.DataFrame(raw_data, columns=column_names) + + # Only use parallelization with larger data sizes. 
+ # Otherwise, the transformation will be slower. + if raw_data.shape[0] < 500: + column_data_list = self._synchronous_transform( + raw_data, self._column_transform_info_list + ) + else: + column_data_list = self._parallel_transform(raw_data, self._column_transform_info_list) + + return np.concatenate(column_data_list, axis=1).astype(float) + + def _inverse_transform_continuous(self, column_transform_info, column_data, sigmas, st): + gm = column_transform_info.transform + data = pd.DataFrame(column_data[:, :2], columns=list(gm.get_output_sdtypes())).astype(float) + data[data.columns[1]] = np.argmax(column_data[:, 1:], axis=1) + if sigmas is not None: + selected_normalized_value = np.random.normal(data.iloc[:, 0], sigmas[st]) + data.iloc[:, 0] = selected_normalized_value + + return gm.reverse_transform(data) + + def _inverse_transform_discrete(self, column_transform_info, column_data): + ohe = column_transform_info.transform + data = pd.DataFrame(column_data, columns=list(ohe.get_output_sdtypes())) + return ohe.reverse_transform(data)[column_transform_info.column_name] + + def inverse_transform(self, data, sigmas=None): + """Take matrix data and output raw data. + + Output uses the same type as input to the transform function. + Either np array or pd dataframe. 
+ """ + st = 0 + recovered_column_data_list = [] + column_names = [] + for column_transform_info in self._column_transform_info_list: + dim = column_transform_info.output_dimensions + column_data = data[:, st : st + dim] + if column_transform_info.column_type == 'continuous': + recovered_column_data = self._inverse_transform_continuous( + column_transform_info, column_data, sigmas, st + ) + else: + recovered_column_data = self._inverse_transform_discrete( + column_transform_info, column_data + ) + + recovered_column_data_list.append(recovered_column_data) + column_names.append(column_transform_info.column_name) + st += dim + + recovered_data = np.column_stack(recovered_column_data_list) + recovered_data = pd.DataFrame(recovered_data, columns=column_names).astype( + self._column_raw_dtypes + ) + if not self.dataframe: + recovered_data = recovered_data.to_numpy() + + return recovered_data + + def convert_column_name_value_to_id(self, column_name, value): + """Get the ids of the given `column_name`.""" + discrete_counter = 0 + column_id = 0 + for column_transform_info in self._column_transform_info_list: + if column_transform_info.column_name == column_name: + break + if column_transform_info.column_type == 'discrete': + discrete_counter += 1 + + column_id += 1 + + else: + raise ValueError(f"The column_name `{column_name}` doesn't exist in the data.") + + ohe = column_transform_info.transform + data = pd.DataFrame([value], columns=[column_transform_info.column_name]) + one_hot = ohe.transform(data).to_numpy()[0] + if sum(one_hot) == 0: + raise ValueError(f"The value `{value}` doesn't exist in the column `{column_name}`.") + + return { + 'discrete_column_id': discrete_counter, + 'column_id': column_id, + 'value_id': np.argmax(one_hot), + } diff --git a/ctgan/dataprocess_using_gan.py b/ctgan/dataprocess_using_gan.py new file mode 100644 index 0000000..53cfff2 --- /dev/null +++ b/ctgan/dataprocess_using_gan.py @@ -0,0 +1,141 @@ +import argparse +import logging +import os 
+import tarfile + +import pandas as pd +from ctgan import CTGAN + + +# 配置日志记录 +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[ + logging.StreamHandler() + ] +) +def extract_model_and_tokenizer(tar_path, extract_path="/tmp"): + """ + 从 tar.gz 文件中提取模型文件和分词器目录。 + + Args: + tar_path (str): 模型 tar.gz 文件路径。 + extract_path (str): 提取文件的目录。 + + Returns: + model_path (str): 提取后的模型文件路径 (.pth)。 + tokenizer_path (str): 提取后的分词器目录路径。 + """ + with tarfile.open(tar_path, "r:gz") as tar: + tar.extractall(path=extract_path) + + model_path = None + + + for root, dirs, files in os.walk(extract_path): + for file in files: + if file.endswith(".pt"): + model_path = os.path.join(root, file) + + + if not model_path: + raise FileNotFoundError("No .pth model file found in the extracted tar.gz.") + + + return model_path +def load_model(model_path): + """加载预训练的 CTGAN 模型.""" + logging.info(f"Loading model from {model_path}") + + model = CTGAN.load(model_path) + logging.info("Model loaded successfully.") + return model + + +def generate_samples(ctgan, data, discrete_columns, min_samples_per_class=5000): + """ + 使用 CTGAN 模型为不足类别生成样本. + + Args: + ctgan: 加载的 CTGAN 模型. + data: 输入数据(DataFrame). + discrete_columns: 离散列名称列表. + min_samples_per_class: 每个类别的最小样本数量. + + Returns: + 包含生成数据的 DataFrame. 
+ """ + label_column = 'label' + label_counts = data[label_column].value_counts() + logging.info("Starting sample generation for underrepresented classes...") + logging.info(f"Current label distribution:\n{label_counts}") + generated_data = [] + for label, count in label_counts.items(): + if count < min_samples_per_class: + # 需要生成的样本数量 + n_to_generate = min_samples_per_class - count + logging.info(f"Label '{label}' has {count} samples, generating {n_to_generate} more samples...") + + # 使用 CTGAN 生成样本 + samples = ctgan.sample( + n=n_to_generate, + condition_column=label_column, + condition_value=label + ) + generated_data.append(samples) + + if generated_data: + # 合并生成的数据 + generated_data = pd.concat(generated_data, ignore_index=True) + logging.info(f"Generated data shape: {generated_data.shape}") + else: + logging.info("No underrepresented classes found. No samples were generated.") + generated_data = pd.DataFrame() + + return generated_data + +def main(): + parser = argparse.ArgumentParser(description="CTGAN Sample Generation Script") + parser.add_argument("--input", type=str, required=True, help="Path to input CSV file.") + parser.add_argument("--model", type=str, required=True, help="Path to the trained CTGAN model.") + parser.add_argument("--output", type=str, required=True, help="Path to save the output CSV file.") + parser.add_argument("--n_samples", type=int, default=1000, help="Number of samples to generate per class.") + args = parser.parse_args() + + input_path = args.input + model_path = args.model + output_dir = args.output + n_samples = args.n_samples + + # 加载输入数据 + logging.info(f"Loading input data from {input_path}") + data = pd.read_csv(input_path) + logging.info(f"Input data shape: {data.shape}") + + # 识别离散列 + label_column = "label" # 假设 label 列为目标列 + discrete_columns = [label_column] + if label_column not in data.columns: + raise ValueError(f"Label column '{label_column}' not found in input data.") + model_p=extract_model_and_tokenizer(model_path) 
+ # 加载 CTGAN 模型 + ctgan = load_model(model_p) + + # 生成样本 + generated_data = generate_samples(ctgan, data, discrete_columns) + + # 将生成的样本与原始数据合并 + combined_data = pd.concat([data, generated_data], ignore_index=True) + logging.info(f"Combined data shape: {combined_data.shape}") + + # 保存生成的样本和原始数据 + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, "generated_samples.csv") # 添加文件名 + + logging.info(f"Saving combined data to {output_path}") + combined_data.to_csv(output_path, index=False) + logging.info("Sample generation and saving completed successfully.") + +if __name__ == "__main__": + main() diff --git a/ctgan/load_data.py b/ctgan/load_data.py new file mode 100644 index 0000000..fe98e2e --- /dev/null +++ b/ctgan/load_data.py @@ -0,0 +1,73 @@ +# import pandas as pd +# import numpy as np +# from sklearn.impute import SimpleImputer +# from imblearn.over_sampling import SMOTE +# import os +# +# # 定义文件夹路径 +# folder_path = "data" +# +# # 定义需要提取的特征 +# selected_features = [ +# "Destination Port", "Flow Duration", "Total Fwd Packets", "Total Backward Packets", +# "Total Length of Fwd Packets", "Total Length of Bwd Packets", "Fwd Packet Length Max", +# "Fwd Packet Length Min", "Fwd Packet Length Mean", "Fwd Packet Length Stddev", +# "Bwd Packet Length Max", "Bwd Packet Length Min", "Bwd Packet Length Mean", +# "Bwd Packet Length Stddev", "Flow Bytes/s", "Flow Packets/s", "Fwd IAT Mean (ms)", +# "Fwd IAT Stddev (ms)", "Fwd IAT Max (ms)", "Fwd IAT Min (ms)", "Bwd IAT Mean (ms)", +# "Bwd IAT Stddev (ms)", "Bwd IAT Max (ms)", "Bwd IAT Min (ms)", "Fwd PSH Flags", +# "Bwd PSH Flags", "Fwd URG Flags", "Bwd URG Flags", "Fwd Packets/s", "Bwd Packets/s", +# "down_up_ratio", "average_packet_size", "avg_fwd_segment_size", "avg_bwd_segment_size", +# "Subflow Fwd Packets", "Subflow Fwd Bytes", "Subflow Bwd Packets", "Subflow Bwd Bytes", +# "Label" +# ] +# +# # 标准化列名函数 +# def standardize_columns(df): +# df.columns = [col.strip().lower().replace(" ", "_") for 
col in df.columns] +# return df +# +# standardized_features = [col.strip().lower().replace(" ", "_") for col in selected_features] +# +# # 读取所有CSV文件 +# all_data = [] +# for file_name in os.listdir(folder_path): +# if file_name.endswith(".csv"): +# file_path = os.path.join(folder_path, file_name) +# df = pd.read_csv(file_path) +# df = standardize_columns(df) +# df_selected = df[[col for col in standardized_features if col in df.columns]].dropna(how="any", subset=["label"]) +# all_data.append(df_selected) +# +# # 合并数据 +# final_data = pd.concat(all_data, ignore_index=True) +# +# # 分离特征和标签 +# X = final_data.drop(columns=["label"]) +# y = final_data["label"] +# +# # 检查和处理 inf 或超大值 +# X.replace([float('inf'), float('-inf')], np.nan, inplace=True) +# +# # 检查并替换所有 pd.NA 为 np.nan +# X = X.replace({pd.NA: np.nan}) +# +# # 确保所有列为浮点数类型 +# X = X.astype(float) +# +# # 检查是否仍有 NaN +# if X.isnull().values.any(): +# print("存在缺失值,填充中位数...") +# imputer = SimpleImputer(strategy="median") # 使用中位数填充 +# X = pd.DataFrame(imputer.fit_transform(X), columns=X.columns) +# +# # 应用 SMOTE +# X_resampled, y_resampled = SMOTE().fit_resample(X, y) +# +# # 转换为 DataFrame +# balanced_data = pd.concat([pd.DataFrame(X_resampled, columns=X.columns), pd.DataFrame(y_resampled, columns=["label"])], axis=1) +# +# # 查看结果 +# print(balanced_data["label"].value_counts()) + +final_data=[] diff --git a/ctgan/synthesizers/__init__.py b/ctgan/synthesizers/__init__.py new file mode 100644 index 0000000..881e08f --- /dev/null +++ b/ctgan/synthesizers/__init__.py @@ -0,0 +1,10 @@ +"""Synthesizers module.""" + +from ctgan.synthesizers.ctgan import CTGAN +from ctgan.synthesizers.tvae import TVAE + +__all__ = ('CTGAN', 'TVAE') + + +def get_all_synthesizers(): + return {name: globals()[name] for name in __all__} diff --git a/ctgan/synthesizers/__pycache__/__init__.cpython-311.pyc b/ctgan/synthesizers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..e7bfe77a077018470ff0b67846f5b064c5acf443 GIT binary patch literal 776 zcmZ3^%ge>Uz`$VFos|BDfq~&Mhy%k+P{wB+1_p-d3@Hpz3@MB$OgW6XOi@gXAU1Oj zb1q913nN1cOB8DgYYRgZTMAn+gC={GNN{CdNk(dMW>soYu|jTsN@-52-b;{ynvAzt zokQFm{WKYGv4n&<x@s~NF*7hQ6tOTcFeHOiz%a;ukj!Tli2aN$3?(o*1_lNfhGh&4 z46ETH!3>&Ae#wk*J_7?Q0|SFF0|Ucn3$Ov13^fd~;<b!5j3r2_Pz|YJl!O~FnW;x2 zm|-P@CgUyk^ql;p#GGPHrdv#U2De!95_40FLGDvf_+{Z@6_b*cn3tX(P?VpQnp_f7 zl9-cSl3A7-lU$OXm={xw6cjP>@tJv<CGqik1(m<JY*I3lOOo?*3+$>i(o;*~6LWIn zkyY#2<m4wO<`moMAyk6mt5}JFfuVun0=F3iElBBbnW{cfbBg8+mnr%Xu?yU0pFv&$ zxjr6ja(odxD4cmg1UCZ%!z~ePb{6q7Ffed3FfbI0fi1ZKwq$Mk1#Yv8+-6s}%`UK* z6@kp~(`3KJ9v`2QpBx{5O9JdwJru`-g}`xpOB~Ebl`bhuOa(`J5g*7!AZyVa_>02^ zl5*^dK+y<_v|?Wd28IvJjEsyo7^E&RNIhWCx`2vqFlbyrMGv^78`wXvF*2HeV8A56 GegOb+M89$X literal 0 HcmV?d00001 diff --git a/ctgan/synthesizers/__pycache__/base.cpython-311.pyc b/ctgan/synthesizers/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96a5fbe3cdb9bc443d8d155f16e47ad2d9aaaa4d GIT binary patch literal 8812 zcmZ3^%ge>Uz`$VFos=#s%E0g##DQTJDC2Vy0|Uc!h7^V<h7`sWrX0p7Mlj73#SErd zqF7QHvsj_38B&;A7?v?GFsz0NFhsGXu(U8lu`@BaGo-M#Fr=_8V`5-f%>+{r#gW1k z%%I6$CGM11oElu2SCWxhoLQAxq>!7RQks*h_Y$PiFBz_Zfq{XIfq_Apfq~(31PcR$ zHv=QXbcPa;Dk#okgfbWwAnSq&qB*dZse~Oa#gGN_IJzq48ip7q28LReT2>r(!2<zd z8UsTO3o&k3#>l|18t!UF1{4>8d<!<RhLs$%I2lqH!x*MBGU5s&B;SLr0TZZ^Sj$$! zwg9XDaD)XB;f~kO9AHN<Frdc*dks4dcOdx!R~(|-$iz^?Si@Gsp2mbOgTq%?L!bsz zH?mt0vB!*VOA2E!gYsyspr!qRj+JtT3g$?Laz;(2L<L3$2Iu^|lGKV4h1|rv#Prl6 zg|z%42s<+`U7;i+RiP*`FC{-$p|~WmBvp@#OF=<F!Lca4*b2-7@gNG~!3q>KGV@9l z@{1HoN(*vQH4%zH3PX@gOHEAyDNBSXRLD<L(8w#zEvVFkn5Gv5wlx@Jsg6QPeo=CU zo_lIuYEfcIevzgg-0tGklK9;Gl+>Jfh(4Gr(n|A^OEUBG&|KgK6;>$8S13*`!SJ6# zejZp1Y$U`MKTXbCOnC*j*dP|%;!IC1f!K1373{8CydaUHymY82C&*3Uz_`T=Vj^U? 
zkvw&a6X6|>TRafMp(codw1J&~;;UPN`9+!OnR$sh@p%PMr6O=4uv?&_{7~cI%7mZ- z$chA#ON)w9^GXng3ByE?lof+YDp2Hs@Glb=tC*Cm#Ju$UfTH}Y)Z~(wlEj?slFYKy znB<c5#JredMEzDAlayGTs#j2XOAzEAs0+XWQ(Pp*z`(%Cz`#&^osogz1;Y!528OR9 z43g4wIIl}+Uy{(iD4}~rLiYlT=v`@rE7E!wr46n~8(d(Kc%Wr=kwv}1|AMIcQx>)s z&+9DWmsrGS@L!iOyd+_GQNs9&gz-fdlPfGH7g$X0vT$Bu5x*dzbzQ>bl7z`c39~B_ zW*1q^udtY3U@`x|!o_NIgM+t&6$GD3%g(X9E^Tm0+TfzJ(G_W<3mg(RI3%ufC|%-E zT57b!^18CwC1taV$`)6YEiQ6cUg5C3z+rh;PI*rLMLGQ|a{3oIq;GIYU*}N0#G$%0 z^@^J1MK$XyYStGyY_4$FT;Q;|!6A2nL+*lt!387N3-WFkIoz*sxL@FKf6Bqt!+D)U z`Vxoq1sUCo9C}wc^e%Ad-R0oD0unK}$YFSe!|(zO{rJkyAR{;Uo%-X)k78a128Kg| zEUxklN9CQlJT#e;;Y|Pr1_n?S$_#4c2r)1)v@=X+=wL`=Okr%{sAa5S3}(<|s%j6$ zmPM0N)AEZ_6_OKka$rRiEVpAVo?wy+8qSG1If+SFizT?Cl6-~oqQnA)#G?Gtyc9?g zfZ}XWeJ2lc_9_MjhN+Cx8EP1^V3j#T0}m5JCvy!$I%5svBE}$wOoke!C5(NH!3-rJ zuR_%!>!@MK0tEzE2DRa6%D~8g-Lzz;6b2APRnN!(s%;G!ido8;ASFB_LnH$u10zEU zLokCTqn{@8EnduAcuN31cWScSVku6|NxQ|8Sd?CTi!HkxOmo1(s2J3QR8UZ8C=z90 zV7SE&3a^6HqAFQr!}V-(@{<#DitY5E(gvX1ew2ZM;YS0*4Q~D`+_D$ARc{DOPq3Kb zazWT=fyo77BM4ey2B9zT7=2}95K>-Xaz)v6L&+6$pNoRNR|I`OFfj6(gNYmB@)If+ zuv`!~TTpUA+zf(NltJhZ%rIqOwP4~SNb4660p{W)e*E}>Qw}UqEXcsXu%5?}gZUr_ ziz5&FL4HO@9`=<CnoLD<3=9lKpcGOJir54Na3~dl@&cI90EGjnasQ)%;e!AhtN8~8 zEQA09k8p$EXHc?81`#j}G7`l8>;mq#q%uS?rZ7Y?r7%V@r!YmafVwA9tl$m_8>ll9 z#h%KM!iLdh;Y?v~VTj^N0d-6`Zt<gaOsY&l0t(4FiN(d>=0IXjjzUUmT4HHVi2`x~ z5302j5}ZTa9sLp#^!!k=38*OvO7fqbm_Yq2<Ro9qSi`UYxg-F22#iZWy1*<3hAfb7 zFgt}2wPA_MV+3VDP>KYr#ol;EYwR=OvcCq+CPs!5P`rUnW&qg(Z}^r7Km|cfc$EHL z(F$C8QkW#+no2;09+Z~_?xlh_S==BJgi*70Q5{Y*nUTzdhXMmb77xf25Jod8vW6iG z?h9lWEkN=qNFM~FhFlF(3X2T`YDlCYxqwJBSixp6E#QOLfl8*ZVZ<*enX)G?VPs(N z%q=L&FH2P@$V|=#m;T`91UQi<B!E&aQvW(30a_h^S}xA{d1a|ZC7^a+LPByuX+nZR zN@`hVa;gHfIw>wmEGmK3Cvf8nit>|Fi;ER9^Az%nQc{aR`jb)>K$VC>Nj|(4nv<WH zf@n7?q?V=TDU@fV<|&jGr=}>R73JqDB<3lkR+OX`<t64Ql;r2<C={0_XDB2VD<mWY zyM!eq=qWhn6zA(GWTwGQO-Mj4P>T~16cUS4LCsbjh2j!W6<Csynpp&Cm=}YLO34Jb zc@v8=Q;QXf5=$~b4$er-Q%HxkoQsiyBq+6{v?vd$9S;g8kP)bjc93X5Wl2VUo&rcx 
zp(wSWD782>4<1<gNuY)xq^1NZzQvqeP+9~ks9u6H`7PmiaKj$8DJvWgX=5V$un3ft zHTiF`$3vp#78j^yg0PD~Y4#Qygma4{J|5(j`1o5a$@vA9x41HkGxLf|67!N%!G_#o z1$RG+KpCP4RHhVxGFA~NPZWXb?_10T`30KHMbe=9ix*;Od{Sa^c4@&aR&bIk29-Vv z3XsD47EgS9dTI$+B0j!K1}&lKA<H^|iubP!3=9nnAH*0WW#{m8@ZR9y>F4j_pUXOj zXF=#i3GFKq+7~%=u5jpF;Ly1tATovbx`5Ip0i}xqDpv$lI#{8~XYgL+P`kpRc7a3f zDUaBMk}KkxYfUfm7+&Eq{J_l4ss2<zYKF-bY26ixSELQLDs9ocV&r~Nz~hR5#|Kt+ zPW7*B41&@~3RJh~Ulee<BH;9aot;zt0~dpi!4+-W9f?=89j<G8UDEctsO@t_+o!|z zf{^+}9*rB?`go;z1p7U^JSXH|<WalAqjrHu?E$|?2TzeA0|NufP!=c(Kw<g$93!Y6 zNMWpDEMq8|kirC}dukXLpk{w4553H<Wx}ZvZFr)V8M)VM+!wX0>RVOaoj7t}rk zRf92%Q3C1$K@}iI0vCW15>ymHp_Z4mtTn7SY(aB1BSQ^q4NDqR4Py#3xO8N!VTcE% zd9cwaK7zZZ1eC|X@(c`FoKObC0(fl#69Lhv?kZwRVaa1mVMQ%zYuRcT7JxDeNCOJ4 z!6=hc*uZ0&$AePgMG>SR0<{Se74R3v$eJ(<IdE|a8h8NbH%OtAnWj*Zkzbq&E&-wC zk3wP|xX{gm7nrafJtRXbBqk^47o~v9^n7sX4DH@3WF!`)lqVLYBFd~Hbx@-X)HDN^ zAHUd=OH&f942o147#RFCd5S=dS8x#oDMLVkjb6CufK1Q^6(}qq&9`_mi{lf^5;MU) zjv_4v1_n)baDh`~1X9ES@0)W&N<y%2i$p+j@P0U?;DpG4j4lSH2v8Y|Trw5oEtwFq zQJ|9PE~I4QVc=0*;Cg{u6NL1f50qZe^SPksb3<HadgY|b3yNkJ#m%pXn|FBJkdmLT zIZN|`iseNqt1D7g9o|os)R*XAR5HGzWPE`~{sxcybso)2Jen)CFY=gO;W4|wWA;=@ zWr^l>CG$&4<`<PLuP9lfsMEN}V|<0j_yUjd4K>RP)_xaM{X1MI=wIYfdLSX&!P3Kf zS5Rz<%v#koTGzE4FKIbm)N;O}<$OWTZ$`$9j1!_41^urG`ggFw3Q4w$9P(E<<S)Pw zW(fo;|3NwZ^EPmk8KotPvn5DgOB6*HsN@6Zq8f%8M$}dqBLmJBC`P#eZlp6;sRSpM zrGis2cwh+BA_3)*g2ZA_6C<G@u_Plw51L+Ci!?!r5tQa2i58T0Zm|?6mZicfJr<CL zB12G;0EG=KJr&t9Ffde!pjU7po%tZO{@~>EKthI4wRV?-w})$j$qbQ;9CBB{BOkB` zL)2;)z;zpDgn<GM<fzXFz!6pgYDItpl7Rs+^MNvg4KE0>7q%FMA7dV43R*zdGNIT5 zas${-^e#aya}DzXP?7@6A`_V7<TcE}44TYU+CHewLvV?if+Ge^K`{p^b~M=#aaUvp z5;X^v4J@FxCZsrFxy78EQ(R;Rk^{#uD>!Muy62F#OA)9+f*GSAQ_4Uw8U&6}P(NFo zh+KD9Ky(V%43moj3ReUaKrI#C>l_l7I3#9>UF1->!l7`1L*WMYb^s{Iz>#_a+zzN= zY~ZW`)p76;1Wgto*K;+DCGbdKU_i}3MZP6SwJ%~i1V_X$u0e}z<lYpbH-wr@7QpLS zxD!D=AG97~8dD8JJX`@5KOpt<(EQ+s>RaS~6n5Vs+8$^oX`-3L0%}M@9905p^Fo+N z{XR7P*kS@Sv;)<b!dk;Pn<0g5E?Uc|iBXe1k(ZHy0oyPgQbz-Iwg6OpA-DR$GX^@4 
zX#gG2lz?A;9{R)qNRMkKxIY7%Fi6hND@n}E1J~{usiaIEz?%yoQ$aYt2-G_R&0Apf zZ6FR+C<e_|l;nfDbKt&aQhrGW%DhN1qAB60$#{zeqygLoDpCU#sQf5Bxder(K!wzb zg4E=a)D&2cGEV{I0FYB4K>=z<X(Z$oV3`j|&{0SLw;2#KA_<zB+~Af=kshd!V9v}d zL2sXcnhxOB%PsEQ#JtkPoOrM!^+D=6Ln;eWU5kqHi!_-aZ4}U80C@Zg+65|-0{KG$ zqhd%cLFw1kfr^w@;3DON00WQ21#bBpin<*xS9s(ufYDuE@fkczlP~gWUE$UGz`)PR z_f$Y=3g>kJ)k^}ZD{L+Zs9qGXxFTQyDropG3MgI?P`oandr3fdMb1S5>nj4*A6S`{ z_>c_z%EiDdKA~o5=tUmwD?HjCn7KekO3koY!8pf$t<xIMD>{zQeh(`bCm*C&1d()I z<9|`W=8AyL2X-z_z7JdsN_t2JKNXVxz{DtMzd+{$1EZilg80bHEX?<XfmxXE0|PTJ zAGrR~<OcOg^AdAY<Ku5}#e=4#N^?MLp7{8}(!?C7410WhN`7*DJS3`&K?S!phyc~P zx46LrZn>!?8Tl#TQX3rRMG7DVpl%gp91%2fSOg+K8jC?icmo3r-e9r6z+(SEM54j> z1B(Q!=m!QkVIs_`@PPqNXtA=Ye_((Ud<<-YADCD{I#^jnKQc403ARK-1o$Lb`93gU z5@2tG9iquu<jugqaEl9+B~vR(ax#-{v4UH!kR%97c(?eVlDUvsM)3R(csLYn1ezy* zao9lCM%WbvFfcHH>i^<t3=9k(m>C%vZ!pMSV32*lAb0_W9xw=9fT0@<JPly@fsKKQ z=L)0L2PSDoqYn(yj7A@s8JMKuA~phy!XFsmgjx(EBi{!G?BqwV_!lsVsRkSp0JPkQ A+W-In literal 0 HcmV?d00001 diff --git a/ctgan/synthesizers/__pycache__/ctgan.cpython-311.pyc b/ctgan/synthesizers/__pycache__/ctgan.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8119686edfd9f07e4536c37ba0978bca11bb7798 GIT binary patch literal 29580 zcmZ3^%ge>Uz`(G0NmcsMa0Z6QAPx+(K^dRp7#SF*Go&y?F{Ci2Fy$~tF@kBPC?+t? 
z9K{T#S)y3MG+PuqnC6J$NMTH2&f&=AjN)Vj>1D~`%H@vY&gF^X$>ojW&E<>Y%jJ*a z&lQLg$Q6td%oU0f0;^-q5zZBk5(cx`azt`PqeQ`M_8hTX@hEXHn<GaeS29X6S1L-1 zk%5UJl`%^i>Ru3A2F6a|1k1_7_{$g=7*@kL4C#z(xRx<9Fsx>V34zRVXGr00VMyUg z<wuo^l4D|UXGq~~VMyV_kdp_?@wYIf2w=!5faL^R7*d2%CDF`OOkoOU&=juXbq;ZN z^i#;qPbtkw)q4qYkDn&vE!O;klFVF9j$7PLi6zMye)&bYhAFq$U5fGx@=Hr@ar&et zW>*HK`h?zM^U2IhO)R>_<_k9B77JLED>${VG&QdzGco5DS6XRaa!F=>USf_W<1Lnw z!jxQ1##`Jji6x1_iMa(isYRNMxA;NakfOxA;<Wst+|(jXrd#|@iN&eGm3bu@sl}O9 zsYSPViW2iu@^j;hOA<>`tJGCB5{r{dGILWkixr|(HF8qRQgiYWLEKmcRgK)#;^M^g zRLx>d=3A^_xm#=?qZ3PRabzUsrR1a*6(@td3&XIu_-w?$z|hVxogtMWiZO)&l-i=0 zQ`kEg(il^iTR5XwQaDmrTNtBQQ`lM<qS!hZDj1{KgBdhAZ}GZh7AF^F=4R$4mgE;z z$)Io*((;QGN-|OvzyYrJ5|jawSzwkkFfgz)Fff4J@Y#ovfnh4+bcPZ}glq{Df?dO~ zjER9^HC$u?oP|)sz<`?K7#V69YnZE885n9=YM5&n;z2fnmDjM;FvP>%k<L)dTEY#M z024J#%NQ9LR>NK5!Vqgx%T~i$!V8uL6E!R~Y^W|O;Rj2Ai4;au6G{ZZ!eAl=&GdBU zDqaSLTGkr2T9zJDbX!uGQEf?Qgs3hN1{($@YFPUiY8c`{feYrRFa$GbviLopH4_9* z+CK-2GCZ(<3G$LA_bt}q(t^~YB2fkghFcu*@tJv<CGqjMm<tk<Z?P35CZ}ZP-eSqg zEG{VmMNbi^Y${@9U|=W$*>H=wIJHod;}%PCYEIfM&dj`m(vo<P4#5=Uga8sM0+j@} zm{T%yixfasv1FE{=HB9rk5A4?EG~|ZFXmuiU{HX7hF_*GRxv4AiFxVy0Y&*)smUcV zC5buNC7ETZG07$AiFq-_h+I(&5z;HDECQ82RbnWqP7mfKNd^XnVo;j+(ZFz*gSR(# zM#x1Dr7Ijt*EzH=acE!U(7nQ;dx1mu2A}v0<<9&L<{QFd9V|T@H^d}6SbDf`@CkG< z-xZO#AgOgxMEi<}_6G(gPH`~N;c|nAf2!LI;{_}iMO3eds4lJCAiS_{N!`}+9mN}K zw$xk@cDcyo(&2nVNUX!T!}$RR_f)nC#WNfy@=f8pz%PH1L%xIguCVNk;tRscAS*jK zIygRq;^`K1azSY_D9k}T5C(-3DCK^h!33%skaI{DC_+JM7-|@?=LqzaSi@4q#K4fk zl)_ZQyoPxh3j@Pycs{6Qt6^IJvI%S=GLgbm!-CTWwDgQx=7P<H+kj#w@g}0=e2|GX zjG$bd!c@an#l^r-!vx}0F)%RHvNJK%Fx0T4=Du2v5+sYWKuH<QUw|CSU=a{e!%@SI z62_oR0Fq#+;lLiopcICv=fL5L;^Gum5M0ByjE#X|HQde=Mo{tu8_ZD4S;M&iS11u3 zPM}agEiD#+(hkH5Br=6Pg=q~3s>&25P)OEt)o`V-S8*^fEZ{`a0B554loe!SEq4t! 
zs3xeAb<WQ%C@o0^7x?K#i7A<>c_j)3sd<SxC6#(bpaQ!{ih+S)CF3o&l+?1!<kXk{ z|Ns9VqRDoPJ-@W1ptPj;78|%cDZa%6(jN~K<w-6|O)N=`PcKR=$hgH*lv<LQnFr-@ z=jY{A#zVBd1lh5Yu}Fo1fkBh&7Hdg<QF6vD7Eld(iv?7j-eN0CEl5o)xy4eJnOc5} zBeAq3A7s-l76_T2ms)&_B`?1y_ZCZTYGR%yPmv<1<l-z!P0Wc;Ni0dc#hI3voeE|a z>43^aa2dy*lbD;7k{ExBH8H0kBk>kbW?o5ZQ9*uAVo7T8EpW*gk0QhgvJVtK#kct2 z%y_V)1YrDlXwVh&Ac{$FVOgXOa+7FsVoowB5WwLKQ(UEtS`=bcqy{P%JsBAo8W^4m zh)&^}VL9J!mfZ@)i_$t*q;)nhUz9f95V(W!fZ7EC&x-<{R|GseSZ@f3PUo4#Gb3;T z^Mc6baf{+MI9^mWyP|4#QNa9)fH_0~*CehPMl&ks*UhS1k$6$w;EKG#MFGPr0)`!| z55#3=R9+O<SP{5EbY<iY&5JtD7sXvVcpk_o%y3yCb5Ta`ij3X`4#}q?l2iPyi>O}` zQNJOgdP7+LuB6-?sS7Hm*Hx@8saRc9vALpRb0F-3q|Zf3pDU6+HzZ|2NJ8cVFN3JY zM+SCLt}h_s10Mr7&li4nQQi*D4?+x5%Jb!B$*o{o;k-d}yWS?f1A-TgoUa%;gPbAh zaz)bRx}@hNNzW6C7XpJWNP1qB489^6d?76SqGWgn-wj#q4U89M46jR@Uy?SzC~bK~ z+H!~cMQP^?91;)s<u7o^7ilmsFeHPLA1tqc@;Rt1{k)C=RBG2SE`Z6P(5N{IxlE{G zK$PxPYzz!3jM-rGi<!I_niy*su$M(Oj5tehhAemqgPQkIOUqyeP39^M$AW^KN^nj_ zNlywT`C!2WaE6OdfM#o5P=;eIvVdg>aD8871>(wqq?jPt2$H+wi$VTSP=F+3_O$$> z^2DN)Dp3@R^q|6a;H1g`N~#a|Mf<C}syl0YYG(vq;g?^avP9z|zt)PxOZ@s5VCV+F zaDPQt#SFIj+_ShBIIUn=!+lZ0<cfsJMSjyO{H7N;Ou<p5$yg)?au{=wCdiQ>Hbj5} z)F1*isN>^ramB|&x(M;{w|L^?3riEhP3`#jTkP@iDf!9q@!-}+ktfIqA5fm;g|=;y zQ*(0S<5w~kfpZ$Dl6C<57Tn|l6CfLl(-;^Sel#$^;0+eF3oL3MSh!hLJ}}@Sq{LVi zJ}{sX2F$Ex9~kfvU{8ZBMrr<oA_9~gKZBAldh;LL_-AV2h+;`$ZefUG1vmQHKuybA z96_nYnJJ}-IaRVSmO@TqWoi*xgFhLR13(slFevqb_@B=)(4wth0<WhTK<$1e)HIIV z-Y>z^)~{iR2jyIldl`^h^r-o?h9Mqavw~XND;fPXIg3D@;-U~xVgNNt!HwrzOliqQ zAdlZ-O3Et&)v32wic)h*HCZ4zf-&<JV?HF6fK(|cC^SG43dn>iL0I6zTeR+=?BE4S zHlP-5#tfE=9I{t9WUq6mUE)x?$f0qCL*oL6#$6$a8G<t$C+bYm>0s$$zbhm=BXOeM z6g@CUL}H@b6t@nR9?lNV4o+~=Kq<OFE&)f>21wDxfW7FdLF?qyGSx6F0Ht+s;2{&} z9V)O&SlUBY$bhL5Ia(0~SPfGRV>(j`;~KPR1{YqMh{6jNfY8*5QF?*eBSog5l7O`+ z929dAAR-b(n1M1hb8=#dCNns$ioC%B`K2Yrpukap7GOo7j9(=Hvj|ph`GO+K15$4B z3H7IUrB6uf%<lp9F`(S!&YYf{4(1ztLj9RtnKPI>^Lp|+nD6q7^i<AJo>4YKc7@{w ze!YwQdRO@ME^z39qiH335y;>oKad0cK|}zE03{Z%JBk89+#rzMpooALI^ZM;B3wWc 
zAOlbe9rX(=>K|BGSfxKOU?F%|S@l0K;3L2;1E+hclr5a@sd=eI;K2k5I1{~G0VPpb z!U5%1P>g&QqtO^bEem*r0J(fY>+`rU#7cul2tYX;>MqO?0$AT2Bm?fnv-HT+ur7f2 zN1z(P6xvt=+9&{Ojm?NsQb9)pP|ByGFi;YQmQabHGy+aC@*ox{VccR%EGS6LO97Wz z;BG#+2G`^Or~F&Iskuq1DJhwG>7Z`?E#7ok41u{Epms5c1#WkPH5RFYOa@neAa{UU zt*Awt5Ikz(Wg95#755>^w#*qK7dhmvaL8TfP`|{Xevw1-3Ww$e4$T{U(x7Yu>6_1x zy(pr7MMNFcE0+KhkgmDMjNk<-5UC5o8W(vqI-GCt@J}%5^agduZ*Xw;b9He|HJhP4 zqvE2l#uZ@=xbf({a&YoSDIY<J4TL`%kXk;1h94QLR1oD8JQOe*R+^09kVGkbSU@3- zrSuVm8v`qTl0Z3K0#f|IT0|H5<*)F|cd*>x7w+H&`>H4&6iw*G1vq_xh)|Fr7{!Is z1s0_bEF7RV3Kl|ynN<YT6v0hEYzCQ+Qe1!n6qNEm&jAN2@;D?X18$finZnb;7{!vx znj)3L-@=GmPe-w(2(&Onv8M>OFhp^r2(>Uoai(&mNT-NmsN_!NN|8ws!w~055pQ9L z;!TlgVTj^O5ea6{l)c3Y9?ndZ;$&cO&d*E9gpBhlgd`^Aq$+?$mXQW<^|-h|$v-3` zvsfXs7~FVH&M!(;0FQww<fnl}z~(6wK*qgv6v{JFiy+<Tl+3iW)FMznIypbLAU`iP zucR1kUSd(I0;qSJT3k|;SdyBeP?DdXT9T1kqz9IE%P&&M1zD2<>MQ3ID<mf6mzIF- zDk#cNPAx9hQ7FhsO)O4TNX|%2&IZdw`{t*l=49rjgM3n&lUSqxN^c6K#hH2OU{jGj z>+a|mtC3MsQc!HAub)^{ky)mfUzDz&m{hE9XlY=sXJBAzV63T7kXQh66Ii=rQF^fz zm<0+Elrlv@BQvi=6Cnvw5S&?+3JV}m7d}5%p%^@Lu27IzTnzRy*x880ui%oJmROoo zQVhyXhDH{UumG8bQer7+l#~|afWl7!G|CCGBQvk07?1Ukfe5Je#o)>)H8D9uAwLi9 zLRbT!SWm%G0mf1A0X2%@UIw|OJToUpAt_Y>GNh4$)vU6_oYK@{P<klKOi4}AL-Ch} zk*S%Ef|042Cc=9dL$5UPB1#7h;x({uA(J5x0owZ&5tTVb3L0rS`HAQW$0s$hC=Zm3 zK{*2)z@RAzxIWb2G)mPqK?EdRQG7~ja$+TR^Wh0jp*%G+J)=Ycq82vS;h2(`tKbis zN&-#bpxS1bs%wU@4K><G^A3t$R7+63grX3iw@{;(La!x(X3*k6V{X{8wqI#(Qfd(> z<AF;UNVW#&T4?E{keR0d&SAwRsRg(K3?x@fz<A_235C*v6wrVisB}ol2PX`$9q>qp zDnQr=%CK%l`MKb7CLC0<gUd531snyRxq+dcfw_gbsh$EtHA<F9fs|07s=yQ;YYL^H z;w-aR!4TxMTu{&{<QHY8XXYj5K&H+VGC^~Oxv6<2iJ-~`;liB!^!T(Q@VsMkrGiFM zetr&U+!D1U4bK2q4Islyi&GVH^3xR%dh*j0k`qf()ANfmlM{0kz_W(Mpzu#dRBb3; z0LNBlUOL?Q;D}2|2q`K}O-MjEtt_=DDZe<C2uEckmZd5b6lLa>D1fIqLH>l4vIRx? 
z=|!o<#R^5K#h`*6-A=c}oMIG9Qw#EwGYBM#lA^>+aQOgHjB2K_fdRq`1&PV{(-$}_ zQo#X`UX)*20I!BYb$FgaB4`v1R1$#-Ae4-VumhCFQ0grt`;$vk5+UJ-RyiTXXktl8 zYHmRZEONnGKt+*zK&S$!(NKz>WIz^rg4!LB76~{A74nM|oI_n4LFqCtzeFLiEHN_& z)SS^#a1IDnfK^(cbc);Ms7c^u18A(D=_Lar14B_R0|UczU1yoskIb~}pZC7qy5Re< z4R$ZTfaE}Z@DR6`Iv_S^SS%T7B9#Hu(E~LNKW_ssjHuyAVFV9`GNBBHGS)DFnzEqw zBWOCPh9Mr*{RMNu6Tk4W2Jpl$$UKM~EAp^V4O<OEJiIej!;U-@RKroj5D)JF)^H*Z z0M&5SFvP<<el^_4Q?WHXH4O3a;ieki8ish#Xg<Vtz9P92kmtcX28Jv?D1%`EatOc# z(I&TxnNyf+`SYYo6yXXOYWPZ&5IocYup;IfzFPh~rWBT1{u0ng6;xrC5R}1C!;jjn ztYL_UPco-~CYVKFD#1L^>@ipX%mYOnf+r4UGB7ZJ2FEm6{fbOL9Su+{fX7>Jv4B#k zro=72c$8M)ExveoSq$n>-V%yO$uwXwUL-X+MYjZyxZryDmLQs@oT6JIXhIN$-0_H7 z3PTCFa=FD9k5a(i;)sWqYPZ<qA*Ba+lDh~rSX5LD8X)4yERKigo?Bd*#hH1<C5d^- zskfMmONzjAMz`4GA@g6ic;ZVCt3Ph>f~S?@A!SO@EiQ!5ZgJ=27Z=Av+J%||;9hxA z3MfodL4-2cVnop86y4%S@F1QQKn=;9qFchKA`oR<NI`<801_g+r~$$b3lBEr-~fdL z*sH|?XmeMf!Ut)DlMmeUfe!&p2lY6=fxE3Alo$l0W;k}%EfBoKuXF*19>{5Tuw3Dg zz5%0^j6r<)2eO(VTIz;^K8TjRp=b!A<!-2$gJ{JE(yAa@0>+2ymyqvZ>EXR0F5AJ< z!vmfFf50#Nfq|7%X0GWRtLxH6m!yp@N}F7fHUSIX5S8k1>u|fvD?Wo|Zt_K5r7OHj z9~hW9Wp0Q{f+QY@N_V(D6_=S(v%u}5xb_usZIH<iWE4R5OWu%G1kqAAq~tnSdiZZh z%Y(QQ7dSu&+%<l_d++s|_iJw37J+6}lHmms0|NtSECp0}eLe;5BO}jJqtElZFvPmn zvX+2UL-S%5DEoogH4HV3HB2?kH7skGma#H0tcKfB%Z6jPl?}CMz&>V$+V?JIsbNcI zVq^$r=y5|H%}QZLHLHnHlciLQnSsG2H8H1Hp**uB1KivK=bNO=oXnC+P(u+^gQkO; zjXCkf`DrD&i51`mTz*bUYLP-&YEf}!eqJ$HA*fpn?RA4v8=kIkW^r+8Dx#?bYt*F_ zWtOGtrRV3T=cMW-=jZCDW#**nr|26xI+~UTq(vEK7nUXlm4p?QI;WU9WkyE1C%dHj zRa6G(gC=n8ON$F^i!&07K)q&o!wI$T3vyRZetKp}u|o8thGjYmd8N7WX{m`NrA4X5 zu?QD|)PZnl9%xQFF(<PsH3i&!D=5lON(2W^W-4Sr0Hh3rOA<>l_50=L>E@-TCxTYf zC=@3rg8KHLRs(39wX`S|?n;mW8Hq(HSlb*RNzXI|P%}>l979E^C8b4qsVVS&F|3&Y zE?81az^+X!2KAD4GxAFm%2Ja{@{5XfP^<%)np9c>Z(_hlJTk$<8i~aUnQ01{C7^~0 zC<%b#9%>(`jS8E<M=>%bGZ$N{4P=v}LP}<CY91)DC?w|O=cOx@XJjU4D1eLwB`KH> zAPpy^uD6wft|3yQ3Icl_X*2_5IAmd8ib6?hUU5FScT=30o2pQpkywzbkXQ_gI8cnI z!8#P~;N+|u3{B1{pcPk{Nu{U_e8{vU%Pr=T#L`<Vpvbw!oLW!}9&4?VMavP87y&gn 
z{Sy83G}&*l7A5ATrxt-)Ah-BX6@%-nTdbMId5L+qm=p6VZ*hfz#vNRXit>vz*>ABy z5@^w6Pz^H$L`(${ptTjZ7~^koBWkXcVo=Kn6auhGkz4%nsFqcULJN9S;S-?pz6@O6 zf7N0T5}!~sQF4kTs5Dgu6CI9srBtp+8E<eqz<fgGLP+?9*n|sGi5I03uSg|!c;6M4 zm|{Jn`l7J*6=Ch`!upql^)Ct=UJ*9@z`)3>d_!F6f&dhKU}of1{>sL{FL^=I;v%2r z6+X)g9G2kn7G-Fc3$`Bd5V*OS&QQyg!dS~(!;l3oIY9(sR^Ejnc4sY14T}py?3P;A z8dgLJn*wc~7CF=~*RW>6OJVRDCAJ!-1)ydv$S@R)JTI9An$-Zy)-a|pr!cpI%4@b( zrZgtRsv*>J6qU!wP{UTs+~bSlDi%nixyTi(s~8r5q7iH-GJ)Iv6xKCpYanacarq5> zO%~LBb5T}CvLm~8HbV->T#UM+XBSE+A)4JxDV%G#P@BBL44T|k4v@k2f};F_)FOq% z<dV!Xa2Hwuwm=uuaRQYr&@n#fh%CJM1X<y5i=`wnFT?K^JG7uE0*$R}GJ-Qlm3(Mk zYDGb6GH4hUn+3lZ^}q|CAQO#TMJ=G_ngGPU_{_Yt{CLnX<Skx=-uSdUa0V{|4UvO0 zY7uA-wrCzG2hIW!v7iYZ$n;E62FN@%&`3&r$t|Yhl3U!#`8lPzd0?||aTXUOLYYjd zDMfQZSq{{YfaX`|G7SEBa9=eZ>V{N!hD8;A0@4kdXKP@1ASm3yc7vC%gXxB}{1s`P zE7G<f7<f3786jkc_YGN%4-DMAAOR3Lf%As2%oSm+6=|1*O)dzV+>n%>kvS*#qNMf& zryC+t(|srTE>&HkwIXDLA{YkkVA>G6L1RtaMJ1anN;VfoY_EveUJ$XpAt^n<X@b)Y zQMoIkIvY4IiCSI|wL~#YN$t9l*(D{j4JA7q!LaCn(vH#{9$V^87++8_yQt)UMah4H zH`LMv+80Iiu88Pe5YdAg{ZvZjf@<P{$}8@17gQ4==%Q586{(~PB1s>_L7w?4#ULQs zQ#nIyLEr^`^^5%KSNPR0aHvBH6y(eeDtSTK|ML%6fx?o)Sj(EiSj&d9P+3sRUc>If z5Ieh;qXgbGW58I}n8K6|GMS;6yM_Z%#H27oM2om<SZg>CGv?q4N7f!~G*Qr^NzNKJ z6qO(`E|hGBTo|#Wu(UGOu-0&*7DmX6C9**23tWcOFr`4$v(<2+s!w5B!-|?+Q`pw9 zqb(lgu3^Ex5V3|EqqyU&Wv$_=W$T&Kvk=8CpduaO78Xo9LB%&@iTDCgOo7!Q6R1tv z8pZ|i-aBqp$cwnDxEVl;u{o-2kQYZL7v&chE2QR?6y+CGDuB9n3i)~9CRJKyMQVxy zwEdK-0LjOYA_!81EM#C{@B<elpmmCQshW(pIEzwKO2K;#G<m^gOVMIbfxw)Tnpd<0 zR91khY)BJ;7i@Dp)MjwOa*HjoC_Oi^0$kMGVl6I7OwPW=TwIz9Ud;tv00}8#7?W=? zf>ws-7Z(?S0_GM=aY=k~(Q=R-D?r3bki9IaDe=j-m`idCa4cNpiU*A>#)G_FB>-&} zz(sz53K&iB(!>Y+f*mZ70)}6pg9WLS;pJ5Lz`)C?04-y9cojY{@bD^3;Jhm=Hlg~W zu<{jQ<pqV8gf%V*Ypn3RqT_s^@{&%#1)YEgB4RT%=jdJ(Q3n+v50FbBY54`3OLQ+v z>rZe2w>gY%NXX0x1i=MN3yc=9&WM=9x0ZDc&xXJag&?>?Wk<;l%?&kM;!h}E(D1pa z;d4d9=R!o>2WBQ|v5yQ)(qdmg#02LLYz(53)4eBoFJQhXqH;w<<${O`xS+WqEkD6! zg3AX<9$tm7GNAILdWH#7dD6l1KtN<d;Y5!9x~@9N<n02L<(i8$7iurjUXi#&?;^j! 
z1r7s9nFB5fplK3R{($(O)4*j8ayPMtu>{@?WMDuvSWt6fkt%XG5PeM!iX0=Ngh9=b z;N`^3d9i4=!*e481EPt8X1hr`LoIU+lM6$v9w>jNFlK{2RxAiwgq{U&q@*A<g|flr zWRYkMGxjCVH7uyEV`S*j&*MklM1k2Ytzj+_DuIu3F)*aC)G*Ixn2WabqlqzvHJCw@ zEis&lfx#I(ms$=TZcj*nW#o8BMlMcBP)M!FEG_|$uO}pc2cp4)KA>@PcpnC{uZU2n zppjRaTTrQ&my%dilvoKGw=GD_OGzx&1MMYpD*|l=!tDQrz(%}5iu4pbOB8ZTi%US$ zkckRLy6`ShVh*SyT2QF~H#<=Q-E`F9Uvv*DXn>kBx;dHIsc0kCpyq~8W^oCqiw@HT zE#JVy=SiuMS&))^1!&s~G6Rk@=L`=ZkXaxc2JTSmDR`zSpav$Wx0tBl2MQ^OUqBt! z%wouNZ*fL`X-*1;VafS<pk?LYVi>emJ3X~XAu}%}GdUG9tdIdKj*Ima{6RDA<(b8) zI`G~phIxq!2?^+dmym$z5Cvq5A?ZW`v;egj5z<AVJPR)Wt1Qtx3h@jmCQ~y(j#4N{ zL8Jsw>c}h51IMObmAGeKSz=CR3Ov-)@=NnltQ7n-xr-V>B{x$+$}LV%3IQi0&_<t< zTkH@Q72jgYE4alDO?9_e!F^gtmBATOS&$0uG!}tcrkbqa(aBp}h%|c(GWM8Sk_yWQ zxA-$5Zh;H1LG1zMV^GvWo262qG0#lUtQ}+o7s-q&$jq1?ZWRKI3=G9rAiYZo1|G2q zSxc3c=v?I1zrwBmfq{cFlJTyb;sqt6i*m+S<czl_@6ftp<#W-@_llYCMLFLN{|<l9 zARpIKrzKt&dG)XG>VIHh<BVjyA*ayce?v;^0|OJUF_@UZc0*WlhT|lw304clR#;up zw%*}*MceL@qTK~WyBi|P*G050iD<1byC`CGMa1ZWh|!M^tW3PdUpW|rWiH5sUlfYC zA{23fC*lFO)CUGePA|s0LNXU*JvP)|koCAA>v2)Y^NNt?1s=~2%pggK5^%A<lEDvU zI~}NW1(*Cz0^mjGI9G1gpskx^WT;_o;6Pq-Udx;(h1~i?YGE>!z}t5WXxs7_8IaqJ zphg1H#u%ifwKYhkKdx>k7B?d6LUIGRdH@p`Yf>4J%lH~bWP7tfp#{;!kOgo3p|(I# z=4+5_&w_6StYs+yZ2<wB!@y9(kOiN+sbR>1w-y(G#=)U#5EN?ofCdL@7_#771#4JP z-Le2aQ-d%8Y$JS@2DOf>VTgy%*pz4^^nm8T$Z{Ecj;Dqp3qHfMhM6)~*0PouBHWe* zAKj>7$b!3m0n#iPLM>|O)i5A7C8w~|Fl2#NID!?gVMX8OhQmfyBJF|4H`pG~1_$i+ zu$S0_CBQ@$XbU8mRl|@4pR+?tFYuUL!-g8FwH!4ZDeRyWQp;K53N{8zWPzrBz^rtJ z6plrVYdA5|0I2m3R#5_4zzk+FFqD9n^+DM+XoD;@3|XKJMqpVUh7!;UV=$Y6Aq%uw z0?bC=%3aHatgA>Tg}au!hHHTqSTPz=!@U40H6kQacv5)Q@S?T_QGLV6Py$+-4b@T- z4P`KZ7W|{CL=I!{hDUZ(b8C5OxLg<}u*UY(@|J+cQlLg<=|dR|DSS1&n5{iV)R5yr zwbg}T0!yrAEgwqgGGu|4iGq#AP|XxOzm^{*mB8~t4Mz><8h$ia3!s!Kpgr4QbCC&T zpQGAWD_A4gpp_yJ!@|H&D^x37D^e?pVxBRwIWRW3%tWMDks497+$5Iwv_=rwOceis zCVs(2)d-=whSbumRtSf=!W5b-jAAZmF95{70yRSEOeySZ1kvN51hkj|tcHOh%N)u8 zmHhBfLUmECcu5XS6$3*SXp<C}yTB4mp%FFWsJa<JCL`(*;^IggUW_u-h}IzS(BrN~ 
z2wApPqK07sywpW@Sd9Rxt7;`{7#4u`cOfeRvuh;LR7#<QFCsmNEwBb_KqFGn(g~_* zB&Qm&8lgr*6c-?3SPDlBOXE;4RwIEmMabl_)QHu})JWAzGcnY#)<~nexmFe|DuX5} zSHrr%2F>$eZjDThbPYomye>elZPDUS9>t}$VC|sNP!82D9I9kdRe^f!+zhn}HS7yO zyY66y!RZv?H6o}w7#V8h7a-MfaHR|>qHDx3)GA<Ci&~E}GSsjya6p)ZV4<s<z|^x4 z#b>lGJ2_Cx&b;!GRqoNEY^V`xR08D=lz1eVLs3&!t!y4Q$c!2(bnzO&2Idr=TE#r^ z8ioaqh?In2)kxL|)d-`twG>fYgUD5A`GuRIMiJa1Q>tMjA;%~YnG4v6%mpkx)-_5C zoDhCOuuyZ%0%wE_f`z6Av_%Xd31y+Gsa38~psRf<*jv0cyy=WJ%IS<XD(Q?V;&aei zv9+o-3|a8DZx+1GTx^PBnj6%MpjsBLx<)k}ZIo{TXyYYJDV$CbT_b^-=4u$?;ilxV z*Qg@%OD4iKGT>?h*C^Mh5NZRXw6wKIZDkW<!vxlzz7)n_22IJtKqdwTw@k>+5%5^O z0@B_V&<RSZIpE<t&<rYg#0Rwa05lbev{+(vJQKX3adbQrJg}#rpx{^4mjGEVn4pkY zlwJy2?+seblaQ8KlAw^AT9T2UqL7hTtdNwNnx~MGT2Pb<TB4q!keHVOUno?Rnwwvi zngY_7sE}3)S~>+=2d1YG0NL9F+9;#|w+g!ZH#tAAxTL5Qv=&JLw8$(qF=Ztqcp0ZA zQxT})oPPEU2)wb^WWL3eQ*?_pDYYcA_!bXjO9W&KMG>eA3m%dM9kt~OT6O-553w=C zC%?G3O2`Ggq6d03mI6pjld)(YXy%^j7JEu&F=&DHEf&z~vMP1=)I2=}jq3E&JS#n; zv|3Gt8U^Tv43Kb2W-&w<JQS?Sbc;DXH4k*OQgIbGc*GhUoC+(Mia^aV&{ESBO%^{t zKj>5;J9wLnv5_0-46@?XlK7JR_`LkQ)DTU^B2X_<leq}ADZl7E0|Nut*@zRppw7O< zl2Vjkpy^q395m5a4;n7!D2HsbDZa%5V%-vQfksqjeqK04stB~C89XO>i#aW`<Q6A< z(NGa+01Z4*SyT%$u@z*ZAaqRxd_tnQ2((YI2y}eUEuNIjT*%B}S!!}oDM&q2Nq$i? 
zNOv2E08P~tfff^i<FW`e4tk3Pv>E3XS3x3ZVx%OssJN&cq_hG=RDuXl1rC`9Wvxmr z$}cXe28r$l5qm&HJBZ)`5z|0K9mrbd;>zM%Y{exhpi|-aLBW}l3OWO=0JL(e2sBD} ziy1Vd3Of1Y7JE*9dU|GF`YjgFN{^x%kO6Fv)qJ=3AWMD0zA8&izQqmYgUZiaT%dJg z@!;8>BG3p$(SESjMDU8PTWpXWaktouGfLCaa#D*{L1c?dDhpB}CsKe@0w>5<;8S95 zaU>-sXM>K~0L|N`78LCS>i}&KOwIr=%Leb-y~PUJO;CA@Ejd3gIkDsxcTp<HhnaaP zsTH?a!J9NR)xbmiMTbElas)(701;^*^LW8?JFw&6ZgGK+)C28<az|pjfKzGFRgeZg zaLSF(ECH=W%rCmdjTDx*_`z$A;|o%YAgis5qCi@hGgIPkG3MQ31s%mx1)8GBO92HR zQ?lVLrd-2Y%*pvVx0v$_N<e2VB<9>=DM&5Ky~UJlbc+R~;}$!JOiV7h#Q|c(gM7&W zKHvw$Vylb?iQZzXj0cI`VlGI{1FeNCE{;zrS`V_G4K#BdpMHxEI!~2f1fI*jC4}O2 zh!`)#cVHbUX~m!&(Xc`gKMh$UeFu~uIl=X=URq`eXkNVd0Xyh0ln+u2oV<{Y(JQzv zn7CZzaJ|CedV$0B0k>d>-3?L6DKVhM$WM7hdVD{yFp5Sqer012RQ|}##>@AGfsL2% z1Bm!4z#yx0QAYQQj4o&yvg8d8o+})Z*EwV_amdcdx}arsk;D25hxG*x>j$EeS47pW ziyB-KHP|3^QPk#&s7(jgQ(280JYq9gu1IKJk+9m~a>d5~BQuj6*B1sRIWCZiQXe=N zL?o^Vt6UK_>)?1QrF?@&;0llY0+lODCfAi5FDW@*RC2ze<b09G`9R1O*N6*INgtV+ zq(KHqb3Nzc=~C(l>4~|>C3S^M>H{kSCvOk`R}KbAwHrJl6H2a#YhD*OyCiOQQQYE+ zxWz>tiybOgtO75DM_&n#|G>;7$@P(eNfKn1BsR0ezH%^#Nl%Ho!6VS)f1O9|5|7$N z9*rwJ8aMcbukb705R$neDSJgy?*k{R7}rMzRxz$GAfkioD<4#!<aHjEOFSwUdDO1( zsDX8<!gLv7(<LC<Q$54vinR8MkSn?lCq&N3T;LD8$RBovKdgi02A|M`z|M>g=BLuK zb8IdcBtnsN(naZ{E7D1z;FNeODLXfHf!oa7Ik`Jj_UK+Pb3dDWChLM^<VDHIE0U2N zd^foHdu&1b;C&fyNI7?S-{6tD&ZBjSM{9-kb#2>A+O`K2PY7Pnw!NtBeMQ^*B9G4% z9-j+5K2N1&=O!=EnwdW*e}l*sL&uA9jt2}cO1WN<a_!*1At?WWA&HSw=p&fyaJ<VS z+~adWQfoos1xYQ?>3BTaS9r89@Mzx^mzfcNQC#<mxb6oAA4XoKk6>~F^Ib`~86{Wb zEiOu0UXir?z@RIr_7Ox(aJ(y{ILBv2$`0lWGBy`wY_7=IOmKZ5r*uWm;JTdsB{};8 zELR)?F3JU7kqeyQenVVtLj4VG$BPmgGni&LEnr>{d_!LOioD@<dD~0!wl`$dR<K;r zu-Kuzqxd5;lY-P21||im87wncK5#JT7$fPrE30%xR(D0xMOniuvWC}XO)klrYzWyB zb5YjmimcNNCW7(`s!KRmm~K#7p}oWKf`aWu1=}kMwli3t$}7&PTp@Br)AFLc)fIWG z87y}tq~>s5m(aW<p}9ikqJ;hx3H=Ka`Z!#jvPbhk@ClJ4k;tz9z{Vh_uz+Pw#0;hf z((?1|XW6eXy(n#XMcQzJ%MCfD>vBey<cv0$T$Ho8B4;td{i(D9D22JK_g>|_gY%+} z(-j@3i_*?lq@AxzyIqoYJ5X|@=Av}S73q)(E)T>dC)C|l)>vY>B7H;Air52;7nGeY 
zDmz_KcACKf5l`9Rv?6Lp;00y-i^}#_l<jA5%;0#StUiO|hO*WSjt8<@Gh$a*T$DAr zB5N|i?FK6Mft31nDceg@wl^dc7HF(+S>v-qWse3ZgTofbd}U)$P@Pk?LhYiw;T3tq zt*%?Vc9dN-al2yTc2VB#f~@-lw+A9}*F`igiD+yHy&$4-QN;R+i1h^#>l-5C(|soS ztVo&Qb5X?jiiq(A5o6GbKL1JnkP{g%h^XEWk(-XpzbK;mR6=$R&kB_d&KD#sE=pKj zk+7J+dP7R?ij>ZEDeFs8)}S)X=c1JF6)E2d><@&cCfMIlx4kH$GQnvE(+uYY!V82S zsOur|?@G#Dk<?tFbWu|Gilpv!NrOw01{*}S$Xt}PyCP{f!HJ-}jN%;61*I#J7G!TI zydYzFQO5F$jO7H^2jbGx>nGJOD23#v2`o3H<gQEUU6RsUQF2kr<cgFD$a$g@?7%_3 z)^v^KhV&gp8)8o|UeNZssO@z{+v}pN_Z3<132qbIAhBeAQP%j1tnmalh__T!C-_eA zeW0ST+<%e(2A3U5JCZJ{I9^e4oZ$OZN`8*kg0PEHI#;B0E=cOEsJ$p@GlA`afYfyX zl}iFD8$>S%s6ftryC7hGLqKdg|0Mp40*Y4z6fX!U-Vl&N;a?O`d?2elKXz8^%=kI+ z*JTYZ$r@agHM$~ebY0f`lC1efS<5T3mK}aK6xEl@E|Og+zeN7JqRAyilZ%RGR}{^z zD_UPtw7#fldqvUq0*~BN8M!$=H>9Lz2;WdPyr5)sL*4R%n$-t(7A?LH3@loF9~oF= z`M-dO4&M)4;L1n&im>Jf23B6fyUMCdELWsl(X%_idqLOlqO$)LW&i8S(U+8?Z%E31 zU}jK>Vf@Izq$2bML`-10BA{?VK;Z)igNh-TbC*YOg7$SGl}kb@E95TN`d<)IxhNEH zMJV7RPv8}vzzaNqH~58nYCkYA^9tP%kh?CRcS%6+hKT%i5w%MqYAft6*auz^QM)J- zbVVfSBQujA$R0tV4-8C#LO(w^GJ}d(NI?sp+(((&2d#?$wP-(|VFu6aGd0@PFgMyD zuOY>C_EHVn;W4$K*-_An67X;Y_T%EfbHt!^WMDODGa#&aphfYl;F%$~8t^O|6R0^| z%a+bi%Z|JbpoR@~l*EN$0%NRaEeG<pAjG<^8l;2jY8Vh}Nzq)vS;LtEv5yPI98R#4 zL8IWDHC$`Z{JsD_=MPf_qHB=n|51+qt>vy^Z4|9xC1%}cEl-|&4bqyr1)wet$aoZt z>^sy@spYBRX<*EAtKq5TMez}+69qP-h8I-_$-Y4~7qL=-yM{NNDTQeb3tAX+*DzpT zxq!a@wwAAkVF7&Q0@$%oqJ|IE&9(eB>^Np|_!q$EJfNDv6zb?B_W3C^HK48_ST8bx zss`0nj0_VPd%R0PJMN%bvfwKmYS^)EbYM?mtYK_LT7y{2j*=TeI|QI+At>~15Wx(Z z?1}MA3=ELHLy!YkiZgRF!8^1dXD#9yR&&nC&j%gNqL2t$LIyd72D+XKJWv8!Qwu(t z2DCz|G!-&%l9>lOc%&FI7KPYF2HRo>U&RLQ%cPco*1SOvxPh9Ikq9~|3N*@>30bm? 
zum-Y(3-7>R9<Gg9h!Y`-^2@;^ijclAVi6puzl3g=f<|#kQD$B`mbHd{psiQ=X$pxL ztEu6BL^Tn-xt0h6K_i`@(WvzNqDloEVF2A#Rt^gE;tbFZ6XbA5aRA5>Xz>6KVx*wN zy5S5wIDl^{DBLDL@K99|XfYgk@i=4{iSZUoVnt@LrX*yDMm{+|uPn8w1e%EAL8}lU zu?`w9%1kM01$BQlpxs{3V!phTc<?9%dc1&lo<N3Dz{4is?a`nM55Pk?MXx}^1rI<3 zX!Pb5OLBfe<t=eYgvY}UzJQI26ukv01do7#hj~B)9pIIRMPES@njkl^f<po_Lc$Fh zxJb=QEJ?j3m<jSxaVlb;HEj1ZBO?RDEdg|g-{MDzfD^<mK7;^h0OyuSdQN^)Vh-4K z(A<@p3|co2+LQ;{i31*^Kn!pcF@erCU<MH^Ac7l2@PJk(f|i7VVjme7fmYt!VuJ)l z6%VB658+(~>A3=48UI0zfk);7xB3eI3p(yMWEHQ->Rp%hyCmy(LqK7HKBx!I2kL?I zb@+W?V^Gn!qHKCyIqs5j+zkP_1uRz-3_db5sen|d@Lk|hxWEH$;b>goQFtIAI3eVU zsOm)l)fJ2v1vEQYZ-Ca|-xU;_P&!d&ip&QFVNQ#WAfm$&viGj!io8Ck!{~BR*!7CA zYlq7XA+Z^P6E!-V?}|uH@mY|vf%$@n$wd*9D<US>Ma(aWnD1cO!*fx@^@@mVhwB4= z<qQ1!H>C8gOBr91GQKEfdPT}~0{aB^2hbL|%Oxon(6~V8MX9hWQehL=AMlHH)ZY*g zc_1n?f$6TO)D=<H1xXi0HLr+jUKiE5B&xF_WKGOPQL8JWRuh=+O32RPTwuCFX@T|z z!wV9o7bQ%uNSIDweIP78-F}k&0@I7anpcE1J6vvvNnIDyx+JEx!sMcu!4)xs4)+^8 z0{z}y-ZR+e^UdPBD4}shLgTuG&Ls(*ixPTQB=oLJm|T)DxhP?FMZ)YNkNFiI^9wxY z7kSK|@(cFV%}7}xc!6L0BER+(er@o1YZII%1Ws`3spw#N02*6RxXy2Mi67Kw<F~lN zZ*hUc0@C1vH_8|ww<3TWe4uUq(?L77k&n1QIvcoA5qa4rj#Gh=_AjDWe4vRQ5Y}X> z(g{w5%=9Ec&L>C!m0Yle*&6EZ0io)kCB^E_0il|DewxfkgRh|6tO+^h7qamm<QM1> znYXw=BZv?qVLLt$B0NkC48@>jGz|<7pu6z7`FrdyaH~NFX5=pM$jvC3Q*)6==L(O` z1sM9k%)q05fm;nT=s-IJz(F^Q1>9Cf@_Zv7a$q%rYHW~0!4-4^Kl087^bPe*j3w}v zC1|}_2~tZE=S~LDIu7_rB)HbOqDDw9V_pH8O`w7T>_P^HEVvs`TWUq_HH?V$tjIef zko}8T<qF*+rd<M>?|_<}!i>v|*`U1-NJ$Q1&TNJh*14>xK4@ahLpfn21+Dep#E86N z7HPdKBlfmt3fgi0sJ7u~buu8%H9=JmI(rte?*)5{6D1rt5J7}y5_1i6qYDndpv^{; z>`E3CbKz^UYM5$Rz%6x7)cCJuLRsO<gflH*U&D)bF%QYEWvyXdfW5tol=l`O0uq|y zKotH~IZ7zNR{^8yF=YT<BviwSzE7!^ErkIDQA3K6A&*;`!H}VtrJSjPIg+8Ak&z*i zfsp~UA$bAPxi}!VK`^Rq3*hUDA+it>O$~g7FhmwYqN)LHPWCGWEiQSsV&T(y-A^~n zdNya(i*;+C?{0p-r{~4`?N3`fUTkQ7K4pahWM1^yw4Kj3Plc#}=y}}G1a8-rg8C{i z_Dz1;y7bx94bOVlJ)hO`v}5|SsT&kDN<r?`RCuzZ52_PG0jPTkQqcRfXXne--p371 zrJ&U$$mTuW-}-Xl?q@UBV{8Xfc-AouvS|q34t>6H_tSaZAh%ruHPAriE2PZ{I<-Vo 
zt_a+U1MPu^v?qCjO7lP$8bP*@fHtBOA*SE%f=mESd=(u4u?~WWdmsX|)uiYUhy`w2 zg4>86L0nKp1gV!~KrJnByAjkKM9d7{;?K!XOo3de6rYlrTyl#GeCim4&7PMJo-RYQ zE{kLt85lI#A#+}!7Tzsx&;(sP<d}ydMAPmUs42=;3_9Mk81qO3_JX2RkZsW0r}Us) zAyD;~&IFp%0gc_;U*J|+ka9)M{DPYO0h24vp%<KE?+VFXkdM146n{l1{sK?@S2oaI zmXx_pb9^pxYhK~j1np%>V!R<NennX2y0G3QVbB<%u*nr+lMa^-m%HGDYLez;UF6lg z!mA0|!;-`Z8mFq7q1aOkY7ZJ-QE<M*?|gyZ`L3A42POtlU#1l)YqE9-uF3tt0OCQ& z3z4x`BD1bU)<C$(<VR*EanN9vIAk#E12cnw3*%QV@b;F3i$aN4gc2|CBz|RM5SO_k zrnaDDdG(^|4VD+xZLg@?UKF#tB4&48%<+<#;{lZ;Iv2(KuZa0~aDQN95Ej28q;g%z z=8}-jj*vZ3Cm8p{UKH}aBIMn{euG=+0=L{<L9r`>iVM^)3Tj;u)VeOHdr44tMar6- zi-Oiy1g$%q?n=qe(OeL{LFs~u=?=pSQnnYRY_CY!c6fhf15YA}U*VCRk$8ni0W=|H zxuaxH1rgImAJ`Z;#jbOxUgA()kg_EIx|+i!HHV98PFK{NE^;_u;c&jd;rx}2fls(6 zwllsb9<*J@08HEvRhkgH!1|)7{uNRE4mZ$l9*6{3_9Fu$r@<Ex0p{W)e*E}>RgR58 zK&-#EtG2Vgr~W#>(j|VSi~K5A_*Jg+YhB{ky2!6{g<t0ahfXnQe)_l(j|Ug?2`&~7 zKK2uWj2?XKntVl|U6w_lmC6tz50snoK?H(T0OEp*;Sjf?LJ%8NorGwzfJ-5_BG3Y- zBG4F$CQ}jUXb(S4CO<z-ew=5P-{Jvfyv$_CiVw)?2H@@;Xlfd=(g4)<DFRJ{7O8@= zvpR^-2N59i!A(DKm#GMJJ_ne{0Bs-^0e6#N_bPl~(qZNJzyKl=7#YN+Zip$|kW#rJ zDtSW!G{tvAM&X9A*bN!k8)8y71VnG}3*F%1yTL2?L0+F#;R6FIk;KTrCiH<xlGW-1 z1C02{%)lnp5(5!66Jpi)zyK#S#8?eM*EK*1194U_#t#fgq!HM>bS74z4-5zrZX{SJ zh!w2JjMef31B`&H00|T@vasrcj&Fw&JW3!gCIJaBP*`ay6@fN(6ukh&BIs0yqMINV zsHlT<quC&P7a?UaxQqogsc&(VfR9(o%LCtP0J#<nbUHv0sL@>n+LHyo<p6SqIA~K< z5oqzmEg|p~UV7ki6Vjb60<QqMC59vcUmTKJ1YS6DOB^CujCAprUQ%LlY7uDUvIvwk zZgD3i7H1|q=jWwmrr%=m^mFsS#T#5wl$x090lDw42weDs4mpRcB>{OJoH!w20qSr6 z;;_jD-{WUjWXZ??I<JtS_$ng<!v|(YM#dWq!WS5XZ!qvRfZ+`Wt_Co?!NA`Dh7TBo zF2E2(@dc<Bo(3?y!61GC72RM+x<D?v!61486+K|kx`2vqFc@7xMK>56FJMCt7=#<x zuP_MT;1cZ7xxsC6gHO35{R*G*MLyLle5x0C)jqJZFfx5)U}0qX0wO+u2`+6$vkwfI z#00mGAki-%0wN#7$j7MvfdP}~DEkN!{Q@E&@=5}XA|Dtqi3w~UL84zk1VmoVhn*2L Wzl@da@c+mFlK%oGKVXpHXa@jQ;6yk8 literal 0 HcmV?d00001 diff --git a/ctgan/synthesizers/__pycache__/tvae.cpython-311.pyc b/ctgan/synthesizers/__pycache__/tvae.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..1d6ab1a3b33b5e4a9046deed08c1c7ac4bfa0f48 GIT binary patch literal 13794 zcmZ3^%ge>Uz`$VFos{mX$iVOz#DQT}DC6@B1_p-d3@Hpz3@MB$OgW5EOkkQhiUmxw zMzN+arZDHQ<+4Yy=W;}G<Z?!F=5j@G<#I=HGlKN6<nZM3M)88#tT}wS{89X1Hd~HB zu3(g4u27T^n9rUgoGTn93}$oWh~$b!iGtajIbyluQR2B0Q4)*{ObqS}DO@cKDcq?X z%a|D$Rx?B0$q*$8mg8w*Na4khlS<)hVTh6jtKe^8ND)XCN7F5n!W7J)DOkl566WZt zkei=Unv<&c5@fiaChILWpUk|}#G+elzF>)4oB@eNiMgpIsYSO~f>M1#Z*c{u7M7;w zm1HL7Xfod7O)knWE{;#lD=Eq^sMKV<#p0Nfn5)Tji_0akB+(~7F(tL=7H>#uUU7aA zNTfKm1f-~>FeO)$@fJTwD5NMcuQ)BgC^xl8lj#<}Q(|#yaAjUeMrv_pRcg^Ko}$FO zl>FTI;*!LY)MSt|VVDWZ_^iRez|hVxogtMWiZO)&6m3z=9SmuVDNHRKQ7kFUEeuht z9SjwWQEb5snk={2UGtLjQ&NknenS}wY57G8B^jv-pb*jH;!;phP;e|tFSY`+K)jU1 zlEnCw%v=SH%)Am!gg8jfB{MfQuQ)S5uUH{J4QvcZwH{n=a(-?>QEG89NPn?{MoDQw zPO3tFkwQ*paS2FAW?o4#nvuboRjDAU)WqZrg^bLUl+-+hoW#o1B7~`_xk;%hDVcfc zAY-vP3uXq$U-_ja1*IhlWvR&}`A{>GS)ic@qSzQ17(mJEvjrmq!&JuU3?+;pVF)gP zg);*KLk+_+CI*JpaFGRY7DO$CT*kn_uo}W-U|?Wms9~&SNoS~KEn$ZXGSo0FV`N}h z4Y$XIAy&1Pt%kLP6G;tA4I8R`CEN&ER9n)Ss(2X~YFTU8YFT>JN_Y|KYFPUiY8c|- zcB9$AjA}y-Lp%peFF}=Xx27-zGiWmUX>#3SEiNrcEh-XZU|_h#5g(tKmst`Ye~Sea z*+rlrE&_#h5y&yOn2S>jZ!xDO8-fTUO^#bE#i==Iw>Ut_2NbKfcu|thEnbv3zr~!A znOh_VvWF$JBsKRIXMB8ePGWI!e0(v;2cV>>prFw3%f!VhCM7E|FFik?C_gJTxg@3} zF(<nuvn(|xxg<R?FQyog+lpgK$`Vuc3Mz|47#J9;1fUs259SdW1_p-WX$%YuKN=YB za`5)1Pbj&_A#;U8<~oP!B@We#9O_p%)Gu(T-{2GLV7|d8(7}9HL}G^QMG^HYBI+L) zm^j73M2E`_9{#CrGmICoToh5gBBHvqWMSEovJ1j`7kTtLoNow;bvSo8KNXgpQG7vI zxr3#L<F2se4CNWM7ln1N2<u`M=-}w!_za4<WaKmnO68zn|GWcSxYsZ&04c%1HH@gK zv6iWZVF5gCA}dEvx3$bQ3=5ELge%5S$%35zvp_Bbt4(37VX0zeU_f&LYYlS^YYj^p zb1;J@Q<XF%Ccwo@L1J-nYKlTaW?n&QNqmA{I@lP72lkq*;Pg-=&cMJ>BmqhptR?wH z$r-npQ!5HI*@`4VLTnIix7gEC6H7{qQg1Qkmfm8^$xkm!EV{*9T#`}@N}-U{Q=|aW z!Jd|1RGwIrQpFE(M3Ej;095W4hcGZOG%!365}U3uNn?iDMIogtLP{O%H+Tj6Bf26c zq+jHfyTU8i!E{4Vbh^YOi5cn_1(mJ{Ds`~k<rD77oWVRJZ6?n}KGiFHsvXQX_ys2z z&tU7QzrZ0^q{zU)u#z2=iSiP2Q{&@ramB|&YODD8TRidcg{6r(P#O04_>}zQ_;_%t zEz$xxnirZIl2dbX;^S8`7RiH>E+|SJz`g?KP%t3?k_d#P=>`S}{J_M?%JG2#M2K;* 
z>V06qLP)T&%70+MM}WNmwu3tLmP;z6-ueS&5U#gSN(@4^1UO-!R1%<6kdlg03Q?^B zgOzgFT!pVn#a?mM43mnhhIIib`$2Lv5{bVmLM=KN8A|vOTA|e+s7e3}mI#9>Fi|23 zrocoNylSmsMl~;m5wm!zVTk7eYhz$w09SrYe(AGjg1}CDP2M7q9Yw~VVhB`^6oITk z)PA?v5(^4a^HRXIAf&_swZV&&LHU*?KQFcT7He^4dTyd72e`;80@XUV_|Xz#krc>y zSr7rL1dHTAEN~_Vc@>m$LA4(wuY(*>B>>Ge@XE{<l=0U=GCuywOd6>&lfNjUc|}AM zT#tbX;wrNT9Nhg}U0hSmW+>07xG1b~MOfn^hvpRy%?liwcabYPQMtLLbE;NoT$DGy zB5%AQ`J%k#MNz9OqE;O&JzO1J9bDj21*MV$Rd68u*@~n}u7)8CR33t(F%8rwV62jc zgb~(iMw6)s9MI+r3=En~keE>g1tF;RQ&50}G^p&Y;)j?5uO~b~s-z+H1ix^9MOVcP zwu}7oSNP=@h%Ax1z^{FQLmTX?B5hD48iJIATm~(v!RZJ@fF<yj)YhPq8iSBvV%7P; zfRBLK4zeQ|)HDMJ2?GNIs1*qg*m=z0t|4PBdpbicM-8hB!vw}y^;*svPB2TMmaB#< z3#N-9g((}Pi=k+14SNk2B9N;X7#M1~YnW=-K{YW%c`Z*3V+unGLn~7adm0l$yabdb zAW9grz@~vXH4NC)qGs1BE(V5L-WrYt@D>J0Ck)r{rZ87=FfgPrS8+2i)bK!fAoc<z z8)2#$7*Jc;s5UV&;Iak9_3#3%h6h!D4dVjj(1u%zYG(=~B2+L7B1VQ9?ppSq!&pNa zM;MfVTIW#P5FuK_-O7YB1~{S084$6Bu0Dlj4J&G2B86=YJ8I_})&CH;Phjj(n!wnj zS;Jn#KvZ~eAi^UXl=m2lnY<X97*jY=LLDBPDI7H{pp;j`42^3p28J4z1>gY@n6F?o zYRH395~_<q;a|&F0<S6<7-~3BRn)M7)Q}Ogpw!4mo<4Bl#pPE73XEIq#rbI^xrr4o zLA;m$|NsB5$yfxcCT=komx2ZzIEzwKN|Q@6^YdPU3K>}IdA}z2E#AcBlFYJ1kX(FP z9#|K&Rw_x%%P2Ahl}>h`T8cR*Km8U9cz^;_+uUMHEK1K!tSC|em5uzj*o#t=^Yh{> zZZTHeVgpxY#YO5M&7c<ZEpAA6JU%loE&mo<S|X@Zev7HN<Q7X#esS?F?&SQO(%d|V zG-q)^B9zILno^_*vYZLjZS?kWDF)T@(25R5flJ|AeDNTY;?qj=!2T=>VqjqK1C^?G z7(tEo8=PDXPB*xD8(eOPD}7)PWR>^`B09KlNUD5b5a5*f2qHQ>Z}3Z8;a6Lsd5Pcn z0>AMMQSlj?b966?s!w3LAto~;XiC-u<{P476AD3ahSG%6DG^IqmvF5x-(U!aNgGVp z*j-exx}sopQPgSz(+vTU>D-gJX9O(}oEf?xaAxER<rRf1)h@~!UY9YuBx80_#^Q>M z#YF+jD*~1m1T3M-=eo}ET#!7|Z;s!J!1ZCP!d6DCiP(_1Ci)_>Qc>{<OcR(surY{9 zAlwMC?yiX94GH-Lj9_TAAZdZ=0*)&Znrlth*z923p$LXS2bgw*?$Fr7eM0Dhw$DXv zpDWrv7b0RmFf)mYePmz~75f4rJ~Fci3VmT<5fu6WCb$^HBqlITWcsKmz$x)Tg@K1} zg7|cqNiq}VrpPT|ydtP{iCgIc4Bg=3n<_d*a)#lIvYA#Zm{vHhWW6Y@yMb*x*Cwuw zJX?4!O50o%w7nu|dy&hw!TByXUq?x&Wsl{AvWwiZ3mBKPEn-{9v4mrV;1ceO+}bOG zHyCa=+hn%UVvEHN$1T<ub?q*2+g)I>1J`3Hoefyk@i_t9Hb5RC<77aGDFP{6EsRkt 
zsjMl2DLgHVsKr4PTPkY`Zwp$L5XBB2apXwh4`$F5y2Sz-L#tv1k$Nv1KwUVdqEJxM z2NBOdW<2iL`s0b+%P))!3@>#U85oLe7#J8r+>()tBT&T!a`)#bMo_g|!&t)*4=bS< zYM5#m;z1b~EK<W<!w?V3=@1@E4MRLgCxpjZ!w?TL3&LZsVTgwpiZyIC4Ds-a3Aw$T z1<KH19q9}yjEfl8pbb>jFvKG&1I`+Tc(}`JIEv&-KusjDZU%-dcnf&}s38FrMNp`r zTg;roT+5XwRRSujpo(fZN<ht8C=a!ADPpeSsO8FIN@1zxDgpKQp{lY#{aG-(h6}A| zW{8JZeBd@atKUlITdb}H`N<i#Se+6}k~40x_~aKCuVlQ%4k?6+Z!sqql-^=VE=@_) z6aqKIi$D#zA`4JK#GYeRoSc}GdW$Owq&Xf`9^T@H6g`kqjV%>qc=0W!f|OgFput`@ z&|orDO<7`2X=?E;_OjHXr2OJka7kN~1TvT>vp7DnEHN`DF$rXPW^ra-aY<rcaw=$m zqUaV|N@`hVa_TMicnCw28`2sD^%D@zD+1+%q7+ctO9K(9AOd835vZkBlm{vgK@}OS z=v08UVZprxUT_MAjeF&Q6mJH%U_VGO$Z2=5T;Y(qA*%(U<@G?c%mZ<`4wfFC8&WDA zEIs@;q}4#85;w$TI#_zR?@G(gvAJO21x3={7p1+gNP9z+Kr}w!7yiJ&$|*C~bdJ?^ zX`@ThMi-?`u1K4J1#gH-b+~o7-Q^XZ!7?}bBCpaFUZoEV%$zbeL?uBI4@9Lq+@6Zd zOsQGmc2QjWinun&Egd`^Jm4~6B||biUobE*fcy>Oe{SIb<uK&oUer{HlGl*>3pj?n z7;BhNQz;_@axY^6$WCzDL?$o>-x!g5yU2}=EO?H^)lf)58^}Y=0Sn-H8^wNxEKV37 zRXr-NmKnLf3@UjLedZLV6y`N7%UBo~R>Mo2S{7u#BT9o>))G)2f?8R_TEmdV2jzf5 z3skFvxu{J?Qr%w5RstG91DgaIxX*%{Tf>kAF9&K^DYK=9Azlb-1w#p_i38=7fLbb0 zb`9GC&;S6GhoDeX8fsXh=x4}++d;ZX;HED_7Q9?WZ=~0<m*^whlm!|#1#{CGQrH$T zu3<-YR}Divy!BlIY7c_dGBA{Ydc{z74MP^FBm(o$b!Wj#cOHf{9H@DvmZOFNAz!4O z!dc5%!?6IV1^^ogCTchrfW|t&Y-A#ZYYjK5ov1uem(7J?0#mGOEmt~2Eq6LYEl-If z!h{+YRP`>Pj&4{C6R1mB%U8p`04Yr&bfLy1BSQ^e3Ku+>7)VaJ{7eirY&HC-E{S1g zV5k+S6|5CPDG`)luBZ`2RbMMyBU~d0?u-b4(m<_94Z{L>I)do}(KRBdI>EjbtraVA z2gyQkjc5&jjaZF9jSxzn2DSIW3Tha#Kw}|b_A+(`hSl(nE!Z^iT8SF*8VRsFBx@uW zsDTxt5g4J!z5pplArxb%L&_gnpdJQTX%;*OE&vV6fW?rB6rL1b)bz*5P{XqTGy;RH z49u?KMGYZP#}$W4$pt!KEocO)n-}P!NrSm)YV^Q@XauSnbXQGa?1{jhn`^kz8A}k2 zNCxzw2yi;#so_pxn}as`Su0h;kOfc7Sq3l%6`P{i1R9$HOChVSkxEBvJuWZ;D?%eu znAY&2`USb1$z!jPLgu5C67e-8mJl^`jUhs%(FCR*cW}=^ATfZ6fx#`a1U&W#YJ4dm ztpZTU%}+_qDOM=S2TPX}C1&Pj=A|ouCJ~?$_Mm1Q>ZCnL1gtV1q)b60uQa!yQZFwh zv8X7qQX#)cp&&6YC9zl!-Y!L(g$L;mf!PF7q^ICnqL5o!T%wSas*tE)q?-bp`cKSJ z0GqA=H#-r-bi~{}$l#RB;^d;#lGJ!;Lt8-uG)1hNlbM~0#Th=(wE{3*3NT$I`5?EH 
zg60ED@)gok^HPfvOH#ps<eZ<Ek_j5hNX&t+MF80h!eNk=01BRI3JD2_FiJ>J$ShV! zRPY0Z6vQt&3MCnt#gN4c#TogfIVl*1CFkdrBxdG;EzQg;Nli~JQpn6p$xKd#uZVy; zS5LveBqOz`JhM1eM<Fv!A+tmwvlzp?M1_O|^r%Tlz;uWLvc(V^6hIS##fXr;#h#Xv zpIBmS6ym1I4DJiuVlU1pO-sv3y~UYQlwS~^lUNK|k-}4+nwg$a5}%TqoLE^D4eBa@ zN()WKTkI*B#h`9Tl>n$~WTj9I>fc!D8K>2P$D1@6i+mUu7}Bdn*112dYO#k7Cw;I7 zEpq|c;HPO^1nU0X;x0ifbh*WxmRWL(6E0E&YLJ6FK)1Nkit=;g!KvvMTVioZWkKpK zrjq=ke2|+!eNBir1fhdS@FZ9ao@j)0e?ZL-$n*=Op9Y#bNl7g#0yT5NlQFlr3Si56 ziXr_Zkh_aAKz3w;2vFDc7Hd&rUV3T~xc3C=Gu`4ZP6Y)~adJ^+0VrkNVgs$)Ni4a= znVMIc3r>2sIIB{N^5fHs5>sw*ftK#%#e>o`cnIp20MyZ-k*|1AdM^TX%5HHaB_?Ns zW|(iW6qlqH6oJ}Ow^%@#_7+=mPGW9BJb360)bA@vOwK4u1i28@EV{*<lbUynH8r=O zr1BP9a(-TNV#zJ;qEwJGGxJhXD>S9SJ;PgEh%|mn0G5ehS??A*$g1Mh63{d_Xu%SA zYUvhxDrA|^E$;k+lFVGtVkAe`TO64sAb;i;-4cPZ;!{!)A;Oqh1PZ`gtl%#3E#}OW z_@XXQpO&R4H8BS=3J9v1i$E3iE#}mM;#<tAxk<N}i&B$|JU|NBKnX0~5agKR;&>x) z^18(f?npz1^-|J`LH!3%KE{E;qku(~pd}m}prTYSEwiKt)P?=Y3>yFYD#ak9eNkHH zinI=-|H;8~g+t;xhs-4onHf<RG%YW3SY6?;y1-#|LsasLsOoi5{Y#?y8@MitT3->h z?%=vBB)^pPih{ugsq2PLmkgZ_SY9;ry<+HlQNizug5Pz8z)K2&7lOjB1jSuch`*u` ze^Dsmicmra`&~il8SWccw(y@2xeyt5K`{QJVEh%q_zt!Q!s1tim9Ge!c5vJklANKj zz~zdH*$$Q~7G4*Hysrp(Ul;PbB;<FZ<V?*)p_nT|F&*r8#pPy{T#+~2U~<LS{e;9t z@t`Z>LD$7YFNud<2#dNBmUK}(`HFZls1JNsOzJL=zy$H@g36Z!l`jgaUJ+Ei$fI_J zM{TXhij*}u*Y#X4>A7CibHAeJeo@ooil)a!9*+y*(O1HgKQJ?ibA4oB66g8?B09Lg zaxloq;q1rX6&9b%IwNIH_C*P;D-v2eME1yBFm=5k?0!+${fe-A2geO={vI39q@fq% z4Q0~{JPJ2>M6UCwT;fr=plW$t)!~w=!-2#Lsty-bJ+7#FT;%b*!sB^?$Mb=p<Oc>V zPN9z=qQmiqluU>B4H2;p*SoT63(79a>Rpl5`@kT^YyS~MOyr!vJR#(Ufb?|%)k^}Z z7X{R>2&i8WP`@EAJ;QNI`ShwuRSQ&>=v)-nzap+bf#rsh<^uH<DhqX&=uYIHz%;>m zhS*&(=__Jt3yLp_X<rf3p1^!VOmc?v#LNXs6Z0l8-;h(9pEE0Gh4Do>{VQ_%6S*fa zPYAoKsDDF5{Dz3c4H2muB4S`t;);mM6%o@LA|S>`ehwv}FAN+?LKFEWu+3omAjBXc zw}5S8{R9@!I4|!c-iiEE_$Tm#ovm<5K;fc*(iH)v3j#_v1VpCrTmkXauM6m163|;w zvLP5gNc=!RbUM!@o*BUx1r)CcC|(dy1e>6ANkD0V$#UyO)(dTy*j`sQxuk4zQQ7Q@ zve|WI>r2Yk7nN<VDBE5Xu)88)cR|4JqJZ5~8M!$=H>9Lz2+u8EU^=sEPSr(8%^Q+( 
zHzbrlu(N3KePCeG;`_+JBFp~;M0EIm;9?Mw0H<2b4-BllhIf@!msqYyxuR!xfcJu~ z-$iBrE6V=Ym7^~yN8gZ?{lLti62thBfk{Q^3y7G&az#Mlf`GyY4h9uNFy{fka8Jz# z24-HN8v+v71vD-RXxtFdxGti9Nko4`$oA+>(Hmp8#9lXbyJYHi(bVILsmFCw-%F;x z7ft=InEGE3(Z47Xa785GBQq0dkeUfJNX;ZD^z%a?qZX$SIQf7Ff>0)oLBpY-A(_vw zm_WloH4LCq9?~|&zWNEZ`3+jxSHqCS1ku63P{qW+kirP=<S?Z()G{NlcS2tc<-!nq zx|RibpeGCD8n98wX4Wtu27OSwE48dOtWew7P|SgMy=qu%*w&y8NRiwb1<lHVU4>56 zpbZUxJ9-?ooG5MqRoGyS=<{T?95u`fK(#(t7MZBwSOBVhk%ho))UYRJNO=K#SP5(Z zlt48HwcEkS&~v5)X<z|qFqs)Vz=g8<vzEDrIfb=`u@!0F8nhk`F}$<@dxx93hLf&- zL~(U6gC<+mvEam9&^laD&96|LnVXrDSVVg7q8MyQu>zJ=!JwMiuQWF)wFtD3w<y0H z+<Ad$K<cpsrIwTy<sq#!2I&Jg{_u5F;7#Tt2hfBg(=C?V)WkgS>>#vJdW$8oBD2^} zQ@E%e)TUrbElbP+Hyn!^L5(NyQX5uKORX3@sDFzsFTW@^F{h{&Bnxh8I)PZA5w0Ro z%K_Zxy2T21*)68L0?;fasI`@vmspZoQ~|O>Ff*?#wWv5X9^R_e<bt%m*dg8qE$b;s zEhvVxhZy6FmVm4PHHL1n=H+J=r`}>oOUzCMw}Og683z;=5Dae6++u^6Q^gH#(?B>k zK+PIU@U-R!F$PZF9-iwQl9xCnFLFp<;gG(-A^ku=e1^#!`-=ivD*`VH=ykB(;1}#* zxhp6(p>(3m6qyeUY@DWGqQmj7u*4MW86|V-H?UkWaJne$d_~x~!{vsM*bKpm8XeAe zMI@*AEJ)eFd_lzIqKL^A5tHj8=9fgwcd+c?xhUd#MZ~qk^@g<ke4AM|D_Ac|>s^u7 z>-6bxnV<$T*J?(}MNkt*8MItTpx?X8dj`vV?pfS3dFSw6m(aW<p?Oh4`-+72bqT{u z5{4HgjIT%-U*s{l!eep)jGpof_SDTtSs{3VU;84z_7#5Z4wf7I!V^3ugirA3sRMJa z@XKH4H@w7exItwH+kW0%ygT{#@L#v^zGUHj(ZctNh3|EX;7b<47cD}sScG2W54*x2 zc7Y=dvswUE2_XC#G;che0d4+`@GJmQHG*Eeq%eY)v+4w=LOS{hkQtl=P~J;QRRDKn zHPqb$Le)VXbam%|P)$8QO=d{%2b7e-DHOK+2t0=as(zqTH@CPzEq{oSRs4`t0`9SZ zri(y}gNs3fbPWs-#AQJ(!y6ntJv{yVUHmgx=I~tPP`$#TdI5&Oq2i~>37NbM1Eo|j z5e{NSfQU#C0hztjWGpHGWkk**(BM}QXmqOxG<;PA8d-)cN(GI+gInZ9n?U-ufC!Kt zusz`9T?AT53noAvh~l~6*(S(x)CPtdEc^}LA2=AqrEZ8R+>ijZif#yt-QX9xAtHK% zSMY{_=m$|b&?+ZX!hw;2P3QxYCnKxS2L=TBk(q%_s3is}5yk{oB*7Za_<;dL!lgl+ z4<^j4Y9AOd35d%<j@FcgcyuMmQ>#E_0UM<MUIgw{f(siC@JeL8ygbM{K+q&f5vamP zS`RM_Q=$i3xR97r1e&qA#SM`G_rr=no!DCf5RuZ7%$#C9P$6CfDlj13Xi)8WOANfW z9V~??kWz~vU2$=UB=W*{y`;qA)FM!YSp<$M&<qKr3k}Ky;AjEIF~rxPb<w{#Y#_^# z?TYq*w>dB{6c;cuFnnNUWMsU-AasF2=mD6#048rR2w#Aq8w@-RV0eSU;R12!27~Se 
zRCI$u<^n3Z!C-a)8@j<DdI1$ZU@*FXif%ACUciPvu*osXe_+5QI?_IZM8ALth`bY{ w6r;`u2228^REkmO3qs0Dg3<m11Dr7PWnxtOz<{0Xu>S~_`2r>}Re@s!05w*xK>z>% literal 0 HcmV?d00001 diff --git a/ctgan/synthesizers/base.py b/ctgan/synthesizers/base.py new file mode 100644 index 0000000..add0dd7 --- /dev/null +++ b/ctgan/synthesizers/base.py @@ -0,0 +1,151 @@ +"""BaseSynthesizer module.""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def set_random_states(random_state, set_model_random_state): + """Context manager for managing the random state. + + Args: + random_state (int or tuple): + The random seed or a tuple of (numpy.random.RandomState, torch.Generator). + set_model_random_state (function): + Function to set the random state on the model. + """ + original_np_state = np.random.get_state() + original_torch_state = torch.get_rng_state() + + random_np_state, random_torch_state = random_state + + np.random.set_state(random_np_state.get_state()) + torch.set_rng_state(random_torch_state.get_state()) + + try: + yield + finally: + current_np_state = np.random.RandomState() + current_np_state.set_state(np.random.get_state()) + current_torch_state = torch.Generator() + current_torch_state.set_state(torch.get_rng_state()) + set_model_random_state((current_np_state, current_torch_state)) + + np.random.set_state(original_np_state) + torch.set_rng_state(original_torch_state) + + +def random_state(function): + """Set the random state before calling the function. + + Args: + function (Callable): + The function to wrap around. 
+ """ + + def wrapper(self, *args, **kwargs): + if self.random_states is None: + return function(self, *args, **kwargs) + + else: + with set_random_states(self.random_states, self.set_random_state): + return function(self, *args, **kwargs) + + return wrapper + + +class BaseSynthesizer: + """Base class for all default synthesizers of ``CTGAN``.""" + + random_states = None + + def __getstate__(self): + """Improve pickling state for ``BaseSynthesizer``. + + Convert to ``cpu`` device before starting the pickling process in order to be able to + load the model even when used from an external tool such as ``SDV``. Also, if + ``random_states`` are set, store their states as dictionaries rather than generators. + + Returns: + dict: + Python dict representing the object. + """ + device_backup = self._device + self.set_device(torch.device('cpu')) + state = self.__dict__.copy() + self.set_device(device_backup) + if ( + isinstance(self.random_states, tuple) + and isinstance(self.random_states[0], np.random.RandomState) + and isinstance(self.random_states[1], torch.Generator) + ): + state['_numpy_random_state'] = self.random_states[0].get_state() + state['_torch_random_state'] = self.random_states[1].get_state() + state.pop('random_states') + + return state + + def __setstate__(self, state): + """Restore the state of a ``BaseSynthesizer``. + + Restore the ``random_states`` from the state dict if those are present and then + set the device according to the current hardware. 
+ """ + if '_numpy_random_state' in state and '_torch_random_state' in state: + np_state = state.pop('_numpy_random_state') + torch_state = state.pop('_torch_random_state') + + current_torch_state = torch.Generator() + current_torch_state.set_state(torch_state) + + current_numpy_state = np.random.RandomState() + current_numpy_state.set_state(np_state) + state['random_states'] = (current_numpy_state, current_torch_state) + + self.__dict__ = state + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + self.set_device(device) + + def save(self, path): + """Save the model in the passed `path`.""" + device_backup = self._device + self.set_device(torch.device('cpu')) + torch.save(self, path) + self.set_device(device_backup) + + @classmethod + def load(cls, path): + """Load the model stored in the passed `path`.""" + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + model = torch.load(path) + model.set_device(device) + return model + + def set_random_state(self, random_state): + """Set the random state. + + Args: + random_state (int, tuple, or None): + Either a tuple containing the (numpy.random.RandomState, torch.Generator) + or an int representing the random seed to use for both random states. 
+ """ + if random_state is None: + self.random_states = random_state + elif isinstance(random_state, int): + self.random_states = ( + np.random.RandomState(seed=random_state), + torch.Generator().manual_seed(random_state), + ) + elif ( + isinstance(random_state, tuple) + and isinstance(random_state[0], np.random.RandomState) + and isinstance(random_state[1], torch.Generator) + ): + self.random_states = random_state + else: + raise TypeError( + f'`random_state` {random_state} expected to be an int or a tuple of ' + '(`np.random.RandomState`, `torch.Generator`)' + ) diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py new file mode 100644 index 0000000..704468e --- /dev/null +++ b/ctgan/synthesizers/ctgan.py @@ -0,0 +1,564 @@ +"""CTGAN module.""" +import logging +import sys +import warnings + +import numpy as np +import pandas as pd +import torch +from torch import optim +from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Module, ReLU, Sequential, functional +from tqdm import tqdm + +from ctgan.data_sampler import DataSampler +from ctgan.data_transformer import DataTransformer +from ctgan.synthesizers.base import BaseSynthesizer, random_state + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler()] +) +class Discriminator(Module): + """Discriminator for the CTGAN.""" + + def __init__(self, input_dim, discriminator_dim, pac=10): + super(Discriminator, self).__init__() + dim = input_dim * pac + self.pac = pac + self.pacdim = dim + seq = [] + for item in list(discriminator_dim): + seq += [Linear(dim, item), LeakyReLU(0.2), Dropout(0.5)] + dim = item + + seq += [Linear(dim, 1)] + self.seq = Sequential(*seq) + + def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lambda_=10): + """Compute the gradient penalty.""" + alpha = torch.rand(real_data.size(0) // pac, 1, 1, device=device) + alpha = alpha.repeat(1, pac, real_data.size(1)) + alpha = 
alpha.view(-1, real_data.size(1)) + + interpolates = alpha * real_data + ((1 - alpha) * fake_data) + + disc_interpolates = self(interpolates) + + gradients = torch.autograd.grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones(disc_interpolates.size(), device=device), + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + + gradients_view = gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1 + gradient_penalty = ((gradients_view) ** 2).mean() * lambda_ + + return gradient_penalty + + def forward(self, input_): + """Apply the Discriminator to the `input_`.""" + assert input_.size()[0] % self.pac == 0 + return self.seq(input_.view(-1, self.pacdim)) + + +class Residual(Module): + """Residual layer for the CTGAN.""" + + def __init__(self, i, o): + super(Residual, self).__init__() + self.fc = Linear(i, o) + self.bn = BatchNorm1d(o) + self.relu = ReLU() + + def forward(self, input_): + """Apply the Residual layer to the `input_`.""" + out = self.fc(input_) + out = self.bn(out) + out = self.relu(out) + return torch.cat([out, input_], dim=1) + + +class Generator(Module): + """Generator for the CTGAN.""" + + def __init__(self, embedding_dim, generator_dim, data_dim): + super(Generator, self).__init__() + dim = embedding_dim + seq = [] + for item in list(generator_dim): + seq += [Residual(dim, item)] + dim += item + seq.append(Linear(dim, data_dim)) + self.seq = Sequential(*seq) + + def forward(self, input_): + """Apply the Generator to the `input_`.""" + data = self.seq(input_) + return data + + +class CTGAN(BaseSynthesizer): + """Conditional Table GAN Synthesizer. + + This is the core class of the CTGAN project, where the different components + are orchestrated together. + For more details about the process, please check the [Modeling Tabular data using + Conditional GAN](https://arxiv.org/abs/1907.00503) paper. + + Args: + embedding_dim (int): + Size of the random sample passed to the Generator. Defaults to 128. 
+ generator_dim (tuple or list of ints): + Size of the output samples for each one of the Residuals. A Residual Layer + will be created for each one of the values provided. Defaults to (256, 256). + discriminator_dim (tuple or list of ints): + Size of the output samples for each one of the Discriminator Layers. A Linear Layer + will be created for each one of the values provided. Defaults to (256, 256). + generator_lr (float): + Learning rate for the generator. Defaults to 2e-4. + generator_decay (float): + Generator weight decay for the Adam Optimizer. Defaults to 1e-6. + discriminator_lr (float): + Learning rate for the discriminator. Defaults to 2e-4. + discriminator_decay (float): + Discriminator weight decay for the Adam Optimizer. Defaults to 1e-6. + batch_size (int): + Number of data samples to process in each step. + discriminator_steps (int): + Number of discriminator updates to do for each generator update. + From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper + default is 5. Default used is 1 to match original CTGAN implementation. + log_frequency (boolean): + Whether to use log frequency of categorical levels in conditional + sampling. Defaults to ``True``. + verbose (boolean): + Whether to have print statements for progress results. Defaults to ``False``. + epochs (int): + Number of training epochs. Defaults to 300. + pac (int): + Number of samples to group together when applying the discriminator. + Defaults to 10. + cuda (bool): + Whether to attempt to use cuda for GPU computation. + If this is False or CUDA is not available, CPU will be used. + Defaults to ``True``. 
+ """ + + def __init__( + self, + embedding_dim=128, + generator_dim=(256, 256), + discriminator_dim=(256, 256), + generator_lr=2e-4, + generator_decay=1e-6, + discriminator_lr=2e-4, + discriminator_decay=1e-6, + batch_size=500, + discriminator_steps=1, + log_frequency=True, + verbose=False, + epochs=300, + pac=10, + cuda=True, + ): + assert batch_size % 2 == 0 + + self._embedding_dim = embedding_dim + self._generator_dim = generator_dim + self._discriminator_dim = discriminator_dim + + self._generator_lr = generator_lr + self._generator_decay = generator_decay + self._discriminator_lr = discriminator_lr + self._discriminator_decay = discriminator_decay + + self._batch_size = batch_size + self._discriminator_steps = discriminator_steps + self._log_frequency = log_frequency + self._verbose = verbose + self._epochs = epochs + self.pac = pac + + if not cuda or not torch.cuda.is_available(): + device = 'cpu' + elif isinstance(cuda, str): + device = cuda + else: + device = 'cuda' + + self._device = torch.device(device) + + self._transformer = None + self._data_sampler = None + self._generator = None + self.loss_values = None + + @staticmethod + def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): + """Deals with the instability of the gumbel_softmax for older versions of torch. + + For more details about the issue: + https://drive.google.com/file/d/1AA5wPfZ1kquaRtVruCd6BiYZGcDeNxyP/view?usp=sharing + + Args: + logits […, num_features]: + Unnormalized log probabilities + tau: + Non-negative scalar temperature + hard (bool): + If True, the returned samples will be discretized as one-hot vectors, + but will be differentiated as if it is the soft sample in autograd + dim (int): + A dimension along which softmax will be computed. Default: -1. + + Returns: + Sampled tensor of same shape as logits from the Gumbel-Softmax distribution. 
+ """ + for _ in range(10): + transformed = functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim) + if not torch.isnan(transformed).any(): + return transformed + + raise ValueError('gumbel_softmax returning NaN.') + + def _apply_activate(self, data): + """Apply proper activation function to the output of the generator.""" + data_t = [] + st = 0 + for column_info in self._transformer.output_info_list: + for span_info in column_info: + if span_info.activation_fn == 'tanh': + ed = st + span_info.dim + data_t.append(torch.tanh(data[:, st:ed])) + st = ed + elif span_info.activation_fn == 'softmax': + ed = st + span_info.dim + transformed = self._gumbel_softmax(data[:, st:ed], tau=0.2) + data_t.append(transformed) + st = ed + else: + raise ValueError(f'Unexpected activation function {span_info.activation_fn}.') + + return torch.cat(data_t, dim=1) + + def _cond_loss(self, data, c, m): + """Compute the cross entropy loss on the fixed discrete column.""" + loss = [] + st = 0 + st_c = 0 + for column_info in self._transformer.output_info_list: + for span_info in column_info: + if len(column_info) != 1 or span_info.activation_fn != 'softmax': + # not discrete column + st += span_info.dim + else: + ed = st + span_info.dim + ed_c = st_c + span_info.dim + tmp = functional.cross_entropy( + data[:, st:ed], torch.argmax(c[:, st_c:ed_c], dim=1), reduction='none' + ) + loss.append(tmp) + st = ed + st_c = ed_c + + loss = torch.stack(loss, dim=1) # noqa: PD013 + + return (loss * m).sum() / data.size()[0] + + def _validate_discrete_columns(self, train_data, discrete_columns): + """Check whether ``discrete_columns`` exists in ``train_data``. + + Args: + train_data (numpy.ndarray or pandas.DataFrame): + Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame. + discrete_columns (list-like): + List of discrete columns to be used to generate the Conditional + Vector. 
If ``train_data`` is a Numpy array, this list should + contain the integer indices of the columns. Otherwise, if it is + a ``pandas.DataFrame``, this list should contain the column names. + """ + if isinstance(train_data, pd.DataFrame): + invalid_columns = set(discrete_columns) - set(train_data.columns) + elif isinstance(train_data, np.ndarray): + invalid_columns = [] + for column in discrete_columns: + if column < 0 or column >= train_data.shape[1]: + invalid_columns.append(column) + else: + raise TypeError('``train_data`` should be either pd.DataFrame or np.array.') + + if invalid_columns: + raise ValueError(f'Invalid columns found: {invalid_columns}') + + @random_state + def fit(self, train_data, discrete_columns=(), epochs=None): + """Fit the CTGAN Synthesizer models to the training data. + + Args: + train_data (numpy.ndarray or pandas.DataFrame): + Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame. + discrete_columns (list-like): + List of discrete columns to be used to generate the Conditional + Vector. If ``train_data`` is a Numpy array, this list should + contain the integer indices of the columns. Otherwise, if it is + a ``pandas.DataFrame``, this list should contain the column names. + """ + self._validate_discrete_columns(train_data, discrete_columns) + + if epochs is None: + epochs = self._epochs + else: + warnings.warn( + ( + '`epochs` argument in `fit` method has been deprecated and will be removed ' + 'in a future version. 
Please pass `epochs` to the constructor instead' + ), + DeprecationWarning, + ) + + self._transformer = DataTransformer() + self._transformer.fit(train_data, discrete_columns) + + train_data = self._transformer.transform(train_data) + + self._data_sampler = DataSampler( + train_data, self._transformer.output_info_list, self._log_frequency + ) + + data_dim = self._transformer.output_dimensions + + self._generator = Generator( + self._embedding_dim + self._data_sampler.dim_cond_vec(), self._generator_dim, data_dim + ).to(self._device) + + discriminator = Discriminator( + data_dim + self._data_sampler.dim_cond_vec(), self._discriminator_dim, pac=self.pac + ).to(self._device) + + optimizerG = optim.Adam( + self._generator.parameters(), + lr=self._generator_lr, + betas=(0.5, 0.9), + weight_decay=self._generator_decay, + ) + + optimizerD = optim.Adam( + discriminator.parameters(), + lr=self._discriminator_lr, + betas=(0.5, 0.9), + weight_decay=self._discriminator_decay, + ) + + mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device) + std = mean + 1 + + self.loss_values = pd.DataFrame(columns=['Epoch', 'Generator Loss', 'Distriminator Loss']) + + epoch_iterator = tqdm(range(epochs), disable=(not self._verbose),file=sys.stdout) + if self._verbose: + description = 'Gen. ({gen:.2f}) | Discrim. 
({dis:.2f})' + epoch_iterator.set_description(description.format(gen=0, dis=0)) + + steps_per_epoch = max(len(train_data) // self._batch_size, 1) + for i in epoch_iterator: + logging.info(f" the epoch {i}") + for id_ in range(steps_per_epoch): + for n in range(self._discriminator_steps): + fakez = torch.normal(mean=mean, std=std) + + condvec = self._data_sampler.sample_condvec(self._batch_size) + if condvec is None: + c1, m1, col, opt = None, None, None, None + real = self._data_sampler.sample_data( + train_data, self._batch_size, col, opt + ) + else: + c1, m1, col, opt = condvec + c1 = torch.from_numpy(c1).to(self._device) + m1 = torch.from_numpy(m1).to(self._device) + fakez = torch.cat([fakez, c1], dim=1) + + perm = np.arange(self._batch_size) + np.random.shuffle(perm) + real = self._data_sampler.sample_data( + train_data, self._batch_size, col[perm], opt[perm] + ) + c2 = c1[perm] + + fake = self._generator(fakez) + fakeact = self._apply_activate(fake) + + real = torch.from_numpy(real.astype('float32')).to(self._device) + + if c1 is not None: + fake_cat = torch.cat([fakeact, c1], dim=1) + real_cat = torch.cat([real, c2], dim=1) + else: + real_cat = real + fake_cat = fakeact + + y_fake = discriminator(fake_cat) + y_real = discriminator(real_cat) + + pen = discriminator.calc_gradient_penalty( + real_cat, fake_cat, self._device, self.pac + ) + loss_d = -(torch.mean(y_real) - torch.mean(y_fake)) + + optimizerD.zero_grad(set_to_none=False) + pen.backward(retain_graph=True) + loss_d.backward() + optimizerD.step() + + fakez = torch.normal(mean=mean, std=std) + condvec = self._data_sampler.sample_condvec(self._batch_size) + + if condvec is None: + c1, m1, col, opt = None, None, None, None + else: + c1, m1, col, opt = condvec + c1 = torch.from_numpy(c1).to(self._device) + m1 = torch.from_numpy(m1).to(self._device) + fakez = torch.cat([fakez, c1], dim=1) + + fake = self._generator(fakez) + fakeact = self._apply_activate(fake) + + if c1 is not None: + y_fake = 
discriminator(torch.cat([fakeact, c1], dim=1)) + else: + y_fake = discriminator(fakeact) + + if condvec is None: + cross_entropy = 0 + else: + cross_entropy = self._cond_loss(fake, c1, m1) + + loss_g = -torch.mean(y_fake) + cross_entropy + + optimizerG.zero_grad(set_to_none=False) + loss_g.backward() + optimizerG.step() + + generator_loss = loss_g.detach().cpu().item() + discriminator_loss = loss_d.detach().cpu().item() + + epoch_loss_df = pd.DataFrame({ + 'Epoch': [i], + 'Generator Loss': [generator_loss], + 'Discriminator Loss': [discriminator_loss], + }) + if not self.loss_values.empty: + self.loss_values = pd.concat([self.loss_values, epoch_loss_df]).reset_index( + drop=True + ) + else: + self.loss_values = epoch_loss_df + print(description.format(gen=generator_loss, dis=discriminator_loss)) + if self._verbose: + epoch_iterator.set_description( + description.format(gen=generator_loss, dis=discriminator_loss) + ) + + @random_state + def sample(self, n, condition_column=None, condition_value=None): + """Sample data similar to the training data. + + Choosing a condition_column and condition_value will increase the probability of the + discrete condition_value happening in the condition_column. + + Args: + n (int): + Number of rows to sample. + condition_column (string): + Name of a discrete column. + condition_value (string): + Name of the category in the condition_column which we wish to increase the + probability of happening. 
+ + Returns: + numpy.ndarray or pandas.DataFrame + """ + if condition_column is not None and condition_value is not None: + condition_info = self._transformer.convert_column_name_value_to_id( + condition_column, condition_value + ) + global_condition_vec = self._data_sampler.generate_cond_from_condition_column_info( + condition_info, self._batch_size + ) + else: + global_condition_vec = None + + steps = n // self._batch_size + 1 + data = [] + for i in range(steps): + mean = torch.zeros(self._batch_size, self._embedding_dim) + std = mean + 1 + fakez = torch.normal(mean=mean, std=std).to(self._device) + + if global_condition_vec is not None: + condvec = global_condition_vec.copy() + else: + condvec = self._data_sampler.sample_original_condvec(self._batch_size) + + if condvec is None: + pass + else: + c1 = condvec + c1 = torch.from_numpy(c1).to(self._device) + fakez = torch.cat([fakez, c1], dim=1) + + fake = self._generator(fakez) + fakeact = self._apply_activate(fake) + data.append(fakeact.detach().cpu().numpy()) + + data = np.concatenate(data, axis=0) + data = data[:n] + + return self._transformer.inverse_transform(data) + + def set_device(self, device): + """Set the `device` to be used ('GPU' or 'CPU).""" + self._device = device + if self._generator is not None: + self._generator.to(self._device) + + def predict(self, data): + if self._generator is None or self._transformer is None or self._data_sampler is None: + raise RuntimeError("模型尚未训练,请先调用 `fit` 方法训练模型。") + + # 预处理输入数据 + if isinstance(data, pd.DataFrame): + data = self._transformer.transform(data) + elif isinstance(data, np.ndarray): + if data.shape[1] != self._transformer.output_dimensions: + raise ValueError( + f"输入数据的列数 ({data.shape[1]}) 与训练数据的列数 ({self._transformer.output_dimensions}) 不匹配。" + ) + else: + raise TypeError("输入数据必须是 pandas.DataFrame 或 numpy.ndarray 类型。") + + # 将数据转换为 PyTorch 张量 + data_tensor = torch.from_numpy(data.astype("float32")).to(self._device) + + # 为 Discriminator 添加条件向量 + condvec = 
self._data_sampler.sample_original_condvec(data.shape[0]) + if condvec is not None: + c1 = torch.from_numpy(condvec).to(self._device) + data_tensor = torch.cat([data_tensor, c1], dim=1) + + # 使用 Discriminator 生成分数 + discriminator = Discriminator( + data_tensor.shape[1], self._discriminator_dim, pac=self.pac + ).to(self._device) + discriminator.load_state_dict(self._generator.state_dict()) # 使用训练好的权重 + + # 预测输出 + with torch.no_grad(): + scores = discriminator(data_tensor) + + return scores.detach().cpu().numpy() \ No newline at end of file diff --git a/ctgan/synthesizers/tvae.py b/ctgan/synthesizers/tvae.py new file mode 100644 index 0000000..ecefbb5 --- /dev/null +++ b/ctgan/synthesizers/tvae.py @@ -0,0 +1,246 @@ +"""TVAE module.""" + +import numpy as np +import pandas as pd +import torch +from torch.nn import Linear, Module, Parameter, ReLU, Sequential +from torch.nn.functional import cross_entropy +from torch.optim import Adam +from torch.utils.data import DataLoader, TensorDataset +from tqdm import tqdm + +from ctgan.data_transformer import DataTransformer +from ctgan.synthesizers.base import BaseSynthesizer, random_state + + +class Encoder(Module): + """Encoder for the TVAE. + + Args: + data_dim (int): + Dimensions of the data. + compress_dims (tuple or list of ints): + Size of each hidden layer. + embedding_dim (int): + Size of the output vector. + """ + + def __init__(self, data_dim, compress_dims, embedding_dim): + super(Encoder, self).__init__() + dim = data_dim + seq = [] + for item in list(compress_dims): + seq += [Linear(dim, item), ReLU()] + dim = item + + self.seq = Sequential(*seq) + self.fc1 = Linear(dim, embedding_dim) + self.fc2 = Linear(dim, embedding_dim) + + def forward(self, input_): + """Encode the passed `input_`.""" + feature = self.seq(input_) + mu = self.fc1(feature) + logvar = self.fc2(feature) + std = torch.exp(0.5 * logvar) + return mu, std, logvar + + +class Decoder(Module): + """Decoder for the TVAE. 
+ + Args: + embedding_dim (int): + Size of the input vector. + decompress_dims (tuple or list of ints): + Size of each hidden layer. + data_dim (int): + Dimensions of the data. + """ + + def __init__(self, embedding_dim, decompress_dims, data_dim): + super(Decoder, self).__init__() + dim = embedding_dim + seq = [] + for item in list(decompress_dims): + seq += [Linear(dim, item), ReLU()] + dim = item + + seq.append(Linear(dim, data_dim)) + self.seq = Sequential(*seq) + self.sigma = Parameter(torch.ones(data_dim) * 0.1) + + def forward(self, input_): + """Decode the passed `input_`.""" + return self.seq(input_), self.sigma + + +def _loss_function(recon_x, x, sigmas, mu, logvar, output_info, factor): + st = 0 + loss = [] + for column_info in output_info: + for span_info in column_info: + if span_info.activation_fn != 'softmax': + ed = st + span_info.dim + std = sigmas[st] + eq = x[:, st] - torch.tanh(recon_x[:, st]) + loss.append((eq**2 / 2 / (std**2)).sum()) + loss.append(torch.log(std) * x.size()[0]) + st = ed + + else: + ed = st + span_info.dim + loss.append( + cross_entropy( + recon_x[:, st:ed], torch.argmax(x[:, st:ed], dim=-1), reduction='sum' + ) + ) + st = ed + + assert st == recon_x.size()[1] + KLD = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp()) + return sum(loss) * factor / x.size()[0], KLD / x.size()[0] + + +class TVAE(BaseSynthesizer): + """TVAE.""" + + def __init__( + self, + embedding_dim=128, + compress_dims=(128, 128), + decompress_dims=(128, 128), + l2scale=1e-5, + batch_size=500, + epochs=300, + loss_factor=2, + cuda=True, + verbose=False, + ): + self.embedding_dim = embedding_dim + self.compress_dims = compress_dims + self.decompress_dims = decompress_dims + + self.l2scale = l2scale + self.batch_size = batch_size + self.loss_factor = loss_factor + self.epochs = epochs + self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss']) + self.verbose = verbose + + if not cuda or not torch.cuda.is_available(): + device = 'cpu' + elif 
isinstance(cuda, str): + device = cuda + else: + device = 'cuda' + + self._device = torch.device(device) + + @random_state + def fit(self, train_data, discrete_columns=()): + """Fit the TVAE Synthesizer models to the training data. + + Args: + train_data (numpy.ndarray or pandas.DataFrame): + Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame. + discrete_columns (list-like): + List of discrete columns to be used to generate the Conditional + Vector. If ``train_data`` is a Numpy array, this list should + contain the integer indices of the columns. Otherwise, if it is + a ``pandas.DataFrame``, this list should contain the column names. + """ + self.transformer = DataTransformer() + self.transformer.fit(train_data, discrete_columns) + train_data = self.transformer.transform(train_data) + dataset = TensorDataset(torch.from_numpy(train_data.astype('float32')).to(self._device)) + loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=False) + + data_dim = self.transformer.output_dimensions + encoder = Encoder(data_dim, self.compress_dims, self.embedding_dim).to(self._device) + self.decoder = Decoder(self.embedding_dim, self.decompress_dims, data_dim).to(self._device) + optimizerAE = Adam( + list(encoder.parameters()) + list(self.decoder.parameters()), weight_decay=self.l2scale + ) + + self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss']) + iterator = tqdm(range(self.epochs), disable=(not self.verbose)) + if self.verbose: + iterator_description = 'Loss: {loss:.3f}' + iterator.set_description(iterator_description.format(loss=0)) + + for i in iterator: + loss_values = [] + batch = [] + for id_, data in enumerate(loader): + optimizerAE.zero_grad() + real = data[0].to(self._device) + mu, std, logvar = encoder(real) + eps = torch.randn_like(std) + emb = eps * std + mu + rec, sigmas = self.decoder(emb) + loss_1, loss_2 = _loss_function( + rec, + real, + sigmas, + mu, + logvar, + self.transformer.output_info_list, 
+ self.loss_factor, + ) + loss = loss_1 + loss_2 + loss.backward() + optimizerAE.step() + self.decoder.sigma.data.clamp_(0.01, 1.0) + + batch.append(id_) + loss_values.append(loss.detach().cpu().item()) + + epoch_loss_df = pd.DataFrame({ + 'Epoch': [i] * len(batch), + 'Batch': batch, + 'Loss': loss_values, + }) + if not self.loss_values.empty: + self.loss_values = pd.concat([self.loss_values, epoch_loss_df]).reset_index( + drop=True + ) + else: + self.loss_values = epoch_loss_df + + if self.verbose: + iterator.set_description( + iterator_description.format(loss=loss.detach().cpu().item()) + ) + + @random_state + def sample(self, samples): + """Sample data similar to the training data. + + Args: + samples (int): + Number of rows to sample. + + Returns: + numpy.ndarray or pandas.DataFrame + """ + self.decoder.eval() + + steps = samples // self.batch_size + 1 + data = [] + for _ in range(steps): + mean = torch.zeros(self.batch_size, self.embedding_dim) + std = mean + 1 + noise = torch.normal(mean=mean, std=std).to(self._device) + fake, sigmas = self.decoder(noise) + fake = torch.tanh(fake) + data.append(fake.detach().cpu().numpy()) + + data = np.concatenate(data, axis=0) + data = data[:samples] + return self.transformer.inverse_transform(data, sigmas.detach().cpu().numpy()) + + def set_device(self, device): + """Set the `device` to be used ('GPU' or 'CPU).""" + self._device = device + self.decoder.to(self._device) diff --git a/dataprocess/Dockerfile b/dataprocess/Dockerfile new file mode 100644 index 0000000..b5b5ebc --- /dev/null +++ b/dataprocess/Dockerfile @@ -0,0 +1,20 @@ +# 使用官方的 Python 基础镜像 +FROM python:3.11-slim + +# 设置工作目录 +WORKDIR /opt/ml/processing + +# 安装系统依赖 +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +# 升级 pip 并安装 Python 包 +RUN pip install --upgrade pip +RUN pip install pandas numpy scikit-learn imbalanced-learn + +# 将处理脚本复制到容器中 +COPY dataprocess.py 
/opt/ml/processing/process_data.py + +# 设置默认的命令 +ENTRYPOINT ["python", "/opt/ml/processing/process_data.py"] diff --git a/dataprocess/dataprocess.py b/dataprocess/dataprocess.py new file mode 100644 index 0000000..c7399be --- /dev/null +++ b/dataprocess/dataprocess.py @@ -0,0 +1,191 @@ +import argparse +import os +import pandas as pd +import numpy as np +from sklearn.impute import SimpleImputer +from sklearn.preprocessing import LabelEncoder +from imblearn.over_sampling import SMOTE +import json +import logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[ + logging.StreamHandler() # Log to stdout + ] +) +label_map = {} # LabelEncoder 映射 +params = {} +def encode_numeric_range(df, names, normalized_low=0, normalized_high=1): + """将数值数据归一化到指定范围 [normalized_low, normalized_high]""" + global params + for name in names: + data_low = df[name].min() + data_high = df[name].max() + # 如果 NaN,赋值为 0 + data_low = 0 if np.isnan(data_low) else data_low + data_high = 0 if np.isnan(data_high) else data_high + # 避免 ZeroDivisionError(如果 min == max,则直接赋 0) + if data_high != data_low: + df[name] = ((df[name] - data_low) / (data_high - data_low)) \ + * (normalized_high - normalized_low) + normalized_low + else: + df[name] = 0 + if name not in params: + params[name] = {} + params[name]['min'] = data_low + params[name]['max'] = data_high + return df + +def encode_numeric_zscore(df, names): + """将数值数据标准化为 Z-score""" + global params + for name in names: + mean = df[name].mean() + sd = df[name].std() + # 如果 NaN,赋值为 0 + mean = 0 if np.isnan(mean) else mean + sd = 0 if np.isnan(sd) else sd + # 避免除以 0 的情况(如果标准差为 0,则直接赋 0) + if sd != 0: + df[name] = (df[name] - mean) / sd + else: + df[name] = 0 + if name not in params: + params[name] = {} + params[name]['mean'] = mean + params[name]['std'] = sd + return df + +def numerical_encoding(df, label_column): + """将类别标签编码为 0,1,2... 
def _standardize_column(name):
    """Normalize a CSV header: trim whitespace, lower-case, spaces -> underscores."""
    return name.strip().lower().replace(" ", "_")


if __name__ == "__main__":
    # Command-line entry point: load every CSV in --input, keep the selected
    # feature columns plus the label, encode/scale them with the helpers defined
    # earlier in this file, and write the results into the --output folder:
    #   * scaling_params.json  - contents of the module-level `params`
    #   * label_map.json       - contents of the module-level `label_map`
    #   * processed_data.csv   - scaled features followed by the encoded label
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input",
        type=str,
        default=r"E:\datasource\MachineLearningCSV\MachineLearningCVE",
        help="Input folder path containing CSV files",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=r"D:\djangoProject\talktive\dataprocess\result.csv",
        help="Output folder path for processed CSV files",
    )
    args = parser.parse_args()

    # NOTE(review): --output is used as a *directory* below even though its
    # default value ends in "result.csv"; the default is left untouched so the
    # CLI stays backward-compatible. Create the folder up front so a missing or
    # invalid output location fails before the expensive processing, not after.
    os.makedirs(args.output, exist_ok=True)

    # Feature and label column definitions (names as they appear in the CSV
    # headers). The trailing comment on each entry is the matching key of
    # `formatted_data`, which is defined outside this block.
    selected_features = [
        "Destination Port",  # formatted_data["Destination Port"]
        "Flow Duration",  # formatted_data["Flow Duration (ms)"]
        "Total Fwd Packets",  # formatted_data["Total Fwd Packets"]
        "Total Backward Packets",  # formatted_data["Total Backward Packets"]
        "Total Length of Fwd Packets",  # formatted_data["Total Length of Fwd Packets"]
        "Total Length of Bwd Packets",  # formatted_data["Total Length of Bwd Packets"]
        "Fwd Packet Length Max",  # formatted_data["Fwd Packet Length Max"]
        "Fwd Packet Length Min",  # formatted_data["Fwd Packet Length Min"]
        "Fwd Packet Length Mean",  # formatted_data["Fwd Packet Length Mean"]
        "Fwd Packet Length Std",  # formatted_data["Fwd Packet Length Stddev"]
        "Bwd Packet Length Max",  # formatted_data["Bwd Packet Length Max"]
        "Bwd Packet Length Min",  # formatted_data["Bwd Packet Length Min"]
        "Bwd Packet Length Mean",  # formatted_data["Bwd Packet Length Mean"]
        "Bwd Packet Length Std",  # formatted_data["Bwd Packet Length Stddev"]
        "Flow Bytes/s",  # formatted_data["Flow Bytes/s"]
        "Flow Packets/s",  # formatted_data["Flow Packets/s"]
        "Flow IAT Mean",  # formatted_data["Flow IAT Mean (ms)"]
        "Flow IAT Std",  # formatted_data["Flow IAT Stddev (ms)"]
        "Flow IAT Max",  # formatted_data["Flow IAT Max (ms)"]
        "Flow IAT Min",  # formatted_data["Flow IAT Min (ms)"]
        "Fwd IAT Mean",  # formatted_data["Fwd IAT Mean (ms)"]
        "Fwd IAT Std",  # formatted_data["Fwd IAT Stddev (ms)"]
        "Fwd IAT Max",  # formatted_data["Fwd IAT Max (ms)"]
        "Fwd IAT Min",  # formatted_data["Fwd IAT Min (ms)"]
        "Bwd IAT Mean",  # formatted_data["Bwd IAT Mean (ms)"]
        "Bwd IAT Std",  # formatted_data["Bwd IAT Stddev (ms)"]
        "Bwd IAT Max",  # formatted_data["Bwd IAT Max (ms)"]
        "Bwd IAT Min",  # formatted_data["Bwd IAT Min (ms)"]
        "Fwd PSH Flags",  # formatted_data["Fwd PSH Flags"]
        "Bwd PSH Flags",  # formatted_data["Bwd PSH Flags"]
        "Fwd URG Flags",  # formatted_data["Fwd URG Flags"]
        "Bwd URG Flags",  # formatted_data["Bwd URG Flags"]
        "Fwd Packets/s",  # formatted_data["Fwd Packets/s"]
        "Bwd Packets/s",  # formatted_data["Bwd Packets/s"]
        "Down/Up Ratio",  # formatted_data["down_up_ratio"]
        "Average Packet Size",  # formatted_data["average_packet_size"]
        "Avg Fwd Segment Size",  # formatted_data["avg_fwd_segment_size"]
        "Avg Bwd Segment Size",  # formatted_data["avg_bwd_segment_size"]
        "Packet Length Mean",  # formatted_data["Packet Length Mean"]
        "Packet Length Std",  # formatted_data["Packet Length Std"]
        "FIN Flag Count",  # formatted_data["FIN Flag Count"]
        "SYN Flag Count",  # formatted_data["SYN Flag Count"]
        "RST Flag Count",  # formatted_data["RST Flag Count"]
        "PSH Flag Count",  # formatted_data["PSH Flag Count"]
        "ACK Flag Count",  # formatted_data["ACK Flag Count"]
        "URG Flag Count",  # formatted_data["URG Flag Count"]
        "CWE Flag Count",  # formatted_data["CWE Flag Count"]
        "ECE Flag Count",  # formatted_data["ECE Flag Count"]
        "Subflow Fwd Packets",  # formatted_data["Subflow Fwd Packets"]
        "Subflow Fwd Bytes",  # formatted_data["Subflow Fwd Bytes"]
        "Subflow Bwd Packets",  # formatted_data["Subflow Bwd Packets"]
        "Subflow Bwd Bytes",  # formatted_data["Subflow Bwd Bytes"]
        # NOTE: "Init_Win_bytes_forward" and "Init_Win_bytes_backward" are
        # currently disabled and therefore not part of the selection.
        "Label",
    ]

    # Headers are compared in normalized form, so the selection list and each
    # file's columns go through the same helper.
    standardized_features = [_standardize_column(col) for col in selected_features]
    all_data = []
    logging.info("Starting data processing...")

    # Load and process CSV files. Sorting makes the row order of the combined
    # data reproducible regardless of the order the OS lists the directory in.
    for file_name in sorted(os.listdir(args.input)):
        if not file_name.endswith(".csv"):
            continue
        file_path = os.path.join(args.input, file_name)
        logging.info("Processing file: %s", file_path)
        df = pd.read_csv(file_path)
        df.columns = [_standardize_column(col) for col in df.columns]
        if "label" not in df.columns:
            # Without a label column the rows cannot be used; skip the file
            # explicitly instead of failing with a KeyError in dropna().
            logging.warning("Skipping file without a 'label' column: %s", file_path)
            continue
        # Keep only the selected columns this file actually provides, in the
        # order of the selection list, and drop rows that have no label.
        present_columns = [col for col in standardized_features if col in df.columns]
        all_data.append(df[present_columns].dropna(how="any", subset=["label"]))

    # Combine all data
    if not all_data:
        raise ValueError("No valid CSV files found in the input folder.")
    final_data = pd.concat(all_data, ignore_index=True)

    # Encode the label column as integers; numerical_encoding also assigns the
    # module-level `label_map` that is written to disk below.
    final_data = numerical_encoding(final_data, "label")
    X = final_data.drop(columns=["label"])
    y = final_data["label"]

    # Data cleaning and transformation: treat +/-inf as missing, then fill
    # every missing value with the median of its column.
    X = X.replace([np.inf, -np.inf], np.nan)
    X = X.fillna(X.median())
    X = encode_numeric_zscore(X, X.columns)
    X = encode_numeric_range(X, X.columns)

    # `params` is a module-level name that is not defined in this block;
    # presumably it is filled in by the encode_numeric_* helpers - confirm.
    scaling_params_file = os.path.join(args.output, "scaling_params.json")
    with open(scaling_params_file, "w", encoding="utf-8") as f:
        json.dump(params, f)

    label_map_file = os.path.join(args.output, "label_map.json")
    with open(label_map_file, "w", encoding="utf-8") as f:
        json.dump(label_map, f)

    # Re-attach the encoded label to the scaled features and save the result.
    final_data = pd.concat([X, y], axis=1)
    output_path = os.path.join(args.output, "processed_data.csv")
    final_data.to_csv(output_path, index=False)
    logging.info("Processed data saved to %s", output_path)
    logging.info("Data processing completed successfully!")