From ac60206cd11ad3222a9bccc32f8b11ce68106b40 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B1=AA=E5=89=8D?= <1104239712@qq.com>
Date: Sun, 9 Feb 2025 15:17:40 +0800
Subject: [PATCH] ctgantraining

---
 ctgan/.idea/.gitignore                        |   8 +
 ctgan/.idea/ctgan.iml                         |  12 +
 .../inspectionProfiles/profiles_settings.xml  |   6 +
 ctgan/.idea/misc.xml                          |   4 +
 ctgan/.idea/modules.xml                       |   8 +
 ctgan/.idea/vcs.xml                           |   6 +
 ctgan/__init__.py                             |  16 +
 ctgan/__main__.py                             | 231 +++++++
 ctgan/__pycache__/__init__.cpython-311.pyc    | Bin 0 -> 545 bytes
 ctgan/__pycache__/__main__.cpython-311.pyc    | Bin 0 -> 2550 bytes
 .../__pycache__/data_sampler.cpython-311.pyc  | Bin 0 -> 9189 bytes
 .../data_transformer.cpython-311.pyc          | Bin 0 -> 13850 bytes
 ctgan/__pycache__/load_data.cpython-311.pyc   | Bin 0 -> 168 bytes
 ctgan/config.py                               |  48 ++
 ctgan/data.py                                 |  88 +++
 ctgan/data_sampler.py                         | 153 +++++
 ctgan/data_transformer.py                     | 265 ++++++++
 ctgan/dataprocess_using_gan.py                | 141 +++++
 ctgan/load_data.py                            |  73 +++
 ctgan/synthesizers/__init__.py                |  10 +
 .../__pycache__/__init__.cpython-311.pyc      | Bin 0 -> 776 bytes
 .../__pycache__/base.cpython-311.pyc          | Bin 0 -> 8812 bytes
 .../__pycache__/ctgan.cpython-311.pyc         | Bin 0 -> 29580 bytes
 .../__pycache__/tvae.cpython-311.pyc          | Bin 0 -> 13794 bytes
 ctgan/synthesizers/base.py                    | 151 +++++
 ctgan/synthesizers/ctgan.py                   | 564 ++++++++++++++++++
 ctgan/synthesizers/tvae.py                    | 246 ++++++++
 dataprocess/Dockerfile                        |  20 +
 dataprocess/dataprocess.py                    | 191 ++++++
 29 files changed, 2241 insertions(+)
 create mode 100644 ctgan/.idea/.gitignore
 create mode 100644 ctgan/.idea/ctgan.iml
 create mode 100644 ctgan/.idea/inspectionProfiles/profiles_settings.xml
 create mode 100644 ctgan/.idea/misc.xml
 create mode 100644 ctgan/.idea/modules.xml
 create mode 100644 ctgan/.idea/vcs.xml
 create mode 100644 ctgan/__init__.py
 create mode 100644 ctgan/__main__.py
 create mode 100644 ctgan/__pycache__/__init__.cpython-311.pyc
 create mode 100644 ctgan/__pycache__/__main__.cpython-311.pyc
 create mode 100644 ctgan/__pycache__/data_sampler.cpython-311.pyc
 create mode 100644 ctgan/__pycache__/data_transformer.cpython-311.pyc
 create mode 100644 ctgan/__pycache__/load_data.cpython-311.pyc
 create mode 100644 ctgan/config.py
 create mode 100644 ctgan/data.py
 create mode 100644 ctgan/data_sampler.py
 create mode 100644 ctgan/data_transformer.py
 create mode 100644 ctgan/dataprocess_using_gan.py
 create mode 100644 ctgan/load_data.py
 create mode 100644 ctgan/synthesizers/__init__.py
 create mode 100644 ctgan/synthesizers/__pycache__/__init__.cpython-311.pyc
 create mode 100644 ctgan/synthesizers/__pycache__/base.cpython-311.pyc
 create mode 100644 ctgan/synthesizers/__pycache__/ctgan.cpython-311.pyc
 create mode 100644 ctgan/synthesizers/__pycache__/tvae.cpython-311.pyc
 create mode 100644 ctgan/synthesizers/base.py
 create mode 100644 ctgan/synthesizers/ctgan.py
 create mode 100644 ctgan/synthesizers/tvae.py
 create mode 100644 dataprocess/Dockerfile
 create mode 100644 dataprocess/dataprocess.py

diff --git a/ctgan/.idea/.gitignore b/ctgan/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/ctgan/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/ctgan/.idea/ctgan.iml b/ctgan/.idea/ctgan.iml
new file mode 100644
index 0000000..8b8c395
--- /dev/null
+++ b/ctgan/.idea/ctgan.iml
@@ -0,0 +1,12 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+  <component name="PyDocumentationSettings">
+    <option name="format" value="PLAIN" />
+    <option name="myDocStringFormat" value="Plain" />
+  </component>
+</module>
\ No newline at end of file
diff --git a/ctgan/.idea/inspectionProfiles/profiles_settings.xml b/ctgan/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/ctgan/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
\ No newline at end of file
diff --git a/ctgan/.idea/misc.xml b/ctgan/.idea/misc.xml
new file mode 100644
index 0000000..4f4e5a4
--- /dev/null
+++ b/ctgan/.idea/misc.xml
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="D:\anconda" project-jdk-type="Python SDK" />
+</project>
\ No newline at end of file
diff --git a/ctgan/.idea/modules.xml b/ctgan/.idea/modules.xml
new file mode 100644
index 0000000..b65f96f
--- /dev/null
+++ b/ctgan/.idea/modules.xml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/ctgan.iml" filepath="$PROJECT_DIR$/.idea/ctgan.iml" />
+    </modules>
+  </component>
+</project>
\ No newline at end of file
diff --git a/ctgan/.idea/vcs.xml b/ctgan/.idea/vcs.xml
new file mode 100644
index 0000000..6c0b863
--- /dev/null
+++ b/ctgan/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$/.." vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/ctgan/__init__.py b/ctgan/__init__.py
new file mode 100644
index 0000000..f693c93
--- /dev/null
+++ b/ctgan/__init__.py
@@ -0,0 +1,16 @@
+# -*- coding: utf-8 -*-
+
+"""Top-level package for ctgan."""
+
+__author__ = 'Qian Wang, Inc.'
+__email__ = 'info@sdv.dev'
+__version__ = '0.10.2.dev0'
+
+from ctgan.load_data import final_data
+from ctgan.synthesizers.ctgan import CTGAN
+from ctgan.synthesizers.tvae import TVAE
+from ctgan.config import HyperparametersParser
+
+__all__ = ('CTGAN', 'TVAE', 'final_data','HyperparametersParser')
+
+
diff --git a/ctgan/__main__.py b/ctgan/__main__.py
new file mode 100644
index 0000000..6a38b64
--- /dev/null
+++ b/ctgan/__main__.py
@@ -0,0 +1,231 @@
+# import os
+# import argparse
+# import pandas as pd
+# from ctgan import CTGAN
+# import logging
+# 
+# # Set up logging
+# logging.basicConfig(
+#     level=logging.INFO,
+#     format="%(asctime)s [%(levelname)s] %(message)s",
+#     handlers=[
+#         logging.StreamHandler()  # Log to stdout
+#     ]
+# )
+# 
+# def train_ctgan(input_data, output_model, epochs, discrete_columns,embedding_dim,
+#                 generator_dim, discriminator_dim,generator_lr,discriminator_lr ,batch_size):
+#     logging.info("Starting CTGAN training...")
+#     data = pd.read_csv(input_data)
+#     logging.info(f"Data loaded from {input_data} with shape {data.shape}")
+#     ctgan = CTGAN(
+#         embedding_dim=embedding_dim,
+#         generator_dim=generator_dim,
+#         discriminator_dim=discriminator_dim,
+#         generator_lr=generator_lr,
+#         discriminator_lr=discriminator_lr,
+#         batch_size=batch_size,
+#         epochs=epochs,
+#         verbose=True
+#     )
+# 
+#     logging.info(f"CTGAN initialized with {epochs} epochs")
+#     ctgan.fit(data, discrete_columns)
+#     logging.info(f"Training completed. Saving model to {output_model}")
+#     ctgan.save(output_model)
+#     logging.info("Model saved successfully.")
+# 
+# if __name__ == "__main__":
+#     # Define default values that match SageMaker hyperparameters
+#     input_data = os.getenv("SM_CHANNEL_DATA1", "/opt/ml/input/data/data1/processed_data.csv")
+#     output_model = os.getenv("SM_OUTPUT_MODEL", "/opt/ml/model/ctgan_model.pt")
+#     epochs = int(os.getenv("SM_EPOCHS", 80))
+#     discrete_columns = os.getenv("SM_DISCRETE_COLUMNS", "label").split(",")
+# 
+#     embedding_dim = int(os.getenv("SM_EMBEDDING_DIM", 128))
+#     generator_dim = tuple(map(int, os.getenv("SM_GENERATOR_DIM", "256,256").split(",")))
+#     discriminator_dim = tuple(map(int, os.getenv("SM_DISCRIMINATOR_DIM", "256,256").split(",")))
+#     batch_size = int(os.getenv("SM_BATCH_SIZE", 500))
+#     generator_lr = float(os.getenv("SM_GENERATOR_LR", 2e-4))
+#     discriminator_lr = float(os.getenv("SM_DISCRIMINATOR_LR", 2e-4))
+# 
+# 
+#     # Log loaded hyperparameters for debugging
+#     logging.info("Loaded hyperparameters:")
+#     logging.info(f"Input Data Path: {input_data}")
+#     logging.info(f"Output Model Path: {output_model}")
+#     logging.info(f"Epochs: {epochs}")
+#     logging.info(f"Discrete Columns: {discrete_columns}")
+# 
+#     # Call the training function
+#     train_ctgan(input_data, output_model, epochs, discrete_columns,embedding_dim,generator_dim,
+#                 discriminator_dim,generator_lr,discriminator_lr,batch_size)
+import datetime
+import json
+import os
+import random
+
+import torch
+import torch.nn as nn
+import pandas as pd
+import logging
+from itertools import product
+from ctgan import HyperparametersParser
+from sklearn.metrics import mean_squared_error
+from sklearn.model_selection import ParameterGrid
+from ctgan import CTGAN
+
+
+# 设置日志
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[logging.StreamHandler()]
+)
+
+
+# 定义模型训练函数
+def train_ctgan(input_data, output_model, epochs, discrete_columns, embedding_dim,
+                generator_dim, discriminator_dim, generator_lr, discriminator_lr, batch_size):
+    logging.info("Starting CTGAN training...")
+    data = pd.read_csv(input_data)
+    logging.info(f"Data loaded from {input_data} with shape {data.shape}")
+
+    # 初始化 CTGAN 模型
+    ctgan = CTGAN(
+        embedding_dim=embedding_dim,
+        generator_dim=generator_dim,
+        discriminator_dim=discriminator_dim,
+        generator_lr=generator_lr,
+        discriminator_lr=discriminator_lr,
+        batch_size=batch_size,
+        epochs=epochs,
+        verbose=True
+    )
+    logging.info("begin train ...")
+
+    # 训练模型
+    ctgan.fit(data, discrete_columns)
+
+    # 保存模型
+    ctgan.save(output_model)
+    logging.info(f"Model saved to {output_model}")
+    return ctgan
+
+
+# 网格搜索主程序
+def evaluate_ctgan(ctgan, input_data, discrete_columns, n_samples=1000):
+    """
+    使用 CTGAN 的 `sample` 方法评估生成样本的质量。
+
+    参数:
+        ctgan: 训练好的 CTGAN 模型。
+        input_data: 输入数据文件路径。
+        discrete_columns: 离散列的名称列表。
+        n_samples: 每个类别生成的样本数量。
+
+    返回:
+        accuracy: 分类准确性。
+        loss: 分类损失。
+    """
+    # 读取真实数据
+    data = pd.read_csv(input_data)
+    logging.info(f"Data loaded for evaluation with shape: {data.shape}")
+
+    # 获取真实标签
+    label_column = 'label'  # target column hard-coded to 'label'; consider discrete_columns[0] instead
+    unique_labels = data[label_column].unique()
+
+    # 生成少数类样本
+    generated_data = []
+    for label in unique_labels:
+        logging.info(f"Generating samples for label: {label}")
+        generated_samples = ctgan.sample(
+            n=n_samples,
+            condition_column=label_column,
+            condition_value=label
+        )
+        generated_data.append(generated_samples)
+
+    # 合并生成样本
+    generated_data = pd.concat(generated_data, ignore_index=True)
+    logging.info(f"Generated data shape: {generated_data.shape}")
+    # 获取真实数据和生成数据的分布
+    real_distribution = data[label_column].value_counts(normalize=True)
+    generated_distribution = generated_data[label_column].value_counts(normalize=True)
+    # 使用均方误差 (MSE) 评估分布差异
+    common_labels = real_distribution.index.intersection(generated_distribution.index)
+    mse = mean_squared_error(real_distribution[common_labels], generated_distribution[common_labels])
+    logging.info(f"[validate] Generated data distribution MSE: {mse:.4f}")
+    # 假设生成样本的标签完全正确,计算伪分类的准确率
+    predicted_labels = generated_data[label_column].values
+    real_labels = data[label_column].sample(len(generated_data), random_state=42, replace=len(generated_data) > len(data)).values
+    correct_predictions = (predicted_labels == real_labels).sum()
+    accuracy = correct_predictions / len(real_labels)
+    # 使用二分类交叉熵计算损失
+    criterion = nn.BCEWithLogitsLoss()
+    predicted_scores_tensor = torch.tensor(predicted_labels, dtype=torch.float32)
+    real_labels_tensor = torch.tensor(real_labels, dtype=torch.float32)
+    loss = criterion(predicted_scores_tensor, real_labels_tensor)
+    logging.info('[validate] loss: {:.4f}, acc: {:.2f}'.format(loss.item(), accuracy * 100))
+    return accuracy, loss.item()
+
+def save_metadata_to_json(output_path, model_path, score, params):
+    metadata = {
+        "ModelPath": model_path,
+        "Score": score,
+        "Parameters": params
+    }
+    json_path = os.path.join(output_path, "metadata.json")
+    try:
+        with open(json_path, "w") as json_file:
+            json.dump(metadata, json_file, indent=4)
+        logging.info(f"Metadata saved to {json_path}")
+    except Exception as e:
+        logging.warning(f"Error saving metadata to JSON: {e}")
+
+
+if __name__ == "__main__":
+    # 定义数据和超参数
+    input_data = os.getenv("SM_CHANNEL_DATA1", "/opt/ml/input/data/train/processed_data.csv")
+    output_model = os.getenv("SM_OUTPUT_MODEL", "/opt/ml/model")
+    hyperparameters_file = "/opt/ml/input/config/hyperparameters.json"
+
+    # 解析超参数
+    parser = HyperparametersParser(hyperparameters_file)
+    hyperparameters = parser.parse_hyperparameters()
+
+    # 提取超参数
+    epochs = hyperparameters.get("epochs", 40)
+    discrete_columns = hyperparameters.get("discrete_columns", "label").split(",")
+    embedding_dim = hyperparameters.get("embedding_dim", 128)
+    generator_dim = tuple(int(x) for x in hyperparameters.get("generator_dim", "(256, 256)").strip("()").split(","))
+    discriminator_dim = tuple(int(x) for x in hyperparameters.get("discriminator_dim", "(256, 256)").strip("()").split(","))
+    generator_lr = hyperparameters.get("generator_lr", 1e-4)
+    discriminator_lr = hyperparameters.get("discriminator_lr", 1e-4)
+    batch_size = hyperparameters.get("batch_size", 100)
+
+    # 训练模型
+    timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
+    model_path = os.path.join(output_model, f"ctgan_model_{timestamp}.pt")
+    logging.info(f"Training with parameters: {hyperparameters}")
+
+    ctgan = train_ctgan(
+        input_data=input_data,
+        output_model=model_path,
+        epochs=epochs,
+        discrete_columns=discrete_columns,
+        embedding_dim=embedding_dim,
+        generator_dim=generator_dim,
+        discriminator_dim=discriminator_dim,
+        generator_lr=generator_lr,
+        discriminator_lr=discriminator_lr,
+        batch_size=batch_size
+    )
+
+    # 评估模型
+    score = evaluate_ctgan(ctgan, input_data, discrete_columns)
+    logging.info(f"Evaluation Score: {score}")
+
+    # # 保存元数据
+    # save_metadata_to_json(output_model, model_path, score, hyperparameters)
\ No newline at end of file
diff --git a/ctgan/__pycache__/__init__.cpython-311.pyc b/ctgan/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..22ad241c302e2fbe3a07d78fc8523242805d766d
GIT binary patch
literal 545
zcmZ3^%ge>Uz`$U6I4j+ok%8echy%k+P{wCF1_p-d3@Hpz3@MCJj44b}OexG!%qc7>
ztT`;XtWm6t3@I!rY&mSX>{0AsHhT_7E@u=cBSQ*D6jusoFoPynl}t!}fo@J}S!#|#
zL1J=tVtT4VT7Hp2a!Gn(o?aDyU}j>TLU>|cx{iWpUb0>lPi9_PzC&?JnO;h2SrxZ|
zo}q!B5r}Q@667pR##>xznR$sh@hOQViJFYJSe--M9sM*JZ?S}gIl5{x7qKufFch(Z
z2xbNb20u;iTkP@iDf!9q@wd3*;}c6uGV+V!<8N`s$EW5dX6D4l-{OvsFH0>d&dkq?
zkH5taaac}%VhY%bB9JM!Bp{;2m3bu@sl}O9sYS(lU?H$|x5UA0JrwDZvcy!dJ3#g)
z=H$f3uVnZP^1?4|7ps_*ti-(Z{D7kTtkmR^n3BYt?2^o~)EKZiG4b)4d6^~g@p=W7
zzc_4i^HWN5QtgTa85kHC85kIfqZt?&J}@&fGJarVU{t=qp!9%Ou7T?Yk5mKa4L<1x
i?gtEN7f{g+2CWOI=mvwv1yuBaOS6F;1dI3?7#IK=EuJ(0

literal 0
HcmV?d00001

diff --git a/ctgan/__pycache__/__main__.cpython-311.pyc b/ctgan/__pycache__/__main__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..182e053419b2c39b7bce116a95f6bb164b049920
GIT binary patch
literal 2550
zcmZ3^%ge>Uz`*dKuqfS=je+4Yhy%lHP{wBsCI*J-3@HpLj5!QZ5SlTH5zJ?bVoqU5
zVaj32WsPEGgvhZ)u`w~EvSqPD^@G?PFm?(vSdJ6MU&g?|uo}i;NM~HbvW%I5VKp;M
zh>;<MwS^&yE0sHiEt?4>T6BhqA(bJE2gC;96!sL3WlRhVtHH7$ktkj?6`UztXe#)a
z7*hGsRB@;9psHHN$iT3g5o96=M+u;*=S|^5Rlh(Gqz)aY@S~X{1P%dVn7R}JRCz`Q
zG<m^P-q{RunFy#7O64O$m2j#65voK|g>b0iO65)EOBF~JTE@b_u$l$t2S$bz#$Yhk
z6nzO2^wVU##p)d5?&w#guBwq(oLrKbo2pr?5Ur|_lUkOVla~nM#ww_4<faxEC#I)r
z7Hcx!Vg<|HVoS>}%1tb}#gUPimy(lORGbXAh=G9tlq^BX<MSj&P_it6M-T%;7Mz{J
zh#p)etRQzma27~Em|4Sss;-s=hj~mj46_+hm}^+F;BJ}CFqZ{YHzPv{H`rJ(QNxtN
zvW68kiPW;<FrTf4X*NR&`&?#p6Khx(fP4%#5t*oAsbNMnn~?#B*&H>D*rKzBbpbq!
zqL_>scJOdT31d!lx27-zGiY-8RY?VxBo>uq=A|oulYl}=QDSCZW?s6Uo?expOJYf)
zLQZ~SN@|KiT2X$kLKTlfd1gt5LUBf7L8?NCCetmp)PnrvjN)7DWvN9;`NgSK;t+E(
z^D;{^6LT`FQd1yWs@N4EN~=^tpmr!E=jRsWq?V+n=qUszmVq3Ro1c=JqfnBsP$lUL
zW)>%wrKTtpmnJ8t78j?L=HyiBX>#0R&&f|u&&*4|#gdtqmVb+>AmtWEQEFmJd~$Kw
zEmpAWia-hH7IRu=$t@O;A)0KrxH9t!N=xEX5=#<q@#L45fH+_ai$G=GErFEG;^d;#
zlGOO*{G8I<yy9CdV5O|dCFzNI#h|EF0E1uJE><xqS&4b+`2j`wS*gh-F(rvP*(I4}
zsWD)+G4b)aiJ5uv@p=W7x46MU8V{Bz5@%pw5MW?nC_cr&z|g>Omxa5<=YpW_MHam)
zEP5AM^zL%=_E=A-y2!0`g<Gk?1uE~kAnJmk=LJE}9U*(5tcxsOS6I9*uz20&5u72m
zLgWIE-bEh0D?EA)u24-@E6gqkT0zi77V9f4))!c;Z?JH-dp3D?q+AeDxyYh=g+=uO
ziz-xu+YX)!f^HWC-7d1YUtw{-z~T;4?bYPfQFD<+<_e3<1r`~oQk{z|x>s0qFTl`e
zP@=rW0rU1P-eBK&PrrcBka!oz5XUMN{rrLw{oEXVa0KXsLRB9eDw%odV4fZ*-uQ!k
z<NZTJK)QVWU0i*tWMJCBk)#g}!FVuRub|`>XRvR)Yk<GANAOEf#du3N*f-wAGuS!E
zHN-XE+21GB*Dv@MYffTPYR)f4ohor3NIK4_EJ!UXNGwXsO)W_+Dz>T;^aP~}1#nsq
zNG!>)Qm7L42PX~%ND_n!a=3z1x|KqekP9@SC^$nBij{(<$StP);#(YvMd<~JMa8MN
z*a{N!QWA@cKy_b{C<6lnI6W4DUCN!5Se%*coS&DLnSP7K)6dQS7H@D#QEFnY2c#4Q
zClgR`g38|b_*-o0sU@j-WksM$t4N4}f#DW&W?l&-)v^{B<Ybl<f$FFtP>HsZp-2K`
zofrcH!!Hh--29Z%oK(9a9R>ylP|;All#zkq12ZEd;|&JR1~9zAz~2CdHyDI3z|ai_
z@e8Qv1_NIM7=B=5U=(R!zrrAVgG;bS=LWaQ0}j~^_A4B+7dhmwaL9jPX5wS}z|O|V
z^pSy$k?9MF_y8sZI2rXmFu(~h1{R)<s!Pn07nmg<2ue*5nj$&F{(?@>MZw@Jg24@5
zH-u#-7)`OfAgg!5!0)24{}o~X2A><e0v*8<B*BWUS1?@^)V?C9eUVq^3a?Is+YNDr
z2}M&XFDRPrV81Bta7EnVy13gVakm4-7sb7=h<jgP5&6Kzz$)0_4$5{~7g@BguxMXk
z(Z0bV01+|-<ut<!f`%7ajIOX4U0^YSOPMaHxgcnILD2Lfi`f+xvkNR{a4D4;ZWjbq
zE(jvB6kN*Uf|}U{L5mB578hA8udrBNV6g<bK%l{U0@n<m6>1lZoG!9BUtw{+07D<-
N8Caw)FiV1?5df9yZ6N>v

literal 0
HcmV?d00001

diff --git a/ctgan/__pycache__/data_sampler.cpython-311.pyc b/ctgan/__pycache__/data_sampler.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c7b71467eacb17d46d3c02c760b931be117cb265
GIT binary patch
literal 9189
zcmZ3^%ge>Uz`$VFos^y-!NBks#DQT}DC4sp0|Uc!h7^V<h7`sWrX0p7CME`Vh7{%&
zh7^`m=4H$b46B);Y8j$fQka4nG+C>JT@p(YgA;QLa#D*Fa`RJ4b5iwQf>ik>gGd-=
zgfc$cfGtU7h+<4(h+;}%jABk<Y5^I?+`<sWn!?h;5XF|l+QJaUp2F6`5XF(g-og;Y
znZnV+5XF_k8O)%`b&DIx$|@fujzTfm9mNVI8L0}%`FSasC7Jnoi8%^osmUezMGA>|
zDGJH?MMbH_1t7`1bcK|}l0=2H{2~SC5O+sEy<`@sGeHz6#6bT2%)!dQ&@M2Yp#<g$
z1_p*2#$}8Q46ETHH4O1^b_!z)Lzh4aEL<5F7&-*f7*m*9IBJ+&7?v?GFsufd55}mb
z)UuSYgN4CF7Th&8EGf)uSe7v{Fsx<*OMzU#5D&K@ouP&yo((3G!dlA;vMF|TEgKU<
zr$7xGsx3uLHEb!Yk}%~Awd^&lHSAe%o2wWY7-~7v8EQF81mTJpkX4|X=fV(cQ_EGu
zxquU)8o@&G5hzT-QZ<Yz3@Hq)Of{@ETp;_>m=NwkH67Wds4|QUJ+?IrS#W=2x4(uH
zT~`f9Eo)DI4J&dCff8E{8~)g-;$>i{W$js0!&<|b#SQib0|P?}JHqAJ;Mgf<@?vOW
zEI|lE;-3Rs{HL(4;Y5wsT5gmOBa{YsVeTMF7at*AHQZ@TsNq(_5YG?ORl`!l5HA4Z
zA%`P4rE_8PDTZnAv;|79DXg_TDXg`;;B<Q$N4i}^Ou9u02XK6V+=-OlI9i#IQ(Fq_
z8g8^0;HzOS5=Svd2xJBb*YKsVgVHOiZ)*8!_)yd%?BzpMS;fu3P|IJ#U&D}v2u+?8
zl-vb!AE=nD;jLlFg2!(S4_FK-pVlyB!Q&As%Zr>#!ReK^h9`v`%;Ev3`}Jt)o~MO@
z$aH@ug&~+hlh-d9UeGZxFn|h3VFm_<&tBki9i@n5K=pbGBP4s|3)C>AAZ7O~xOpke
z5Ea1;nk;@VLBd6#fV{<CoS#;bn^>XAbc;DBHSZR0Vsc4lSt6)#k59|fWW2?loS##g
zn-`y%mzG})(x>oC-^D5>B`Yy6JwKo*KPxr4B&H-WC%YuGEHx&%Bt0=N22}LNLn@3S
zy@JYHLYc+!DVfE|MX4pJ@euQ>43H`pz4-Xdyv&mLcs-k({N%)(Vmm!d9Xt#S48>{;
z3=9nncZJ0!R9_TUz9Ou=pzxBg#s%T97128c*Th}X^*j-INjK~QPuOQrl0pejP>~J}
zPb~%phN+Cx89+sQEPpK{6GJBxdZn|TH-%v$Q;$F}gC=7UsQ6yVu#))}lb%5lsB|g<
zC7ohW1Su#q6oFj!i_0b_v$!NVKexcHN&~;&5Gq0GwAhM)f#F93!v_W?B`ySULrC_5
zfzt&8rvug(4g9Yd_<vw#QsTP7!waSn>WV<=^cHJoUWu{MO2%8PDJ7K!so=Op@jS?_
zV9(2d6BF3$oVAQKj9H*?0^^BHJzT*IMWC{0CF3pTl+4_fOhusFd5g2SAQ74ZK#4{{
z0qo@>ka@)WRUMR+rNHrbLCOt+Rz%LQoMQ`NUyyPG`vn}@C~k*GoC9*iN#Tk(P<e!?
zAW-}#9n7$j5ga&05)2FsMWAvLB^pWelMX0N8bRXk0#6bIE%3hJ6aqq_2^XCbuQ(-M
z;7RI;o*~#1cZF9TqT~Wk64<YrjJMcRQqvMkb4qCE<03hbSE=G(6OezMApS)pK``rv
zvfc&9U=Rw4zv!57#W4Y#P;c;v&tUBFyTT&}=7N=i39$cHGTvfItjH`z@jPis6%vu4
zbX81#+60xbMWCDkc9o`Z5hnu!!!4e8a6y)nnwMUZaf`XQH1`%$UcoKas??(V;#*Sa
zMOu7rVo6bEMSO9|ErIx=%#`?~%J}5O5>SCxTm(vxw^)l3^U_mqu_YE1q~@jEV$aLZ
z1DSP;wKyZOAoUh=ZeqnP2@GRFt!<E@(ik#%AoEkx^NT8P$zez)CzgPv;tPuMlWqye
z=b@`8zQr34Hz6~%_!e_ce!8aMEtcZcoU~gkAiv+@&Mz$~C@le({<nB@^3&tfic$+p
zQ}dE5ZwX;Bk*T-@Qb0<B0+1;+<rWtrByKTg-QvkjtU$P$E4j3&C^fGnJ~QPOKa36Y
zFE88&AUkg%xF8QfN?=g_RX_s{;KJz^2dqx162Vp)gCe-tjTv0>-j$JGz%tWwj%NqU
z6%NS{ER3Atj2{>f<X1KZIYn@yMI&w~>VQ-AM^MrI1w?#gV3Onh0wOMONPXpC(AHg}
zalzR2MEC`*(2H83SF}PwcBp^gVo+2Dv1C86G4SzsFx`+=`@q1>>A?sgJG^i3D_-DN
zT%diCU+)UP-UkLQUJu3_GU_WB=lD(Fydf$+LvxPqMN#z$Ob=w`Cve@Bl$%jGr}m<x
z?iESh4-AZgZXgc`x-m|0ydffYT}1Vgi0aa$C0QFpHY9@J4xt^1JH<9+ZjryJ=6FTT
z@qpqLHJ6Jbu2)1{FM!bv0g)>Ls@DZ{E(z$Yh`K0XaYex5f`G*hQSk{(6PO+d2u)3f
zl-fJAFG#svlybWw<#tib{fe0TMFDq^k{kSj6GEm$Enu7ydx>A^0t|g+V~~)(0V^pl
z@FZ=pzQB`oktgX2Ptpw$(GRT55?mh{m?gNrfQSyBuUrh8+G{v27&@OYzn~F(Q6u<@
zMldMgRPSo*uCcga;&vkaf_B(N?XWA_VIXmhyLtv|GA>wno$$F}9(hqO>WW?zNL>4d
zvMGpGxGSrG7K>o+4I^8eQ3=u~^??HtQKDiUTsNeoJNUs-#t)7%5Q!OOe7K_Qu7c7M
zwgn|C9Ko<?gVKu96&_3Kb~s*8u)C;WcSXT&0`CoB1yDTN-BnRr;<mzehv7wSn~N&8
zS5#~#_});^-O99uZA0xv1J^4Ct``+tC-B|~PME-QMZkE2*8%<mnn!dG@Lw?TzbFuJ
z!^RsVuDc@ZqJZTMdyhS}Ag;j%sfz+OHw5GtST3=>C}1#w|Av6njKqoj6S#<v;t%p5
ze}7;Ak>GMO8B_p+n-UBR44|$fs3iaV1>E4MVL(*GHH@Ir5TdY_sRX17!ePjQn@|F0
zF)(DoRV;va*`ON0R1H%WsEGh(ry!Y))<$Mp!;ID$VXk3X0CFc-Co+)&Qj0nW!JG}Y
zdjY5|h^!RM##9e(XG2;VARR@D3=9mK97To<3=Bn}M(i!_)QW<{yp;Hq%-rHzY(<HA
zDfzj#SU_}<E~xBeOD@d?)i`X4Md_gSIZKf#NLUE19StfuSwRJ2@h!%pVvrL-#UZ2!
zl#B-%m!BJ-oROcIoC>Z8;xqG7QY)%dQOZPI%6vf8eG<6I1}+ecFPOMp5peHdz01Sb
z<2To6j^%v2S#~Q_*Jy7@T%&hU+VYCD<pIVEJWdyRoUZUVbuiu#5S`95iD!Z01fGim
zDpv$l78G6*(74WTaEaewL*fN~gNytYSNJV1a9Cg_AW)QoBL4FZM$p(9asq;NcZ!%m
zBV{N_1)gjcfb@bR2bn-gY@pBrOCp<%nrLbm(UJ^mas{~&Vge)bXh|9q=9p3~Gm71)
zI>Bau2&_6;QdnzPP;?_)jy_h#l*R<^%Q9e3s5NLK54Ef{3|Sl?n?bk+sq0w73hIm3
zvZb)rpt>Bl-E2_1nQK^Tm}*$lSb`Ze*%B`^GBCKO=A{-TmZahuJp>ICmJ}ss=4Ixk
z>v3^`k}wDdrIwTy<rQ1O1wewy`FSY{8cB&I$r%b23d+z%LS|~QCaOA++7PURm3k=V
z<t7$qBbf)wUkZ?ZDmGIc74q{^bu;oyV1u8Tc`2F6i6xnN>0l2Qr{<(4m!zgBpqYhg
zFPfx6Cd^k@{Nb5blA4}cq)?PvP?TDnnpXnupey93DS%v$YA!TDD&sR#;xqF=txu@$
zu$UVRb16)dLS`P3UI3SEez(}ua`F>PjE#y!K_wfgsMF*u0`-W%rCgC4sLW!6ls~uF
z62a}?B2Z7T$QY#F7ewfTibi%w@x}^{*CJ4R{1yu+rf;!;e0WO;=0Jq6ZgC@t6@zRB
z#kc~rb9##pGX4+`GNdduxk?JT#6y(}1Qm0Sz{T7LDFz;aE8L)A1#XQC+!{B8#V43N
zm6e|pd0kfblC199q&1m4Sgu$ET+|D^q8E4}IO3vg<Q3V-3mj5U`2~9_W{56exuRfn
zLBe>4<pqA%i~O!v_+2|#?jqzYE=X7&ki5X}eUab$3cq&;%MDYr4wl}ap2!KM7kOo`
z@XB70H@v`Wc#*^ChOFXsS?f!())#F2F3S2}k@dg8A$3Dc>WUb+AQp4K;1zRGET)6|
zhMqpE5i_bU@M>M;(0;%#JR#(QwC)9c#|!+97dRX-i&;>>fr`q{Z3K$h8V1B@(gJuf
z3o15X7<Kr)mZ^rRhAD+H8)P~Ia*bcZgfgIqC{UPC3t#Yn9#aWY4UMShvDe!u=7M|%
zb{fi1JDMG+Hr6oKfNFkb<ms4NW)xFF9S5+fHJ~CNKIoptl)|!x6|J&hNnxu&opyk`
z34Oe&h8bFTf(lLcDmP+E%ydYpnwg)cP+FX-kYALUo(Y<g0d<$aT{JyNCe>uT#aK|}
z56WwZT*?i}rJx20C=(}^l%(cC@~kRIAt>V(1%Ozf+yU+nBl4{#D>%0n1%c!pL0Oz5
zIX@>pGo_-qC>WHPK(!k<H$oaB;?VpCb0IuaDr3o%=!$Yc25tf8&IjE5=y~&jgv=EQ
z>+2HEmn57oO1NH;aP8o|A)&e;cSGSNHH#|}mSFB$rZsFUYA<ToUeU0<C}G>ddskR=
zisf}-<x9fKOP!W@u1(&cdr{Bfik`zo6~`+ojt5Ln_+AiJz9<}ZML4R111bArWbGT;
zI%s*X;sUSwMGlPz{K6gFko*P;4se=bU|;~H08ozjYyrt{%nfEGFgXSWf_V%%En=i0
zrZr4xDUh{>u?A^6DwrXip_a7-o`xA1YM5(SQ1!Sl#CpUqF)-A!fhMmOAiEvzP8=>q
z%BgUr3^fdh>8~0#L^)7{G`WMC3sHHD3_YIczR+Z@dJdTwg-$aT<>w;OKddDIkt<Ft
z$t+H*gjdesc}&n`CHfpDQp3ZsC>^cU0U8a#-o8Mw4ap3wjh8%3Gx8LQ^2>`MbvbDK
z2<&P^mF}m>c#EYtvno}S1Cl*KIrA1fsL7t0R|3wMx0tIk3yQQsnGu|0H9-|B2PBJ{
zfpRBf-Yw?j{G40N`2{7nm@`u<ZZRg`V$3fFHT=NY6kcm`Lvt#~yH(;?awk{<lp`mC
zbL0mx1|I$n>$^fCQ#7v&DPIy&Uh2GpWufO1&kc?jgp@A|nOzYw>u|myEIxyIisc0!
znFo9V9n5z{#Ah(i2%hLW#TPWr0v@d3bYtvv>2T~QyuriY@73iswP;G^b#e7e;_53{
zu4vfqNW7xqa6s^ihVw;nmn-5f7kON-@VH(8BTy~p+*5ghLjscBK(iFE^ao;tvm0nw
zdpamF!el@*E{s)bkZ1xodP?$35_1&tN^_G^i$L)KZC_-j7VBv;f*UNFjNnj(%!5G6
zHlCEsTyS|74{BvdAO{(;Og|{hKvM?|3=jB)ukb5g;829P7UVjZn?O||i2u0{+-XLh
z@<CL-HB2=~9pyBp6h@3}&s>AlX|7>RVMb&-1gn;%hNXrvg#{r3p34IDHkeb`QrKFV
zYFMC^5_2$vCVQ0}_DTub3f4oaVZiM-j<Up@($x6O6!3_?pC$_={UTa_8lZFvsY-6~
zf!b-{H3i_YMsU?~i@7W{8I)e*i$SFjIHf`YT?1P6fo7BPQsUE!K%HFBx&!dM6{O`3
zHmAx2IYNllH5n9HVc-P*luw{1V+QMtk}L9t7sQQrXkOrRxya{oh0mpf`39fR6+Yb+
zc^7O0ujmC`;S26yzQMuWQ8y#&3cuz>4$Thc8)D)e%pHzkSlr>($$f)gsDld}?|z!x
zpfSq4#N5>Q_*-1@@sI@`@$t8K;^PZT6LX+4?D6p_`N{F|Mb4o74Q?=pgIF;jA{|6z
zf{1(&0ZQXw!@;3c1WHX{0#q&+_k#V?zyN_aSi&x_gneLfWMl;mI-*c+Ow6p79~j^S
zmn5qb;|B&L(u{{y<O2hokP&8;{J;Pww3t|xJ}|%u4mMWd4-A+D#HAodYO)p;gUTt^
zywco)N^ml{#g?Cxm6}`vc0YJd3Zfi5*T!K32`sy!G6n_)P!=x^Wnf_Vz|6?Vc!NRm
j0)ylO2B8Zu^nopyiBb3i19ozP+DEX=7chyb3T!O^Xst}y

literal 0
HcmV?d00001

diff --git a/ctgan/__pycache__/data_transformer.cpython-311.pyc b/ctgan/__pycache__/data_transformer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eef72269ccd8a69bdbac53fbddfa87b182e24a6e
GIT binary patch
literal 13850
zcmZ3^%ge>Uz`$VFos=G-#lY|w#DQT}DC6@i1_p-d3@Hpz3@MB$OgW6XOi@gXAU1Oj
za}-MoV+wN)YZO}wV+u<SdoD*5M=oa+Cs>{}hbxymiaVDliieSbi6NCKg)N0Wg(ID5
z850A;Y9^@p3{kvbInH#36s|>#XmWf^4DJjm+${_#JgNN4m>C#WGsE=>q%a3FX!2Hx
zyCjw*h7={{6{qDF<)#)X<mRW8=A`Pq1gX?yyv3E5n46kXQd*Fc>Zi$ciz6VhC^07|
zHRl$4N@`AGWon8h(=9RQoYLZw)FP+E;?xws{G!~%oXo1!qFcQFd8r=xC9ZkN`6;PI
zw>W|e67xLs((-RHr)1{d;!R90$t+7O$;{7-Ps_U{?3|xdnwtlA3P|Bf)?3`kVBvU>
zGvJ((%7WBeoFy<<-{Q_MEh#81iO&QXE(qnMWag&k6=&w>6(@uI0>iMN`<w=jnpB1;
z#uSDqrWD2~<`m%+mKMe+mK4?&hA7q)wibpcwiJ;Rjuyr!_7u(*hA55{t`>$UPEe3W
zai#FIFhp^u@U}2S@ucvzFhudD@V78T@ui5S2(~ar@uvv2FhmKY2m~`|irwOOK@H}r
zJs<%ExNlR7^tiYb6ciME^HWlD6q57vN;30G^Gk~rApS2_D9<d(P)Jm8N~}yR&P>d6
z_w`js%u7+og9W!jNk(d}LP@?tqC#<UVoqX_LQ-l;d1`8&LbR@-j)Gw<*m#hZM1``{
z<dXa%Jq4G{;^d;#l2o`QiAAXjso>~JQ7A3W%u82DR6vPDJ+LV+K`H-vwa7a6hgBu^
z$?#Ogz`y_sUJ(DY5(6mZ*D%&F#KVG~p@yl3A)XP&O<@dX&}2&FW@KP+E=o--NmWQx
zNJ%V7RDee;JXAm-<5-klYz4L+!~@~n#EST2NYX1-(8$ay(L|L8sqsy$$jmLxRmdyN
zO-e0N$WK#nPb@74g`h%8W^qYTW>P6A1r{r0=E1{O!QIyv;!%(ZAY7iBnVwM+Uy@Oj
zTAY!elcJ!JmXn`|&AxE3N(F?<wEQB4MC_J=Vk;?CAv?981YsR0m+{6Ud+nBBJi7gw
z%(qyIQ*+X8@t`OaKv!1`iV6h<g<r-lRxv4AiFxVy0Y&*)smUcVC5buNC7ETZG07$A
ziFq-gfR9HG_=3t?9P#m)d6^~g@l_I_)QX%~VKNd73=G9m3=9kn3?GCT<g_|iu5d`*
zP&5S5au+y2=rbt3lR@bM!~kJX>H#_Ivj-z2jbwpD!8nCc5=?`LTBZ_^DiD)_fgy#d
z1XijtFff#`LxmX_7-|@oF)}c$hNr6<hAbACENY2Z!;l3xV+}KEiCD{A!@PhErVp8}
zVXQ$bbr~5-_z@~<n6o%wJXDv|vXltJWEdD~m{M43m`g+vA}OpXY^W|omaSn)VMmos
zXG&qMVX0zZU|7SkjFo|5HQYDB44Rx(A3}-}(UP$ODEVQQCeXqLl(aFk6(}hvXcVOA
zfg;_lC^0t`OTKahXVo;2EV#5v%*=xn7m)G_$%3HNlG36)w5$s%&LOo5W|nnSNJzj^
z86+eq<R^iWE+VfKfl}2i?zE!(+<1t+n#{K(ax;rTMOS=TVoqsle0pkLY7w}WDFUU_
zB54K&hFe^SAbAN&H@8?y67w={u@~p3mE<N?6!9=HFlch$Vuu)7Tm(uNx0rJ?^KLQc
zq~;ZgGcYg|fs+0$=CsU`Tg=6!xwj<C5_2+B;*;}p3-a?)^FWnRN@j9mNq$igH%K`s
ziD<GGfs*$v7Ellu@q%2!l%9KwFRv8jeyGOcVo<DtqDG-e2Bd^PJ}t8(9^v9Dc~Hcl
z<VI8lDxgfc29hbo7{sNg<Sa0plHb97msk0Qs?l9ZnYqCWL}o_LiCn>QMZ@Z%r1cd^
z>l?C)S7ePpu(R-SePm$a<N5+3I+#9iF>rGCvvsj`luU@e$f0<JL-7KK;$2CZ89`TM
zEiX!1U6Hiv;CsribVEY<hJ?xuF^Rj9iVGAMgk6->xgx2vHe*BGfy#?|fmiebFGvL4
zkdt4acTvvZ12>xh*GC360j@6~;sXO4Cs&ai0|Ns}F~P~ezyK-$K3{@XCBzgH%NQ6K
zR>N}`I8%cX30wrmz?!L%3l(OhLIkCFsbwhvMJUvu8m1Jc8s-vET*5?HQkc=QEn^MK
z8WxNU%Ubo$Ewe-+KQC1wBfmreQtuXlb8HG~#f+4jhJJSCfO)P+6_hW)S&h9&4a5e8
ze~~7Lr2$I#>?x(WxtXcO7#XNY2c!lR=C_#hGg6@$D6u3pJ-;Y36(a-jf-?}z+A3MB
z*#{==0LnahSTYZvbZ34~eh2dn4(@*TF7^p56S6OIC|%)Dy1=0X$`HYGqAyBnUXj#9
z<Oii2Vv?ZzptnKuqNL3gNt+82wm0N-a3=qi3@B|WP&ok30HAViIzu`Ga$70|)UpDV
zUEpMw0%}>o6FO?`SOO|~z^WLK>Oa&ZTI`#`T*D9#Hz9?kg`tMYg#j%)F@c)KD2bgh
z3lwQ!(^J?$g&Ua3fLae?n8REG%G?m8NN%cOtYKb*R_E0+)-WzW@&Qy60|T1dvfw7x
zFvNpe7+{6qFoS2;8pbS8as$h_FvRxOvedBDFqJVB$<{C|03|1|3^Gx}Si_Qz)Fwgo
zS1oIvObx>VK4cv*HX)rg42a@v0a9xMrk{bKhBZq7#zVE2k)ej6f-wso`!%cxb`>wE
z7)7+QYZ$WN{>Nnxy6HWADIDMy1-O7J$w*a5NI=Up2?<Drl3Qj;u|guK+J>~o(K><&
z3E)BvQG+4+8+j=TiSRakLIO%(1F8*{F=3jJ><!7tELJEkNi9%F%qh-SNX{?KD=7wb
zgcOvaH8-ePpPN`xlvx38uYkKu#RZ9Z3ZSlHQ7*V$fYzP}aZ82=ECT}rC=-Gjai2kD
z4kQ!C3f3}~AgMubfKFuU5ejA~GGkz1Sjl*cxwxcAlj#<dp200PsH;HPOaW56|KhUA
z$t*4bH5=@z)X@T2FD<h~&n727IWec$P7k3l9aKhv>dOX(3-aC&w1I6Y^8&Yp+)KD2
z{0s8lkSZ0)p%i+i$P(lurXp)l;mn*@np=$Iks@0L1_q*ikio#f@EPg@Y0nE_v_g2U
z^8)6X9&<b({0q{aU>_7IgREr)R}p@ig0}>q9ehv}#)AT{ND$;8PEex?+#tQhm06sb
zS6q^qmz;WwsUYPRCn#LOE!kVFMTvRoskc~*GZG6@Z}Dd4rKJ|dL)w+ax7boZBMZg1
zgrKd-_@cz}c&L~(Obp&h0uMJp900DpOhJ{BKFBF-i3J6zc_~Gp?5oKJY4dS_%}Oju
zyd?lj(NGhMZ}C9V1$dC77}72jg$kkCRt##0f}4hr)B~<8tAs(`Lv9bIWtK#Ml1K<6
zxML>3pr{P$kjdUqR07d*HzZ{{SbF%L@=MOpSn9OI=OVx16@J4H46K|kj5lPJI$V1E
z9zfESmDdG%?~C%@SLD4f@W|W%$J18U4H;YV4g_8?bH8Zdam4@>T^FT2uSk1#_}mqj
zoKmyEV?)jbap#NT&R4{pJ3JmJX<b*cxTIupQOW8G2swfrA%8<#7es5^kXF7Tt^I+4
zgVUQ4LUwrH;Fr1}r+0xLgl=f)T-UI=q+zwg;z0BP-z(NZ7d3*fXarvnP`x0a`T!(r
zb4kNyN5p~Z1G!ghLoRBBUeO2zN%06=;ZeNKqj8Bx<ASE$6-}>;Jl<D$yf5&0-{29r
zz@v19N9lr^^%XVWi#)y%?hTL@o$EZtmw1dXn0Q{~@w&p}bpeb%2qGK^PL?QDH7H)d
zRkaWkxJ%1~DBMxHw9GY3HOz==ashJYgq8JZ725((CWFht7-$tU3ll>PQx?btU<D~`
znG7||H7uEoOBnkYYnf^o5v|!;)*99Y@D4IqIh3ejLahjD*^pbhSs*`xRibo$*--V=
zFd{m`Y3Sy$*Ra(v&t^zrn#)957}T(*Fr+ZFGNmx4G1-8+{19`PNi}ByC)fjMM2QGO
z0M$3O94Ra{>?olCuS{z=5RC~`T{Rpjte{Rh)UE8n44UkIpf(&s6}KLIz%HeVOAl1+
zgL{jh5}kp8;e$Q6ume{en!Ml=?-oZ%etceOZb9WO_OzVDl9JTCTWpEN#hK}OMWANG
zEvCGJTdY;7Mft@=pjI=aFl0+C1{Gzum^1TAZm}go3O$zM%&JsPPH=ft6a*@*Ky@0V
zyb*w!oeJv0f(BxXbU`H<H@y5wO}WJeDR^(OX6B`&RuqH65>gf^Kr6FbqKF|R<PNJU
zdI5!1Svjc4VgVOfH^ii1oz@%D@(Z}HNb6jeHoYWmdQsZ^inRH4Y5Pmk_7|lcuSh#y
z;E=e%FV#`fQ+<Ko_5#1{jO+{A))!=LAiNHi8@x(acy&6MZiq=t&zY37z~rKs#uYJ*
z4(_j93_>DPB<32=sJ<wzeMMUPg09mQ-H;1Xp%(?it_X&8u-y<>T9AD~*$#rPh}&Nm
z_qZhPaZ%jsinvz?TSp-XPEhP9>EyW~AU&h{f{YmiT@f(9F5qxUz~Q2R(-i@y>jIva
z1U%0wUlj1WBH-7-dc)YPgRQryr*ekq9I1=^Dp&Yb7No97T$p)*U;iS9!2^EL4qnV6
z29&u$_;VZsxQJmy6fxj#E+P#hvCy+LF$GRA!%9X!O(t;aDhdIOgs`9|8Dnr%F)%QI
z;vX_fA&eBKunMk98BfH+6m@{i0;Svrh8uj!SNL>0m~Ze4_1AXQ&Iq~4uXKf9={mpm
zC4TLT{JK~8buVz}Vg?PU3jz+B58$9lXQ*WcbzWi%Kn;APU}CLdtYJlN&eyWluppYw
zDU8|RvbIR8h8cUQTf?vbWDimRLzp!zHE08Qs3ig;LytQiGqIP(gv~6h0ZrmCRpo^w
zW`jo00}}I65{nf;fevo%fSbGU91LqfD->7eC1(`n=jE3cBQ>V|!E-2}c08y}1(|Q~
z1C=`piA6<;mBmQSWYE~YCM&q$C<1kUiu^!@Ng{|y0i_pqNP6)Hl^{}B8#3T#(=7q0
z9H`|67J-yXpq?1Gi(d?iSWps!7cgS+NRCIOxGFV}pOKqD*wyub(%Ljg=bno}Mh;YX
zsM>?ZpW+!I<OLqZ8&djLq^vu<ZwO1zh@RnlMOdrD<pwC18f@^n#P4u{-{FRU!gT@l
zO9JXESgvb0UD9y6;2d!wD(<32{1uJ(ivkH(1QITQ(F0_K4wp0>E;xo=h={qU5qm`=
z_M$-C6@j=5VDx}r{sM<QX37TjK0qn%vjTXi9G<cZ38(B7##(li)XfA<-SSKf$V2Hh
z42X6EYO2Ld)wS$-au~Ydsh6<s8s-HcA3-t=5?RBJs+W<Wr-+FGd7vSMc?}Dyni^(Q
zelUY3E2%jlGf$xaI-yBcPVg%NWrdgj|NsB5$#{z`FFq?jsaTU8DQAGZTLjM;kc<Kf
zFij3b9w`D%Mc(2pD9TSxEiO(iM#}}a*r77TpkWYbc7U}1Zwbf4+>ewQFiTxbMRPzo
zVFN5Dh>CY`-T-9;4o)>t8(a-nHV~8Qa7W7vQnqu|7pPoOvbiW_dqv9jf?e>1(8!BY
zQCFm*Ca_One;{QIk+-@iWqn1;`hrd1h2ZduQW00AA||lk;1RvfBXfyIW=8l$9@Q&6
zsu#fM0l!jj#f+j0{7R4-V1XrC=?^Zap$P~S#Na$K4V*_%d*&rbDzo5hw81Z{6vh^`
zDJ4XEIR&Y$ikjjuhP6;8tPr+=#+uiljR}K>wUF#!L9~ZcShK+%D*_EIf&2#MArqiU
zRYYqX)i1TodD1lu3qS=YvKBBKzrGSsFB+_lfguajwga<km{XY6pw(>);KOTRB~YS7
z49Wlv&7q8#gWA7U>fk!6D6t$d5CdsWDkLf(jnL>(aqtB^UQuKZO6j1Ohc;3n0~NaH
z6$ZST#5Ulv5mdD~f(X!<1*DDc1mc3s&HTa$9_V37tjH|Z<SB9mNw|RscM#zLB0NC^
zXwaj`7sLXMFBgS_sz>hR{JiAElGMD!lGLI|kO(VeY8l$*0@a~KfgnAgE<!P==m3=f
z;5HYyg;phr(<kddK0gS~;}~6iCQw(OtnT~=4h9~P9=jQeSGZ+n6kp<2y1=b;Ltgp1
zywxRns~Z9mGgPifS$$+?QsDc-z@)%;fk);7kIV-S1_hPt^5&Q1&3AAdP`ToibWuL}
zihS|~9+{{7f<1K?<n%TKUf?&s$Zvjy-~2ki-6ejzy`>lVU9a%FUf^(rRE)@(0hGYO
znf@9$)7K!khFOv4jghh;6SNb_kiyu?l!lb8QO3)0bdFIw+^7Y?8npRg?3G*zeC(M4
zZG60z4Y^B==s?%NT+9sW{35#;TpgmU1gc@HVMAow6r@20aEF<>(YXZFRD<TxEO-`P
z0PoO3rNC4T8=`Xy4K-F|`$2;Xh$0KE2x6~c$F3jj5>^}|?PvoD!3>&gelJ1!s|cLa
zK!boqg`kW864T@ZC;lRMHoL_FYPH@HN>43`hm<Sv#gLIZaIPvU0qF$iIJU&1bkH0x
zXw2>wTTx<ON`CGwHb}Sp7E5MMe)28BqSUg~qT*CUoeHb2LA_@9G9*y5^%h%kW_oU7
z@hzs}5^$Fp+;F`mS)7`anp~2a5)Uyrvnn+OG9itn5|@q7%tNyex%Z4wRN+;=AC&2D
zfipd<>wA}9u%~8*_X<WZG+N=f!e|BS6&by~EIWd(+Xh{-4Z3I>0vb892)$wvdVxRe
zB7fKw{;=!(v6uK`uf`=`<WIT6pVGl{S3zkB=L+o&i6FQ`aEH(ip$(~5v}_I-UQlqo
zsNi}<!L@^>Bcvy0hGI|rMGn~;+=3UlW$$Wet>L_`VR1>rVu#2c$peO9n0UZskNgR>
z3xUCBw4ft)JgPV3%sX6qJSQZA;0(bDDV_cg_yw=?D_`PQUf^<(U-JsT<^>K-NNE7d
z{ou69z`y`64R(P$fGDK_Gp^DACD$`!9}>fx>(SQR)w1Ah-l8{M84=YtA{T*HL-~Od
zyCyqY>MjC}jTBXY(r_(^0QHcdiwPi&#X68Es8<e7e&{O*q_LzJSogb18@>L(qxd8!
zae=BNL|+}6xVDOHNWO07bIHu-qM6?nGrx<5{#OkBFYpIk<PW&QAJD;agJ0-6zse<k
zl?5Re`L(X_YhBQGxT5WIh2QrAhc7rr{ZK|!;ZgbqT)?L=)-tCv)Uu>A)UqN^ow+c?
zcGQB3KKMv#Eqe`P3R4PGD^m?~4RZ}UXe17`xUc1?VMA04Da_D*ktELEJzEV2Xxw%g
z3j@PyPz?-DcD0;&Vl@odlT8g<4F_7G$XUaJy;`W@M6FC18ERN@sANMV8w@=lpYk#=
z)H3&U;;2l}5@9VD^5P^!g;2v)!@7nA)%7?k2?j(Zf$E-Gt{R3cP$C6~RIx}67xqa~
z)a+HuRl^1D#mxh^6_Hwn;DHQCv917XLSl5-6iSOzA+vwQiMgo?pz%+IM9{*%g3=Pu
zQajK*CcKrWkXD+P3|jez7}jvjEXhbMQphWS40tHy7bz5^fLp8JStW!*@K7$K>j%lM
z{E${1r~wTb4`~8r#}*I)9vK0J1f;nKo9QVoNlea0)LxJ}$_J#u1w?>H-MAq=MDQ4G
zJxH_x)I5|bN=?o$OD%#A)W(CVA;{<nxSa^9qKZK6_aacCa7z@c`dfSmVUYe}P%8k`
zu!Lf8`|_3`M!i&}h+bM?C^-izBlN(-xF2{J`1m`RZm6nX;840DF4@6zLsj#Ns_O>^
zK29+(0qSQ8sb3MY>~OvzB{!oG1Q#eRC|cpXuwsGk1$q4|QU;*uAeAfph|w1fo$DH2
zmo&U?h-j{`-obW&=_50frqCA#CQTv8q|gTr22I`T8XlK4JiuyD<UepQ@Ty<u(Y?f@
zdqL0nB9F@z9+wL|E;o22uJg!W;*r0gV0uNt@gk4Y6&|MxJWe-w_$DZKdUiN<I6ak=
zo+E!j)A^#L%N0qN4!*mR(sN`M<n740AnAQk())^}_jSpjOOio1K(k27ADNk?xV|tj
zNpW@XePCk{5b3G=z`)39enVFKx~$11S(A&hW>;j*I{Y5+%U<A+1s8$XOTabo5-^3a
zmIbZ!QvuEAAjU9jU}Km}j0}02HLOVG9BllDwT3x`8MI8Gmbr!n%3|qJMXnB0SZWw&
zGo-N2MH^#jVyt0B^jdMv57jWGgVe!G<a9>#9%L<h4I7S<6Ky02drua%+(4OO<3R4m
z)^MPfOYjk1F%;jj*D%dyNZ|ncnR6~Hs-7lBP-7QIS&42w*BsD@EO(W<duj=IAwp(K
zF=){-sO(A4EKAK(NPv!cf>uB#=z)upB12H_1~qW21Vb`ltFghe!3qggG6@PP`KiTu
z>Lm)P6`92)pd}3;<3QO~4_y9Kaf5V~ftEWeBvi@cP!F+JA)$&XK~IzUmJqA~0G)-7
z&rG?+2~h)La)6DC&rH$eh7@k#G7Qv%xWyF)vdOimD8C53nF1=einKt!0X4OYDnWz6
z(D^2i{#&eIQ&DEXOY=%ni;8+ddcY+Yyb*JYJwGosJ|n*b+eDLmaz3cZ3|bZmnydnc
z6{LL~Uy>i6nNp>XUQQ9v0Gc@{76F%952R&4<H91MpfNE>amC3Q!}x&#M1l$|$XKxV
zC4T!0{Ps5lL?^I9hs3-t@!MVCx4R*ru)y&mXnI=V0gvbm!A`#p&koP8Yz)G(7i6P0
zNMDeRf}o2+(N~0`FYrV^fKGfs8dIS&+%M=kU(xfwpc`<3Kky=d;1&MB4wf52qSJLJ
z=`LWoD5QEtNcFmq-X$Tui$VrhgbX$aUJ)|xV86@F*HPYS(_=Hk@DjJo1#X$ULNXU*
z6IPgBkWB!gz!(q-W?vLaydso%fhX|;8-ud(4N0ZD5>j(G7pPy9(7Ga_^?{v5fa?PT
zgn+Ki)MP3G&65^^)-P)^f~P2pKszfm1wjkNLCHBj{uWn!d@gtcNqqb*p7{8}(!?C7
z410WhN`7*De32|DszJ*niogpdz{~v*1Ek<F67cW{XmYa%QIU3mv`q#PpmjCi(h^)r
z6@eD#fC*6bTl^c`^J!p!zz<9;tQ;R0Km-pLtI-DrR6;_EmFoiooUjlBap8o4EGuY5
z0tRWv&notT0Zu3}vFd$bfD;^itO_3(;DmwzNES}0@Uyxxf;7WPB@R}t4-9ZZgqhWk
z@dE=A$)&(5_kjVGaAA}LITe$DgbOH0GzE*MgW3zA4K$#xC}<;05vb2}ixoVZ3~7(B
z6@aIVz%7wmY+3n9IhjdCpeAk+Xt3p$Kv7DG9=y$zS_D}MT?A^I7l8(&!7DRwu|d|)
zfFl8tF2JK?pvB9-IBX#4!LDc?0|Nu7N+`Y#n&@U`WMsU-AbWv9_5p*^1yuBaLFfVu
z-Cz*D07EwzR4!mcHyE5QU_%eMMLVn}WGygVp?^Wi>LR!G6>jSWjvHbM6S8iwNZw!x
rz5qraS((L|zA!M0GhJX1{=ieg%*geD0XsP%?IT#`3z)=I1rB}y%zUhQ

literal 0
HcmV?d00001

diff --git a/ctgan/__pycache__/load_data.cpython-311.pyc b/ctgan/__pycache__/load_data.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a7c66a7be7e09ca9ccedc03198c88c49b47c286
GIT binary patch
literal 168
zcmZ3^%ge>Uz`(#RoR{v&&A{*&#DQT(DC09o$#jNvhA4&<hF}IwMn6r)TU=?Gd5Jmk
zDTyVCD;Yk6RQ}R&v5HB_O3X{o4=BpdN=+__DM`%9F3BuQjY%#^Pt1$S$xloH>(DEx
v{Ka9Do1apelWJGQ#=yV;va48-fq~%zGb1D8hae^f9=--H5G-P0U|;|MN){;!

literal 0
HcmV?d00001

diff --git a/ctgan/config.py b/ctgan/config.py
new file mode 100644
index 0000000..b04e5c0
--- /dev/null
+++ b/ctgan/config.py
@@ -0,0 +1,48 @@
+import json
+import logging
+
+
class HyperparametersParser:
    """Parse a JSON hyperparameters file and coerce values to Python types."""

    def __init__(self, hyperparameters_file):
        """
        Args:
            hyperparameters_file (str): Path to a JSON file mapping
                hyperparameter names to (possibly string-encoded) values.
        """
        self.hyperparameters_file = hyperparameters_file
        self.hyperparameters = {}

    def parse_hyperparameters(self):
        """Parse the hyperparameters file and convert each value's type.

        Returns:
            dict: The parsed hyperparameters. Empty if the file is missing
            or unreadable (errors are logged, never raised).
        """
        try:
            with open(self.hyperparameters_file, "r") as f:
                injected_hyperparameters = json.load(f)

            logging.info("Parsing hyperparameters...")
            for key, value in injected_hyperparameters.items():
                logging.info(f"Raw hyperparameter - Key: {key}, Value: {value}")
                self.hyperparameters[key] = self._convert_type(value)
        except FileNotFoundError:
            logging.warning(f"Hyperparameters file {self.hyperparameters_file} not found. Using default parameters.")
        except Exception as e:
            logging.error(f"Error parsing hyperparameters: {e}")
        return self.hyperparameters

    @staticmethod
    def _convert_type(value):
        """Convert a string value to bool, int, or float when possible.

        Non-string values (already typed by ``json.load``) pass through
        unchanged; strings that parse as none of these stay strings.
        """
        if isinstance(value, str):
            lowered = value.lower()
            if lowered == "true":
                return True
            if lowered == "false":
                return False

            # Try int first so "42" stays integral, then float. Falling back
            # to float also accepts scientific notation such as "1e-3",
            # which the previous '"." in value' check rejected.
            try:
                return int(value)
            except ValueError:
                try:
                    return float(value)
                except ValueError:
                    pass  # Not numeric; keep the original string.
        return value
diff --git a/ctgan/data.py b/ctgan/data.py
new file mode 100644
index 0000000..8a48da0
--- /dev/null
+++ b/ctgan/data.py
@@ -0,0 +1,88 @@
+"""Data loading."""
+
+import json
+
+import numpy as np
+import pandas as pd
+
+
def read_csv(csv_filename, meta_filename=None, header=True, discrete=None):
    """Read a csv file and determine its discrete columns.

    Discrete columns come either from a JSON metadata file (every column
    whose type is not 'continuous') or from a comma-separated ``discrete``
    string; with ``header=False`` the latter is parsed as column indices.
    """
    data = pd.read_csv(csv_filename, header='infer' if header else None)

    if meta_filename:
        with open(meta_filename) as meta_file:
            metadata = json.load(meta_file)
        discrete_columns = [
            col['name'] for col in metadata['columns'] if col['type'] != 'continuous'
        ]
    elif discrete:
        names = discrete.split(',')
        discrete_columns = names if header else [int(name) for name in names]
    else:
        discrete_columns = []

    return data, discrete_columns
+
+
def read_tsv(data_filename, meta_filename):
    """Read a tsv data file guided by its metadata file.

    Each metadata line describes one column: ``C min max`` for continuous,
    ``D v1 v2 ...`` for discrete. Discrete values in the data file are
    replaced by their index into the metadata value list.
    """
    with open(meta_filename) as f:
        raw_lines = f.readlines()

    parsed = [line.replace('{', ' ').replace('}', ' ').split() for line in raw_lines]

    discrete = []
    continuous = []
    column_info = []
    for idx, item in enumerate(parsed):
        kind = item[0]
        if kind == 'C':
            continuous.append(idx)
            column_info.append((float(item[1]), float(item[2])))
        else:
            assert kind == 'D'
            discrete.append(idx)
            column_info.append(item[1:])

    with open(data_filename) as f:
        rows = []
        for line in f:
            values = []
            for idx, token in enumerate(line.split()):
                if idx in continuous:
                    values.append(token)
                else:
                    assert idx in discrete
                    # Label-encode the discrete value via its metadata index.
                    values.append(column_info[idx].index(token))
            rows.append(values)

    return np.asarray(rows, dtype='float32'), discrete
+
+
def write_tsv(data, meta, output_filename):
    """Write rows to a tsv file, decoding discrete columns via ``meta``.

    Continuous values are written as-is; discrete values are written as the
    category string looked up in ``meta['column_info']``.
    """
    continuous = meta['continuous_columns']
    discrete = meta['discrete_columns']
    info = meta['column_info']

    with open(output_filename, 'w') as f:
        for row in data:
            for idx, value in enumerate(row):
                if idx in continuous:
                    f.write(f'{value} ')
                else:
                    assert idx in discrete
                    f.write(f'{info[idx][int(value)]} ')
            f.write('\n')
diff --git a/ctgan/data_sampler.py b/ctgan/data_sampler.py
new file mode 100644
index 0000000..d9b67d1
--- /dev/null
+++ b/ctgan/data_sampler.py
@@ -0,0 +1,153 @@
+"""DataSampler module."""
+
+import numpy as np
+
+
class DataSampler(object):
    """DataSampler samples the conditional vector and corresponding data for CTGAN."""

    def __init__(self, data, output_info, log_frequency):
        # data: transformed training matrix (rows x encoded dims), where each
        #     discrete column is one-hot encoded.
        # output_info: per original column, a list of SpanInfo spans
        #     (dim, activation_fn) — produced by the DataTransformer.
        # log_frequency: if True, category sampling probabilities use
        #     log(1 + count) instead of raw counts.
        self._data_length = len(data)

        def is_discrete_column(column_info):
            # A discrete column is encoded as exactly one softmax span.
            return len(column_info) == 1 and column_info[0].activation_fn == 'softmax'

        n_discrete_columns = sum([
            1 for column_info in output_info if is_discrete_column(column_info)
        ])

        # NOTE(review): allocated but never written afterwards — stays all
        # zeros, so generate_cond_from_condition_column_info always uses 0 as
        # the per-column offset; confirm that is intended.
        self._discrete_column_matrix_st = np.zeros(n_discrete_columns, dtype='int32')

        # Store the row id for each category in each discrete column.
        # For example _rid_by_cat_cols[a][b] is a list of all rows with the
        # a-th discrete column equal value b.
        self._rid_by_cat_cols = []

        # Compute _rid_by_cat_cols
        st = 0
        for column_info in output_info:
            if is_discrete_column(column_info):
                span_info = column_info[0]
                ed = st + span_info.dim

                rid_by_cat = []
                for j in range(span_info.dim):
                    # Rows where this one-hot category is active (non-zero).
                    rid_by_cat.append(np.nonzero(data[:, st + j])[0])
                self._rid_by_cat_cols.append(rid_by_cat)
                st = ed
            else:
                # Continuous column: skip all of its spans.
                st += sum([span_info.dim for span_info in column_info])
        assert st == data.shape[1]

        # Prepare an interval matrix for efficiently sample conditional vector
        max_category = max(
            [column_info[0].dim for column_info in output_info if is_discrete_column(column_info)],
            default=0,
        )

        # Per discrete column: start offset within the conditional vector,
        # number of categories, and category probabilities (rows padded to
        # max_category with zeros).
        self._discrete_column_cond_st = np.zeros(n_discrete_columns, dtype='int32')
        self._discrete_column_n_category = np.zeros(n_discrete_columns, dtype='int32')
        self._discrete_column_category_prob = np.zeros((n_discrete_columns, max_category))
        self._n_discrete_columns = n_discrete_columns
        self._n_categories = sum([
            column_info[0].dim for column_info in output_info if is_discrete_column(column_info)
        ])

        # Second pass: fill in the per-column probability table and offsets.
        st = 0
        current_id = 0
        current_cond_st = 0
        for column_info in output_info:
            if is_discrete_column(column_info):
                span_info = column_info[0]
                ed = st + span_info.dim
                category_freq = np.sum(data[:, st:ed], axis=0)
                if log_frequency:
                    # Dampen class imbalance: sample by log(1 + count).
                    category_freq = np.log(category_freq + 1)
                category_prob = category_freq / np.sum(category_freq)
                self._discrete_column_category_prob[current_id, : span_info.dim] = category_prob
                self._discrete_column_cond_st[current_id] = current_cond_st
                self._discrete_column_n_category[current_id] = span_info.dim
                current_cond_st += span_info.dim
                current_id += 1
                st = ed
            else:
                st += sum([span_info.dim for span_info in column_info])

    def _random_choice_prob_index(self, discrete_column_id):
        """Weighted category sampling, vectorized over a batch.

        For each entry of ``discrete_column_id``, draw one category index
        according to that column's stored probabilities (inverse-CDF via
        cumsum: first position where the cumulative probability exceeds r).
        """
        probs = self._discrete_column_category_prob[discrete_column_id]
        r = np.expand_dims(np.random.rand(probs.shape[0]), axis=1)
        return (probs.cumsum(axis=1) > r).argmax(axis=1)

    def sample_condvec(self, batch):
        """Generate the conditional vector for training.

        Returns:
            cond (batch x #categories):
                The conditional vector.
            mask (batch x #discrete columns):
                A one-hot vector indicating the selected discrete column.
            discrete column id (batch):
                Integer representation of mask.
            category_id_in_col (batch):
                Selected category in the selected discrete column.
        """
        if self._n_discrete_columns == 0:
            return None

        # Pick one discrete column per sample, uniformly at random.
        discrete_column_id = np.random.choice(np.arange(self._n_discrete_columns), batch)

        cond = np.zeros((batch, self._n_categories), dtype='float32')
        mask = np.zeros((batch, self._n_discrete_columns), dtype='float32')
        mask[np.arange(batch), discrete_column_id] = 1
        category_id_in_col = self._random_choice_prob_index(discrete_column_id)
        # Offset the within-column category id into the global cond vector.
        category_id = self._discrete_column_cond_st[discrete_column_id] + category_id_in_col
        cond[np.arange(batch), category_id] = 1

        return cond, mask, discrete_column_id, category_id_in_col

    def sample_original_condvec(self, batch):
        """Generate the conditional vector for generation use original frequency."""
        if self._n_discrete_columns == 0:
            return None

        # Flatten the (column x category) probability table; dropping zeros
        # removes both padding entries and never-observed categories, then
        # renormalize over the remaining global category positions.
        category_freq = self._discrete_column_category_prob.flatten()
        category_freq = category_freq[category_freq != 0]
        category_freq = category_freq / np.sum(category_freq)
        col_idxs = np.random.choice(np.arange(len(category_freq)), batch, p=category_freq)
        cond = np.zeros((batch, self._n_categories), dtype='float32')
        cond[np.arange(batch), col_idxs] = 1

        return cond

    def sample_data(self, data, n, col, opt):
        """Sample data from original training data satisfying the sampled conditional vector.

        Args:
            data:
                The transformed training data matrix.
            n (int):
                Number of rows to sample (used only when ``col`` is None).
            col:
                Per-sample discrete column ids, or None for unconditional
                uniform row sampling.
            opt:
                Per-sample category ids (paired with ``col``) that each
                sampled row must have active.

        Returns:
            n:
                n rows of matrix data.
        """
        if col is None:
            idx = np.random.randint(len(data), size=n)
            return data[idx]

        # For each (column, category) pair, pick a random row among those
        # that actually have that category set.
        idx = []
        for c, o in zip(col, opt):
            idx.append(np.random.choice(self._rid_by_cat_cols[c][o]))

        return data[idx]

    def dim_cond_vec(self):
        """Return the total number of categories."""
        return self._n_categories

    def generate_cond_from_condition_column_info(self, condition_info, batch):
        """Generate the condition vector."""
        vec = np.zeros((batch, self._n_categories), dtype='float32')
        # NOTE(review): _discrete_column_matrix_st is never updated after
        # initialization (all zeros) — see the note in __init__.
        id_ = self._discrete_column_matrix_st[condition_info['discrete_column_id']]
        id_ += condition_info['value_id']
        vec[:, id_] = 1
        return vec
diff --git a/ctgan/data_transformer.py b/ctgan/data_transformer.py
new file mode 100644
index 0000000..7ed8b0e
--- /dev/null
+++ b/ctgan/data_transformer.py
@@ -0,0 +1,265 @@
+"""DataTransformer module."""
+
+from collections import namedtuple
+
+import numpy as np
+import pandas as pd
+from joblib import Parallel, delayed
+from rdt.transformers import ClusterBasedNormalizer, OneHotEncoder
+
+SpanInfo = namedtuple('SpanInfo', ['dim', 'activation_fn'])
+ColumnTransformInfo = namedtuple(
+    'ColumnTransformInfo',
+    ['column_name', 'column_type', 'transform', 'output_info', 'output_dimensions'],
+)
+
+
class DataTransformer(object):
    """Data Transformer.

    Model continuous columns with a BayesianGMM and normalize them to a scalar between [-1, 1]
    and a vector. Discrete columns are encoded using a OneHotEncoder.
    """

    def __init__(self, max_clusters=10, weight_threshold=0.005):
        """Create a data transformer.

        Args:
            max_clusters (int):
                Maximum number of Gaussian distributions in Bayesian GMM.
            weight_threshold (float):
                Weight threshold for a Gaussian distribution to be kept.
        """
        self._max_clusters = max_clusters
        self._weight_threshold = weight_threshold

    def _fit_continuous(self, data):
        """Train Bayesian GMM for continuous columns.

        Args:
            data (pd.DataFrame):
                A dataframe containing a column.

        Returns:
            namedtuple:
                A ``ColumnTransformInfo`` object.
        """
        column_name = data.columns[0]
        # Cap the number of mixture components by the row count so the
        # Bayesian GMM can always be fit on tiny datasets.
        gm = ClusterBasedNormalizer(
            missing_value_generation='from_column',
            max_clusters=min(len(data), self._max_clusters),
            weight_threshold=self._weight_threshold,
        )
        gm.fit(data, column_name)
        # Only components whose weight survived the threshold are kept.
        num_components = sum(gm.valid_component_indicator)

        return ColumnTransformInfo(
            column_name=column_name,
            column_type='continuous',
            transform=gm,
            # One 'tanh' dim for the normalized scalar plus a softmax
            # one-hot over the selected mixture component.
            output_info=[SpanInfo(1, 'tanh'), SpanInfo(num_components, 'softmax')],
            output_dimensions=1 + num_components,
        )

    def _fit_discrete(self, data):
        """Fit one hot encoder for discrete column.

        Args:
            data (pd.DataFrame):
                A dataframe containing a column.

        Returns:
            namedtuple:
                A ``ColumnTransformInfo`` object.
        """
        column_name = data.columns[0]
        ohe = OneHotEncoder()
        ohe.fit(data, column_name)
        num_categories = len(ohe.dummies)

        return ColumnTransformInfo(
            column_name=column_name,
            column_type='discrete',
            transform=ohe,
            output_info=[SpanInfo(num_categories, 'softmax')],
            output_dimensions=num_categories,
        )

    def fit(self, raw_data, discrete_columns=()):
        """Fit the ``DataTransformer``.

        Fits a ``ClusterBasedNormalizer`` for continuous columns and a
        ``OneHotEncoder`` for discrete columns.

        This step also counts the #columns in matrix data and span information.
        """
        self.output_info_list = []
        self.output_dimensions = 0
        # Remember the input container type so inverse_transform can return
        # the same kind (DataFrame in, DataFrame out).
        self.dataframe = True

        if not isinstance(raw_data, pd.DataFrame):
            self.dataframe = False
            # work around for RDT issue #328 Fitting with numerical column names fails
            discrete_columns = [str(column) for column in discrete_columns]
            column_names = [str(num) for num in range(raw_data.shape[1])]
            raw_data = pd.DataFrame(raw_data, columns=column_names)

        # Original dtypes, restored on inverse_transform.
        self._column_raw_dtypes = raw_data.infer_objects().dtypes
        self._column_transform_info_list = []
        for column_name in raw_data.columns:
            if column_name in discrete_columns:
                column_transform_info = self._fit_discrete(raw_data[[column_name]])
            else:
                column_transform_info = self._fit_continuous(raw_data[[column_name]])

            self.output_info_list.append(column_transform_info.output_info)
            self.output_dimensions += column_transform_info.output_dimensions
            self._column_transform_info_list.append(column_transform_info)

    def _transform_continuous(self, column_transform_info, data):
        # Transform one continuous column into [normalized, one-hot component].
        column_name = data.columns[0]
        flattened_column = data[column_name].to_numpy().flatten()
        data = data.assign(**{column_name: flattened_column})
        gm = column_transform_info.transform
        transformed = gm.transform(data)

        #  Converts the transformed data to the appropriate output format.
        #  The first column (ending in '.normalized') stays the same,
        #  but the label encoded column (ending in '.component') is one hot encoded.
        output = np.zeros((len(transformed), column_transform_info.output_dimensions))
        output[:, 0] = transformed[f'{column_name}.normalized'].to_numpy()
        index = transformed[f'{column_name}.component'].to_numpy().astype(int)
        output[np.arange(index.size), index + 1] = 1.0

        return output

    def _transform_discrete(self, column_transform_info, data):
        # One-hot encode one discrete column.
        ohe = column_transform_info.transform
        return ohe.transform(data).to_numpy()

    def _synchronous_transform(self, raw_data, column_transform_info_list):
        """Take a Pandas DataFrame and transform columns synchronous.

        Outputs a list with Numpy arrays.
        """
        column_data_list = []
        for column_transform_info in column_transform_info_list:
            column_name = column_transform_info.column_name
            data = raw_data[[column_name]]
            if column_transform_info.column_type == 'continuous':
                column_data_list.append(self._transform_continuous(column_transform_info, data))
            else:
                column_data_list.append(self._transform_discrete(column_transform_info, data))

        return column_data_list

    def _parallel_transform(self, raw_data, column_transform_info_list):
        """Take a Pandas DataFrame and transform columns in parallel.

        Outputs a list with Numpy arrays.
        """
        processes = []
        for column_transform_info in column_transform_info_list:
            column_name = column_transform_info.column_name
            data = raw_data[[column_name]]
            process = None
            if column_transform_info.column_type == 'continuous':
                process = delayed(self._transform_continuous)(column_transform_info, data)
            else:
                process = delayed(self._transform_discrete)(column_transform_info, data)
            processes.append(process)

        # joblib preserves input order, so output columns stay aligned.
        return Parallel(n_jobs=-1)(processes)

    def transform(self, raw_data):
        """Take raw data and output a matrix data."""
        if not isinstance(raw_data, pd.DataFrame):
            column_names = [str(num) for num in range(raw_data.shape[1])]
            raw_data = pd.DataFrame(raw_data, columns=column_names)

        # Only use parallelization with larger data sizes.
        # Otherwise, the transformation will be slower.
        if raw_data.shape[0] < 500:
            column_data_list = self._synchronous_transform(
                raw_data, self._column_transform_info_list
            )
        else:
            column_data_list = self._parallel_transform(raw_data, self._column_transform_info_list)

        return np.concatenate(column_data_list, axis=1).astype(float)

    def _inverse_transform_continuous(self, column_transform_info, column_data, sigmas, st):
        # Rebuild the (normalized, component) frame the normalizer expects:
        # column 1 is recovered as the argmax over the one-hot component part.
        gm = column_transform_info.transform
        data = pd.DataFrame(column_data[:, :2], columns=list(gm.get_output_sdtypes())).astype(float)
        data[data.columns[1]] = np.argmax(column_data[:, 1:], axis=1)
        if sigmas is not None:
            # Optionally jitter the normalized value; sigmas is indexed by the
            # column's start offset in the full matrix.
            selected_normalized_value = np.random.normal(data.iloc[:, 0], sigmas[st])
            data.iloc[:, 0] = selected_normalized_value

        return gm.reverse_transform(data)

    def _inverse_transform_discrete(self, column_transform_info, column_data):
        # Decode a one-hot block back to the original category values.
        ohe = column_transform_info.transform
        data = pd.DataFrame(column_data, columns=list(ohe.get_output_sdtypes()))
        return ohe.reverse_transform(data)[column_transform_info.column_name]

    def inverse_transform(self, data, sigmas=None):
        """Take matrix data and output raw data.

        Output uses the same type as input to the transform function.
        Either np array or pd dataframe.
        """
        st = 0
        recovered_column_data_list = []
        column_names = []
        for column_transform_info in self._column_transform_info_list:
            dim = column_transform_info.output_dimensions
            column_data = data[:, st : st + dim]
            if column_transform_info.column_type == 'continuous':
                recovered_column_data = self._inverse_transform_continuous(
                    column_transform_info, column_data, sigmas, st
                )
            else:
                recovered_column_data = self._inverse_transform_discrete(
                    column_transform_info, column_data
                )

            recovered_column_data_list.append(recovered_column_data)
            column_names.append(column_transform_info.column_name)
            st += dim

        recovered_data = np.column_stack(recovered_column_data_list)
        recovered_data = pd.DataFrame(recovered_data, columns=column_names).astype(
            self._column_raw_dtypes
        )
        if not self.dataframe:
            recovered_data = recovered_data.to_numpy()

        return recovered_data

    def convert_column_name_value_to_id(self, column_name, value):
        """Get the ids of the given `column_name`."""
        discrete_counter = 0
        column_id = 0
        for column_transform_info in self._column_transform_info_list:
            if column_transform_info.column_name == column_name:
                break
            if column_transform_info.column_type == 'discrete':
                discrete_counter += 1

            column_id += 1

        # for/else: runs only when the loop finished without `break`,
        # i.e. the column name was never found.
        else:
            raise ValueError(f"The column_name `{column_name}` doesn't exist in the data.")

        ohe = column_transform_info.transform
        data = pd.DataFrame([value], columns=[column_transform_info.column_name])
        one_hot = ohe.transform(data).to_numpy()[0]
        # An all-zero encoding means the value is unknown to the encoder.
        if sum(one_hot) == 0:
            raise ValueError(f"The value `{value}` doesn't exist in the column `{column_name}`.")

        return {
            'discrete_column_id': discrete_counter,
            'column_id': column_id,
            'value_id': np.argmax(one_hot),
        }
diff --git a/ctgan/dataprocess_using_gan.py b/ctgan/dataprocess_using_gan.py
new file mode 100644
index 0000000..53cfff2
--- /dev/null
+++ b/ctgan/dataprocess_using_gan.py
@@ -0,0 +1,141 @@
+import argparse
+import logging
+import os
+import tarfile
+
+import pandas as pd
+from ctgan import CTGAN
+
+
# Configure logging: INFO level, timestamped messages, emitted to the console.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.StreamHandler()
    ]
)
def extract_model_and_tokenizer(tar_path, extract_path="/tmp"):
    """Extract a CTGAN model file from a tar.gz archive.

    Args:
        tar_path (str): Path to the model tar.gz archive.
        extract_path (str): Directory to extract files into.

    Returns:
        str: Path to the extracted model file (``.pt``). If several
        ``.pt`` files exist, the last one encountered by ``os.walk`` wins
        (unchanged from the original behavior).

    Raises:
        FileNotFoundError: If the archive contains no ``.pt`` file.
    """
    # NOTE(review): extractall on an untrusted archive is vulnerable to path
    # traversal; pass filter="data" once Python >= 3.12 can be assumed.
    with tarfile.open(tar_path, "r:gz") as tar:
        tar.extractall(path=extract_path)

    model_path = None
    for root, _dirs, files in os.walk(extract_path):
        for file in files:
            if file.endswith(".pt"):
                model_path = os.path.join(root, file)

    if not model_path:
        # BUG FIX: the message previously said ".pth" while the search above
        # matches ".pt"; the docstring also promised a tokenizer path that
        # was never returned.
        raise FileNotFoundError("No .pt model file found in the extracted tar.gz.")

    return model_path
def load_model(model_path):
    """Load a pre-trained CTGAN model from a ``.pt`` file at ``model_path``."""
    logging.info(f"Loading model from {model_path}")

    model = CTGAN.load(model_path)
    logging.info("Model loaded successfully.")
    return model
+
+
def generate_samples(ctgan, data, discrete_columns, min_samples_per_class=5000):
    """Generate synthetic rows for under-represented classes with CTGAN.

    Args:
        ctgan: A trained CTGAN model.
        data: Input DataFrame; must contain a 'label' column.
        discrete_columns: Discrete column names (kept for interface
            compatibility; not read here).
        min_samples_per_class: Target minimum number of samples per class.

    Returns:
        DataFrame holding only the generated rows (empty if every class
        already meets the minimum).
    """
    label_column = 'label'
    label_counts = data[label_column].value_counts()
    logging.info("Starting sample generation for underrepresented classes...")
    logging.info(f"Current label distribution:\n{label_counts}")

    frames = []
    for label, count in label_counts.items():
        shortfall = min_samples_per_class - count
        if shortfall <= 0:
            continue
        logging.info(f"Label '{label}' has {count} samples, generating {shortfall} more samples...")
        # Conditionally sample rows carrying this label from CTGAN.
        frames.append(
            ctgan.sample(n=shortfall, condition_column=label_column, condition_value=label)
        )

    if not frames:
        logging.info("No underrepresented classes found. No samples were generated.")
        return pd.DataFrame()

    generated = pd.concat(frames, ignore_index=True)
    logging.info(f"Generated data shape: {generated.shape}")
    return generated
+
def main():
    """CLI entry point: augment a CSV dataset with CTGAN-generated samples.

    Loads the input CSV, extracts and loads a trained CTGAN model from a
    tar.gz archive, generates samples for under-represented classes, and
    writes the combined original + generated data to the output directory.
    """
    parser = argparse.ArgumentParser(description="CTGAN Sample Generation Script")
    parser.add_argument("--input", type=str, required=True, help="Path to input CSV file.")
    parser.add_argument("--model", type=str, required=True, help="Path to the trained CTGAN model.")
    parser.add_argument("--output", type=str, required=True, help="Path to save the output CSV file.")
    parser.add_argument("--n_samples", type=int, default=1000, help="Number of samples to generate per class.")
    args = parser.parse_args()

    input_path = args.input
    model_path = args.model
    output_dir = args.output
    n_samples = args.n_samples

    # Load the input data.
    logging.info(f"Loading input data from {input_path}")
    data = pd.read_csv(input_path)
    logging.info(f"Input data shape: {data.shape}")

    # Identify discrete columns; 'label' is assumed to be the target column.
    label_column = "label"
    discrete_columns = [label_column]
    if label_column not in data.columns:
        raise ValueError(f"Label column '{label_column}' not found in input data.")

    # Extract the model file from the archive and load the CTGAN model.
    model_p = extract_model_and_tokenizer(model_path)
    ctgan = load_model(model_p)

    # Generate samples. BUG FIX: --n_samples was parsed but never used, so
    # generate_samples silently fell back to its 5000 default; the flag now
    # sets the per-class minimum as its help text promises.
    generated_data = generate_samples(
        ctgan, data, discrete_columns, min_samples_per_class=n_samples
    )

    # Merge the generated samples with the original data.
    combined_data = pd.concat([data, generated_data], ignore_index=True)
    logging.info(f"Combined data shape: {combined_data.shape}")

    # Persist the combined dataset.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "generated_samples.csv")

    logging.info(f"Saving combined data to {output_path}")
    combined_data.to_csv(output_path, index=False)
    logging.info("Sample generation and saving completed successfully.")


if __name__ == "__main__":
    main()
diff --git a/ctgan/load_data.py b/ctgan/load_data.py
new file mode 100644
index 0000000..fe98e2e
--- /dev/null
+++ b/ctgan/load_data.py
@@ -0,0 +1,73 @@
+# import pandas as pd
+# import numpy as np
+# from sklearn.impute import SimpleImputer
+# from imblearn.over_sampling import SMOTE
+# import os
+#
+# # 定义文件夹路径
+# folder_path = "data"
+#
+# # 定义需要提取的特征
+# selected_features = [
+#     "Destination Port", "Flow Duration", "Total Fwd Packets", "Total Backward Packets",
+#     "Total Length of Fwd Packets", "Total Length of Bwd Packets", "Fwd Packet Length Max",
+#     "Fwd Packet Length Min", "Fwd Packet Length Mean", "Fwd Packet Length Stddev",
+#     "Bwd Packet Length Max", "Bwd Packet Length Min", "Bwd Packet Length Mean",
+#     "Bwd Packet Length Stddev", "Flow Bytes/s", "Flow Packets/s", "Fwd IAT Mean (ms)",
+#     "Fwd IAT Stddev (ms)", "Fwd IAT Max (ms)", "Fwd IAT Min (ms)", "Bwd IAT Mean (ms)",
+#     "Bwd IAT Stddev (ms)", "Bwd IAT Max (ms)", "Bwd IAT Min (ms)", "Fwd PSH Flags",
+#     "Bwd PSH Flags", "Fwd URG Flags", "Bwd URG Flags", "Fwd Packets/s", "Bwd Packets/s",
+#     "down_up_ratio", "average_packet_size", "avg_fwd_segment_size", "avg_bwd_segment_size",
+#     "Subflow Fwd Packets", "Subflow Fwd Bytes", "Subflow Bwd Packets", "Subflow Bwd Bytes",
+#     "Label"
+# ]
+#
+# # 标准化列名函数
+# def standardize_columns(df):
+#     df.columns = [col.strip().lower().replace(" ", "_") for col in df.columns]
+#     return df
+#
+# standardized_features = [col.strip().lower().replace(" ", "_") for col in selected_features]
+#
+# # 读取所有CSV文件
+# all_data = []
+# for file_name in os.listdir(folder_path):
+#     if file_name.endswith(".csv"):
+#         file_path = os.path.join(folder_path, file_name)
+#         df = pd.read_csv(file_path)
+#         df = standardize_columns(df)
+#         df_selected = df[[col for col in standardized_features if col in df.columns]].dropna(how="any", subset=["label"])
+#         all_data.append(df_selected)
+#
+# # 合并数据
+# final_data = pd.concat(all_data, ignore_index=True)
+#
+# # 分离特征和标签
+# X = final_data.drop(columns=["label"])
+# y = final_data["label"]
+#
+# # 检查和处理 inf 或超大值
+# X.replace([float('inf'), float('-inf')], np.nan, inplace=True)
+#
+# # 检查并替换所有 pd.NA 为 np.nan
+# X = X.replace({pd.NA: np.nan})
+#
+# # 确保所有列为浮点数类型
+# X = X.astype(float)
+#
+# # 检查是否仍有 NaN
+# if X.isnull().values.any():
+#     print("存在缺失值,填充中位数...")
+#     imputer = SimpleImputer(strategy="median")  # 使用中位数填充
+#     X = pd.DataFrame(imputer.fit_transform(X), columns=X.columns)
+#
+# # 应用 SMOTE
+# X_resampled, y_resampled = SMOTE().fit_resample(X, y)
+#
+# # 转换为 DataFrame
+# balanced_data = pd.concat([pd.DataFrame(X_resampled, columns=X.columns), pd.DataFrame(y_resampled, columns=["label"])], axis=1)
+#
+# # 查看结果
+# print(balanced_data["label"].value_counts())
+
+final_data=[]
diff --git a/ctgan/synthesizers/__init__.py b/ctgan/synthesizers/__init__.py
new file mode 100644
index 0000000..881e08f
--- /dev/null
+++ b/ctgan/synthesizers/__init__.py
@@ -0,0 +1,10 @@
+"""Synthesizers module."""
+
+from ctgan.synthesizers.ctgan import CTGAN
+from ctgan.synthesizers.tvae import TVAE
+
+__all__ = ('CTGAN', 'TVAE')
+
+
+def get_all_synthesizers():
+    return {name: globals()[name] for name in __all__}
diff --git a/ctgan/synthesizers/__pycache__/__init__.cpython-311.pyc b/ctgan/synthesizers/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e7bfe77a077018470ff0b67846f5b064c5acf443
GIT binary patch
literal 776
zcmZ3^%ge>Uz`$VFos|BDfq~&Mhy%k+P{wB+1_p-d3@Hpz3@MB$OgW6XOi@gXAU1Oj
zb1q913nN1cOB8DgYYRgZTMAn+gC={GNN{CdNk(dMW>soYu|jTsN@-52-b;{ynvAzt
zokQFm{WKYGv4n&<x@s~NF*7hQ6tOTcFeHOiz%a;ukj!Tli2aN$3?(o*1_lNfhGh&4
z46ETH!3>&Ae#wk*J_7?Q0|SFF0|Ucn3$Ov13^fd~;<b!5j3r2_Pz|YJl!O~FnW;x2
zm|-P@CgUyk^ql;p#GGPHrdv#U2De!95_40FLGDvf_+{Z@6_b*cn3tX(P?VpQnp_f7
zl9-cSl3A7-lU$OXm={xw6cjP>@tJv<CGqik1(m<JY*I3lOOo?*3+$>i(o;*~6LWIn
zkyY#2<m4wO<`moMAyk6mt5}JFfuVun0=F3iElBBbnW{cfbBg8+mnr%Xu?yU0pFv&$
zxjr6ja(odxD4cmg1UCZ%!z~ePb{6q7Ffed3FfbI0fi1ZKwq$Mk1#Yv8+-6s}%`UK*
z6@kp~(`3KJ9v`2QpBx{5O9JdwJru`-g}`xpOB~Ebl`bhuOa(`J5g*7!AZyVa_>02^
zl5*^dK+y<_v|?Wd28IvJjEsyo7^E&RNIhWCx`2vqFlbyrMGv^78`wXvF*2HeV8A56
GegOb+M89$X

literal 0
HcmV?d00001

diff --git a/ctgan/synthesizers/__pycache__/base.cpython-311.pyc b/ctgan/synthesizers/__pycache__/base.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96a5fbe3cdb9bc443d8d155f16e47ad2d9aaaa4d
GIT binary patch
literal 8812
zcmZ3^%ge>Uz`$VFos=#s%E0g##DQTJDC2Vy0|Uc!h7^V<h7`sWrX0p7Mlj73#SErd
zqF7QHvsj_38B&;A7?v?GFsz0NFhsGXu(U8lu`@BaGo-M#Fr=_8V`5-f%>+{r#gW1k
z%%I6$CGM11oElu2SCWxhoLQAxq>!7RQks*h_Y$PiFBz_Zfq{XIfq_Apfq~(31PcR$
zHv=QXbcPa;Dk#okgfbWwAnSq&qB*dZse~Oa#gGN_IJzq48ip7q28LReT2>r(!2<zd
z8UsTO3o&k3#>l|18t!UF1{4>8d<!<RhLs$%I2lqH!x*MBGU5s&B;SLr0TZZ^Sj$$!
zwg9XDaD)XB;f~kO9AHN<Frdc*dks4dcOdx!R~(|-$iz^?Si@Gsp2mbOgTq%?L!bsz
zH?mt0vB!*VOA2E!gYsyspr!qRj+JtT3g$?Laz;(2L<L3$2Iu^|lGKV4h1|rv#Prl6
zg|z%42s<+`U7;i+RiP*`FC{-$p|~WmBvp@#OF=<F!Lca4*b2-7@gNG~!3q>KGV@9l
z@{1HoN(*vQH4%zH3PX@gOHEAyDNBSXRLD<L(8w#zEvVFkn5Gv5wlx@Jsg6QPeo=CU
zo_lIuYEfcIevzgg-0tGklK9;Gl+>Jfh(4Gr(n|A^OEUBG&|KgK6;>$8S13*`!SJ6#
zejZp1Y$U`MKTXbCOnC*j*dP|%;!IC1f!K1373{8CydaUHymY82C&*3Uz_`T=Vj^U?
zkvw&a6X6|>TRafMp(codw1J&~;;UPN`9+!OnR$sh@p%PMr6O=4uv?&_{7~cI%7mZ-
z$chA#ON)w9^GXng3ByE?lof+YDp2Hs@Glb=tC*Cm#Ju$UfTH}Y)Z~(wlEj?slFYKy
znB<c5#JredMEzDAlayGTs#j2XOAzEAs0+XWQ(Pp*z`(%Cz`#&^osogz1;Y!528OR9
z43g4wIIl}+Uy{(iD4}~rLiYlT=v`@rE7E!wr46n~8(d(Kc%Wr=kwv}1|AMIcQx>)s
z&+9DWmsrGS@L!iOyd+_GQNs9&gz-fdlPfGH7g$X0vT$Bu5x*dzbzQ>bl7z`c39~B_
zW*1q^udtY3U@`x|!o_NIgM+t&6$GD3%g(X9E^Tm0+TfzJ(G_W<3mg(RI3%ufC|%-E
zT57b!^18CwC1taV$`)6YEiQ6cUg5C3z+rh;PI*rLMLGQ|a{3oIq;GIYU*}N0#G$%0
z^@^J1MK$XyYStGyY_4$FT;Q;|!6A2nL+*lt!387N3-WFkIoz*sxL@FKf6Bqt!+D)U
z`Vxoq1sUCo9C}wc^e%Ad-R0oD0unK}$YFSe!|(zO{rJkyAR{;Uo%-X)k78a128Kg|
zEUxklN9CQlJT#e;;Y|Pr1_n?S$_#4c2r)1)v@=X+=wL`=Okr%{sAa5S3}(<|s%j6$
zmPM0N)AEZ_6_OKka$rRiEVpAVo?wy+8qSG1If+SFizT?Cl6-~oqQnA)#G?Gtyc9?g
zfZ}XWeJ2lc_9_MjhN+Cx8EP1^V3j#T0}m5JCvy!$I%5svBE}$wOoke!C5(NH!3-rJ
zuR_%!>!@MK0tEzE2DRa6%D~8g-Lzz;6b2APRnN!(s%;G!ido8;ASFB_LnH$u10zEU
zLokCTqn{@8EnduAcuN31cWScSVku6|NxQ|8Sd?CTi!HkxOmo1(s2J3QR8UZ8C=z90
zV7SE&3a^6HqAFQr!}V-(@{<#DitY5E(gvX1ew2ZM;YS0*4Q~D`+_D$ARc{DOPq3Kb
zazWT=fyo77BM4ey2B9zT7=2}95K>-Xaz)v6L&+6$pNoRNR|I`OFfj6(gNYmB@)If+
zuv`!~TTpUA+zf(NltJhZ%rIqOwP4~SNb4660p{W)e*E}>Qw}UqEXcsXu%5?}gZUr_
ziz5&FL4HO@9`=<CnoLD<3=9lKpcGOJir54Na3~dl@&cI90EGjnasQ)%;e!AhtN8~8
zEQA09k8p$EXHc?81`#j}G7`l8>;mq#q%uS?rZ7Y?r7%V@r!YmafVwA9tl$m_8>ll9
z#h%KM!iLdh;Y?v~VTj^N0d-6`Zt<gaOsY&l0t(4FiN(d>=0IXjjzUUmT4HHVi2`x~
z5302j5}ZTa9sLp#^!!k=38*OvO7fqbm_Yq2<Ro9qSi`UYxg-F22#iZWy1*<3hAfb7
zFgt}2wPA_MV+3VDP>KYr#ol;EYwR=OvcCq+CPs!5P`rUnW&qg(Z}^r7Km|cfc$EHL
z(F$C8QkW#+no2;09+Z~_?xlh_S==BJgi*70Q5{Y*nUTzdhXMmb77xf25Jod8vW6iG
z?h9lWEkN=qNFM~FhFlF(3X2T`YDlCYxqwJBSixp6E#QOLfl8*ZVZ<*enX)G?VPs(N
z%q=L&FH2P@$V|=#m;T`91UQi<B!E&aQvW(30a_h^S}xA{d1a|ZC7^a+LPByuX+nZR
zN@`hVa;gHfIw>wmEGmK3Cvf8nit>|Fi;ER9^Az%nQc{aR`jb)>K$VC>Nj|(4nv<WH
zf@n7?q?V=TDU@fV<|&jGr=}>R73JqDB<3lkR+OX`<t64Ql;r2<C={0_XDB2VD<mWY
zyM!eq=qWhn6zA(GWTwGQO-Mj4P>T~16cUS4LCsbjh2j!W6<Csynpp&Cm=}YLO34Jb
zc@v8=Q;QXf5=$~b4$er-Q%HxkoQsiyBq+6{v?vd$9S;g8kP)bjc93X5Wl2VUo&rcx
zp(wSWD782>4<1<gNuY)xq^1NZzQvqeP+9~ks9u6H`7PmiaKj$8DJvWgX=5V$un3ft
zHTiF`$3vp#78j^yg0PD~Y4#Qygma4{J|5(j`1o5a$@vA9x41HkGxLf|67!N%!G_#o
z1$RG+KpCP4RHhVxGFA~NPZWXb?_10T`30KHMbe=9ix*;Od{Sa^c4@&aR&bIk29-Vv
z3XsD47EgS9dTI$+B0j!K1}&lKA<H^|iubP!3=9nnAH*0WW#{m8@ZR9y>F4j_pUXOj
zXF=#i3GFKq+7~%=u5jpF;Ly1tATovbx`5Ip0i}xqDpv$lI#{8~XYgL+P`kpRc7a3f
zDUaBMk}KkxYfUfm7+&Eq{J_l4ss2<zYKF-bY26ixSELQLDs9ocV&r~Nz~hR5#|Kt+
zPW7*B41&@~3RJh~Ulee<BH;9aot;zt0~dpi!4+-W9f?=89j<G8UDEctsO@t_+o!|z
zf{^+}9*rB?`go;z1p7U^JSXH|<WalAqjrHu?E$|?2TzeA0|NufP!=c(Kw<g$93!Y6
zNMWpDEMq8|kirC}dukXLpk{w4553H<Wx}ZvZFr)V8M&#2)VM+!wX0>RVOaoj7t}rk
zRf92%Q3C1$K@}iI0vCW15>ymHp_Z4mtTn7SY(aB1BSQ^q4NDqR4Py#3xO8N!VTcE%
zd9cwaK7zZZ1eC|X@(c`FoKObC0(fl#69Lhv?kZwRVaa1mVMQ%zYuRcT7JxDeNCOJ4
z!6=hc*uZ0&$AePgMG>SR0<{Se74R3v$eJ(<IdE|a8h8NbH%OtAnWj*Zkzbq&E&-wC
zk3wP|xX{gm7nrafJtRXbBqk^47o~v9^n7sX4DH@3WF!`)lqVLYBFd~Hbx@-X)HDN^
zAHUd=OH&f942o147#RFCd5S=dS8x#oDMLVkjb6CufK1Q^6(}qq&9`_mi{lf^5;MU)
zjv_4v1_n)baDh`~1X9ES@0)W&N<y%2i$p+j@P0U?;DpG4j4lSH2v8Y|Trw5oEtwFq
zQJ|9PE~I4QVc=0*;Cg{u6NL1f50qZe^SPksb3<HadgY|b3yNkJ#m%pXn|FBJkdmLT
zIZN|`iseNqt1D7g9o|os)R*XAR5HGzWPE`~{sxcybso)2Jen)CFY=gO;W4|wWA;=@
zWr^l>CG$&4<`<PLuP9lfsMEN}V|<0j_yUjd4K>RP)_xaM{X1MI=wIYfdLSX&!P3Kf
zS5Rz<%v#koTGzE4FKIbm)N;O}<$OWTZ$`$9j1!_41^urG`ggFw3Q4w$9P(E<<S)Pw
zW(fo;|3NwZ^EPmk8KotPvn5DgOB6*HsN@6Zq8f%8M$}dqBLmJBC`P#eZlp6;sRSpM
zrGis2cwh+BA_3)*g2ZA_6C<G@u_Plw51L+Ci!?!r5tQa2i58T0Zm|?6mZicfJr<CL
zB12G;0EG=KJr&t9Ffde!pjU7po%tZO{@~>EKthI4wRV?-w})$j$qbQ;9CBB{BOkB`
zL)2;)z;zpDgn<GM<fzXFz!6pgYDItpl7Rs+^MNvg4KE0>7q%FMA7dV43R*zdGNIT5
zas${-^e#aya}DzXP?7@6A`_V7<TcE}44TYU+CHewLvV?if+Ge^K`{p^b~M=#aaUvp
z5;X^v4J@FxCZsrFxy78EQ(R;Rk^{#uD>!Muy62F#OA)9+f*GSAQ_4Uw8U&6}P(NFo
zh+KD9Ky(V%43moj3ReUaKrI#C>l_l7I3#9>UF1->!l7`1L*WMYb^s{Iz>#_a+zzN=
zY~ZW`)p76;1Wgto*K;+DCGbdKU_i}3MZP6SwJ%~i1V_X$u0e}z<lYpbH-wr@7QpLS
zxD!D=AG97~8dD8JJX`@5KOpt<(EQ+s>RaS~6n5Vs+8$^oX`-3L0%}M@9905p^Fo+N
z{XR7P*kS@Sv;)<b!dk;Pn<0g5E?Uc|iBXe1k(ZHy0oyPgQbz-Iwg6OpA-DR$GX^@4
zX#gG2lz?A;9{R)qNRMkKxIY7%Fi6hND@n}E1J~{usiaIEz?%yoQ$aYt2-G_R&0Apf
zZ6FR+C<e_|l;nfDbKt&aQhrGW%DhN1qAB60$#{zeqygLoDpCU#sQf5Bxder(K!wzb
zg4E=a)D&2cGEV{I0FYB4K>=z<X(Z$oV3`j|&{0SLw;2#KA_<zB+~Af=kshd!V9v}d
zL2sXcnhxOB%PsEQ#JtkPoOrM!^+D=6Ln;eWU5kqHi!_-aZ4}U80C@Zg+65|-0{KG$
zqhd%cLFw1kfr^w@;3DON00WQ21#bBpin<*xS9s(ufYDuE@fkczlP~gWUE$UGz`)PR
z_f$Y=3g>kJ)k^}ZD{L+Zs9qGXxFTQyDropG3MgI?P`oandr3fdMb1S5>nj4*A6S`{
z_>c_z%EiDdKA~o5=tUmwD?HjCn7KekO3koY!8pf$t<xIMD>{zQeh(`bCm*C&1d()I
z<9|`W=8AyL2X-z_z7JdsN_t2JKNXVxz{DtMzd+{$1EZilg80bHEX?<XfmxXE0|PTJ
zAGrR~<OcOg^AdAY<Ku5}#e=4#N^?MLp7{8}(!?C7410WhN`7*DJS3`&K?S!phyc~P
zx46LrZn>!?8Tl#TQX3rRMG7DVpl%gp91%2fSOg+K8jC?icmo3r-e9r6z+(SEM54j>
z1B(Q!=m!QkVIs_`@PPqNXtA=Ye_((Ud<<-YADCD{I#^jnKQc403ARK-1o$Lb`93gU
z5@2tG9iquu<jugqaEl9+B~vR(ax#-{v4UH!kR%97c(?eVlDUvsM)3R(csLYn1ezy*
zao9lCM%WbvFfcHH>i^<t3=9k(m>C%vZ!pMSV32*lAb0_W9xw=9fT0@<JPly@fsKKQ
z=L)0L2PSDoqYn(yj7A@s8JMKuA~phy!XFsmgjx(EBi{!G?BqwV_!lsVsRkSp0JPkQ
A+W-In

literal 0
HcmV?d00001

diff --git a/ctgan/synthesizers/__pycache__/ctgan.cpython-311.pyc b/ctgan/synthesizers/__pycache__/ctgan.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8119686edfd9f07e4536c37ba0978bca11bb7798
GIT binary patch
literal 29580
zcmZ3^%ge>Uz`(G0NmcsMa0Z6QAPx+(K^dRp7#SF*Go&y?F{Ci2Fy$~tF@kBPC?+t?
z9K{T#S)y3MG+PuqnC6J$NMTH2&f&=AjN)Vj>1D~`%H@vY&gF^X$>ojW&E<>Y%jJ*a
z&lQLg$Q6td%oU0f0;^-q5zZBk5(cx`azt`PqeQ`M_8hTX@hEXHn<GaeS29X6S1L-1
zk%5UJl`%^i>Ru3A2F6a|1k1_7_{$g=7*@kL4C#z(xRx<9Fsx>V34zRVXGr00VMyUg
z<wuo^l4D|UXGq~~VMyV_kdp_?@wYIf2w=!5faL^R7*d2%CDF`OOkoOU&=juXbq;ZN
z^i#;qPbtkw)q4qYkDn&vE!O;klFVF9j$7PLi6zMye)&bYhAFq$U5fGx@=Hr@ar&et
zW>*HK`h?zM^U2IhO)R>_<_k9B77JLED>${VG&QdzGco5DS6XRaa!F=>USf_W<1Lnw
z!jxQ1##`Jji6x1_iMa(isYRNMxA;NakfOxA;<Wst+|(jXrd#|@iN&eGm3bu@sl}O9
zsYSPViW2iu@^j;hOA<>`tJGCB5{r{dGILWkixr|(HF8qRQgiYWLEKmcRgK)#;^M^g
zRLx>d=3A^_xm#=?qZ3PRabzUsrR1a*6(@td3&XIu_-w?$z|hVxogtMWiZO)&l-i=0
zQ`kEg(il^iTR5XwQaDmrTNtBQQ`lM<qS!hZDj1{KgBdhAZ}GZh7AF^F=4R$4mgE;z
z$)Io*((;QGN-|OvzyYrJ5|jawSzwkkFfgz)Fff4J@Y#ovfnh4+bcPZ}glq{Df?dO~
zjER9^HC$u?oP|)sz<`?K7#V69YnZE885n9=YM5&n;z2fnmDjM;FvP>%k<L)dTEY#M
z024J#%NQ9LR>NK5!Vqgx%T~i$!V8uL6E!R~Y^W|O;Rj2Ai4;au6G{ZZ!eAl=&GdBU
zDqaSLTGkr2T9zJDbX!uGQEf?Qgs3hN1{($@YFPUiY8c`{feYrRFa$GbviLopH4_9*
z+CK-2GCZ(<3G$LA_bt}q(t^~YB2fkghFcu*@tJv<CGqjMm<tk<Z?P35CZ}ZP-eSqg
zEG{VmMNbi^Y${@9U|=W$*>H=wIJHod;}%PCYEIfM&dj`m(vo<P4#5=Uga8sM0+j@}
zm{T%yixfasv1FE{=HB9rk5A4?EG~|ZFXmuiU{HX7hF_*GRxv4AiFxVy0Y&*)smUcV
zC5buNC7ETZG07$AiFq-_h+I(&5z;HDECQ82RbnWqP7mfKNd^XnVo;j+(ZFz*gSR(#
zM#x1Dr7Ijt*EzH=acE!U(7nQ;dx1mu2A}v0<<9&L<{QFd9V|T@H^d}6SbDf`@CkG<
z-xZO#AgOgxMEi<}_6G(gPH`~N;c|nAf2!LI;{_}iMO3eds4lJCAiS_{N!`}+9mN}K
zw$xk@cDcyo(&2nVNUX!T!}$RR_f)nC#WNfy@=f8pz%PH1L%xIguCVNk;tRscAS*jK
zIygRq;^`K1azSY_D9k}T5C(-3DCK^h!33%skaI{DC_+JM7-|@?=LqzaSi@4q#K4fk
zl)_ZQyoPxh3j@Pycs{6Qt6^IJvI%S=GLgbm!-CTWwDgQx=7P<H+kj#w@g}0=e2|GX
zjG$bd!c@an#l^r-!vx}0F)%RHvNJK%Fx0T4=Du2v5+sYWKuH<QUw|CSU=a{e!%@SI
z62_oR0Fq#+;lLiopcICv=fL5L;^Gum5M0ByjE#X|HQde=Mo{tu8_ZD4S;M&iS11u3
zPM}agEiD#+(hkH5Br=6Pg=q~3s>&25P)OEt)o`V-S8*^fEZ{`a0B554loe!SEq4t!
zs3xeAb<WQ%C@o0^7x?K#i7A<>c_j)3sd<SxC6#(bpaQ!{ih+S)CF3o&l+?1!<kXk{
z|Ns9VqRDoPJ-@W1ptPj;78|%cDZa%6(jN~K<w-6|O)N=`PcKR=$hgH*lv<LQnFr-@
z=jY{A#zVBd1lh5Yu}Fo1fkBh&7Hdg<QF6vD7Eld(iv?7j-eN0CEl5o)xy4eJnOc5}
zBeAq3A7s-l76_T2ms)&_B`?1y_ZCZTYGR%yPmv<1<l-z!P0Wc;Ni0dc#hI3voeE|a
z>43^aa2dy*lbD;7k{ExBH8H0kBk>kbW?o5ZQ9*uAVo7T8EpW*gk0QhgvJVtK#kct2
z%y_V)1YrDlXwVh&Ac{$FVOgXOa+7FsVoowB5WwLKQ(UEtS`=bcqy{P%JsBAo8W^4m
zh)&^}VL9J!mfZ@)i_$t*q;)nhUz9f95V(W!fZ7EC&x-<{R|GseSZ@f3PUo4#Gb3;T
z^Mc6baf{+MI9^mWyP|4#QNa9)fH_0~*CehPMl&ks*UhS1k$6$w;EKG#MFGPr0)`!|
z55#3=R9+O<SP{5EbY<iY&5JtD7sXvVcpk_o%y3yCb5Ta`ij3X`4#}q?l2iPyi>O}`
zQNJOgdP7+LuB6-?sS7Hm*Hx@8saRc9vALpRb0F-3q|Zf3pDU6+HzZ|2NJ8cVFN3JY
zM+SCLt}h_s10Mr7&li4nQQi*D4?+x5%Jb!B$*o{o;k-d}yWS?f1A-TgoUa%;gPbAh
zaz)bRx}@hNNzW6C7XpJWNP1qB489^6d?76SqGWgn-wj#q4U89M46jR@Uy?SzC~bK~
z+H!~cMQP^?91;)s<u7o^7ilmsFeHPLA1tqc@;Rt1{k)C=RBG2SE`Z6P(5N{IxlE{G
zK$PxPYzz!3jM-rGi<!I_niy*su$M(Oj5tehhAemqgPQkIOUqyeP39^M$AW^KN^nj_
zNlywT`C!2WaE6OdfM#o5P=;eIvVdg>aD8871>(wqq?jPt2$H+wi$VTSP=F+3_O$$>
z^2DN)Dp3@R^q|6a;H1g`N~#a|Mf<C}syl0YYG(vq;g?^avP9z|zt)PxOZ@s5VCV+F
zaDPQt#SFIj+_ShBIIUn=!+lZ0<cfsJMSjyO{H7N;Ou<p5$yg)?au{=wCdiQ>Hbj5}
z)F1*isN>^ramB|&x(M;{w|L^?3riEhP3`#jTkP@iDf!9q@!-}+ktfIqA5fm;g|=;y
zQ*(0S<5w~kfpZ$Dl6C<57Tn|l6CfLl(-;^Sel#$^;0+eF3oL3MSh!hLJ}}@Sq{LVi
zJ}{sX2F$Ex9~kfvU{8ZBMrr<oA_9~gKZBAldh;LL_-AV2h+;`$ZefUG1vmQHKuybA
z96_nYnJJ}-IaRVSmO@TqWoi*xgFhLR13(slFevqb_@B=)(4wth0<WhTK<$1e)HIIV
z-Y>z^)~{iR2jyIldl`^h^r-o?h9Mqavw~XND;fPXIg3D@;-U~xVgNNt!HwrzOliqQ
zAdlZ-O3Et&)v32wic)h*HCZ4zf-&<JV?HF6fK(|cC^SG43dn>iL0I6zTeR+=?BE4S
zHlP-5#tfE=9I{t9WUq6mUE)x?$f0qCL*oL6#$6$a8G<t$C+bYm>0s$$zbhm=BXOeM
z6g@CUL}H@b6t@nR9?lNV4o+~=Kq<OFE&)f>21wDxfW7FdLF?qyGSx6F0Ht+s;2{&}
z9V)O&SlUBY$bhL5Ia(0~SPfGRV>(j`;~KPR1{YqMh{6jNfY8*5QF?*eBSog5l7O`+
z929dAAR-b(n1M1hb8=#dCNns$ioC%B`K2Yrpukap7GOo7j9(=Hvj|ph`GO+K15$4B
z3H7IUrB6uf%<lp9F`(S!&YYf{4(1ztLj9RtnKPI>^Lp|+nD6q7^i<AJo>4YKc7@{w
ze!YwQdRO@ME^z39qiH335y;>oKad0cK|}zE03{Z%JBk89+#rzMpooALI^ZM;B3wWc
zAOlbe9rX(=>K|BGSfxKOU?F%|S@l0K;3L2;1E+hclr5a@sd=eI;K2k5I1{~G0VPpb
z!U5%1P>g&QqtO^bEem*r0J(fY>+`rU#7cul2tYX;>MqO?0$AT2Bm?fnv-HT+ur7f2
zN1z(P6xvt=+9&{Ojm?NsQb9)pP|ByGFi;YQmQabHGy+aC@*ox{VccR%EGS6LO97Wz
z;BG#+2G`^Or~F&Iskuq1DJhwG>7Z`?E#7ok41u{Epms5c1#WkPH5RFYOa@neAa{UU
zt*Awt5Ikz(Wg95#755>^w#*qK7dhmvaL8TfP`|{Xevw1-3Ww$e4$T{U(x7Yu>6_1x
zy(pr7MMNFcE0+KhkgmDMjNk<-5UC5o8W(vqI-GCt@J}%5^agduZ*Xw;b9He|HJhP4
zqvE2l#uZ@=xbf({a&YoSDIY<J4TL`%kXk;1h94QLR1oD8JQOe*R+^09kVGkbSU@3-
zrSuVm8v`qTl0Z3K0#f|IT0|H5<*)F|cd*>x7w+H&`>H4&6iw*G1vq_xh)|Fr7{!Is
z1s0_bEF7RV3Kl|ynN<YT6v0hEYzCQ+Qe1!n6qNEm&jAN2@;D?X18$finZnb;7{!vx
znj)3L-@=GmPe-w(2(&Onv8M>OFhp^r2(>Uoai(&mNT-NmsN_!NN|8ws!w~055pQ9L
z;!TlgVTj^O5ea6{l)c3Y9?ndZ;$&cO&d*E9gpBhlgd`^Aq$+?$mXQW<^|-h|$v-3`
zvsfXs7~FVH&M!(;0FQww<fnl}z~(6wK*qgv6v{JFiy+<Tl+3iW)FMznIypbLAU`iP
zucR1kUSd(I0;qSJT3k|;SdyBeP?DdXT9T1kqz9IE%P&&M1zD2<>MQ3ID<mf6mzIF-
zDk#cNPAx9hQ7FhsO)O4TNX|%2&IZdw`{t*l=49rjgM3n&lUSqxN^c6K#hH2OU{jGj
z>+a|mtC3MsQc!HAub)^{ky)mfUzDz&m{hE9XlY=sXJBAzV63T7kXQh66Ii=rQF^fz
zm<0+Elrlv@BQvi=6Cnvw5S&?+3JV}m7d}5%p%^@Lu27IzTnzRy*x880ui%oJmROoo
zQVhyXhDH{UumG8bQer7+l#~|afWl7!G|CCGBQvk07?1Ukfe5Je#o)>)H8D9uAwLi9
zLRbT!SWm%G0mf1A0X2%@UIw|OJToUpAt_Y>GNh4$)vU6_oYK@{P<klKOi4}AL-Ch}
zk*S%Ef|042Cc=9dL$5UPB1#7h;x({uA(J5x0owZ&5tTVb3L0rS`HAQW$0s$hC=Zm3
zK{*2)z@RAzxIWb2G)mPqK?EdRQG7~ja$+TR^Wh0jp*%G+J)=Ycq82vS;h2(`tKbis
zN&-#bpxS1bs%wU@4K><G^A3t$R7+63grX3iw@{;(La!x(X3*k6V{X{8wqI#(Qfd(>
z<AF;UNVW#&T4?E{keR0d&SAwRsRg(K3?x@fz<A_235C*v6wrVisB}ol2PX`$9q>qp
zDnQr=%CK%l`MKb7CLC0<gUd531snyRxq+dcfw_gbsh$EtHA<F9fs|07s=yQ;YYL^H
z;w-aR!4TxMTu{&{<QHY8XXYj5K&H+VGC^~Oxv6<2iJ-~`;liB!^!T(Q@VsMkrGiFM
zetr&U+!D1U4bK2q4Islyi&GVH^3xR%dh*j0k`qf()ANfmlM{0kz_W(Mpzu#dRBb3;
z0LNBlUOL?Q;D}2|2q`K}O-MjEtt_=DDZe<C2uEckmZd5b6lLa>D1fIqLH>l4vIRx?
z=|!o<#R^5K#h`*6-A=c}oMIG9Qw#EwGYBM#lA^>+aQOgHjB2K_fdRq`1&PV{(-$}_
zQo#X`UX)*20I!BYb$FgaB4`v1R1$#-Ae4-VumhCFQ0grt`;$vk5+UJ-RyiTXXktl8
zYHmRZEONnGKt+*zK&S$!(NKz>WIz^rg4!LB76~{A74nM|oI_n4LFqCtzeFLiEHN_&
z)SS^#a1IDnfK^(cbc);Ms7c^u18A(D=_Lar14B_R0|UczU1yoskIb~}pZC7qy5Re<
z4R$ZTfaE}Z@DR6`Iv_S^SS%T7B9#Hu(E~LNKW_ssjHuyAVFV9`GNBBHGS)DFnzEqw
zBWOCPh9Mr*{RMNu6Tk4W2Jpl$$UKM~EAp^V4O<OEJiIej!;U-@RKroj5D)JF)^H*Z
z0M&5SFvP<<el^_4Q?WHXH4O3a;ieki8ish#Xg<Vtz9P92kmtcX28Jv?D1%`EatOc#
z(I&TxnNyf+`SYYo6yXXOYWPZ&5IocYup;IfzFPh~rWBT1{u0ng6;xrC5R}1C!;jjn
ztYL_UPco-~CYVKFD#1L^>@ipX%mYOnf+r4UGB7ZJ2FEm6{fbOL9Su+{fX7>Jv4B#k
zro=72c$8M)ExveoSq$n>-V%yO$uwXwUL-X+MYjZyxZryDmLQs@oT6JIXhIN$-0_H7
z3PTCFa=FD9k5a(i;)sWqYPZ<qA*Ba+lDh~rSX5LD8X)4yERKigo?Bd*#hH1<C5d^-
zskfMmONzjAMz`4GA@g6ic;ZVCt3Ph>f~S?@A!SO@EiQ!5ZgJ=27Z=Av+J%||;9hxA
z3MfodL4-2cVnop86y4%S@F1QQKn=;9qFchKA`oR<NI`<801_g+r~$$b3lBEr-~fdL
z*sH|?XmeMf!Ut)DlMmeUfe!&p2lY6=fxE3Alo$l0W;k}%EfBoKuXF*19>{5Tuw3Dg
zz5%0^j6r<)2eO(VTIz;^K8TjRp=b!A<!-2$gJ{JE(yAa@0>+2ymyqvZ>EXR0F5AJ<
z!vmfFf50#Nfq|7%X0GWRtLxH6m!yp@N}F7fHUSIX5S8k1>u|fvD?Wo|Zt_K5r7OHj
z9~hW9Wp0Q{f+QY@N_V(D6_=S(v%u}5xb_usZIH<iWE4R5OWu%G1kqAAq~tnSdiZZh
z%Y(QQ7dSu&+%<l_d++s|_iJw37J+6}lHmms0|NtSECp0}eLe;5BO}jJqtElZFvPmn
zvX+2UL-S%5DEoogH4HV3HB2?kH7skGma#H0tcKfB%Z6jPl?}CMz&>V$+V?JIsbNcI
zVq^$r=y5|H%}QZLHLHnHlciLQnSsG2H8H1Hp**uB1KivK=bNO=oXnC+P(u+^gQkO;
zjXCkf`DrD&i51`mTz*bUYLP-&YEf}!eqJ$HA*fpn?RA4v8=kIkW^r+8Dx#?bYt*F_
zWtOGtrRV3T=cMW-=jZCDW#**nr|26xI+~UTq(vEK7nUXlm4p?QI;WU9WkyE1C%dHj
zRa6G(gC=n8ON$F^i!&07K)q&o!wI$T3vyRZetKp}u|o8thGjYmd8N7WX{m`NrA4X5
zu?QD|)PZnl9%xQFF(<PsH3i&!D=5lON(2W^W-4Sr0Hh3rOA<>l_50=L>E@-TCxTYf
zC=@3rg8KHLRs(39wX`S|?n;mW8Hq(HSlb*RNzXI|P%}>l979E^C8b4qsVVS&F|3&Y
zE?81az^+X!2KAD4GxAFm%2Ja{@{5XfP^<%)np9c>Z(_hlJTk$<8i~aUnQ01{C7^~0
zC<%b#9%>(`jS8E<M=>%bGZ$N{4P=v}LP}<CY91)DC?w|O=cOx@XJjU4D1eLwB`KH>
zAPpy^uD6wft|3yQ3Icl_X*2_5IAmd8ib6?hUU5FScT=30o2pQpkywzbkXQ_gI8cnI
z!8#P~;N+|u3{B1{pcPk{Nu{U_e8{vU%Pr=T#L`<Vpvbw!oLW!}9&4?VMavP87y&gn
z{Sy83G}&*l7A5ATrxt-)Ah-BX6@%-nTdbMId5L+qm=p6VZ*hfz#vNRXit>vz*>ABy
z5@^w6Pz^H$L`(${ptTjZ7~^koBWkXcVo=Kn6auhGkz4%nsFqcULJN9S;S-?pz6@O6
zf7N0T5}!~sQF4kTs5Dgu6CI9srBtp+8E<eqz<fgGLP+?9*n|sGi5I03uSg|!c;6M4
zm|{Jn`l7J*6=Ch`!upql^)Ct=UJ*9@z`)3>d_!F6f&dhKU}of1{>sL{FL^=I;v%2r
z6+X)g9G2kn7G-Fc3$`Bd5V*OS&QQyg!dS~(!;l3oIY9(sR^Ejnc4sY14T}py?3P;A
z8dgLJn*wc~7CF=~*RW>6OJVRDCAJ!-1)ydv$S@R)JTI9An$-Zy)-a|pr!cpI%4@b(
zrZgtRsv*>J6qU!wP{UTs+~bSlDi%nixyTi(s~8r5q7iH-GJ)Iv6xKCpYanacarq5>
zO%~LBb5T}CvLm~8HbV->T#UM+XBSE+A)4JxDV%G#P@BBL44T|k4v@k2f};F_)FOq%
z<dV!Xa2Hwuwm=uuaRQYr&@n#fh%CJM1X<y5i=`wnFT?K^JG7uE0*$R}GJ-Qlm3(Mk
zYDGb6GH4hUn+3lZ^}q|CAQO#TMJ=G_ngGPU_{_Yt{CLnX<Skx=-uSdUa0V{|4UvO0
zY7uA-wrCzG2hIW!v7iYZ$n;E62FN@%&`3&r$t|Yhl3U!#`8lPzd0?||aTXUOLYYjd
zDMfQZSq{{YfaX`|G7SEBa9=eZ>V{N!hD8;A0@4kdXKP@1ASm3yc7vC%gXxB}{1s`P
zE7G<f7<f3786jkc_YGN%4-DMAAOR3Lf%As2%oSm+6=|1*O)dzV+>n%>kvS*#qNMf&
zryC+t(|srTE>&HkwIXDLA{YkkVA>G6L1RtaMJ1anN;VfoY_EveUJ$XpAt^n<X@b)Y
zQMoIkIvY4IiCSI|wL~#YN$t9l*(D{j4JA7q!LaCn(vH#{9$V^87++8_yQt)UMah4H
zH`LMv+80Iiu88Pe5YdAg{ZvZjf@<P{$}8@17gQ4==%Q586{(~PB1s>_L7w?4#ULQs
zQ#nIyLEr^`^^5%KSNPR0aHvBH6y(eeDtSTK|ML%6fx?o)Sj(EiSj&d9P+3sRUc>If
z5Ieh;qXgbGW58I}n8K6|GMS;6yM_Z%#H27oM2om<SZg>CGv?q4N7f!~G*Qr^NzNKJ
z6qO(`E|hGBTo|#Wu(UGOu-0&*7DmX6C9**23tWcOFr`4$v(<2+s!w5B!-|?+Q`pw9
zqb(lgu3^Ex5V3|EqqyU&Wv$_=W$T&Kvk=8CpduaO78Xo9LB%&@iTDCgOo7!Q6R1tv
z8pZ|i-aBqp$cwnDxEVl;u{o-2kQYZL7v&chE2QR?6y+CGDuB9n3i)~9CRJKyMQVxy
zwEdK-0LjOYA_!81EM#C{@B<elpmmCQshW(pIEzwKO2K;#G<m^gOVMIbfxw)Tnpd<0
zR91khY)BJ;7i@Dp)MjwOa*HjoC_Oi^0$kMGVl6I7OwPW=TwIz9Ud;tv00}8#7?W=?
zf>ws-7Z(?S0_GM=aY=k~(Q=R-D?r3bki9IaDe=j-m`idCa4cNpiU*A>#)G_FB>-&}
zz(sz53K&iB(!>Y+f*mZ70)}6pg9WLS;pJ5Lz`)C?04-y9cojY{@bD^3;Jhm=Hlg~W
zu<{jQ<pqV8gf%V*Ypn3RqT_s^@{&%#1)YEgB4RT%=jdJ(Q3n+v50FbBY54`3OLQ+v
z>rZe2w>gY%NXX0x1i=MN3yc=9&WM=9x0ZDc&xXJag&?>?Wk<;l%?&kM;!h}E(D1pa
z;d4d9=R!o>2WBQ|v5yQ)(qdmg#02LLYz(53)4eBoFJQhXqH;w<<${O`xS+WqEkD6!
zg3AX<9$tm7GNAILdWH#7dD6l1KtN<d;Y5!9x~@9N<n02L<(i8$7iurjUXi#&?;^j!
z1r7s9nFB5fplK3R{($(O)4*j8ayPMtu>{@?WMDuvSWt6fkt%XG5PeM!iX0=Ngh9=b
z;N`^3d9i4=!*e481EPt8X1hr`LoIU+lM6$v9w>jNFlK{2RxAiwgq{U&q@*A<g|flr
zWRYkMGxjCVH7uyEV`S*j&*MklM1k2Ytzj+_DuIu3F)*aC)G*Ixn2WabqlqzvHJCw@
zEis&lfx#I(ms$=TZcj*nW#o8BMlMcBP)M!FEG_|$uO}pc2cp4)KA>@PcpnC{uZU2n
zppjRaTTrQ&my%dilvoKGw=GD_OGzx&1MMYpD*|l=!tDQrz(%}5iu4pbOB8ZTi%US$
zkckRLy6`ShVh*SyT2QF~H#<=Q-E`F9Uvv*DXn>kBx;dHIsc0kCpyq~8W^oCqiw@HT
zE#JVy=SiuMS&))^1!&s~G6Rk@=L`=ZkXaxc2JTSmDR`zSpav$Wx0tBl2MQ^OUqBt!
z%wouNZ*fL`X-*1;VafS<pk?LYVi>emJ3X~XAu}%}GdUG9tdIdKj*Ima{6RDA<(b8)
zI`G~phIxq!2?^+dmym$z5Cvq5A?ZW`v;egj5z<AVJPR)Wt1Qtx3h@jmCQ~y(j#4N{
zL8Jsw>c}h51IMObmAGeKSz=CR3Ov-)@=NnltQ7n-xr-V>B{x$+$}LV%3IQi0&_<t<
zTkH@Q72jgYE4alDO?9_e!F^gtmBATOS&$0uG!}tcrkbqa(aBp}h%|c(GWM8Sk_yWQ
zxA-$5Zh;H1LG1zMV^GvWo262qG0#lUtQ}+o7s-q&$jq1?ZWRKI3=G9rAiYZo1|G2q
zSxc3c=v?I1zrwBmfq{cFlJTyb;sqt6i*m+S<czl_@6ftp<#W-@_llYCMLFLN{|<l9
zARpIKrzKt&dG)XG>VIHh<BVjyA*ayce?v;^0|OJUF_@UZc0*WlhT|lw304clR#;up
zw%*}*MceL@qTK~WyBi|P*G050iD<1byC`CGMa1ZWh|!M^tW3PdUpW|rWiH5sUlfYC
zA{23fC*lFO)CUGePA|s0LNXU*JvP)|koCAA>v2)Y^NNt?1s=~2%pggK5^%A<lEDvU
zI~}NW1(*Cz0^mjGI9G1gpskx^WT;_o;6Pq-Udx;(h1~i?YGE>!z}t5WXxs7_8IaqJ
zphg1H#u%ifwKYhkKdx>k7B?d6LUIGRdH@p`Yf>4J%lH~bWP7tfp#{;!kOgo3p|(I#
z=4+5_&w_6StYs+yZ2<wB!@y9(kOiN+sbR>1w-y(G#=)U#5EN?ofCdL@7_#771#4JP
z-Le2aQ-d%8Y$JS@2DOf>VTgy%*pz4^^nm8T$Z{Ecj;Dqp3qHfMhM6)~*0PouBHWe*
zAKj>7$b!3m0n#iPLM>|O)i5A7C8w~|Fl2#NID!?gVMX8OhQmfyBJF|4H`pG~1_$i+
zu$S0_CBQ@$XbU8mRl|@4pR+?tFYuUL!-g8FwH!4ZDeRyWQp;K53N{8zWPzrBz^rtJ
z6plrVYdA5|0I2m3R#5_4zzk+FFqD9n^+DM+XoD;@3|XKJMqpVUh7!;UV=$Y6Aq%uw
z0?bC=%3aHatgA>Tg}au!hHHTqSTPz=!@U40H6kQacv5)Q@S?T_QGLV6Py$+-4b@T-
z4P`KZ7W|{CL=I!{hDUZ(b8C5OxLg<}u*UY(@|J+cQlLg<=|dR|DSS1&n5{iV)R5yr
zwbg}T0!yrAEgwqgGGu|4iGq#AP|XxOzm^{*mB8~t4Mz><8h$ia3!s!Kpgr4QbCC&T
zpQGAWD_A4gpp_yJ!@|H&D^x37D^e?pVxBRwIWRW3%tWMDks497+$5Iwv_=rwOceis
zCVs(2)d-=whSbumRtSf=!W5b-jAAZmF95{70yRSEOeySZ1kvN51hkj|tcHOh%N)u8
zmHhBfLUmECcu5XS6$3*SXp<C}yTB4mp%FFWsJa<JCL`(*;^IggUW_u-h}IzS(BrN~
z2wApPqK07sywpW@Sd9Rxt7;`{7#4u`cOfeRvuh;LR7#<QFCsmNEwBb_KqFGn(g~_*
zB&Qm&8lgr*6c-?3SPDlBOXE;4RwIEmMabl_)QHu})JWAzGcnY#)<~nexmFe|DuX5}
zSHrr%2F>$eZjDThbPYomye>elZPDUS9>t}$VC|sNP!82D9I9kdRe^f!+zhn}HS7yO
zyY66y!RZv?H6o}w7#V8h7a-MfaHR|>qHDx3)GA<Ci&~E}GSsjya6p)ZV4<s<z|^x4
z#b>lGJ2_Cx&b;!GRqoNEY^V`xR08D=lz1eVLs3&!t!y4Q$c!2(bnzO&2Idr=TE#r^
z8ioaqh?In2)kxL|)d-`twG>fYgUD5A`GuRIMiJa1Q>tMjA;%~YnG4v6%mpkx)-_5C
zoDhCOuuyZ%0%wE_f`z6Av_%Xd31y+Gsa38~psRf<*jv0cyy=WJ%IS<XD(Q?V;&aei
zv9+o-3|a8DZx+1GTx^PBnj6%MpjsBLx<)k}ZIo{TXyYYJDV$CbT_b^-=4u$?;ilxV
z*Qg@%OD4iKGT>?h*C^Mh5NZRXw6wKIZDkW<!vxlzz7)n_22IJtKqdwTw@k>+5%5^O
z0@B_V&<RSZIpE<t&<rYg#0Rwa05lbev{+(vJQKX3adbQrJg}#rpx{^4mjGEVn4pkY
zlwJy2?+seblaQ8KlAw^AT9T2UqL7hTtdNwNnx~MGT2Pb<TB4q!keHVOUno?Rnwwvi
zngY_7sE}3)S~>+=2d1YG0NL9F+9;#|w+g!ZH#tAAxTL5Qv=&JLw8$(qF=Ztqcp0ZA
zQxT})oPPEU2)wb^WWL3eQ*?_pDYYcA_!bXjO9W&KMG>eA3m%dM9kt~OT6O-553w=C
zC%?G3O2`Ggq6d03mI6pjld)(YXy%^j7JEu&F=&DHEf&z~vMP1=)I2=}jq3E&JS#n;
zv|3Gt8U^Tv43Kb2W-&w<JQS?Sbc;DXH4k*OQgIbGc*GhUoC+(Mia^aV&{ESBO%^{t
zKj>5;J9wLnv5_0-46@?XlK7JR_`LkQ)DTU^B2X_<leq}ADZl7E0|Nut*@zRppw7O<
zl2Vjkpy^q395m5a4;n7!D2HsbDZa%5V%-vQfksqjeqK04stB~C89XO>i#aW`<Q6A<
z(NGa+01Z4*SyT%$u@z*ZAaqRxd_tnQ2((YI2y}eUEuNIjT*%B}S!!}oDM&q2Nq$i?
zNOv2E08P~tfff^i<FW`e4tk3Pv>E3XS3x3ZVx%OssJN&cq_hG=RDuXl1rC`9Wvxmr
z$}cXe28r$l5qm&HJBZ)`5z|0K9mrbd;>zM%Y{exhpi|-aLBW}l3OWO=0JL(e2sBD}
ziy1Vd3Of1Y7JE*9dU|GF`YjgFN{^x%kO6Fv)qJ=3AWMD0zA8&izQqmYgUZiaT%dJg
z@!;8>BG3p$(SESjMDU8PTWpXWaktouGfLCaa#D*{L1c?dDhpB}CsKe@0w>5<;8S95
zaU>-sXM>K~0L|N`78LCS>i}&KOwIr=%Leb-y~PUJO;CA@Ejd3gIkDsxcTp<HhnaaP
zsTH?a!J9NR)xbmiMTbElas)(701;^*^LW8?JFw&6ZgGK+)C28<az|pjfKzGFRgeZg
zaLSF(ECH=W%rCmdjTDx*_`z$A;|o%YAgis5qCi@hGgIPkG3MQ31s%mx1)8GBO92HR
zQ?lVLrd-2Y%*pvVx0v$_N<e2VB<9>=DM&5Ky~UJlbc+R~;}$!JOiV7h#Q|c(gM7&W
zKHvw$Vylb?iQZzXj0cI`VlGI{1FeNCE{;zrS`V_G4K#BdpMHxEI!~2f1fI*jC4}O2
zh!`)#cVHbUX~m!&(Xc`gKMh$UeFu~uIl=X=URq`eXkNVd0Xyh0ln+u2oV<{Y(JQzv
zn7CZzaJ|CedV$0B0k>d>-3?L6DKVhM$WM7hdVD{yFp5Sqer012RQ|}##>@AGfsL2%
z1Bm!4z#yx0QAYQQj4o&yvg8d8o+})Z*EwV_amdcdx}arsk;D25hxG*x>j$EeS47pW
ziyB-KHP|3^QPk#&s7(jgQ(280JYq9gu1IKJk+9m~a>d5~BQuj6*B1sRIWCZiQXe=N
zL?o^Vt6UK_>)?1QrF?@&;0llY0+lODCfAi5FDW@*RC2ze<b09G`9R1O*N6*INgtV+
zq(KHqb3Nzc=~C(l>4~|>C3S^M>H{kSCvOk`R}KbAwHrJl6H2a#YhD*OyCiOQQQYE+
zxWz>tiybOgtO75DM_&n#|G>;7$@P(eNfKn1BsR0ezH%^#Nl%Ho!6VS)f1O9|5|7$N
z9*rwJ8aMcbukb705R$neDSJgy?*k{R7}rMzRxz$GAfkioD<4#!<aHjEOFSwUdDO1(
zsDX8<!gLv7(<LC<Q$54vinR8MkSn?lCq&N3T;LD8$RBovKdgi02A|M`z|M>g=BLuK
zb8IdcBtnsN(naZ{E7D1z;FNeODLXfHf!oa7Ik`Jj_UK+Pb3dDWChLM^<VDHIE0U2N
zd^foHdu&1b;C&fyNI7?S-{6tD&ZBjSM{9-kb#2>A+O`K2PY7Pnw!NtBeMQ^*B9G4%
z9-j+5K2N1&=O!=EnwdW*e}l*sL&uA9jt2}cO1WN<a_!*1At?WWA&HSw=p&fyaJ<VS
z+~adWQfoos1xYQ?>3BTaS9r89@Mzx^mzfcNQC#<mxb6oAA4XoKk6>~F^Ib`~86{Wb
zEiOu0UXir?z@RIr_7Ox(aJ(y{ILBv2$`0lWGBy`wY_7=IOmKZ5r*uWm;JTdsB{};8
zELR)?F3JU7kqeyQenVVtLj4VG$BPmgGni&LEnr>{d_!LOioD@<dD~0!wl`$dR<K;r
zu-Kuzqxd5;lY-P21||im87wncK5#JT7$fPrE30%xR(D0xMOniuvWC}XO)klrYzWyB
zb5YjmimcNNCW7(`s!KRmm~K#7p}oWKf`aWu1=}kMwli3t$}7&PTp@Br)AFLc)fIWG
z87y}tq~>s5m(aW<p}9ikqJ;hx3H=Ka`Z!#jvPbhk@ClJ4k;tz9z{Vh_uz+Pw#0;hf
z((?1|XW6eXy(n#XMcQzJ%MCfD>vBey<cv0$T$Ho8B4;td{i(D9D22JK_g>|_gY%+}
z(-j@3i_*?lq@AxzyIqoYJ5X|@=Av}S73q)(E)T>dC)C|l)>vY>B7H;Air52;7nGeY
zDmz_KcACKf5l`9Rv?6Lp;00y-i^}#_l<jA5%;0#StUiO|hO*WSjt8<@Gh$a*T$DAr
zB5N|i?FK6Mft31nDceg@wl^dc7HF(+S>v-qWse3ZgTofbd}U)$P@Pk?LhYiw;T3tq
zt*%?Vc9dN-al2yTc2VB#f~@-lw+A9}*F`igiD+yHy&$4-QN;R+i1h^#>l-5C(|soS
ztVo&Qb5X?jiiq(A5o6GbKL1JnkP{g%h^XEWk(-XpzbK;mR6=$R&kB_d&KD#sE=pKj
zk+7J+dP7R?ij>ZEDeFs8)}S)X=c1JF6)E2d><@&cCfMIlx4kH$GQnvE(+uYY!V82S
zsOur|?@G#Dk<?tFbWu|Gilpv!NrOw01{*}S$Xt}PyCP{f!HJ-}jN%;61*I#J7G!TI
zydYzFQO5F$jO7H^2jbGx>nGJOD23#v2`o3H<gQEUU6RsUQF2kr<cgFD$a$g@?7%_3
z)^v^KhV&gp8)8o|UeNZssO@z{+v}pN_Z3<132qbIAhBeAQP%j1tnmalh__T!C-_eA
zeW0ST+<%e(2A3U5JCZJ{I9^e4oZ$OZN`8*kg0PEHI#;B0E=cOEsJ$p@GlA`afYfyX
zl}iFD8$>S%s6ftryC7hGLqKdg|0Mp40*Y4z6fX!U-Vl&N;a?O`d?2elKXz8^%=kI+
z*JTYZ$r@agHM$~ebY0f`lC1efS<5T3mK}aK6xEl@E|Og+zeN7JqRAyilZ%RGR}{^z
zD_UPtw7#fldqvUq0*~BN8M!$=H>9Lz2;WdPyr5)sL*4R%n$-t(7A?LH3@loF9~oF=
z`M-dO4&M)4;L1n&im>Jf23B6fyUMCdELWsl(X%_idqLOlqO$)LW&i8S(U+8?Z%E31
zU}jK>Vf@Izq$2bML`-10BA{?VK;Z)igNh-TbC*YOg7$SGl}kb@E95TN`d<)IxhNEH
zMJV7RPv8}vzzaNqH~58nYCkYA^9tP%kh?CRcS%6+hKT%i5w%MqYAft6*auz^QM)J-
zbVVfSBQujA$R0tV4-8C#LO(w^GJ}d(NI?sp+(((&2d#?$wP-(|VFu6aGd0@PFgMyD
zuOY>C_EHVn;W4$K*-_An67X;Y_T%EfbHt!^WMDODGa#&aphfYl;F%$~8t^O|6R0^|
z%a+bi%Z|JbpoR@~l*EN$0%NRaEeG<pAjG<^8l;2jY8Vh}Nzq)vS;LtEv5yPI98R#4
zL8IWDHC$`Z{JsD_=MPf_qHB=n|51+qt>vy^Z4|9xC1%}cEl-|&4bqyr1)wet$aoZt
z>^sy@spYBRX<*EAtKq5TMez}+69qP-h8I-_$-Y4~7qL=-yM{NNDTQeb3tAX+*DzpT
zxq!a@wwAAkVF7&Q0@$%oqJ|IE&9(eB>^Np|_!q$EJfNDv6zb?B_W3C^HK48_ST8bx
zss`0nj0_VPd%R0PJMN%bvfwKmYS^)EbYM?mtYK_LT7y{2j*=TeI|QI+At>~15Wx(Z
z?1}MA3=ELHLy!YkiZgRF!8^1dXD#9yR&&nC&j%gNqL2t$LIyd72D+XKJWv8!Qwu(t
z2DCz|G!-&%l9>lOc%&FI7KPYF2HRo>U&RLQ%cPco*1SOvxPh9Ikq9~|3N*@>30bm?
zum-Y(3-7>R9<Gg9h!Y`-^2@;^ijclAVi6puzl3g=f<|#kQD$B`mbHd{psiQ=X$pxL
ztEu6BL^Tn-xt0h6K_i`@(WvzNqDloEVF2A#Rt^gE;tbFZ6XbA5aRA5>Xz>6KVx*wN
zy5S5wIDl^{DBLDL@K99|XfYgk@i=4{iSZUoVnt@LrX*yDMm{+|uPn8w1e%EAL8}lU
zu?`w9%1kM01$BQlpxs{3V!phTc<?9%dc1&lo<N3Dz{4is?a`nM55Pk?MXx}^1rI<3
zX!Pb5OLBfe<t=eYgvY}UzJQI26ukv01do7#hj~B)9pIIRMPES@njkl^f<po_Lc$Fh
zxJb=QEJ?j3m<jSxaVlb;HEj1ZBO?RDEdg|g-{MDzfD^<mK7;^h0OyuSdQN^)Vh-4K
z(A<@p3|co2+LQ;{i31*^Kn!pcF@erCU<MH^Ac7l2@PJk(f|i7VVjme7fmYt!VuJ)l
z6%VB658+(~>A3=48UI0zfk);7xB3eI3p(yMWEHQ->Rp%hyCmy(LqK7HKBx!I2kL?I
zb@+W?V^Gn!qHKCyIqs5j+zkP_1uRz-3_db5sen|d@Lk|hxWEH$;b>goQFtIAI3eVU
zsOm)l)fJ2v1vEQYZ-Ca|-xU;_P&!d&ip&QFVNQ#WAfm$&viGj!io8Ck!{~BR*!7CA
zYlq7XA+Z^P6E!-V?}|uH@mY|vf%$@n$wd*9D<US>Ma(aWnD1cO!*fx@^@@mVhwB4=
z<qQ1!H>C8gOBr91GQKEfdPT}~0{aB^2hbL|%Oxon(6~V8MX9hWQehL=AMlHH)ZY*g
zc_1n?f$6TO)D=<H1xXi0HLr+jUKiE5B&xF_WKGOPQL8JWRuh=+O32RPTwuCFX@T|z
z!wV9o7bQ%uNSIDweIP78-F}k&0@I7anpcE1J6vvvNnIDyx+JEx!sMcu!4)xs4)+^8
z0{z}y-ZR+e^UdPBD4}shLgTuG&Ls(*ixPTQB=oLJm|T)DxhP?FMZ)YNkNFiI^9wxY
z7kSK|@(cFV%}7}xc!6L0BER+(er@o1YZII%1Ws`3spw#N02*6RxXy2Mi67Kw<F~lN
zZ*hUc0@C1vH_8|ww<3TWe4uUq(?L77k&n1QIvcoA5qa4rj#Gh=_AjDWe4vRQ5Y}X>
z(g{w5%=9Ec&L>C!m0Yle*&6EZ0io)kCB^E_0il|DewxfkgRh|6tO+^h7qamm<QM1>
znYXw=BZv?qVLLt$B0NkC48@>jGz|<7pu6z7`FrdyaH~NFX5=pM$jvC3Q*)6==L(O`
z1sM9k%)q05fm;nT=s-IJz(F^Q1>9Cf@_Zv7a$q%rYHW~0!4-4^Kl087^bPe*j3w}v
zC1|}_2~tZE=S~LDIu7_rB)HbOqDDw9V_pH8O`w7T>_P^HEVvs`TWUq_HH?V$tjIef
zko}8T<qF*+rd<M>?|_<}!i>v|*`U1-NJ$Q1&TNJh*14>xK4@ahLpfn21+Dep#E86N
z7HPdKBlfmt3fgi0sJ7u~buu8%H9=JmI(rte?*)5{6D1rt5J7}y5_1i6qYDndpv^{;
z>`E3CbKz^UYM5$Rz%6x7)cCJuLRsO<gflH*U&D)bF%QYEWvyXdfW5tol=l`O0uq|y
zKotH~IZ7zNR{^8yF=YT<BviwSzE7!^ErkIDQA3K6A&*;`!H}VtrJSjPIg+8Ak&z*i
zfsp~UA$bAPxi}!VK`^Rq3*hUDA+it>O$~g7FhmwYqN)LHPWCGWEiQSsV&T(y-A^~n
zdNya(i*;+C?{0p-r{~4`?N3`fUTkQ7K4pahWM1^yw4Kj3Plc#}=y}}G1a8-rg8C{i
z_Dz1;y7bx94bOVlJ)hO`v}5|SsT&kDN<r?`RCuzZ52_PG0jPTkQqcRfXXne--p371
zrJ&U$$mTuW-}-Xl?q@UBV{8Xfc-AouvS|q34t>6H_tSaZAh%ruHPAriE2PZ{I<-Vo
zt_a+U1MPu^v?qCjO7lP$8bP*@fHtBOA*SE%f=mESd=(u4u?~WWdmsX|)uiYUhy`w2
zg4>86L0nKp1gV!~KrJnByAjkKM9d7{;?K!XOo3de6rYlrTyl#GeCim4&7PMJo-RYQ
zE{kLt85lI#A#+}!7Tzsx&;(sP<d}ydMAPmUs42=;3_9Mk81qO3_JX2RkZsW0r}Us)
zAyD;~&IFp%0gc_;U*J|+ka9)M{DPYO0h24vp%<KE?+VFXkdM146n{l1{sK?@S2oaI
zmXx_pb9^pxYhK~j1np%>V!R<NennX2y0G3QVbB<%u*nr+lMa^-m%HGDYLez;UF6lg
z!mA0|!;-`Z8mFq7q1aOkY7ZJ-QE<M*?|gyZ`L3A42POtlU#1l)YqE9-uF3tt0OCQ&
z3z4x`BD1bU)<C$(<VR*EanN9vIAk#E12cnw3*%QV@b;F3i$aN4gc2|CBz|RM5SO_k
zrnaDDdG(^|4VD+xZLg@?UKF#tB4&48%<+<#;{lZ;Iv2(KuZa0~aDQN95Ej28q;g%z
z=8}-jj*vZ3Cm8p{UKH}aBIMn{euG=+0=L{<L9r`>iVM^)3Tj;u)VeOHdr44tMar6-
zi-Oiy1g$%q?n=qe(OeL{LFs~u=?=pSQnnYRY_CY!c6fhf15YA}U*VCRk$8ni0W=|H
zxuaxH1rgImAJ`Z;#jbOxUgA()kg_EIx|+i!HHV98PFK{NE^;_u;c&jd;rx}2fls(6
zwllsb9<*J@08HEvRhkgH!1|)7{uNRE4mZ$l9*6{3_9Fu$r@<Ex0p{W)e*E}>RgR58
zK&-#EtG2Vgr~W#>(j|VSi~K5A_*Jg+YhB{ky2!6{g<t0ahfXnQe)_l(j|Ug?2`&~7
zKK2uWj2?XKntVl|U6w_lmC6tz50snoK?H(T0OEp*;Sjf?LJ%8NorGwzfJ-5_BG3Y-
zBG4F$CQ}jUXb(S4CO<z-ew=5P-{Jvfyv$_CiVw)?2H@@;Xlfd=(g4)<DFRJ{7O8@=
zvpR^-2N59i!A(DKm#GMJJ_ne{0Bs-^0e6#N_bPl~(qZNJzyKl=7#YN+Zip$|kW#rJ
zDtSW!G{tvAM&X9A*bN!k8)8y71VnG}3*F%1yTL2?L0+F#;R6FIk;KTrCiH<xlGW-1
z1C02{%)lnp5(5!66Jpi)zyK#S#8?eM*EK*1194U_#t#fgq!HM>bS74z4-5zrZX{SJ
zh!w2JjMef31B`&H00|T@vasrcj&Fw&JW3!gCIJaBP*`ay6@fN(6ukh&BIs0yqMINV
zsHlT<quC&P7a?UaxQqogsc&(VfR9(o%LCtP0J#<nbUHv0sL@>n+LHyo<p6SqIA~K<
z5oqzmEg|p~UV7ki6Vjb60<QqMC59vcUmTKJ1YS6DOB^CujCAprUQ%LlY7uDUvIvwk
zZgD3i7H1|q=jWwmrr%=m^mFsS#T#5wl$x090lDw42weDs4mpRcB>{OJoH!w20qSr6
z;;_jD-{WUjWXZ??I<JtS_$ng<!v|(YM#dWq!WS5XZ!qvRfZ+`Wt_Co?!NA`Dh7TBo
zF2E2(@dc<Bo(3?y!61GC72RM+x<D?v!61486+K|kx`2vqFc@7xMK>56FJMCt7=#<x
zuP_MT;1cZ7xxsC6gHO35{R*G*MLyLle5x0C)jqJZFfx5)U}0qX0wO+u2`+6$vkwfI
z#00mGAki-%0wN#7$j7MvfdP}~DEkN!{Q@E&@=5}XA|Dtqi3w~UL84zk1VmoVhn*2L
Wzl@da@c+mFlK%oGKVXpHXa@jQ;6yk8

literal 0
HcmV?d00001

diff --git a/ctgan/synthesizers/__pycache__/tvae.cpython-311.pyc b/ctgan/synthesizers/__pycache__/tvae.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d6ab1a3b33b5e4a9046deed08c1c7ac4bfa0f48
GIT binary patch
literal 13794
zcmZ3^%ge>Uz`$VFos{mX$iVOz#DQT}DC6@B1_p-d3@Hpz3@MB$OgW5EOkkQhiUmxw
zMzN+arZDHQ<+4Yy=W;}G<Z?!F=5j@G<#I=HGlKN6<nZM3M)88#tT}wS{89X1Hd~HB
zu3(g4u27T^n9rUgoGTn93}$oWh~$b!iGtajIbyluQR2B0Q4)*{ObqS}DO@cKDcq?X
z%a|D$Rx?B0$q*$8mg8w*Na4khlS<)hVTh6jtKe^8ND)XCN7F5n!W7J)DOkl566WZt
zkei=Unv<&c5@fiaChILWpUk|}#G+elzF>)4oB@eNiMgpIsYSO~f>M1#Z*c{u7M7;w
zm1HL7Xfod7O)knWE{;#lD=Eq^sMKV<#p0Nfn5)Tji_0akB+(~7F(tL=7H>#uUU7aA
zNTfKm1f-~>FeO)$@fJTwD5NMcuQ)BgC^xl8lj#<}Q(|#yaAjUeMrv_pRcg^Ko}$FO
zl>FTI;*!LY)MSt|VVDWZ_^iRez|hVxogtMWiZO)&6m3z=9SmuVDNHRKQ7kFUEeuht
z9SjwWQEb5snk={2UGtLjQ&NknenS}wY57G8B^jv-pb*jH;!;phP;e|tFSY`+K)jU1
zlEnCw%v=SH%)Am!gg8jfB{MfQuQ)S5uUH{J4QvcZwH{n=a(-?>QEG89NPn?{MoDQw
zPO3tFkwQ*paS2FAW?o4#nvuboRjDAU)WqZrg^bLUl+-+hoW#o1B7~`_xk;%hDVcfc
zAY-vP3uXq$U-_ja1*IhlWvR&}`A{>GS)ic@qSzQ17(mJEvjrmq!&JuU3?+;pVF)gP
zg);*KLk+_+CI*JpaFGRY7DO$CT*kn_uo}W-U|?Wms9~&SNoS~KEn$ZXGSo0FV`N}h
z4Y$XIAy&1Pt%kLP6G;tA4I8R`CEN&ER9n)Ss(2X~YFTU8YFT>JN_Y|KYFPUiY8c|-
zcB9$AjA}y-Lp%peFF}=Xx27-zGiWmUX>#3SEiNrcEh-XZU|_h#5g(tKmst`Ye~Sea
z*+rlrE&_#h5y&yOn2S>jZ!xDO8-fTUO^#bE#i==Iw>Ut_2NbKfcu|thEnbv3zr~!A
znOh_VvWF$JBsKRIXMB8ePGWI!e0(v;2cV>>prFw3%f!VhCM7E|FFik?C_gJTxg@3}
zF(<nuvn(|xxg<R?FQyog+lpgK$`Vuc3Mz|47#J9;1fUs259SdW1_p-WX$%YuKN=YB
za`5)1Pbj&_A#;U8<~oP!B@We#9O_p%)Gu(T-{2GLV7|d8(7}9HL}G^QMG^HYBI+L)
zm^j73M2E`_9{#CrGmICoToh5gBBHvqWMSEovJ1j`7kTtLoNow;bvSo8KNXgpQG7vI
zxr3#L<F2se4CNWM7ln1N2<u`M=-}w!_za4<WaKmnO68zn|GWcSxYsZ&04c%1HH@gK
zv6iWZVF5gCA}dEvx3$bQ3=5ELge%5S$%35zvp_Bbt4(37VX0zeU_f&LYYlS^YYj^p
zb1;J@Q<XF%Ccwo@L1J-nYKlTaW?n&QNqmA{I@lP72lkq*;Pg-=&cMJ>BmqhptR?wH
z$r-npQ!5HI*@`4VLTnIix7gEC6H7{qQg1Qkmfm8^$xkm!EV{*9T#`}@N}-U{Q=|aW
z!Jd|1RGwIrQpFE(M3Ej;095W4hcGZOG%!365}U3uNn?iDMIogtLP{O%H+Tj6Bf26c
zq+jHfyTU8i!E{4Vbh^YOi5cn_1(mJ{Ds`~k<rD77oWVRJZ6?n}KGiFHsvXQX_ys2z
z&tU7QzrZ0^q{zU)u#z2=iSiP2Q{&@ramB|&YODD8TRidcg{6r(P#O04_>}zQ_;_%t
zEz$xxnirZIl2dbX;^S8`7RiH>E+|SJz`g?KP%t3?k_d#P=>`S}{J_M?%JG2#M2K;*
z>V06qLP)T&%70+MM}WNmwu3tLmP;z6-ueS&5U#gSN(@4^1UO-!R1%<6kdlg03Q?^B
zgOzgFT!pVn#a?mM43mnhhIIib`$2Lv5{bVmLM=KN8A|vOTA|e+s7e3}mI#9>Fi|23
zrocoNylSmsMl~;m5wm!zVTk7eYhz$w09SrYe(AGjg1}CDP2M7q9Yw~VVhB`^6oITk
z)PA?v5(^4a^HRXIAf&_swZV&&LHU*?KQFcT7He^4dTyd72e`;80@XUV_|Xz#krc>y
zSr7rL1dHTAEN~_Vc@>m$LA4(wuY(*>B>>Ge@XE{<l=0U=GCuywOd6>&lfNjUc|}AM
zT#tbX;wrNT9Nhg}U0hSmW+>07xG1b~MOfn^hvpRy%?liwcabYPQMtLLbE;NoT$DGy
zB5%AQ`J%k#MNz9OqE;O&JzO1J9bDj21*MV$Rd68u*@~n}u7)8CR33t(F%8rwV62jc
zgb~(iMw6)s9MI+r3=En~keE>g1tF;RQ&50}G^p&Y;)j?5uO~b~s-z+H1ix^9MOVcP
zwu}7oSNP=@h%Ax1z^{FQLmTX?B5hD48iJIATm~(v!RZJ@fF<yj)YhPq8iSBvV%7P;
zfRBLK4zeQ|)HDMJ2?GNIs1*qg*m=z0t|4PBdpbicM-8hB!vw}y^;*svPB2TMmaB#<
z3#N-9g((}Pi=k+14SNk2B9N;X7#M1~YnW=-K{YW%c`Z*3V+unGLn~7adm0l$yabdb
zAW9grz@~vXH4NC)qGs1BE(V5L-WrYt@D>J0Ck)r{rZ87=FfgPrS8+2i)bK!fAoc<z
z8)2#$7*Jc;s5UV&;Iak9_3#3%h6h!D4dVjj(1u%zYG(=~B2+L7B1VQ9?ppSq!&pNa
zM;MfVTIW#P5FuK_-O7YB1~{S084$6Bu0Dlj4J&G2B86=YJ8I_})&CH;Phjj(n!wnj
zS;Jn#KvZ~eAi^UXl=m2lnY<X97*jY=LLDBPDI7H{pp;j`42^3p28J4z1>gY@n6F?o
zYRH395~_<q;a|&F0<S6<7-~3BRn)M7)Q}Ogpw!4mo<4Bl#pPE73XEIq#rbI^xrr4o
zLA;m$|NsB5$yfxcCT=komx2ZzIEzwKN|Q@6^YdPU3K>}IdA}z2E#AcBlFYJ1kX(FP
z9#|K&Rw_x%%P2Ahl}>h`T8cR*Km8U9cz^;_+uUMHEK1K!tSC|em5uzj*o#t=^Yh{>
zZZTHeVgpxY#YO5M&7c<ZEpAA6JU%loE&mo<S|X@Zev7HN<Q7X#esS?F?&SQO(%d|V
zG-q)^B9zILno^_*vYZLjZS?kWDF)T@(25R5flJ|AeDNTY;?qj=!2T=>VqjqK1C^?G
z7(tEo8=PDXPB*xD8(eOPD}7)PWR>^`B09KlNUD5b5a5*f2qHQ>Z}3Z8;a6Lsd5Pcn
z0>AMMQSlj?b966?s!w3LAto~;XiC-u<{P476AD3ahSG%6DG^IqmvF5x-(U!aNgGVp
z*j-exx}sopQPgSz(+vTU>D-gJX9O(}oEf?xaAxER<rRf1)h@~!UY9YuBx80_#^Q>M
z#YF+jD*~1m1T3M-=eo}ET#!7|Z;s!J!1ZCP!d6DCiP(_1Ci)_>Qc>{<OcR(surY{9
zAlwMC?yiX94GH-Lj9_TAAZdZ=0*)&Znrlth*z923p$LXS2bgw*?$Fr7eM0Dhw$DXv
zpDWrv7b0RmFf)mYePmz~75f4rJ~Fci3VmT<5fu6WCb$^HBqlITWcsKmz$x)Tg@K1}
zg7|cqNiq}VrpPT|ydtP{iCgIc4Bg=3n<_d*a)#lIvYA#Zm{vHhWW6Y@yMb*x*Cwuw
zJX?4!O50o%w7nu|dy&hw!TByXUq?x&Wsl{AvWwiZ3mBKPEn-{9v4mrV;1ceO+}bOG
zHyCa=+hn%UVvEHN$1T<ub?q*2+g)I>1J`3Hoefyk@i_t9Hb5RC<77aGDFP{6EsRkt
zsjMl2DLgHVsKr4PTPkY`Zwp$L5XBB2apXwh4`$F5y2Sz-L#tv1k$Nv1KwUVdqEJxM
z2NBOdW<2iL`s0b+%P))!3@>#U85oLe7#J8r+>()tBT&T!a`)#bMo_g|!&t)*4=bS<
zYM5#m;z1b~EK<W<!w?V3=@1@E4MRLgCxpjZ!w?TL3&LZsVTgwpiZyIC4Ds-a3Aw$T
z1<KH19q9}yjEfl8pbb>jFvKG&1I`+Tc(}`JIEv&-KusjDZU%-dcnf&}s38FrMNp`r
zTg;roT+5XwRRSujpo(fZN<ht8C=a!ADPpeSsO8FIN@1zxDgpKQp{lY#{aG-(h6}A|
zW{8JZeBd@atKUlITdb}H`N<i#Se+6}k~40x_~aKCuVlQ%4k?6+Z!sqql-^=VE=@_)
z6aqKIi$D#zA`4JK#GYeRoSc}GdW$Owq&Xf`9^T@H6g`kqjV%>qc=0W!f|OgFput`@
z&|orDO<7`2X=?E;_OjHXr2OJka7kN~1TvT>vp7DnEHN`DF$rXPW^ra-aY<rcaw=$m
zqUaV|N@`hVa_TMicnCw28`2sD^%D@zD+1+%q7+ctO9K(9AOd835vZkBlm{vgK@}OS
z=v08UVZprxUT_MAjeF&Q6mJH%U_VGO$Z2=5T;Y(qA*%(U<@G?c%mZ<`4wfFC8&WDA
zEIs@;q}4#85;w$TI#_zR?@G(gvAJO21x3={7p1+gNP9z+Kr}w!7yiJ&$|*C~bdJ?^
zX`@ThMi-?`u1K4J1#gH-b+~o7-Q^XZ!7?}bBCpaFUZoEV%$zbeL?uBI4@9Lq+@6Zd
zOsQGmc2QjWinun&Egd`^Jm4~6B||biUobE*fcy>Oe{SIb<uK&oUer{HlGl*>3pj?n
z7;BhNQz;_@axY^6$WCzDL?$o>-x!g5yU2}=EO?H^)lf)58^}Y=0Sn-H8^wNxEKV37
zRXr-NmKnLf3@UjLedZLV6y`N7%UBo~R>Mo2S{7u#BT9o>))G)2f?8R_TEmdV2jzf5
z3skFvxu{J?Qr%w5RstG91DgaIxX*%{Tf>kAF9&K^DYK=9Azlb-1w#p_i38=7fLbb0
zb`9GC&;S6GhoDeX8fsXh=x4}++d;ZX;HED_7Q9?WZ=~0<m*^whlm!|#1#{CGQrH$T
zu3<-YR}Divy!BlIY7c_dGBA{Ydc{z74MP^FBm(o$b!Wj#cOHf{9H@DvmZOFNAz!4O
z!dc5%!?6IV1^^ogCTchrfW|t&Y-A#ZYYjK5ov1uem(7J?0#mGOEmt~2Eq6LYEl-If
z!h{+YRP`>Pj&4{C6R1mB%U8p`04Yr&bfLy1BSQ^e3Ku+>7)VaJ{7eirY&HC-E{S1g
zV5k+S6|5CPDG`)luBZ`2RbMMyBU~d0?u-b4(m<_94Z{L>I)do}(KRBdI>EjbtraVA
z2gyQkjc5&jjaZF9jSxzn2DSIW3Tha#Kw}|b_A+(`hSl(nE!Z^iT8SF*8VRsFBx@uW
zsDTxt5g4J!z5pplArxb%L&_gnpdJQTX%;*OE&vV6fW?rB6rL1b)bz*5P{XqTGy;RH
z49u?KMGYZP#}$W4$pt!KEocO)n-}P!NrSm)YV^Q@XauSnbXQGa?1{jhn`^kz8A}k2
zNCxzw2yi;#so_pxn}as`Su0h;kOfc7Sq3l%6`P{i1R9$HOChVSkxEBvJuWZ;D?%eu
znAY&2`USb1$z!jPLgu5C67e-8mJl^`jUhs%(FCR*cW}=^ATfZ6fx#`a1U&W#YJ4dm
ztpZTU%}+_qDOM=S2TPX}C1&Pj=A|ouCJ~?$_Mm1Q>ZCnL1gtV1q)b60uQa!yQZFwh
zv8X7qQX#)cp&&6YC9zl!-Y!L(g$L;mf!PF7q^ICnqL5o!T%wSas*tE)q?-bp`cKSJ
z0GqA=H#-r-bi~{}$l#RB;^d;#lGJ!;Lt8-uG)1hNlbM~0#Th=(wE{3*3NT$I`5?EH
zg60ED@)gok^HPfvOH#ps<eZ<Ek_j5hNX&t+MF80h!eNk=01BRI3JD2_FiJ>J$ShV!
zRPY0Z6vQt&3MCnt#gN4c#TogfIVl*1CFkdrBxdG;EzQg;Nli~JQpn6p$xKd#uZVy;
zS5LveBqOz`JhM1eM<Fv!A+tmwvlzp?M1_O|^r%Tlz;uWLvc(V^6hIS##fXr;#h#Xv
zpIBmS6ym1I4DJiuVlU1pO-sv3y~UYQlwS~^lUNK|k-}4+nwg$a5}%TqoLE^D4eBa@
zN()WKTkI*B#h`9Tl>n$~WTj9I>fc!D8K>2P$D1@6i+mUu7}Bdn*112dYO#k7Cw;I7
zEpq|c;HPO^1nU0X;x0ifbh*WxmRWL(6E0E&YLJ6FK)1Nkit=;g!KvvMTVioZWkKpK
zrjq=ke2|+!eNBir1fhdS@FZ9ao@j)0e?ZL-$n*=Op9Y#bNl7g#0yT5NlQFlr3Si56
ziXr_Zkh_aAKz3w;2vFDc7Hd&rUV3T~xc3C=Gu`4ZP6Y)~adJ^+0VrkNVgs$)Ni4a=
znVMIc3r>2sIIB{N^5fHs5>sw*ftK#%#e>o`cnIp20MyZ-k*|1AdM^TX%5HHaB_?Ns
zW|(iW6qlqH6oJ}Ow^%@#_7+=mPGW9BJb360)bA@vOwK4u1i28@EV{*<lbUynH8r=O
zr1BP9a(-TNV#zJ;qEwJGGxJhXD>S9SJ;PgEh%|mn0G5ehS??A*$g1Mh63{d_Xu%SA
zYUvhxDrA|^E$;k+lFVGtVkAe`TO64sAb;i;-4cPZ;!{!)A;Oqh1PZ`gtl%#3E#}OW
z_@XXQpO&R4H8BS=3J9v1i$E3iE#}mM;#<tAxk<N}i&B$|JU|NBKnX0~5agKR;&>x)
z^18(f?npz1^-|J`LH!3%KE{E;qku(~pd}m}prTYSEwiKt)P?=Y3>yFYD#ak9eNkHH
zinI=-|H;8~g+t;xhs-4onHf<RG%YW3SY6?;y1-#|LsasLsOoi5{Y#?y8@MitT3->h
z?%=vBB)^pPih{ugsq2PLmkgZ_SY9;ry<+HlQNizug5Pz8z)K2&7lOjB1jSuch`*u`
ze^Dsmicmra`&~il8SWccw(y@2xeyt5K`{QJVEh%q_zt!Q!s1tim9Ge!c5vJklANKj
zz~zdH*$$Q~7G4*Hysrp(Ul;PbB;<FZ<V?*)p_nT|F&*r8#pPy{T#+~2U~<LS{e;9t
z@t`Z>LD$7YFNud<2#dNBmUK}(`HFZls1JNsOzJL=zy$H@g36Z!l`jgaUJ+Ei$fI_J
zM{TXhij*}u*Y#X4>A7CibHAeJeo@ooil)a!9*+y*(O1HgKQJ?ibA4oB66g8?B09Lg
zaxloq;q1rX6&9b%IwNIH_C*P;D-v2eME1yBFm=5k?0!+${fe-A2geO={vI39q@fq%
z4Q0~{JPJ2>M6UCwT;fr=plW$t)!~w=!-2#Lsty-bJ+7#FT;%b*!sB^?$Mb=p<Oc>V
zPN9z=qQmiqluU>B4H2;p*SoT63(79a>Rpl5`@kT^YyS~MOyr!vJR#(Ufb?|%)k^}Z
z7X{R>2&i8WP`@EAJ;QNI`ShwuRSQ&>=v)-nzap+bf#rsh<^uH<DhqX&=uYIHz%;>m
zhS*&(=__Jt3yLp_X<rf3p1^!VOmc?v#LNXs6Z0l8-;h(9pEE0Gh4Do>{VQ_%6S*fa
zPYAoKsDDF5{Dz3c4H2muB4S`t;);mM6%o@LA|S>`ehwv}FAN+?LKFEWu+3omAjBXc
zw}5S8{R9@!I4|!c-iiEE_$Tm#ovm<5K;fc*(iH)v3j#_v1VpCrTmkXauM6m163|;w
zvLP5gNc=!RbUM!@o*BUx1r)CcC|(dy1e>6ANkD0V$#UyO)(dTy*j`sQxuk4zQQ7Q@
zve|WI>r2Yk7nN<VDBE5Xu)88)cR|4JqJZ5~8M!$=H>9Lz2+u8EU^=sEPSr(8%^Q+(
zHzbrlu(N3KePCeG;`_+JBFp~;M0EIm;9?Mw0H<2b4-BllhIf@!msqYyxuR!xfcJu~
z-$iBrE6V=Ym7^~yN8gZ?{lLti62thBfk{Q^3y7G&az#Mlf`GyY4h9uNFy{fka8Jz#
z24-HN8v+v71vD-RXxtFdxGti9Nko4`$oA+>(Hmp8#9lXbyJYHi(bVILsmFCw-%F;x
z7ft=InEGE3(Z47Xa785GBQq0dkeUfJNX;ZD^z%a?qZX$SIQf7Ff>0)oLBpY-A(_vw
zm_WloH4LCq9?~|&zWNEZ`3+jxSHqCS1ku63P{qW+kirP=<S?Z()G{NlcS2tc<-!nq
zx|RibpeGCD8n98wX4Wtu27OSwE48dOtWew7P|SgMy=qu%*w&y8NRiwb1<lHVU4>56
zpbZUxJ9-?ooG5MqRoGyS=<{T?95u`fK(#(t7MZBwSOBVhk%ho))UYRJNO=K#SP5(Z
zlt48HwcEkS&~v5)X<z|qFqs)Vz=g8<vzEDrIfb=`u@!0F8nhk`F}$<@dxx93hLf&-
zL~(U6gC<+mvEam9&^laD&96|LnVXrDSVVg7q8MyQu>zJ=!JwMiuQWF)wFtD3w<y0H
z+<Ad$K<cpsrIwTy<sq#!2I&Jg{_u5F;7#Tt2hfBg(=C?V)WkgS>>#vJdW$8oBD2^}
zQ@E%e)TUrbElbP+Hyn!^L5(NyQX5uKORX3@sDFzsFTW@^F{h{&Bnxh8I)PZA5w0Ro
z%K_Zxy2T21*)68L0?;fasI`@vmspZoQ~|O>Ff*?#wWv5X9^R_e<bt%m*dg8qE$b;s
zEhvVxhZy6FmVm4PHHL1n=H+J=r`}>oOUzCMw}Og683z;=5Dae6++u^6Q^gH#(?B>k
zK+PIU@U-R!F$PZF9-iwQl9xCnFLFp<;gG(-A^ku=e1^#!`-=ivD*`VH=ykB(;1}#*
zxhp6(p>(3m6qyeUY@DWGqQmj7u*4MW86|V-H?UkWaJne$d_~x~!{vsM*bKpm8XeAe
zMI@*AEJ)eFd_lzIqKL^A5tHj8=9fgwcd+c?xhUd#MZ~qk^@g<ke4AM|D_Ac|>s^u7
z>-6bxnV<$T*J?(}MNkt*8MItTpx?X8dj`vV?pfS3dFSw6m(aW<p?Oh4`-+72bqT{u
z5{4HgjIT%-U*s{l!eep)jGpof_SDTtSs{3VU;84z_7#5Z4wf7I!V^3ugirA3sRMJa
z@XKH4H@w7exItwH+kW0%ygT{#@L#v^zGUHj(ZctNh3|EX;7b<47cD}sScG2W54*x2
zc7Y=dvswUE2_XC#G;che0d4+`@GJmQHG*Eeq%eY)v+4w=LOS{hkQtl=P~J;QRRDKn
zHPqb$Le)VXbam%|P)$8QO=d{%2b7e-DHOK+2t0=as(zqTH@CPzEq{oSRs4`t0`9SZ
zri(y}gNs3fbPWs-#AQJ(!y6ntJv{yVUHmgx=I~tPP`$#TdI5&Oq2i~>37NbM1Eo|j
z5e{NSfQU#C0hztjWGpHGWkk**(BM}QXmqOxG<;PA8d-)cN(GI+gInZ9n?U-ufC!Kt
zusz`9T?AT53noAvh~l~6*(S(x)CPtdEc^}LA2=AqrEZ8R+>ijZif#yt-QX9xAtHK%
zSMY{_=m$|b&?+ZX!hw;2P3QxYCnKxS2L=TBk(q%_s3is}5yk{oB*7Za_<;dL!lgl+
z4<^j4Y9AOd35d%<j@FcgcyuMmQ>#E_0UM<MUIgw{f(siC@JeL8ygbM{K+q&f5vamP
zS`RM_Q=$i3xR97r1e&qA#SM`G_rr=no!DCf5RuZ7%$#C9P$6CfDlj13Xi)8WOANfW
z9V~??kWz~vU2$=UB=W*{y`;qA)FM!YSp<$M&<qKr3k}Ky;AjEIF~rxPb<w{#Y#_^#
z?TYq*w>dB{6c;cuFnnNUWMsU-AasF2=mD6#048rR2w#Aq8w@-RV0eSU;R12!27~Se
zRCI$u<^n3Z!C-a)8@j<DdI1$ZU@*FXif%ACUciPvu*osXe_+5QI?_IZM8ALth`bY{
w6r;`u2228^REkmO3qs0Dg3<m11Dr7PWnxtOz<{0Xu>S~_`2r>}Re@s!05w*xK>z>%

literal 0
HcmV?d00001

diff --git a/ctgan/synthesizers/base.py b/ctgan/synthesizers/base.py
new file mode 100644
index 0000000..add0dd7
--- /dev/null
+++ b/ctgan/synthesizers/base.py
@@ -0,0 +1,151 @@
+"""BaseSynthesizer module."""
+
+import contextlib
+
+import numpy as np
+import torch
+
+
@contextlib.contextmanager
def set_random_states(random_state, set_model_random_state):
    """Context manager for managing the random state.

    Args:
        random_state (int or tuple):
            The random seed or a tuple of (numpy.random.RandomState, torch.Generator).
        set_model_random_state (function):
            Function to set the random state on the model.
    """
    # Remember the global RNG states so they can be restored on exit.
    saved_np_state = np.random.get_state()
    saved_torch_state = torch.get_rng_state()

    # Install the model's RNG states as the global ones for the duration.
    np_generator, torch_generator = random_state
    np.random.set_state(np_generator.get_state())
    torch.set_rng_state(torch_generator.get_state())

    try:
        yield
    finally:
        # Capture where the RNGs ended up and hand the advanced states
        # back to the model, so consecutive calls keep progressing.
        updated_np = np.random.RandomState()
        updated_np.set_state(np.random.get_state())
        updated_torch = torch.Generator()
        updated_torch.set_state(torch.get_rng_state())
        set_model_random_state((updated_np, updated_torch))

        # Restore the global RNG states that were active on entry.
        np.random.set_state(saved_np_state)
        torch.set_rng_state(saved_torch_state)
+
+
def random_state(function):
    """Set the random state before calling the function.

    Args:
        function (Callable):
            The function to wrap around.

    Returns:
        Callable:
            The wrapped function. When ``self.random_states`` is set, the call
            runs under those RNG states and the advanced states are stored back
            via ``self.set_random_state``; otherwise the call is unchanged.
    """
    import functools

    # ``functools.wraps`` preserves the wrapped function's name, docstring and
    # signature metadata (the original wrapper lost them, which breaks
    # introspection and help()).
    @functools.wraps(function)
    def wrapper(self, *args, **kwargs):
        if self.random_states is None:
            # No fixed random state configured: use the ambient global RNGs.
            return function(self, *args, **kwargs)

        with set_random_states(self.random_states, self.set_random_state):
            return function(self, *args, **kwargs)

    return wrapper
+
+
class BaseSynthesizer:
    """Base class for all default synthesizers of ``CTGAN``.

    Provides pickling support, save/load helpers, and random-state handling
    shared by the concrete synthesizers (e.g. ``CTGAN``, ``TVAE``).
    """

    # Either ``None`` (use the global RNGs) or a tuple of
    # (numpy.random.RandomState, torch.Generator) consumed by the
    # ``random_state`` decorator applied to ``fit``/``sample``.
    random_states = None

    def __getstate__(self):
        """Improve pickling state for ``BaseSynthesizer``.

        Convert to ``cpu`` device before starting the pickling process in order to be able to
        load the model even when used from an external tool such as ``SDV``. Also, if
        ``random_states`` are set, store their states as dictionaries rather than generators.

        Returns:
            dict:
                Python dict representing the object.
        """
        # NOTE(review): assumes the subclass defines ``self._device`` and a
        # ``set_device`` method — true for the synthesizers in this package.
        device_backup = self._device
        self.set_device(torch.device('cpu'))
        state = self.__dict__.copy()
        self.set_device(device_backup)
        if (
            isinstance(self.random_states, tuple)
            and isinstance(self.random_states[0], np.random.RandomState)
            and isinstance(self.random_states[1], torch.Generator)
        ):
            # ``torch.Generator`` objects are not picklable, so store their raw
            # state payloads and rebuild the generators in ``__setstate__``.
            state['_numpy_random_state'] = self.random_states[0].get_state()
            state['_torch_random_state'] = self.random_states[1].get_state()
            state.pop('random_states')

        return state

    def __setstate__(self, state):
        """Restore the state of a ``BaseSynthesizer``.

        Restore the ``random_states`` from the state dict if those are present and then
        set the device according to the current hardware.
        """
        if '_numpy_random_state' in state and '_torch_random_state' in state:
            np_state = state.pop('_numpy_random_state')
            torch_state = state.pop('_torch_random_state')

            # Rebuild generator objects from the raw states saved by __getstate__.
            current_torch_state = torch.Generator()
            current_torch_state.set_state(torch_state)

            current_numpy_state = np.random.RandomState()
            current_numpy_state.set_state(np_state)
            state['random_states'] = (current_numpy_state, current_torch_state)

        self.__dict__ = state
        # Re-select the device on the loading machine, which may differ from
        # the machine the model was pickled on.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.set_device(device)

    def save(self, path):
        """Save the model in the passed `path`."""
        # Persist on CPU so the file can be loaded on machines without a GPU.
        device_backup = self._device
        self.set_device(torch.device('cpu'))
        torch.save(self, path)
        self.set_device(device_backup)

    @classmethod
    def load(cls, path):
        """Load the model stored in the passed `path`."""
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model = torch.load(path)
        model.set_device(device)
        return model

    def set_random_state(self, random_state):
        """Set the random state.

        Args:
            random_state (int, tuple, or None):
                Either a tuple containing the (numpy.random.RandomState, torch.Generator)
                or an int representing the random seed to use for both random states.
        """
        if random_state is None:
            self.random_states = random_state
        elif isinstance(random_state, int):
            # A single integer seeds both the numpy and the torch generators.
            self.random_states = (
                np.random.RandomState(seed=random_state),
                torch.Generator().manual_seed(random_state),
            )
        elif (
            isinstance(random_state, tuple)
            and isinstance(random_state[0], np.random.RandomState)
            and isinstance(random_state[1], torch.Generator)
        ):
            self.random_states = random_state
        else:
            raise TypeError(
                f'`random_state` {random_state} expected to be an int or a tuple of '
                '(`np.random.RandomState`, `torch.Generator`)'
            )
diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py
new file mode 100644
index 0000000..704468e
--- /dev/null
+++ b/ctgan/synthesizers/ctgan.py
@@ -0,0 +1,564 @@
+"""CTGAN module."""
+import logging
+import sys
+import warnings
+
+import numpy as np
+import pandas as pd
+import torch
+from torch import optim
+from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Module, ReLU, Sequential, functional
+from tqdm import tqdm
+
+from ctgan.data_sampler import DataSampler
+from ctgan.data_transformer import DataTransformer
+from ctgan.synthesizers.base import BaseSynthesizer, random_state
+
+# NOTE(review): configuring the root logger at import time is a side effect on
+# every consumer of this module; library code would normally create a module
+# logger via ``logging.getLogger(__name__)`` and leave configuration to the
+# application.  Kept as-is since ``fit`` below relies on ``logging.info``.
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[logging.StreamHandler()]
+)
class Discriminator(Module):
    """Discriminator (critic) for the CTGAN.

    Rows are scored in groups of ``pac`` samples: the inputs are flattened
    ``pac`` at a time before being passed through the MLP.
    """

    def __init__(self, input_dim, discriminator_dim, pac=10):
        super(Discriminator, self).__init__()
        width = input_dim * pac
        self.pac = pac
        self.pacdim = width
        layers = []
        for hidden_size in discriminator_dim:
            layers += [Linear(width, hidden_size), LeakyReLU(0.2), Dropout(0.5)]
            width = hidden_size

        layers += [Linear(width, 1)]
        self.seq = Sequential(*layers)

    def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lambda_=10):
        """Compute the WGAN-GP gradient penalty."""
        rows, cols = real_data.size(0), real_data.size(1)

        # One interpolation coefficient per pac-group, broadcast to every
        # row and column of that group.
        alpha = torch.rand(rows // pac, 1, 1, device=device)
        alpha = alpha.repeat(1, pac, cols).view(-1, cols)

        mixed = alpha * real_data + (1 - alpha) * fake_data
        critic_mixed = self(mixed)

        grads = torch.autograd.grad(
            outputs=critic_mixed,
            inputs=mixed,
            grad_outputs=torch.ones(critic_mixed.size(), device=device),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

        # Penalize deviation of the per-group gradient norm from 1.
        slopes = grads.view(-1, pac * cols).norm(2, dim=1) - 1
        return (slopes**2).mean() * lambda_

    def forward(self, input_):
        """Apply the Discriminator to the `input_`."""
        assert input_.size()[0] % self.pac == 0
        return self.seq(input_.view(-1, self.pacdim))
+
+
class Residual(Module):
    """Residual layer for the CTGAN.

    Applies Linear -> BatchNorm -> ReLU and concatenates the result with the
    input, so the output width is ``i + o``.
    """

    def __init__(self, i, o):
        super(Residual, self).__init__()
        self.fc = Linear(i, o)
        self.bn = BatchNorm1d(o)
        self.relu = ReLU()

    def forward(self, input_):
        """Apply the Residual layer to the `input_`."""
        hidden = self.relu(self.bn(self.fc(input_)))
        return torch.cat([hidden, input_], dim=1)
+
+
class Generator(Module):
    """Generator for the CTGAN.

    A stack of ``Residual`` blocks (each widening the representation) followed
    by a final linear projection to ``data_dim``.
    """

    def __init__(self, embedding_dim, generator_dim, data_dim):
        super(Generator, self).__init__()
        layers = []
        width = embedding_dim
        for out_size in generator_dim:
            layers.append(Residual(width, out_size))
            # Residual concatenates its input, so the width grows each block.
            width += out_size
        layers.append(Linear(width, data_dim))
        self.seq = Sequential(*layers)

    def forward(self, input_):
        """Apply the Generator to the `input_`."""
        return self.seq(input_)
+
+
class CTGAN(BaseSynthesizer):
    """Conditional Table GAN Synthesizer.

    This is the core class of the CTGAN project, where the different components
    are orchestrated together.
    For more details about the process, please check the [Modeling Tabular data using
    Conditional GAN](https://arxiv.org/abs/1907.00503) paper.

    Args:
        embedding_dim (int):
            Size of the random sample passed to the Generator. Defaults to 128.
        generator_dim (tuple or list of ints):
            Size of the output samples for each one of the Residuals. A Residual Layer
            will be created for each one of the values provided. Defaults to (256, 256).
        discriminator_dim (tuple or list of ints):
            Size of the output samples for each one of the Discriminator Layers. A Linear Layer
            will be created for each one of the values provided. Defaults to (256, 256).
        generator_lr (float):
            Learning rate for the generator. Defaults to 2e-4.
        generator_decay (float):
            Generator weight decay for the Adam Optimizer. Defaults to 1e-6.
        discriminator_lr (float):
            Learning rate for the discriminator. Defaults to 2e-4.
        discriminator_decay (float):
            Discriminator weight decay for the Adam Optimizer. Defaults to 1e-6.
        batch_size (int):
            Number of data samples to process in each step.
        discriminator_steps (int):
            Number of discriminator updates to do for each generator update.
            From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper
            default is 5. Default used is 1 to match original CTGAN implementation.
        log_frequency (boolean):
            Whether to use log frequency of categorical levels in conditional
            sampling. Defaults to ``True``.
        verbose (boolean):
            Whether to have print statements for progress results. Defaults to ``False``.
        epochs (int):
            Number of training epochs. Defaults to 300.
        pac (int):
            Number of samples to group together when applying the discriminator.
            Defaults to 10.
        cuda (bool):
            Whether to attempt to use cuda for GPU computation.
            If this is False or CUDA is not available, CPU will be used.
            Defaults to ``True``.
    """

    def __init__(
        self,
        embedding_dim=128,
        generator_dim=(256, 256),
        discriminator_dim=(256, 256),
        generator_lr=2e-4,
        generator_decay=1e-6,
        discriminator_lr=2e-4,
        discriminator_decay=1e-6,
        batch_size=500,
        discriminator_steps=1,
        log_frequency=True,
        verbose=False,
        epochs=300,
        pac=10,
        cuda=True,
    ):
        assert batch_size % 2 == 0

        self._embedding_dim = embedding_dim
        self._generator_dim = generator_dim
        self._discriminator_dim = discriminator_dim

        self._generator_lr = generator_lr
        self._generator_decay = generator_decay
        self._discriminator_lr = discriminator_lr
        self._discriminator_decay = discriminator_decay

        self._batch_size = batch_size
        self._discriminator_steps = discriminator_steps
        self._log_frequency = log_frequency
        self._verbose = verbose
        self._epochs = epochs
        self.pac = pac

        if not cuda or not torch.cuda.is_available():
            device = 'cpu'
        elif isinstance(cuda, str):
            device = cuda
        else:
            device = 'cuda'

        self._device = torch.device(device)

        self._transformer = None
        self._data_sampler = None
        self._generator = None
        # Keep a handle to the trained critic so ``predict`` can reuse its
        # weights (the original ``predict`` tried to load the *generator's*
        # state dict into a fresh discriminator, which always fails).
        self._discriminator = None
        self.loss_values = None

    @staticmethod
    def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
        """Deals with the instability of the gumbel_softmax for older versions of torch.

        For more details about the issue:
        https://drive.google.com/file/d/1AA5wPfZ1kquaRtVruCd6BiYZGcDeNxyP/view?usp=sharing

        Args:
            logits […, num_features]:
                Unnormalized log probabilities
            tau:
                Non-negative scalar temperature
            hard (bool):
                If True, the returned samples will be discretized as one-hot vectors,
                but will be differentiated as if it is the soft sample in autograd
            dim (int):
                A dimension along which softmax will be computed. Default: -1.

        Returns:
            Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.

        Raises:
            ValueError:
                If ten consecutive draws all contain NaN values.
        """
        # Retry a few times: older torch versions could return NaN samples.
        for _ in range(10):
            transformed = functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim)
            if not torch.isnan(transformed).any():
                return transformed

        raise ValueError('gumbel_softmax returning NaN.')

    def _apply_activate(self, data):
        """Apply proper activation function to the output of the generator.

        Continuous spans use ``tanh``; discrete (one-hot) spans use a
        Gumbel-Softmax so the sampling stays differentiable.
        """
        data_t = []
        st = 0
        for column_info in self._transformer.output_info_list:
            for span_info in column_info:
                if span_info.activation_fn == 'tanh':
                    ed = st + span_info.dim
                    data_t.append(torch.tanh(data[:, st:ed]))
                    st = ed
                elif span_info.activation_fn == 'softmax':
                    ed = st + span_info.dim
                    transformed = self._gumbel_softmax(data[:, st:ed], tau=0.2)
                    data_t.append(transformed)
                    st = ed
                else:
                    raise ValueError(f'Unexpected activation function {span_info.activation_fn}.')

        return torch.cat(data_t, dim=1)

    def _cond_loss(self, data, c, m):
        """Compute the cross entropy loss on the fixed discrete column.

        Args:
            data: raw generator output (pre-activation logits).
            c: conditional vector (one-hot targets per discrete column).
            m: mask selecting which discrete column was conditioned per row.
        """
        loss = []
        st = 0
        st_c = 0
        for column_info in self._transformer.output_info_list:
            for span_info in column_info:
                if len(column_info) != 1 or span_info.activation_fn != 'softmax':
                    # not discrete column
                    st += span_info.dim
                else:
                    ed = st + span_info.dim
                    ed_c = st_c + span_info.dim
                    tmp = functional.cross_entropy(
                        data[:, st:ed], torch.argmax(c[:, st_c:ed_c], dim=1), reduction='none'
                    )
                    loss.append(tmp)
                    st = ed
                    st_c = ed_c

        loss = torch.stack(loss, dim=1)  # noqa: PD013

        # Only the column actually conditioned on (mask ``m``) contributes.
        return (loss * m).sum() / data.size()[0]

    def _validate_discrete_columns(self, train_data, discrete_columns):
        """Check whether ``discrete_columns`` exists in ``train_data``.

        Args:
            train_data (numpy.ndarray or pandas.DataFrame):
                Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
            discrete_columns (list-like):
                List of discrete columns to be used to generate the Conditional
                Vector. If ``train_data`` is a Numpy array, this list should
                contain the integer indices of the columns. Otherwise, if it is
                a ``pandas.DataFrame``, this list should contain the column names.

        Raises:
            TypeError: if ``train_data`` is neither a DataFrame nor an ndarray.
            ValueError: if any requested column does not exist.
        """
        if isinstance(train_data, pd.DataFrame):
            invalid_columns = set(discrete_columns) - set(train_data.columns)
        elif isinstance(train_data, np.ndarray):
            invalid_columns = []
            for column in discrete_columns:
                if column < 0 or column >= train_data.shape[1]:
                    invalid_columns.append(column)
        else:
            raise TypeError('``train_data`` should be either pd.DataFrame or np.array.')

        if invalid_columns:
            raise ValueError(f'Invalid columns found: {invalid_columns}')

    @random_state
    def fit(self, train_data, discrete_columns=(), epochs=None):
        """Fit the CTGAN Synthesizer models to the training data.

        Args:
            train_data (numpy.ndarray or pandas.DataFrame):
                Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
            discrete_columns (list-like):
                List of discrete columns to be used to generate the Conditional
                Vector. If ``train_data`` is a Numpy array, this list should
                contain the integer indices of the columns. Otherwise, if it is
                a ``pandas.DataFrame``, this list should contain the column names.
            epochs (int or None):
                Deprecated override of the epoch count; prefer the constructor.
        """
        self._validate_discrete_columns(train_data, discrete_columns)

        if epochs is None:
            epochs = self._epochs
        else:
            warnings.warn(
                (
                    '`epochs` argument in `fit` method has been deprecated and will be removed '
                    'in a future version. Please pass `epochs` to the constructor instead'
                ),
                DeprecationWarning,
            )

        self._transformer = DataTransformer()
        self._transformer.fit(train_data, discrete_columns)

        train_data = self._transformer.transform(train_data)

        self._data_sampler = DataSampler(
            train_data, self._transformer.output_info_list, self._log_frequency
        )

        data_dim = self._transformer.output_dimensions

        self._generator = Generator(
            self._embedding_dim + self._data_sampler.dim_cond_vec(), self._generator_dim, data_dim
        ).to(self._device)

        discriminator = Discriminator(
            data_dim + self._data_sampler.dim_cond_vec(), self._discriminator_dim, pac=self.pac
        ).to(self._device)

        optimizerG = optim.Adam(
            self._generator.parameters(),
            lr=self._generator_lr,
            betas=(0.5, 0.9),
            weight_decay=self._generator_decay,
        )

        optimizerD = optim.Adam(
            discriminator.parameters(),
            lr=self._discriminator_lr,
            betas=(0.5, 0.9),
            weight_decay=self._discriminator_decay,
        )

        mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device)
        std = mean + 1

        # Fixed: 'Discriminator Loss' was misspelled ('Distriminator Loss'),
        # which made the per-epoch concat below silently create a new column.
        self.loss_values = pd.DataFrame(columns=['Epoch', 'Generator Loss', 'Discriminator Loss'])

        epoch_iterator = tqdm(range(epochs), disable=(not self._verbose), file=sys.stdout)
        # Fixed: define the format string unconditionally; it was previously
        # only defined when ``verbose`` was True, yet referenced in every
        # epoch, raising NameError for the default ``verbose=False``.
        description = 'Gen. ({gen:.2f}) | Discrim. ({dis:.2f})'
        if self._verbose:
            epoch_iterator.set_description(description.format(gen=0, dis=0))

        steps_per_epoch = max(len(train_data) // self._batch_size, 1)
        for i in epoch_iterator:
            logging.info(' the epoch %s', i)
            for id_ in range(steps_per_epoch):
                # Train the critic ``discriminator_steps`` times per generator step.
                for n in range(self._discriminator_steps):
                    fakez = torch.normal(mean=mean, std=std)

                    condvec = self._data_sampler.sample_condvec(self._batch_size)
                    if condvec is None:
                        c1, m1, col, opt = None, None, None, None
                        real = self._data_sampler.sample_data(
                            train_data, self._batch_size, col, opt
                        )
                    else:
                        c1, m1, col, opt = condvec
                        c1 = torch.from_numpy(c1).to(self._device)
                        m1 = torch.from_numpy(m1).to(self._device)
                        fakez = torch.cat([fakez, c1], dim=1)

                        # Shuffle so real rows are not always paired with
                        # their own conditional vector.
                        perm = np.arange(self._batch_size)
                        np.random.shuffle(perm)
                        real = self._data_sampler.sample_data(
                            train_data, self._batch_size, col[perm], opt[perm]
                        )
                        c2 = c1[perm]

                    fake = self._generator(fakez)
                    fakeact = self._apply_activate(fake)

                    real = torch.from_numpy(real.astype('float32')).to(self._device)

                    if c1 is not None:
                        fake_cat = torch.cat([fakeact, c1], dim=1)
                        real_cat = torch.cat([real, c2], dim=1)
                    else:
                        real_cat = real
                        fake_cat = fakeact

                    y_fake = discriminator(fake_cat)
                    y_real = discriminator(real_cat)

                    pen = discriminator.calc_gradient_penalty(
                        real_cat, fake_cat, self._device, self.pac
                    )
                    # WGAN critic loss (maximize real score minus fake score).
                    loss_d = -(torch.mean(y_real) - torch.mean(y_fake))

                    optimizerD.zero_grad(set_to_none=False)
                    pen.backward(retain_graph=True)
                    loss_d.backward()
                    optimizerD.step()

                fakez = torch.normal(mean=mean, std=std)
                condvec = self._data_sampler.sample_condvec(self._batch_size)

                if condvec is None:
                    c1, m1, col, opt = None, None, None, None
                else:
                    c1, m1, col, opt = condvec
                    c1 = torch.from_numpy(c1).to(self._device)
                    m1 = torch.from_numpy(m1).to(self._device)
                    fakez = torch.cat([fakez, c1], dim=1)

                fake = self._generator(fakez)
                fakeact = self._apply_activate(fake)

                if c1 is not None:
                    y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
                else:
                    y_fake = discriminator(fakeact)

                if condvec is None:
                    cross_entropy = 0
                else:
                    # Penalize the generator for ignoring the condition vector.
                    cross_entropy = self._cond_loss(fake, c1, m1)

                loss_g = -torch.mean(y_fake) + cross_entropy

                optimizerG.zero_grad(set_to_none=False)
                loss_g.backward()
                optimizerG.step()

            generator_loss = loss_g.detach().cpu().item()
            discriminator_loss = loss_d.detach().cpu().item()

            epoch_loss_df = pd.DataFrame({
                'Epoch': [i],
                'Generator Loss': [generator_loss],
                'Discriminator Loss': [discriminator_loss],
            })
            if not self.loss_values.empty:
                self.loss_values = pd.concat([self.loss_values, epoch_loss_df]).reset_index(
                    drop=True
                )
            else:
                self.loss_values = epoch_loss_df
            # Fixed: removed the unconditional ``print(description.format(...))``
            # that crashed with NameError when ``verbose=False`` and duplicated
            # the progress-bar output when ``verbose=True``.
            if self._verbose:
                epoch_iterator.set_description(
                    description.format(gen=generator_loss, dis=discriminator_loss)
                )

        # Keep the trained critic so ``predict`` can score new data.
        self._discriminator = discriminator

    @random_state
    def sample(self, n, condition_column=None, condition_value=None):
        """Sample data similar to the training data.

        Choosing a condition_column and condition_value will increase the probability of the
        discrete condition_value happening in the condition_column.

        Args:
            n (int):
                Number of rows to sample.
            condition_column (string):
                Name of a discrete column.
            condition_value (string):
                Name of the category in the condition_column which we wish to increase the
                probability of happening.

        Returns:
            numpy.ndarray or pandas.DataFrame
        """
        if condition_column is not None and condition_value is not None:
            condition_info = self._transformer.convert_column_name_value_to_id(
                condition_column, condition_value
            )
            global_condition_vec = self._data_sampler.generate_cond_from_condition_column_info(
                condition_info, self._batch_size
            )
        else:
            global_condition_vec = None

        # Generate in batch-sized chunks and trim the excess at the end.
        steps = n // self._batch_size + 1
        data = []
        for _ in range(steps):
            mean = torch.zeros(self._batch_size, self._embedding_dim)
            std = mean + 1
            fakez = torch.normal(mean=mean, std=std).to(self._device)

            if global_condition_vec is not None:
                condvec = global_condition_vec.copy()
            else:
                condvec = self._data_sampler.sample_original_condvec(self._batch_size)

            if condvec is not None:
                c1 = torch.from_numpy(condvec).to(self._device)
                fakez = torch.cat([fakez, c1], dim=1)

            fake = self._generator(fakez)
            fakeact = self._apply_activate(fake)
            data.append(fakeact.detach().cpu().numpy())

        data = np.concatenate(data, axis=0)
        data = data[:n]

        return self._transformer.inverse_transform(data)

    def set_device(self, device):
        """Set the `device` to be used ('GPU' or 'CPU)."""
        self._device = device
        if self._generator is not None:
            self._generator.to(self._device)
        # ``getattr`` keeps backward compatibility with pickles created before
        # the discriminator was stored on the instance.
        if getattr(self, '_discriminator', None) is not None:
            self._discriminator.to(self._device)

    def predict(self, data):
        """Score data with the trained discriminator.

        Args:
            data (pandas.DataFrame or numpy.ndarray):
                Rows to score. A DataFrame is run through the fitted
                transformer; an ndarray must already be in transformed space.

        Returns:
            numpy.ndarray:
                Critic scores, one per group of ``pac`` rows (higher means the
                discriminator considers the group more realistic).

        Raises:
            RuntimeError: if the model has not been fitted yet.
            ValueError: if an ndarray has the wrong number of columns.
            TypeError: if ``data`` is neither a DataFrame nor an ndarray.
        """
        if (
            self._generator is None
            or self._transformer is None
            or self._data_sampler is None
            or getattr(self, '_discriminator', None) is None
        ):
            raise RuntimeError("模型尚未训练,请先调用 `fit` 方法训练模型。")

        # Bring the input into the transformed feature space.
        if isinstance(data, pd.DataFrame):
            data = self._transformer.transform(data)
        elif isinstance(data, np.ndarray):
            if data.shape[1] != self._transformer.output_dimensions:
                raise ValueError(
                    f"输入数据的列数 ({data.shape[1]}) 与训练数据的列数 ({self._transformer.output_dimensions}) 不匹配。"
                )
        else:
            raise TypeError("输入数据必须是 pandas.DataFrame 或 numpy.ndarray 类型。")

        data_tensor = torch.from_numpy(data.astype("float32")).to(self._device)

        # Append a conditional vector, matching the input the discriminator
        # saw during training.
        condvec = self._data_sampler.sample_original_condvec(data.shape[0])
        if condvec is not None:
            c1 = torch.from_numpy(condvec).to(self._device)
            data_tensor = torch.cat([data_tensor, c1], dim=1)

        # The discriminator scores rows in groups of ``pac`` and asserts
        # divisibility; pad by repeating the last row so any batch size works.
        remainder = data_tensor.shape[0] % self.pac
        if remainder:
            pad = data_tensor[-1:].repeat(self.pac - remainder, 1)
            data_tensor = torch.cat([data_tensor, pad], dim=0)

        # Fixed: score with the *trained* discriminator. The original built a
        # fresh Discriminator and loaded the generator's state dict into it,
        # which always raised on mismatched parameter shapes.
        with torch.no_grad():
            scores = self._discriminator(data_tensor)

        return scores.detach().cpu().numpy()
\ No newline at end of file
diff --git a/ctgan/synthesizers/tvae.py b/ctgan/synthesizers/tvae.py
new file mode 100644
index 0000000..ecefbb5
--- /dev/null
+++ b/ctgan/synthesizers/tvae.py
@@ -0,0 +1,246 @@
+"""TVAE module."""
+
+import numpy as np
+import pandas as pd
+import torch
+from torch.nn import Linear, Module, Parameter, ReLU, Sequential
+from torch.nn.functional import cross_entropy
+from torch.optim import Adam
+from torch.utils.data import DataLoader, TensorDataset
+from tqdm import tqdm
+
+from ctgan.data_transformer import DataTransformer
+from ctgan.synthesizers.base import BaseSynthesizer, random_state
+
+
class Encoder(Module):
    """Encoder network for the TVAE.

    Maps a data row to the parameters (mean, std, log-variance) of the
    latent Gaussian distribution.

    Args:
        data_dim (int):
            Dimensions of the data.
        compress_dims (tuple or list of ints):
            Size of each hidden layer.
        embedding_dim (int):
            Size of the output vector.
    """

    def __init__(self, data_dim, compress_dims, embedding_dim):
        super(Encoder, self).__init__()
        layers = []
        in_features = data_dim
        for hidden_size in compress_dims:
            layers.append(Linear(in_features, hidden_size))
            layers.append(ReLU())
            in_features = hidden_size

        self.seq = Sequential(*layers)
        # Two parallel heads: one for the mean, one for the log-variance.
        self.fc1 = Linear(in_features, embedding_dim)
        self.fc2 = Linear(in_features, embedding_dim)

    def forward(self, input_):
        """Return ``(mu, std, logvar)`` of the latent distribution for `input_`."""
        hidden = self.seq(input_)
        mu = self.fc1(hidden)
        logvar = self.fc2(hidden)
        std = torch.exp(0.5 * logvar)
        return mu, std, logvar
+
+
class Decoder(Module):
    """Decoder network for the TVAE.

    Maps a latent embedding back to the (pre-activation) data space.

    Args:
        embedding_dim (int):
            Size of the input vector.
        decompress_dims (tuple or list of ints):
            Size of each hidden layer.
        data_dim (int):
            Dimensions of the data.
    """

    def __init__(self, embedding_dim, decompress_dims, data_dim):
        super(Decoder, self).__init__()
        layers = []
        in_features = embedding_dim
        for hidden_size in decompress_dims:
            layers += [Linear(in_features, hidden_size), ReLU()]
            in_features = hidden_size

        layers.append(Linear(in_features, data_dim))
        self.seq = Sequential(*layers)
        # Per-column output noise scale, learned jointly with the network.
        self.sigma = Parameter(torch.ones(data_dim) * 0.1)

    def forward(self, input_):
        """Return the reconstruction and the learned per-column sigma."""
        return self.seq(input_), self.sigma
+
+
+def _loss_function(recon_x, x, sigmas, mu, logvar, output_info, factor):
+    st = 0
+    loss = []
+    for column_info in output_info:
+        for span_info in column_info:
+            if span_info.activation_fn != 'softmax':
+                ed = st + span_info.dim
+                std = sigmas[st]
+                eq = x[:, st] - torch.tanh(recon_x[:, st])
+                loss.append((eq**2 / 2 / (std**2)).sum())
+                loss.append(torch.log(std) * x.size()[0])
+                st = ed
+
+            else:
+                ed = st + span_info.dim
+                loss.append(
+                    cross_entropy(
+                        recon_x[:, st:ed], torch.argmax(x[:, st:ed], dim=-1), reduction='sum'
+                    )
+                )
+                st = ed
+
+    assert st == recon_x.size()[1]
+    KLD = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp())
+    return sum(loss) * factor / x.size()[0], KLD / x.size()[0]
+
+
class TVAE(BaseSynthesizer):
    """Variational autoencoder synthesizer for tabular data (TVAE)."""

    def __init__(
        self,
        embedding_dim=128,
        compress_dims=(128, 128),
        decompress_dims=(128, 128),
        l2scale=1e-5,
        batch_size=500,
        epochs=300,
        loss_factor=2,
        cuda=True,
        verbose=False,
    ):
        # Size of the latent embedding produced by the encoder.
        self.embedding_dim = embedding_dim
        # Hidden layer sizes for the encoder / decoder respectively.
        self.compress_dims = compress_dims
        self.decompress_dims = decompress_dims

        # L2 weight decay passed to the Adam optimizer.
        self.l2scale = l2scale
        self.batch_size = batch_size
        # Weight applied to the reconstruction term of the loss.
        self.loss_factor = loss_factor
        self.epochs = epochs
        # Per-batch training losses; populated by `fit`.
        self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
        self.verbose = verbose

        # `cuda` may be a bool or an explicit device string (e.g. 'cuda:1').
        if not cuda or not torch.cuda.is_available():
            device = 'cpu'
        elif isinstance(cuda, str):
            device = cuda
        else:
            device = 'cuda'

        self._device = torch.device(device)

    @random_state
    def fit(self, train_data, discrete_columns=()):
        """Fit the TVAE Synthesizer models to the training data.

        Args:
            train_data (numpy.ndarray or pandas.DataFrame):
                Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
            discrete_columns (list-like):
                List of discrete columns to be used to generate the Conditional
                Vector. If ``train_data`` is a Numpy array, this list should
                contain the integer indices of the columns. Otherwise, if it is
                a ``pandas.DataFrame``, this list should contain the column names.
        """
        # Fit the transformer and move the transformed data into tensors.
        self.transformer = DataTransformer()
        self.transformer.fit(train_data, discrete_columns)
        train_data = self.transformer.transform(train_data)
        dataset = TensorDataset(torch.from_numpy(train_data.astype('float32')).to(self._device))
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=False)

        # Build encoder/decoder sized to the transformed representation and
        # optimize both jointly with a single Adam optimizer.
        data_dim = self.transformer.output_dimensions
        encoder = Encoder(data_dim, self.compress_dims, self.embedding_dim).to(self._device)
        self.decoder = Decoder(self.embedding_dim, self.decompress_dims, data_dim).to(self._device)
        optimizerAE = Adam(
            list(encoder.parameters()) + list(self.decoder.parameters()), weight_decay=self.l2scale
        )

        self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
        iterator = tqdm(range(self.epochs), disable=(not self.verbose))
        if self.verbose:
            iterator_description = 'Loss: {loss:.3f}'
            iterator.set_description(iterator_description.format(loss=0))

        for i in iterator:
            loss_values = []
            batch = []
            for id_, data in enumerate(loader):
                optimizerAE.zero_grad()
                real = data[0].to(self._device)
                # Reparameterization trick: sample the latent embedding from
                # N(mu, std) using noise eps ~ N(0, 1).
                mu, std, logvar = encoder(real)
                eps = torch.randn_like(std)
                emb = eps * std + mu
                rec, sigmas = self.decoder(emb)
                loss_1, loss_2 = _loss_function(
                    rec,
                    real,
                    sigmas,
                    mu,
                    logvar,
                    self.transformer.output_info_list,
                    self.loss_factor,
                )
                loss = loss_1 + loss_2
                loss.backward()
                optimizerAE.step()
                # Keep the learned per-column noise scales in a sane range.
                self.decoder.sigma.data.clamp_(0.01, 1.0)

                batch.append(id_)
                loss_values.append(loss.detach().cpu().item())

            # Append this epoch's per-batch losses to the running log.
            epoch_loss_df = pd.DataFrame({
                'Epoch': [i] * len(batch),
                'Batch': batch,
                'Loss': loss_values,
            })
            if not self.loss_values.empty:
                self.loss_values = pd.concat([self.loss_values, epoch_loss_df]).reset_index(
                    drop=True
                )
            else:
                self.loss_values = epoch_loss_df

            if self.verbose:
                # `iterator_description` is only defined when verbose is True.
                iterator.set_description(
                    iterator_description.format(loss=loss.detach().cpu().item())
                )

    @random_state
    def sample(self, samples):
        """Sample data similar to the training data.

        Args:
            samples (int):
                Number of rows to sample.

        Returns:
            numpy.ndarray or pandas.DataFrame
        """
        self.decoder.eval()

        # Generate in batches of `batch_size` and trim the excess at the end.
        steps = samples // self.batch_size + 1
        data = []
        for _ in range(steps):
            mean = torch.zeros(self.batch_size, self.embedding_dim)
            std = mean + 1
            noise = torch.normal(mean=mean, std=std).to(self._device)
            fake, sigmas = self.decoder(noise)
            fake = torch.tanh(fake)
            data.append(fake.detach().cpu().numpy())

        data = np.concatenate(data, axis=0)
        data = data[:samples]
        # `sigmas` from the last batch is fine: the decoder returns its sigma
        # Parameter unchanged, so it does not depend on the input batch.
        return self.transformer.inverse_transform(data, sigmas.detach().cpu().numpy())

    def set_device(self, device):
        """Set the `device` to be used ('GPU' or 'CPU')."""
        self._device = device
        self.decoder.to(self._device)
diff --git a/dataprocess/Dockerfile b/dataprocess/Dockerfile
new file mode 100644
index 0000000..b5b5ebc
--- /dev/null
+++ b/dataprocess/Dockerfile
@@ -0,0 +1,20 @@
# Official slim Python base image
FROM python:3.11-slim

# Working directory (matches the /opt/ml/processing layout the script is copied into)
WORKDIR /opt/ml/processing

# System packages needed to compile native Python wheels
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Upgrade pip and install the Python packages used by the processing script
RUN pip install --upgrade pip
RUN pip install pandas numpy scikit-learn imbalanced-learn

# Copy the processing script into the image
COPY dataprocess.py /opt/ml/processing/process_data.py

# Run the processing script by default
ENTRYPOINT ["python", "/opt/ml/processing/process_data.py"]
diff --git a/dataprocess/dataprocess.py b/dataprocess/dataprocess.py
new file mode 100644
index 0000000..c7399be
--- /dev/null
+++ b/dataprocess/dataprocess.py
@@ -0,0 +1,191 @@
+import argparse
+import os
+import pandas as pd
+import numpy as np
+from sklearn.impute import SimpleImputer
+from sklearn.preprocessing import LabelEncoder
+from imblearn.over_sampling import SMOTE
+import json
+import logging
# Configure root logging once at import time; all progress goes to stdout.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.StreamHandler()  # Log to stdout
    ]
)
label_map = {}  # label <-> integer code mappings, filled in by numerical_encoding()
params = {}  # per-column scaling stats (min/max, mean/std), filled in by the encode_* helpers
def encode_numeric_range(df, names, normalized_low=0, normalized_high=1):
    """Min-max scale the given columns into [normalized_low, normalized_high].

    Each column's observed min/max is recorded in the module-level ``params``
    dict so the same scaling can be reapplied later.
    """
    global params
    for name in names:
        low = df[name].min()
        high = df[name].max()
        # An empty or all-NaN column yields NaN bounds; fall back to 0.
        low = 0 if np.isnan(low) else low
        high = 0 if np.isnan(high) else high
        if high == low:
            # Constant column: avoid division by zero, map everything to 0.
            df[name] = 0
        else:
            df[name] = ((df[name] - low) / (high - low)) \
                        * (normalized_high - normalized_low) + normalized_low
        col_params = params.setdefault(name, {})
        col_params['min'] = low
        col_params['max'] = high
    return df
+
def encode_numeric_zscore(df, names):
    """Standardize the given columns to z-scores (mean 0, std 1).

    Each column's observed mean/std is recorded in the module-level
    ``params`` dict so the same transform can be reproduced later.
    """
    global params
    for name in names:
        col_mean = df[name].mean()
        col_std = df[name].std()
        # An empty or all-NaN column yields NaN statistics; fall back to 0.
        col_mean = 0 if np.isnan(col_mean) else col_mean
        col_std = 0 if np.isnan(col_std) else col_std
        if col_std == 0:
            # Constant column: z-score undefined, map everything to 0.
            df[name] = 0
        else:
            df[name] = (df[name] - col_mean) / col_std
        col_params = params.setdefault(name, {})
        col_params['mean'] = col_mean
        col_params['std'] = col_std
    return df
+
def numerical_encoding(df, label_column):
    """Encode class labels as integers 0..K-1 and store the mapping.

    Fits a ``LabelEncoder`` on ``df[label_column]``, replaces the column in
    place with integer codes, and records both directions of the mapping in
    the module-level ``label_map`` dict.
    """
    global label_map
    label_encoder = LabelEncoder()
    df[label_column] = label_encoder.fit_transform(df[label_column])

    # LabelEncoder assigns codes 0..K-1 in sorted order of classes_, so
    # enumerate() is the idiomatic replacement for zip(classes_, range(...)).
    label_map = {
        "label_to_number": {label: idx for idx, label in enumerate(label_encoder.classes_)},
        "number_to_label": {idx: label for idx, label in enumerate(label_encoder.classes_)},
    }
    return df
+
+# def oversample_data(X, y):
+#     """Oversample the minority class using replication to avoid heavy computation."""
+#     lists = [X[y == label].values.tolist() for label in np.unique(y)]
+#     totall_list = []
+#     for i in range(len(lists)):
+#         while len(lists[i]) < 5000:
+#             lists[i].extend(lists[i])
+#         totall_list.extend(lists[i])
+#     new_y = np.concatenate([[label] * len(lists[i]) for i, label in enumerate(np.unique(y))])
+#     new_X = np.array(totall_list)
+#     return pd.DataFrame(new_X, columns=X.columns), pd.Series(new_y)
+
if __name__ == "__main__":
    # Command-line interface: read raw CICIDS-style CSVs from --input,
    # write the processed dataset and scaling metadata into --output.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, default=r"E:\datasource\MachineLearningCSV\MachineLearningCVE", help="Input folder path containing CSV files")
    # BUG FIX: the old default pointed at a .csv *file*, but every use of
    # `args.output` below joins file names onto it as if it were a folder.
    parser.add_argument("--output", type=str, default=r"D:\djangoProject\talktive\dataprocess", help="Output folder path for processed CSV files")
    args = parser.parse_args()
    # Robustness: make sure the output folder exists before writing into it.
    os.makedirs(args.output, exist_ok=True)
    # Feature and label column definitions
    selected_features = [
        "Destination Port",  # formatted_data["Destination Port"]
        "Flow Duration",  # formatted_data["Flow Duration (ms)"]
        "Total Fwd Packets",  # formatted_data["Total Fwd Packets"]
        "Total Backward Packets",  # formatted_data["Total Backward Packets"]
        "Total Length of Fwd Packets",  # formatted_data["Total Length of Fwd Packets"]
        "Total Length of Bwd Packets",  # formatted_data["Total Length of Bwd Packets"]
        "Fwd Packet Length Max",  # formatted_data["Fwd Packet Length Max"]
        "Fwd Packet Length Min",  # formatted_data["Fwd Packet Length Min"]
        "Fwd Packet Length Mean",  # formatted_data["Fwd Packet Length Mean"]
        "Fwd Packet Length Std",  # formatted_data["Fwd Packet Length Stddev"]
        "Bwd Packet Length Max",  # formatted_data["Bwd Packet Length Max"]
        "Bwd Packet Length Min",  # formatted_data["Bwd Packet Length Min"]
        "Bwd Packet Length Mean",  # formatted_data["Bwd Packet Length Mean"]
        "Bwd Packet Length Std",  # formatted_data["Bwd Packet Length Stddev"]
        "Flow Bytes/s",  # formatted_data["Flow Bytes/s"]
        "Flow Packets/s",  # formatted_data["Flow Packets/s"]
        "Flow IAT Mean",  # formatted_data["Flow IAT Mean (ms)"]
        "Flow IAT Std",  # formatted_data["Flow IAT Stddev (ms)"]
        "Flow IAT Max",  # formatted_data["Flow IAT Max (ms)"]
        "Flow IAT Min",  # formatted_data["Flow IAT Min (ms)"]
        "Fwd IAT Mean",  # formatted_data["Fwd IAT Mean (ms)"]
        "Fwd IAT Std",  # formatted_data["Fwd IAT Stddev (ms)"]
        "Fwd IAT Max",  # formatted_data["Fwd IAT Max (ms)"]
        "Fwd IAT Min",  # formatted_data["Fwd IAT Min (ms)"]
        "Bwd IAT Mean",  # formatted_data["Bwd IAT Mean (ms)"]
        "Bwd IAT Std",  # formatted_data["Bwd IAT Stddev (ms)"]
        "Bwd IAT Max",  # formatted_data["Bwd IAT Max (ms)"]
        "Bwd IAT Min",  # formatted_data["Bwd IAT Min (ms)"]
        "Fwd PSH Flags",  # formatted_data["Fwd PSH Flags"]
        "Bwd PSH Flags",  # formatted_data["Bwd PSH Flags"]
        "Fwd URG Flags",  # formatted_data["Fwd URG Flags"]
        "Bwd URG Flags",  # formatted_data["Bwd URG Flags"]
        "Fwd Packets/s",  # formatted_data["Fwd Packets/s"]
        "Bwd Packets/s",  # formatted_data["Bwd Packets/s"]
        "Down/Up Ratio",  # formatted_data["down_up_ratio"]
        "Average Packet Size",  # formatted_data["average_packet_size"]
        "Avg Fwd Segment Size",  # formatted_data["avg_fwd_segment_size"]
        "Avg Bwd Segment Size",  # formatted_data["avg_bwd_segment_size"]
        "Packet Length Mean",  # formatted_data["Packet Length Mean"]
        "Packet Length Std",  # formatted_data["Packet Length Std"]
        "FIN Flag Count",  # formatted_data["FIN Flag Count"]
        "SYN Flag Count",  # formatted_data["SYN Flag Count"]
        "RST Flag Count",  # formatted_data["RST Flag Count"]
        "PSH Flag Count",  # formatted_data["PSH Flag Count"]
        "ACK Flag Count",  # formatted_data["ACK Flag Count"]
        "URG Flag Count",  # formatted_data["URG Flag Count"]
        "CWE Flag Count",  # formatted_data["CWE Flag Count"]
        "ECE Flag Count",  # formatted_data["ECE Flag Count"]
        "Subflow Fwd Packets",  # formatted_data["Subflow Fwd Packets"]
        "Subflow Fwd Bytes",  # formatted_data["Subflow Fwd Bytes"]
        "Subflow Bwd Packets",  # formatted_data["Subflow Bwd Packets"]
        "Subflow Bwd Bytes",  # formatted_data["Subflow Bwd Bytes"]
        # "Init_Win_bytes_forward",  # formatted_data["Init_Win_bytes_forward"]
        # "Init_Win_bytes_backward",  # formatted_data["Init_Win_bytes_backward"],
        "Label"
    ]

    # Column names are normalized (lowercase, underscores) to tolerate the
    # inconsistent whitespace in the raw CSV headers.
    standardized_features = [col.strip().lower().replace(" ", "_") for col in selected_features]
    all_data = []
    logging.info("Starting data processing...")

    # Load and process CSV files
    for file_name in os.listdir(args.input):
        if file_name.endswith(".csv"):
            file_path = os.path.join(args.input, file_name)
            logging.info(f"Processing file: {file_path}")
            df = pd.read_csv(file_path)
            df.columns = [col.strip().lower().replace(" ", "_") for col in df.columns]
            # Keep only the selected columns that are present, and drop rows
            # with a missing label.
            df_selected = df[[col for col in standardized_features if col in df.columns]].dropna(how="any", subset=["label"])
            all_data.append(df_selected)

    # Combine all data
    if all_data:
        final_data = pd.concat(all_data, ignore_index=True)
    else:
        raise ValueError("No valid CSV files found in the input folder.")
    final_data = numerical_encoding(final_data, "label")
    X = final_data.drop(columns=["label"])
    y = final_data["label"]
    # Data cleaning and transformation: infinities become NaN, NaNs become
    # the column median, then the features are scaled.
    X.replace([float('inf'), float('-inf')], np.nan, inplace=True)
    X.fillna(X.median(), inplace=True)
    # NOTE(review): features are z-scored and then min-max scaled on top, and
    # both parameter sets are saved — confirm this double scaling is intended.
    X = encode_numeric_zscore(X, X.columns)
    X = encode_numeric_range(X, X.columns)
    # Persist the scaling parameters and the label mapping so inference can
    # reproduce the exact same preprocessing.
    scaling_params_file = os.path.join(args.output, "scaling_params.json")
    with open(scaling_params_file, 'w') as f:
        json.dump(params, f)
    label_map_file = os.path.join(args.output, "label_map.json")
    with open(label_map_file, "w") as f:
        json.dump(label_map, f)

    # Save processed data
    final_data = pd.concat([X, y], axis=1)
    output_path = os.path.join(args.output, "processed_data.csv")
    final_data.to_csv(output_path, index=False)
    logging.info(f"Processed data saved to {output_path}")
    logging.info("Data processing completed successfully!")
-- 
GitLab