diff --git a/ctgan/.idea/.gitignore b/ctgan/.idea/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..13566b81b018ad684f3a35fee301741b2734c8f4
--- /dev/null
+++ b/ctgan/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/ctgan/.idea/ctgan.iml b/ctgan/.idea/ctgan.iml
new file mode 100644
index 0000000000000000000000000000000000000000..8b8c395472a5a6b3598af42086e590417ace9933
--- /dev/null
+++ b/ctgan/.idea/ctgan.iml
@@ -0,0 +1,12 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+  <component name="PyDocumentationSettings">
+    <option name="format" value="PLAIN" />
+    <option name="myDocStringFormat" value="Plain" />
+  </component>
+</module>
\ No newline at end of file
diff --git a/ctgan/.idea/inspectionProfiles/profiles_settings.xml b/ctgan/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99
--- /dev/null
+++ b/ctgan/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
\ No newline at end of file
diff --git a/ctgan/.idea/misc.xml b/ctgan/.idea/misc.xml
new file mode 100644
index 0000000000000000000000000000000000000000..4f4e5a44800916865a9def0e30be9d80359132be
--- /dev/null
+++ b/ctgan/.idea/misc.xml
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="D:\anconda" project-jdk-type="Python SDK" />
+</project>
\ No newline at end of file
diff --git a/ctgan/.idea/modules.xml b/ctgan/.idea/modules.xml
new file mode 100644
index 0000000000000000000000000000000000000000..b65f96f97a89631f2f6562cb543e0e8f892b0236
--- /dev/null
+++ b/ctgan/.idea/modules.xml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/ctgan.iml" filepath="$PROJECT_DIR$/.idea/ctgan.iml" />
+    </modules>
+  </component>
+</project>
\ No newline at end of file
diff --git a/ctgan/.idea/vcs.xml b/ctgan/.idea/vcs.xml
new file mode 100644
index 0000000000000000000000000000000000000000..6c0b8635858dc7ad44b93df54b762707ce49eefc
--- /dev/null
+++ b/ctgan/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$/.." vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/ctgan/__init__.py b/ctgan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f693c93a20525d2b2bbda8804f66f814e00f6151
--- /dev/null
+++ b/ctgan/__init__.py
@@ -0,0 +1,16 @@
+# -*- coding: utf-8 -*-
+
+"""Top-level package for ctgan."""
+
+__author__ = 'Qian Wang, Inc.'
+__email__ = 'info@sdv.dev'
+__version__ = '0.10.2.dev0'
+
+from ctgan.load_data import final_data
+from ctgan.synthesizers.ctgan import CTGAN
+from ctgan.synthesizers.tvae import TVAE
+from ctgan.config import HyperparametersParser
+
+__all__ = ('CTGAN', 'TVAE', 'final_data', 'HyperparametersParser')
+
+
diff --git a/ctgan/__main__.py b/ctgan/__main__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a38b64fa67e9f4382726517a83f641441182f35
--- /dev/null
+++ b/ctgan/__main__.py
@@ -0,0 +1,231 @@
+# import os
+# import argparse
+# import pandas as pd
+# from ctgan import CTGAN
+# import logging
+# 
+# # Set up logging
+# logging.basicConfig(
+#     level=logging.INFO,
+#     format="%(asctime)s [%(levelname)s] %(message)s",
+#     handlers=[
+#         logging.StreamHandler()  # Log to stdout
+#     ]
+# )
+# 
+# def train_ctgan(input_data, output_model, epochs, discrete_columns,embedding_dim,
+#                 generator_dim, discriminator_dim,generator_lr,discriminator_lr ,batch_size):
+#     logging.info("Starting CTGAN training...")
+#     data = pd.read_csv(input_data)
+#     logging.info(f"Data loaded from {input_data} with shape {data.shape}")
+#     ctgan = CTGAN(
+#         embedding_dim=embedding_dim,
+#         generator_dim=generator_dim,
+#         discriminator_dim=discriminator_dim,
+#         generator_lr=generator_lr,
+#         discriminator_lr=discriminator_lr,
+#         batch_size=batch_size,
+#         epochs=epochs,
+#         verbose=True
+#     )
+# 
+#     logging.info(f"CTGAN initialized with {epochs} epochs")
+#     ctgan.fit(data, discrete_columns)
+#     logging.info(f"Training completed. Saving model to {output_model}")
+#     ctgan.save(output_model)
+#     logging.info("Model saved successfully.")
+# 
+# if __name__ == "__main__":
+#     # Define default values that match SageMaker hyperparameters
+#     input_data = os.getenv("SM_CHANNEL_DATA1", "/opt/ml/input/data/data1/processed_data.csv")
+#     output_model = os.getenv("SM_OUTPUT_MODEL", "/opt/ml/model/ctgan_model.pt")
+#     epochs = int(os.getenv("SM_EPOCHS", 80))
+#     discrete_columns = os.getenv("SM_DISCRETE_COLUMNS", "label").split(",")
+# 
+#     embedding_dim = int(os.getenv("SM_EMBEDDING_DIM", 128))
+#     generator_dim = tuple(map(int, os.getenv("SM_GENERATOR_DIM", "256,256").split(",")))
+#     discriminator_dim = tuple(map(int, os.getenv("SM_DISCRIMINATOR_DIM", "256,256").split(",")))
+#     batch_size = int(os.getenv("SM_BATCH_SIZE", 500))
+#     generator_lr = float(os.getenv("SM_GENERATOR_LR", 2e-4))
+#     discriminator_lr = float(os.getenv("SM_DISCRIMINATOR_LR", 2e-4))
+# 
+# 
+#     # Log loaded hyperparameters for debugging
+#     logging.info("Loaded hyperparameters:")
+#     logging.info(f"Input Data Path: {input_data}")
+#     logging.info(f"Output Model Path: {output_model}")
+#     logging.info(f"Epochs: {epochs}")
+#     logging.info(f"Discrete Columns: {discrete_columns}")
+# 
+#     # Call the training function
+#     train_ctgan(input_data, output_model, epochs, discrete_columns,embedding_dim,generator_dim,
+#                 discriminator_dim,generator_lr,discriminator_lr,batch_size)
+import ast
+import datetime
+import json
+import os
+import random
+
+import torch
+import torch.nn as nn
+import pandas as pd
+import logging
+from itertools import product
+from ctgan import HyperparametersParser
+from sklearn.metrics import mean_squared_error
+from sklearn.model_selection import ParameterGrid
+from ctgan import CTGAN
+
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[logging.StreamHandler()]
+)
+
+
+# Model training function
+def train_ctgan(input_data, output_model, epochs, discrete_columns, embedding_dim,
+                generator_dim, discriminator_dim, generator_lr, discriminator_lr, batch_size):
+    logging.info("Starting CTGAN training...")
+    data = pd.read_csv(input_data)
+    logging.info(f"Data loaded from {input_data} with shape {data.shape}")
+
+    # Initialize the CTGAN model
+    ctgan = CTGAN(
+        embedding_dim=embedding_dim,
+        generator_dim=generator_dim,
+        discriminator_dim=discriminator_dim,
+        generator_lr=generator_lr,
+        discriminator_lr=discriminator_lr,
+        batch_size=batch_size,
+        epochs=epochs,
+        verbose=True
+    )
+    logging.info(f"begin train ...")
+
+    # Train the model
+    ctgan.fit(data, discrete_columns)
+
+    # Save the model
+    ctgan.save(output_model)
+    logging.info(f"Model saved to {output_model}")
+    return ctgan
+
+
+# Model evaluation
+def evaluate_ctgan(ctgan, input_data, discrete_columns, n_samples=1000):
+    """
+    Evaluate the quality of the generated samples using CTGAN's `sample` method.
+
+    Args:
+        ctgan: The trained CTGAN model.
+        input_data: Path to the input data file.
+        discrete_columns: List of discrete column names.
+        n_samples: Number of samples to generate per class.
+
+    Returns:
+        accuracy: Classification accuracy.
+        loss: Classification loss.
+    """
+    # Read the real data
+    data = pd.read_csv(input_data)
+    logging.info(f"Data loaded for evaluation with shape: {data.shape}")
+
+    # Get the real labels
+    label_column = 'label'  # Assume the target column is named 'label'
+    unique_labels = data[label_column].unique()
+
+    # Generate samples conditioned on each label value
+    generated_data = []
+    for label in unique_labels:
+        logging.info(f"Generating samples for label: {label}")
+        generated_samples = ctgan.sample(
+            n=n_samples,
+            condition_column=label_column,
+            condition_value=label
+        )
+        generated_data.append(generated_samples)
+
+    # Concatenate the generated samples
+    generated_data = pd.concat(generated_data, ignore_index=True)
+    logging.info(f"Generated data shape: {generated_data.shape}")
+    # Compare the label distributions of the real and generated data
+    real_distribution = data[label_column].value_counts(normalize=True)
+    generated_distribution = generated_data[label_column].value_counts(normalize=True)
+    # Use mean squared error (MSE) to measure the distribution difference
+    common_labels = real_distribution.index.intersection(generated_distribution.index)
+    mse = mean_squared_error(real_distribution[common_labels], generated_distribution[common_labels])
+    logging.info(f"[validate] Generated data distribution MSE: {mse:.4f}")
+    # Assuming the generated labels are all correct, compute a pseudo classification accuracy
+    predicted_labels = generated_data[label_column].values
+    real_labels = data[label_column].sample(len(generated_data), random_state=42).values
+    correct_predictions = (predicted_labels == real_labels).sum()
+    accuracy = correct_predictions / len(real_labels)
+    # Compute a loss with binary cross-entropy
+    criterion = nn.BCEWithLogitsLoss()
+    predicted_scores_tensor = torch.tensor(predicted_labels, dtype=torch.float32)
+    real_labels_tensor = torch.tensor(real_labels, dtype=torch.float32)
+    loss = criterion(predicted_scores_tensor, real_labels_tensor)
+    logging.info('[validate] loss: {:.4f}, acc: {:.2f}'.format(loss.item(), accuracy * 100))
+    return accuracy, loss.item()
+
+def save_metadata_to_json(output_path, model_path, score, params):
+    metadata = {
+        "ModelPath": model_path,
+        "Score": score,
+        "Parameters": params
+    }
+    json_path = os.path.join(output_path, "metadata.json")
+    try:
+        with open(json_path, "w") as json_file:
+            json.dump(metadata, json_file, indent=4)
+        logging.info(f"Metadata saved to {json_path}")
+    except Exception as e:
+        logging.warning(f"Error saving metadata to JSON: {e}")
+
+
+if __name__ == "__main__":
+    # Define data paths and hyperparameters
+    input_data = os.getenv("SM_CHANNEL_DATA1", "/opt/ml/input/data/train/processed_data.csv")
+    output_model = os.getenv("SM_OUTPUT_MODEL", "/opt/ml/model")
+    hyperparameters_file = "/opt/ml/input/config/hyperparameters.json"
+
+    # Parse hyperparameters
+    parser = HyperparametersParser(hyperparameters_file)
+    hyperparameters = parser.parse_hyperparameters()
+
+    # Extract hyperparameters
+    epochs = hyperparameters.get("epochs", 40)
+    discrete_columns = hyperparameters.get("discrete_columns", "label").split(",")
+    embedding_dim = hyperparameters.get("embedding_dim", 128)
+    generator_dim = ast.literal_eval(hyperparameters.get("generator_dim", "(256, 256)"))
+    discriminator_dim = ast.literal_eval(hyperparameters.get("discriminator_dim", "(256, 256)"))
+    generator_lr = hyperparameters.get("generator_lr", 1e-4)
+    discriminator_lr = hyperparameters.get("discriminator_lr", 1e-4)
+    batch_size = hyperparameters.get("batch_size", 100)
+
+    # Train the model
+    timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
+    model_path = os.path.join(output_model, f"ctgan_model_{timestamp}.pt")
+    logging.info(f"Training with parameters: {hyperparameters}")
+
+    ctgan = train_ctgan(
+        input_data=input_data,
+        output_model=model_path,
+        epochs=epochs,
+        discrete_columns=discrete_columns,
+        embedding_dim=embedding_dim,
+        generator_dim=generator_dim,
+        discriminator_dim=discriminator_dim,
+        generator_lr=generator_lr,
+        discriminator_lr=discriminator_lr,
+        batch_size=batch_size
+    )
+
+    # Evaluate the model
+    score = evaluate_ctgan(ctgan, input_data, discrete_columns)
+    logging.info(f"Evaluation Score: {score}")
+
+    # # Save metadata
+    # save_metadata_to_json(output_model, model_path, score, hyperparameters)
\ No newline at end of file
diff --git a/ctgan/__pycache__/__init__.cpython-311.pyc b/ctgan/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..22ad241c302e2fbe3a07d78fc8523242805d766d
Binary files /dev/null and b/ctgan/__pycache__/__init__.cpython-311.pyc differ
diff --git a/ctgan/__pycache__/__main__.cpython-311.pyc b/ctgan/__pycache__/__main__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..182e053419b2c39b7bce116a95f6bb164b049920
Binary files /dev/null and b/ctgan/__pycache__/__main__.cpython-311.pyc differ
diff --git a/ctgan/__pycache__/data_sampler.cpython-311.pyc b/ctgan/__pycache__/data_sampler.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c7b71467eacb17d46d3c02c760b931be117cb265
Binary files /dev/null and b/ctgan/__pycache__/data_sampler.cpython-311.pyc differ
diff --git a/ctgan/__pycache__/data_transformer.cpython-311.pyc b/ctgan/__pycache__/data_transformer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eef72269ccd8a69bdbac53fbddfa87b182e24a6e
Binary files /dev/null and b/ctgan/__pycache__/data_transformer.cpython-311.pyc differ
diff --git a/ctgan/__pycache__/load_data.cpython-311.pyc b/ctgan/__pycache__/load_data.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a7c66a7be7e09ca9ccedc03198c88c49b47c286
Binary files /dev/null and b/ctgan/__pycache__/load_data.cpython-311.pyc differ
diff --git a/ctgan/config.py b/ctgan/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b04e5c0090702980ae49e8ba688aa60ed8d1dbe2
--- /dev/null
+++ b/ctgan/config.py
@@ -0,0 +1,48 @@
+import json
+import logging
+
+
+class HyperparametersParser:
+    def __init__(self, hyperparameters_file):
+        self.hyperparameters_file = hyperparameters_file
+        self.hyperparameters = {}
+
+    def parse_hyperparameters(self):
+        """
+        Parse the hyperparameters file and convert value types.
+        """
+        try:
+            with open(self.hyperparameters_file, "r") as f:
+                injected_hyperparameters = json.load(f)
+
+            logging.info("Parsing hyperparameters...")
+            for key, value in injected_hyperparameters.items():
+                logging.info(f"Raw hyperparameter - Key: {key}, Value: {value}")
+                self.hyperparameters[key] = self._convert_type(value)
+        except FileNotFoundError:
+            logging.warning(f"Hyperparameters file {self.hyperparameters_file} not found. Using default parameters.")
+        except Exception as e:
+            logging.error(f"Error parsing hyperparameters: {e}")
+        return self.hyperparameters
+
+    @staticmethod
+    def _convert_type(value):
+        """
+        Convert a hyperparameter value to an appropriate Python type (string, integer, float, or boolean).
+        """
+        if isinstance(value, str):
+            # Check for a boolean value
+            if value.lower() == "true":
+                return True
+            elif value.lower() == "false":
+                return False
+
+            # Check for an integer or float
+            try:
+                if "." in value:
+                    return float(value)
+                else:
+                    return int(value)
+            except ValueError:
+                pass  # If conversion fails, keep the string
+        return value
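+
+
+# Illustrative usage sketch (commented out; not executed on import). Assuming a
+# SageMaker-style hyperparameters.json such as {"epochs": "40", "generator_lr": "0.0002",
+# "verbose": "true"}, the parser returns native Python types:
+#
+#     parser = HyperparametersParser("/opt/ml/input/config/hyperparameters.json")
+#     params = parser.parse_hyperparameters()
+#     params.get("epochs")        # -> 40 (int)
+#     params.get("generator_lr")  # -> 0.0002 (float)
+#     params.get("verbose")       # -> True (bool)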
diff --git a/ctgan/data.py b/ctgan/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a48da0c565b060b8fa46e1bd8ad2330ec338a07
--- /dev/null
+++ b/ctgan/data.py
@@ -0,0 +1,88 @@
+"""Data loading."""
+
+import json
+
+import numpy as np
+import pandas as pd
+
+
+def read_csv(csv_filename, meta_filename=None, header=True, discrete=None):
+    """Read a csv file."""
+    data = pd.read_csv(csv_filename, header='infer' if header else None)
+
+    if meta_filename:
+        with open(meta_filename) as meta_file:
+            metadata = json.load(meta_file)
+
+        discrete_columns = [
+            column['name'] for column in metadata['columns'] if column['type'] != 'continuous'
+        ]
+
+    elif discrete:
+        discrete_columns = discrete.split(',')
+        if not header:
+            discrete_columns = [int(i) for i in discrete_columns]
+
+    else:
+        discrete_columns = []
+
+    return data, discrete_columns
+
+
+def read_tsv(data_filename, meta_filename):
+    """Read a tsv file."""
+    with open(meta_filename) as f:
+        column_info = f.readlines()
+
+    column_info_raw = [x.replace('{', ' ').replace('}', ' ').split() for x in column_info]
+
+    discrete = []
+    continuous = []
+    column_info = []
+
+    for idx, item in enumerate(column_info_raw):
+        if item[0] == 'C':
+            continuous.append(idx)
+            column_info.append((float(item[1]), float(item[2])))
+        else:
+            assert item[0] == 'D'
+            discrete.append(idx)
+            column_info.append(item[1:])
+
+    meta = {
+        'continuous_columns': continuous,
+        'discrete_columns': discrete,
+        'column_info': column_info,
+    }
+
+    with open(data_filename) as f:
+        lines = f.readlines()
+
+    data = []
+    for row in lines:
+        row_raw = row.split()
+        row = []
+        for idx, col in enumerate(row_raw):
+            if idx in continuous:
+                row.append(col)
+            else:
+                assert idx in discrete
+                row.append(column_info[idx].index(col))
+
+        data.append(row)
+
+    return np.asarray(data, dtype='float32'), meta['discrete_columns']
+
+
+def write_tsv(data, meta, output_filename):
+    """Write to a tsv file."""
+    with open(output_filename, 'w') as f:
+        for row in data:
+            for idx, col in enumerate(row):
+                if idx in meta['continuous_columns']:
+                    print(col, end=' ', file=f)
+                else:
+                    assert idx in meta['discrete_columns']
+                    print(meta['column_info'][idx][int(col)], end=' ', file=f)
+
+            print(file=f)
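+
+
+# Illustrative sketch of the .meta format that read_tsv/write_tsv expect, inferred from
+# the parsing code above (the exact file conventions are an assumption). One line per
+# column: 'C' with min and max for continuous, 'D' with the category values for discrete:
+#
+#     C 0.0 120.5
+#     D {yes no maybe}
+#
+# read_tsv maps each discrete value to its index in that list; write_tsv maps it back.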
diff --git a/ctgan/data_sampler.py b/ctgan/data_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9b67d195a3d3336b87ea360029aeca9e7e1f340
--- /dev/null
+++ b/ctgan/data_sampler.py
@@ -0,0 +1,153 @@
+"""DataSampler module."""
+
+import numpy as np
+
+
+class DataSampler(object):
+    """DataSampler samples the conditional vector and corresponding data for CTGAN."""
+
+    def __init__(self, data, output_info, log_frequency):
+        self._data_length = len(data)
+
+        def is_discrete_column(column_info):
+            return len(column_info) == 1 and column_info[0].activation_fn == 'softmax'
+
+        n_discrete_columns = sum([
+            1 for column_info in output_info if is_discrete_column(column_info)
+        ])
+
+        self._discrete_column_matrix_st = np.zeros(n_discrete_columns, dtype='int32')
+
+        # Store the row id for each category in each discrete column.
+        # For example _rid_by_cat_cols[a][b] is a list of all rows with the
+        # a-th discrete column equal value b.
+        self._rid_by_cat_cols = []
+
+        # Compute _rid_by_cat_cols
+        st = 0
+        for column_info in output_info:
+            if is_discrete_column(column_info):
+                span_info = column_info[0]
+                ed = st + span_info.dim
+
+                rid_by_cat = []
+                for j in range(span_info.dim):
+                    rid_by_cat.append(np.nonzero(data[:, st + j])[0])
+                self._rid_by_cat_cols.append(rid_by_cat)
+                st = ed
+            else:
+                st += sum([span_info.dim for span_info in column_info])
+        assert st == data.shape[1]
+
+        # Prepare an interval matrix for efficiently sampling the conditional vector
+        max_category = max(
+            [column_info[0].dim for column_info in output_info if is_discrete_column(column_info)],
+            default=0,
+        )
+
+        self._discrete_column_cond_st = np.zeros(n_discrete_columns, dtype='int32')
+        self._discrete_column_n_category = np.zeros(n_discrete_columns, dtype='int32')
+        self._discrete_column_category_prob = np.zeros((n_discrete_columns, max_category))
+        self._n_discrete_columns = n_discrete_columns
+        self._n_categories = sum([
+            column_info[0].dim for column_info in output_info if is_discrete_column(column_info)
+        ])
+
+        st = 0
+        current_id = 0
+        current_cond_st = 0
+        for column_info in output_info:
+            if is_discrete_column(column_info):
+                span_info = column_info[0]
+                ed = st + span_info.dim
+                category_freq = np.sum(data[:, st:ed], axis=0)
+                if log_frequency:
+                    category_freq = np.log(category_freq + 1)
+                category_prob = category_freq / np.sum(category_freq)
+                self._discrete_column_category_prob[current_id, : span_info.dim] = category_prob
+                self._discrete_column_cond_st[current_id] = current_cond_st
+                self._discrete_column_n_category[current_id] = span_info.dim
+                current_cond_st += span_info.dim
+                current_id += 1
+                st = ed
+            else:
+                st += sum([span_info.dim for span_info in column_info])
+
+    def _random_choice_prob_index(self, discrete_column_id):
+        probs = self._discrete_column_category_prob[discrete_column_id]
+        r = np.expand_dims(np.random.rand(probs.shape[0]), axis=1)
+        return (probs.cumsum(axis=1) > r).argmax(axis=1)
+
+    def sample_condvec(self, batch):
+        """Generate the conditional vector for training.
+
+        Returns:
+            cond (batch x #categories):
+                The conditional vector.
+            mask (batch x #discrete columns):
+                A one-hot vector indicating the selected discrete column.
+            discrete column id (batch):
+                Integer representation of mask.
+            category_id_in_col (batch):
+                Selected category in the selected discrete column.
+        """
+        if self._n_discrete_columns == 0:
+            return None
+
+        discrete_column_id = np.random.choice(np.arange(self._n_discrete_columns), batch)
+
+        cond = np.zeros((batch, self._n_categories), dtype='float32')
+        mask = np.zeros((batch, self._n_discrete_columns), dtype='float32')
+        mask[np.arange(batch), discrete_column_id] = 1
+        category_id_in_col = self._random_choice_prob_index(discrete_column_id)
+        category_id = self._discrete_column_cond_st[discrete_column_id] + category_id_in_col
+        cond[np.arange(batch), category_id] = 1
+
+        return cond, mask, discrete_column_id, category_id_in_col
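+
+    # Illustrative sketch (assumed toy setup): with two discrete columns of 3 and 2
+    # categories respectively, ``_n_categories`` is 5, so for ``batch=4``:
+    #   cond has shape (4, 5)  - one-hot over all discrete categories
+    #   mask has shape (4, 2)  - one-hot over the discrete columns
+    #   discrete_column_id and category_id_in_col both have shape (4,)
+    # Continuous columns never contribute to the conditional vector.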
+
+    def sample_original_condvec(self, batch):
+        """Generate the conditional vector for generation use original frequency."""
+        if self._n_discrete_columns == 0:
+            return None
+
+        category_freq = self._discrete_column_category_prob.flatten()
+        category_freq = category_freq[category_freq != 0]
+        category_freq = category_freq / np.sum(category_freq)
+        col_idxs = np.random.choice(np.arange(len(category_freq)), batch, p=category_freq)
+        cond = np.zeros((batch, self._n_categories), dtype='float32')
+        cond[np.arange(batch), col_idxs] = 1
+
+        return cond
+
+    def sample_data(self, data, n, col, opt):
+        """Sample data from original training data satisfying the sampled conditional vector.
+
+        Args:
+            data:
+                The training data.
+            n:
+                Number of rows to sample.
+            col:
+                Indices of the selected discrete columns, or ``None``.
+            opt:
+                Selected category index within each selected column.
+
+        Returns:
+            n rows of matrix data.
+        """
+        if col is None:
+            idx = np.random.randint(len(data), size=n)
+            return data[idx]
+
+        idx = []
+        for c, o in zip(col, opt):
+            idx.append(np.random.choice(self._rid_by_cat_cols[c][o]))
+
+        return data[idx]
+
+    def dim_cond_vec(self):
+        """Return the total number of categories."""
+        return self._n_categories
+
+    def generate_cond_from_condition_column_info(self, condition_info, batch):
+        """Generate the condition vector."""
+        vec = np.zeros((batch, self._n_categories), dtype='float32')
+        id_ = self._discrete_column_matrix_st[condition_info['discrete_column_id']]
+        id_ += condition_info['value_id']
+        vec[:, id_] = 1
+        return vec
diff --git a/ctgan/data_transformer.py b/ctgan/data_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ed8b0eb129a4d1eb0dbf46780f97d824eb4319d
--- /dev/null
+++ b/ctgan/data_transformer.py
@@ -0,0 +1,265 @@
+"""DataTransformer module."""
+
+from collections import namedtuple
+
+import numpy as np
+import pandas as pd
+from joblib import Parallel, delayed
+from rdt.transformers import ClusterBasedNormalizer, OneHotEncoder
+
+SpanInfo = namedtuple('SpanInfo', ['dim', 'activation_fn'])
+ColumnTransformInfo = namedtuple(
+    'ColumnTransformInfo',
+    ['column_name', 'column_type', 'transform', 'output_info', 'output_dimensions'],
+)
+
+
+class DataTransformer(object):
+    """Data Transformer.
+
+    Model continuous columns with a Bayesian GMM and normalize them to a scalar in [-1, 1]
+    plus a one-hot mode-indicator vector. Discrete columns are encoded using a OneHotEncoder.
+    """
+
+    def __init__(self, max_clusters=10, weight_threshold=0.005):
+        """Create a data transformer.
+
+        Args:
+            max_clusters (int):
+                Maximum number of Gaussian distributions in Bayesian GMM.
+            weight_threshold (float):
+                Weight threshold for a Gaussian distribution to be kept.
+        """
+        self._max_clusters = max_clusters
+        self._weight_threshold = weight_threshold
+
+    def _fit_continuous(self, data):
+        """Train Bayesian GMM for continuous columns.
+
+        Args:
+            data (pd.DataFrame):
+                A dataframe containing a column.
+
+        Returns:
+            namedtuple:
+                A ``ColumnTransformInfo`` object.
+        """
+        column_name = data.columns[0]
+        gm = ClusterBasedNormalizer(
+            missing_value_generation='from_column',
+            max_clusters=min(len(data), self._max_clusters),
+            weight_threshold=self._weight_threshold,
+        )
+        gm.fit(data, column_name)
+        num_components = sum(gm.valid_component_indicator)
+
+        return ColumnTransformInfo(
+            column_name=column_name,
+            column_type='continuous',
+            transform=gm,
+            output_info=[SpanInfo(1, 'tanh'), SpanInfo(num_components, 'softmax')],
+            output_dimensions=1 + num_components,
+        )
+
+    def _fit_discrete(self, data):
+        """Fit one hot encoder for discrete column.
+
+        Args:
+            data (pd.DataFrame):
+                A dataframe containing a column.
+
+        Returns:
+            namedtuple:
+                A ``ColumnTransformInfo`` object.
+        """
+        column_name = data.columns[0]
+        ohe = OneHotEncoder()
+        ohe.fit(data, column_name)
+        num_categories = len(ohe.dummies)
+
+        return ColumnTransformInfo(
+            column_name=column_name,
+            column_type='discrete',
+            transform=ohe,
+            output_info=[SpanInfo(num_categories, 'softmax')],
+            output_dimensions=num_categories,
+        )
+
+    def fit(self, raw_data, discrete_columns=()):
+        """Fit the ``DataTransformer``.
+
+        Fits a ``ClusterBasedNormalizer`` for continuous columns and a
+        ``OneHotEncoder`` for discrete columns.
+
+        This step also counts the #columns in matrix data and span information.
+        """
+        self.output_info_list = []
+        self.output_dimensions = 0
+        self.dataframe = True
+
+        if not isinstance(raw_data, pd.DataFrame):
+            self.dataframe = False
+            # work around for RDT issue #328 Fitting with numerical column names fails
+            discrete_columns = [str(column) for column in discrete_columns]
+            column_names = [str(num) for num in range(raw_data.shape[1])]
+            raw_data = pd.DataFrame(raw_data, columns=column_names)
+
+        self._column_raw_dtypes = raw_data.infer_objects().dtypes
+        self._column_transform_info_list = []
+        for column_name in raw_data.columns:
+            if column_name in discrete_columns:
+                column_transform_info = self._fit_discrete(raw_data[[column_name]])
+            else:
+                column_transform_info = self._fit_continuous(raw_data[[column_name]])
+
+            self.output_info_list.append(column_transform_info.output_info)
+            self.output_dimensions += column_transform_info.output_dimensions
+            self._column_transform_info_list.append(column_transform_info)
+
+    def _transform_continuous(self, column_transform_info, data):
+        column_name = data.columns[0]
+        flattened_column = data[column_name].to_numpy().flatten()
+        data = data.assign(**{column_name: flattened_column})
+        gm = column_transform_info.transform
+        transformed = gm.transform(data)
+
+        #  Converts the transformed data to the appropriate output format.
+        #  The first column (ending in '.normalized') stays the same,
+        #  but the label encoded column (ending in '.component') is one-hot encoded.
+        output = np.zeros((len(transformed), column_transform_info.output_dimensions))
+        output[:, 0] = transformed[f'{column_name}.normalized'].to_numpy()
+        index = transformed[f'{column_name}.component'].to_numpy().astype(int)
+        output[np.arange(index.size), index + 1] = 1.0
+
+        return output
+
+    def _transform_discrete(self, column_transform_info, data):
+        ohe = column_transform_info.transform
+        return ohe.transform(data).to_numpy()
+
+    def _synchronous_transform(self, raw_data, column_transform_info_list):
+        """Take a Pandas DataFrame and transform columns synchronous.
+
+        Outputs a list with Numpy arrays.
+        """
+        column_data_list = []
+        for column_transform_info in column_transform_info_list:
+            column_name = column_transform_info.column_name
+            data = raw_data[[column_name]]
+            if column_transform_info.column_type == 'continuous':
+                column_data_list.append(self._transform_continuous(column_transform_info, data))
+            else:
+                column_data_list.append(self._transform_discrete(column_transform_info, data))
+
+        return column_data_list
+
+    def _parallel_transform(self, raw_data, column_transform_info_list):
+        """Take a Pandas DataFrame and transform columns in parallel.
+
+        Outputs a list with Numpy arrays.
+        """
+        processes = []
+        for column_transform_info in column_transform_info_list:
+            column_name = column_transform_info.column_name
+            data = raw_data[[column_name]]
+            process = None
+            if column_transform_info.column_type == 'continuous':
+                process = delayed(self._transform_continuous)(column_transform_info, data)
+            else:
+                process = delayed(self._transform_discrete)(column_transform_info, data)
+            processes.append(process)
+
+        return Parallel(n_jobs=-1)(processes)
+
+    def transform(self, raw_data):
+        """Take raw data and output a matrix data."""
+        if not isinstance(raw_data, pd.DataFrame):
+            column_names = [str(num) for num in range(raw_data.shape[1])]
+            raw_data = pd.DataFrame(raw_data, columns=column_names)
+
+        # Only use parallelization with larger data sizes.
+        # Otherwise, the transformation will be slower.
+        if raw_data.shape[0] < 500:
+            column_data_list = self._synchronous_transform(
+                raw_data, self._column_transform_info_list
+            )
+        else:
+            column_data_list = self._parallel_transform(raw_data, self._column_transform_info_list)
+
+        return np.concatenate(column_data_list, axis=1).astype(float)
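+
+    # Illustrative usage sketch (commented out; assumed toy data):
+    #
+    #     transformer = DataTransformer()
+    #     transformer.fit(df, discrete_columns=['label'])
+    #     matrix = transformer.transform(df)   # shape (len(df), transformer.output_dimensions)
+    #     recovered = transformer.inverse_transform(matrix)
+    #
+    # Each continuous column becomes one normalized scalar plus a one-hot mode indicator;
+    # each discrete column becomes a one-hot block.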
+
+    def _inverse_transform_continuous(self, column_transform_info, column_data, sigmas, st):
+        gm = column_transform_info.transform
+        data = pd.DataFrame(column_data[:, :2], columns=list(gm.get_output_sdtypes())).astype(float)
+        data[data.columns[1]] = np.argmax(column_data[:, 1:], axis=1)
+        if sigmas is not None:
+            selected_normalized_value = np.random.normal(data.iloc[:, 0], sigmas[st])
+            data.iloc[:, 0] = selected_normalized_value
+
+        return gm.reverse_transform(data)
+
+    def _inverse_transform_discrete(self, column_transform_info, column_data):
+        ohe = column_transform_info.transform
+        data = pd.DataFrame(column_data, columns=list(ohe.get_output_sdtypes()))
+        return ohe.reverse_transform(data)[column_transform_info.column_name]
+
+    def inverse_transform(self, data, sigmas=None):
+        """Take matrix data and output raw data.
+
+        Output uses the same type as input to the transform function.
+        Either np array or pd dataframe.
+        """
+        st = 0
+        recovered_column_data_list = []
+        column_names = []
+        for column_transform_info in self._column_transform_info_list:
+            dim = column_transform_info.output_dimensions
+            column_data = data[:, st : st + dim]
+            if column_transform_info.column_type == 'continuous':
+                recovered_column_data = self._inverse_transform_continuous(
+                    column_transform_info, column_data, sigmas, st
+                )
+            else:
+                recovered_column_data = self._inverse_transform_discrete(
+                    column_transform_info, column_data
+                )
+
+            recovered_column_data_list.append(recovered_column_data)
+            column_names.append(column_transform_info.column_name)
+            st += dim
+
+        recovered_data = np.column_stack(recovered_column_data_list)
+        recovered_data = pd.DataFrame(recovered_data, columns=column_names).astype(
+            self._column_raw_dtypes
+        )
+        if not self.dataframe:
+            recovered_data = recovered_data.to_numpy()
+
+        return recovered_data
+
+    def convert_column_name_value_to_id(self, column_name, value):
+        """Get the ids of the given `column_name`."""
+        discrete_counter = 0
+        column_id = 0
+        for column_transform_info in self._column_transform_info_list:
+            if column_transform_info.column_name == column_name:
+                break
+            if column_transform_info.column_type == 'discrete':
+                discrete_counter += 1
+
+            column_id += 1
+
+        else:
+            raise ValueError(f"The column_name `{column_name}` doesn't exist in the data.")
+
+        ohe = column_transform_info.transform
+        data = pd.DataFrame([value], columns=[column_transform_info.column_name])
+        one_hot = ohe.transform(data).to_numpy()[0]
+        if sum(one_hot) == 0:
+            raise ValueError(f"The value `{value}` doesn't exist in the column `{column_name}`.")
+
+        return {
+            'discrete_column_id': discrete_counter,
+            'column_id': column_id,
+            'value_id': np.argmax(one_hot),
+        }
diff --git a/ctgan/dataprocess_using_gan.py b/ctgan/dataprocess_using_gan.py
new file mode 100644
index 0000000000000000000000000000000000000000..53cfff2b061c764689cd39a86edb937eb2b2ab72
--- /dev/null
+++ b/ctgan/dataprocess_using_gan.py
@@ -0,0 +1,141 @@
+import argparse
+import logging
+import os
+import tarfile
+
+import pandas as pd
+from ctgan import CTGAN
+
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[
+        logging.StreamHandler()
+    ]
+)
+def extract_model_and_tokenizer(tar_path, extract_path="/tmp"):
+    """
+    Extract the model file from a tar.gz archive.
+
+    Args:
+        tar_path (str): Path to the model tar.gz file.
+        extract_path (str): Directory into which the archive is extracted.
+
+    Returns:
+        model_path (str): Path to the extracted model file (.pt).
+    """
+    with tarfile.open(tar_path, "r:gz") as tar:
+        tar.extractall(path=extract_path)
+
+    model_path = None
+
+
+    for root, dirs, files in os.walk(extract_path):
+        for file in files:
+            if file.endswith(".pt"):
+                model_path = os.path.join(root, file)
+
+
+    if not model_path:
+        raise FileNotFoundError("No .pth model file found in the extracted tar.gz.")
+
+
+    return model_path
+def load_model(model_path):
+    """加载预训练的 CTGAN 模型."""
+    logging.info(f"Loading model from {model_path}")
+
+    model = CTGAN.load(model_path)
+    logging.info("Model loaded successfully.")
+    return model
+
+
+def generate_samples(ctgan, data, discrete_columns, min_samples_per_class=5000):
+    """
+    Use the CTGAN model to generate samples for underrepresented classes.
+
+    Args:
+        ctgan: The loaded CTGAN model.
+        data: Input data (DataFrame).
+        discrete_columns: List of discrete column names.
+        min_samples_per_class: Minimum number of samples per class.
+
+    Returns:
+        DataFrame containing the generated data.
+    """
+    label_column = 'label'
+    label_counts = data[label_column].value_counts()
+    logging.info("Starting sample generation for underrepresented classes...")
+    logging.info(f"Current label distribution:\n{label_counts}")
+    generated_data = []
+    for label, count in label_counts.items():
+        if count < min_samples_per_class:
+            # Number of samples that need to be generated
+            n_to_generate = min_samples_per_class - count
+            logging.info(f"Label '{label}' has {count} samples, generating {n_to_generate} more samples...")
+
+            # Generate samples with CTGAN
+            samples = ctgan.sample(
+                n=n_to_generate,
+                condition_column=label_column,
+                condition_value=label
+            )
+            generated_data.append(samples)
+
+    if generated_data:
+        # Concatenate the generated data
+        generated_data = pd.concat(generated_data, ignore_index=True)
+        logging.info(f"Generated data shape: {generated_data.shape}")
+    else:
+        logging.info("No underrepresented classes found. No samples were generated.")
+        generated_data = pd.DataFrame()
+
+    return generated_data
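+
+# Illustrative sketch (commented out; paths and label value are hypothetical) of the
+# conditional oversampling that generate_samples performs per underrepresented label:
+#
+#     ctgan = CTGAN.load("ctgan_model.pt")
+#     df = pd.read_csv("processed_data.csv")
+#     extra = ctgan.sample(n=500, condition_column="label", condition_value=1)
+#     balanced = pd.concat([df, extra], ignore_index=True)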
+
+def main():
+    parser = argparse.ArgumentParser(description="CTGAN Sample Generation Script")
+    parser.add_argument("--input", type=str, required=True, help="Path to input CSV file.")
+    parser.add_argument("--model", type=str, required=True, help="Path to the trained CTGAN model.")
+    parser.add_argument("--output", type=str, required=True, help="Path to save the output CSV file.")
+    parser.add_argument("--n_samples", type=int, default=1000, help="Number of samples to generate per class.")
+    args = parser.parse_args()
+
+    input_path = args.input
+    model_path = args.model
+    output_dir = args.output
+    n_samples = args.n_samples
+
+    # Load the input data
+    logging.info(f"Loading input data from {input_path}")
+    data = pd.read_csv(input_path)
+    logging.info(f"Input data shape: {data.shape}")
+
+    # Identify discrete columns
+    label_column = "label"  # Assume the 'label' column is the target column
+    discrete_columns = [label_column]
+    if label_column not in data.columns:
+        raise ValueError(f"Label column '{label_column}' not found in input data.")
+    model_p = extract_model_and_tokenizer(model_path)
+    # Load the CTGAN model
+    ctgan = load_model(model_p)
+
+    # Generate samples
+    generated_data = generate_samples(ctgan, data, discrete_columns)
+
+    # Merge the generated samples with the original data
+    combined_data = pd.concat([data, generated_data], ignore_index=True)
+    logging.info(f"Combined data shape: {combined_data.shape}")
+
+    # Save the generated samples together with the original data
+    os.makedirs(output_dir, exist_ok=True)
+    output_path = os.path.join(output_dir, "generated_samples.csv")  # Append the output file name
+
+    logging.info(f"Saving combined data to {output_path}")
+    combined_data.to_csv(output_path, index=False)
+    logging.info("Sample generation and saving completed successfully.")
+
+if __name__ == "__main__":
+    main()
diff --git a/ctgan/load_data.py b/ctgan/load_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe98e2e2b7aadce78b7c0409b62ff6c8e7fa9da5
--- /dev/null
+++ b/ctgan/load_data.py
@@ -0,0 +1,73 @@
+# import pandas as pd
+# import numpy as np
+# from sklearn.impute import SimpleImputer
+# from imblearn.over_sampling import SMOTE
+# import os
+#
+# # Define the data folder path
+# folder_path = "data"
+#
+# # Define the features to extract
+# selected_features = [
+#     "Destination Port", "Flow Duration", "Total Fwd Packets", "Total Backward Packets",
+#     "Total Length of Fwd Packets", "Total Length of Bwd Packets", "Fwd Packet Length Max",
+#     "Fwd Packet Length Min", "Fwd Packet Length Mean", "Fwd Packet Length Stddev",
+#     "Bwd Packet Length Max", "Bwd Packet Length Min", "Bwd Packet Length Mean",
+#     "Bwd Packet Length Stddev", "Flow Bytes/s", "Flow Packets/s", "Fwd IAT Mean (ms)",
+#     "Fwd IAT Stddev (ms)", "Fwd IAT Max (ms)", "Fwd IAT Min (ms)", "Bwd IAT Mean (ms)",
+#     "Bwd IAT Stddev (ms)", "Bwd IAT Max (ms)", "Bwd IAT Min (ms)", "Fwd PSH Flags",
+#     "Bwd PSH Flags", "Fwd URG Flags", "Bwd URG Flags", "Fwd Packets/s", "Bwd Packets/s",
+#     "down_up_ratio", "average_packet_size", "avg_fwd_segment_size", "avg_bwd_segment_size",
+#     "Subflow Fwd Packets", "Subflow Fwd Bytes", "Subflow Bwd Packets", "Subflow Bwd Bytes",
+#     "Label"
+# ]
+#
+# # Column-name standardization function
+# def standardize_columns(df):
+#     df.columns = [col.strip().lower().replace(" ", "_") for col in df.columns]
+#     return df
+#
+# standardized_features = [col.strip().lower().replace(" ", "_") for col in selected_features]
+#
+# # Read all CSV files
+# all_data = []
+# for file_name in os.listdir(folder_path):
+#     if file_name.endswith(".csv"):
+#         file_path = os.path.join(folder_path, file_name)
+#         df = pd.read_csv(file_path)
+#         df = standardize_columns(df)
+#         df_selected = df[[col for col in standardized_features if col in df.columns]].dropna(how="any", subset=["label"])
+#         all_data.append(df_selected)
+#
+# # Concatenate the data
+# final_data = pd.concat(all_data, ignore_index=True)
+#
+# # Split features and labels
+# X = final_data.drop(columns=["label"])
+# y = final_data["label"]
+#
+# # Check for and handle inf or extremely large values
+# X.replace([float('inf'), float('-inf')], np.nan, inplace=True)
+#
+# # Replace all pd.NA with np.nan
+# X = X.replace({pd.NA: np.nan})
+#
+# # Ensure all columns are floats
+# X = X.astype(float)
+#
+# # Check whether NaN values remain
+# if X.isnull().values.any():
+#     print("Missing values found, imputing with the median...")
+#     imputer = SimpleImputer(strategy="median")  # Impute with the median
+#     X = pd.DataFrame(imputer.fit_transform(X), columns=X.columns)
+#
+# # Apply SMOTE
+# X_resampled, y_resampled = SMOTE().fit_resample(X, y)
+#
+# # Convert back to a DataFrame
+# balanced_data = pd.concat([pd.DataFrame(X_resampled, columns=X.columns), pd.DataFrame(y_resampled, columns=["label"])], axis=1)
+#
+# # Inspect the result
+# print(balanced_data["label"].value_counts())
+
+final_data = []
diff --git a/ctgan/synthesizers/__init__.py b/ctgan/synthesizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..881e08f02bc2f410d13cbe415d22b8caea088ccc
--- /dev/null
+++ b/ctgan/synthesizers/__init__.py
@@ -0,0 +1,10 @@
+"""Synthesizers module."""
+
+from ctgan.synthesizers.ctgan import CTGAN
+from ctgan.synthesizers.tvae import TVAE
+
+__all__ = ('CTGAN', 'TVAE')
+
+
+def get_all_synthesizers():
+    return {name: globals()[name] for name in __all__}
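+
+
+# Illustrative sketch: get_all_synthesizers() returns a name -> class mapping, e.g.
+# {'CTGAN': <class 'ctgan.synthesizers.ctgan.CTGAN'>, 'TVAE': <class 'ctgan.synthesizers.tvae.TVAE'>}.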
diff --git a/ctgan/synthesizers/__pycache__/__init__.cpython-311.pyc b/ctgan/synthesizers/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e7bfe77a077018470ff0b67846f5b064c5acf443
Binary files /dev/null and b/ctgan/synthesizers/__pycache__/__init__.cpython-311.pyc differ
diff --git a/ctgan/synthesizers/__pycache__/base.cpython-311.pyc b/ctgan/synthesizers/__pycache__/base.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96a5fbe3cdb9bc443d8d155f16e47ad2d9aaaa4d
Binary files /dev/null and b/ctgan/synthesizers/__pycache__/base.cpython-311.pyc differ
diff --git a/ctgan/synthesizers/__pycache__/ctgan.cpython-311.pyc b/ctgan/synthesizers/__pycache__/ctgan.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8119686edfd9f07e4536c37ba0978bca11bb7798
Binary files /dev/null and b/ctgan/synthesizers/__pycache__/ctgan.cpython-311.pyc differ
diff --git a/ctgan/synthesizers/__pycache__/tvae.cpython-311.pyc b/ctgan/synthesizers/__pycache__/tvae.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d6ab1a3b33b5e4a9046deed08c1c7ac4bfa0f48
Binary files /dev/null and b/ctgan/synthesizers/__pycache__/tvae.cpython-311.pyc differ
diff --git a/ctgan/synthesizers/base.py b/ctgan/synthesizers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..add0dd7e1419a5500f7ed29dfccb3b2672da18bb
--- /dev/null
+++ b/ctgan/synthesizers/base.py
@@ -0,0 +1,151 @@
+"""BaseSynthesizer module."""
+
+import contextlib
+
+import numpy as np
+import torch
+
+
+@contextlib.contextmanager
+def set_random_states(random_state, set_model_random_state):
+    """Context manager for managing the random state.
+
+    Args:
+        random_state (int or tuple):
+            The random seed or a tuple of (numpy.random.RandomState, torch.Generator).
+        set_model_random_state (function):
+            Function to set the random state on the model.
+    """
+    original_np_state = np.random.get_state()
+    original_torch_state = torch.get_rng_state()
+
+    random_np_state, random_torch_state = random_state
+
+    np.random.set_state(random_np_state.get_state())
+    torch.set_rng_state(random_torch_state.get_state())
+
+    try:
+        yield
+    finally:
+        current_np_state = np.random.RandomState()
+        current_np_state.set_state(np.random.get_state())
+        current_torch_state = torch.Generator()
+        current_torch_state.set_state(torch.get_rng_state())
+        set_model_random_state((current_np_state, current_torch_state))
+
+        np.random.set_state(original_np_state)
+        torch.set_rng_state(original_torch_state)
+
+
+def random_state(function):
+    """Set the random state before calling the function.
+
+    Args:
+        function (Callable):
+            The function to wrap around.
+    """
+
+    def wrapper(self, *args, **kwargs):
+        if self.random_states is None:
+            return function(self, *args, **kwargs)
+
+        else:
+            with set_random_states(self.random_states, self.set_random_state):
+                return function(self, *args, **kwargs)
+
+    return wrapper
+
+
+class BaseSynthesizer:
+    """Base class for all default synthesizers of ``CTGAN``."""
+
+    random_states = None
+
+    def __getstate__(self):
+        """Improve pickling state for ``BaseSynthesizer``.
+
+        Convert to ``cpu`` device before starting the pickling process in order to be able to
+        load the model even when used from an external tool such as ``SDV``. Also, if
+        ``random_states`` are set, store their states as dictionaries rather than generators.
+
+        Returns:
+            dict:
+                Python dict representing the object.
+        """
+        device_backup = self._device
+        self.set_device(torch.device('cpu'))
+        state = self.__dict__.copy()
+        self.set_device(device_backup)
+        if (
+            isinstance(self.random_states, tuple)
+            and isinstance(self.random_states[0], np.random.RandomState)
+            and isinstance(self.random_states[1], torch.Generator)
+        ):
+            state['_numpy_random_state'] = self.random_states[0].get_state()
+            state['_torch_random_state'] = self.random_states[1].get_state()
+            state.pop('random_states')
+
+        return state
+
+    def __setstate__(self, state):
+        """Restore the state of a ``BaseSynthesizer``.
+
+        Restore the ``random_states`` from the state dict if those are present and then
+        set the device according to the current hardware.
+        """
+        if '_numpy_random_state' in state and '_torch_random_state' in state:
+            np_state = state.pop('_numpy_random_state')
+            torch_state = state.pop('_torch_random_state')
+
+            current_torch_state = torch.Generator()
+            current_torch_state.set_state(torch_state)
+
+            current_numpy_state = np.random.RandomState()
+            current_numpy_state.set_state(np_state)
+            state['random_states'] = (current_numpy_state, current_torch_state)
+
+        self.__dict__ = state
+        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.set_device(device)
+
+    def save(self, path):
+        """Save the model in the passed `path`."""
+        device_backup = self._device
+        self.set_device(torch.device('cpu'))
+        torch.save(self, path)
+        self.set_device(device_backup)
+
+    @classmethod
+    def load(cls, path):
+        """Load the model stored in the passed `path`."""
+        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        model = torch.load(path)
+        model.set_device(device)
+        return model
+
+    def set_random_state(self, random_state):
+        """Set the random state.
+
+        Args:
+            random_state (int, tuple, or None):
+                Either a tuple containing the (numpy.random.RandomState, torch.Generator)
+                or an int representing the random seed to use for both random states.
+        """
+        if random_state is None:
+            self.random_states = random_state
+        elif isinstance(random_state, int):
+            self.random_states = (
+                np.random.RandomState(seed=random_state),
+                torch.Generator().manual_seed(random_state),
+            )
+        elif (
+            isinstance(random_state, tuple)
+            and isinstance(random_state[0], np.random.RandomState)
+            and isinstance(random_state[1], torch.Generator)
+        ):
+            self.random_states = random_state
+        else:
+            raise TypeError(
+                f'`random_state` {random_state} expected to be an int or a tuple of '
+                '(`np.random.RandomState`, `torch.Generator`)'
+            )
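+
+
+# Illustrative usage sketch (commented out; the model path is hypothetical): seeding a
+# synthesizer for reproducible sampling. An int seed builds both a numpy RandomState and
+# a torch Generator, which the ``random_state`` decorator swaps in around decorated methods.
+#
+#     model = CTGAN.load("ctgan_model.pt")
+#     model.set_random_state(42)
+#     samples = model.sample(100)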
diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..704468e9c382f0e207634e63ea953e5c421cff35
--- /dev/null
+++ b/ctgan/synthesizers/ctgan.py
@@ -0,0 +1,564 @@
+"""CTGAN module."""
+import logging
+import sys
+import warnings
+
+import numpy as np
+import pandas as pd
+import torch
+from torch import optim
+from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Module, ReLU, Sequential, functional
+from tqdm import tqdm
+
+from ctgan.data_sampler import DataSampler
+from ctgan.data_transformer import DataTransformer
+from ctgan.synthesizers.base import BaseSynthesizer, random_state
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[logging.StreamHandler()]
+)
+
+
+class Discriminator(Module):
+    """Discriminator for the CTGAN."""
+
+    def __init__(self, input_dim, discriminator_dim, pac=10):
+        super(Discriminator, self).__init__()
+        dim = input_dim * pac
+        self.pac = pac
+        self.pacdim = dim
+        seq = []
+        for item in list(discriminator_dim):
+            seq += [Linear(dim, item), LeakyReLU(0.2), Dropout(0.5)]
+            dim = item
+
+        seq += [Linear(dim, 1)]
+        self.seq = Sequential(*seq)
+
+    def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lambda_=10):
+        """Compute the gradient penalty."""
+        alpha = torch.rand(real_data.size(0) // pac, 1, 1, device=device)
+        alpha = alpha.repeat(1, pac, real_data.size(1))
+        alpha = alpha.view(-1, real_data.size(1))
+
+        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
+
+        disc_interpolates = self(interpolates)
+
+        gradients = torch.autograd.grad(
+            outputs=disc_interpolates,
+            inputs=interpolates,
+            grad_outputs=torch.ones(disc_interpolates.size(), device=device),
+            create_graph=True,
+            retain_graph=True,
+            only_inputs=True,
+        )[0]
+
+        gradients_view = gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1
+        gradient_penalty = ((gradients_view) ** 2).mean() * lambda_
+
+        return gradient_penalty
+
+    def forward(self, input_):
+        """Apply the Discriminator to the `input_`."""
+        assert input_.size()[0] % self.pac == 0
+        return self.seq(input_.view(-1, self.pacdim))
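+
+    # Shape sketch (illustrative): with pac=10 and input_dim=d, a batch of 500 rows of
+    # width d is viewed as (50, 10 * d) before the linear stack, so the discriminator
+    # scores packs of 10 samples jointly (the PacGAN construction used by CTGAN).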
+
+
+class Residual(Module):
+    """Residual layer for the CTGAN."""
+
+    def __init__(self, i, o):
+        super(Residual, self).__init__()
+        self.fc = Linear(i, o)
+        self.bn = BatchNorm1d(o)
+        self.relu = ReLU()
+
+    def forward(self, input_):
+        """Apply the Residual layer to the `input_`."""
+        out = self.fc(input_)
+        out = self.bn(out)
+        out = self.relu(out)
+        return torch.cat([out, input_], dim=1)
+
+
+class Generator(Module):
+    """Generator for the CTGAN."""
+
+    def __init__(self, embedding_dim, generator_dim, data_dim):
+        super(Generator, self).__init__()
+        dim = embedding_dim
+        seq = []
+        for item in list(generator_dim):
+            seq += [Residual(dim, item)]
+            dim += item
+        seq.append(Linear(dim, data_dim))
+        self.seq = Sequential(*seq)
+
+    def forward(self, input_):
+        """Apply the Generator to the `input_`."""
+        data = self.seq(input_)
+        return data
+
+
+class CTGAN(BaseSynthesizer):
+    """Conditional Table GAN Synthesizer.
+
+    This is the core class of the CTGAN project, where the different components
+    are orchestrated together.
+    For more details about the process, please check the [Modeling Tabular data using
+    Conditional GAN](https://arxiv.org/abs/1907.00503) paper.
+
+    Args:
+        embedding_dim (int):
+            Size of the random sample passed to the Generator. Defaults to 128.
+        generator_dim (tuple or list of ints):
+            Size of the output samples for each one of the Residuals. A Residual Layer
+            will be created for each one of the values provided. Defaults to (256, 256).
+        discriminator_dim (tuple or list of ints):
+            Size of the output samples for each one of the Discriminator Layers. A Linear Layer
+            will be created for each one of the values provided. Defaults to (256, 256).
+        generator_lr (float):
+            Learning rate for the generator. Defaults to 2e-4.
+        generator_decay (float):
+            Generator weight decay for the Adam Optimizer. Defaults to 1e-6.
+        discriminator_lr (float):
+            Learning rate for the discriminator. Defaults to 2e-4.
+        discriminator_decay (float):
+            Discriminator weight decay for the Adam Optimizer. Defaults to 1e-6.
+        batch_size (int):
+            Number of data samples to process in each step.
+        discriminator_steps (int):
+            Number of discriminator updates to do for each generator update.
+            From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper
+            default is 5. Default used is 1 to match original CTGAN implementation.
+        log_frequency (boolean):
+            Whether to use log frequency of categorical levels in conditional
+            sampling. Defaults to ``True``.
+        verbose (boolean):
+            Whether to have print statements for progress results. Defaults to ``False``.
+        epochs (int):
+            Number of training epochs. Defaults to 300.
+        pac (int):
+            Number of samples to group together when applying the discriminator.
+            Defaults to 10.
+        cuda (bool):
+            Whether to attempt to use cuda for GPU computation.
+            If this is False or CUDA is not available, CPU will be used.
+            Defaults to ``True``.
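+
+    Example:
+        A minimal illustrative sketch; ``df`` and ``'category_col'`` below are
+        placeholders for your own training data and discrete column names::
+
+            model = CTGAN(epochs=10)
+            model.fit(df, discrete_columns=['category_col'])
+            synthetic_data = model.sample(1000)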
+    """
+
+    def __init__(
+        self,
+        embedding_dim=128,
+        generator_dim=(256, 256),
+        discriminator_dim=(256, 256),
+        generator_lr=2e-4,
+        generator_decay=1e-6,
+        discriminator_lr=2e-4,
+        discriminator_decay=1e-6,
+        batch_size=500,
+        discriminator_steps=1,
+        log_frequency=True,
+        verbose=False,
+        epochs=300,
+        pac=10,
+        cuda=True,
+    ):
+        assert batch_size % 2 == 0
+
+        self._embedding_dim = embedding_dim
+        self._generator_dim = generator_dim
+        self._discriminator_dim = discriminator_dim
+
+        self._generator_lr = generator_lr
+        self._generator_decay = generator_decay
+        self._discriminator_lr = discriminator_lr
+        self._discriminator_decay = discriminator_decay
+
+        self._batch_size = batch_size
+        self._discriminator_steps = discriminator_steps
+        self._log_frequency = log_frequency
+        self._verbose = verbose
+        self._epochs = epochs
+        self.pac = pac
+
+        if not cuda or not torch.cuda.is_available():
+            device = 'cpu'
+        elif isinstance(cuda, str):
+            device = cuda
+        else:
+            device = 'cuda'
+
+        self._device = torch.device(device)
+
+        self._transformer = None
+        self._data_sampler = None
+        self._generator = None
+        self._discriminator = None
+        self.loss_values = None
+
+    @staticmethod
+    def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
+        """Deals with the instability of the gumbel_softmax for older versions of torch.
+
+        For more details about the issue:
+        https://drive.google.com/file/d/1AA5wPfZ1kquaRtVruCd6BiYZGcDeNxyP/view?usp=sharing
+
+        Args:
+            logits […, num_features]:
+                Unnormalized log probabilities
+            tau:
+                Non-negative scalar temperature
+            hard (bool):
+                If True, the returned samples will be discretized as one-hot vectors,
+                but will be differentiated as if it is the soft sample in autograd
+            dim (int):
+                A dimension along which softmax will be computed. Default: -1.
+
+        Returns:
+            Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
+        """
+        for _ in range(10):
+            transformed = functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim)
+            if not torch.isnan(transformed).any():
+                return transformed
+
+        raise ValueError('gumbel_softmax returning NaN.')
+
+    def _apply_activate(self, data):
+        """Apply proper activation function to the output of the generator."""
+        data_t = []
+        st = 0
+        for column_info in self._transformer.output_info_list:
+            for span_info in column_info:
+                if span_info.activation_fn == 'tanh':
+                    ed = st + span_info.dim
+                    data_t.append(torch.tanh(data[:, st:ed]))
+                    st = ed
+                elif span_info.activation_fn == 'softmax':
+                    ed = st + span_info.dim
+                    transformed = self._gumbel_softmax(data[:, st:ed], tau=0.2)
+                    data_t.append(transformed)
+                    st = ed
+                else:
+                    raise ValueError(f'Unexpected activation function {span_info.activation_fn}.')
+
+        return torch.cat(data_t, dim=1)
+
+    def _cond_loss(self, data, c, m):
+        """Compute the cross entropy loss on the fixed discrete column."""
+        loss = []
+        st = 0
+        st_c = 0
+        for column_info in self._transformer.output_info_list:
+            for span_info in column_info:
+                if len(column_info) != 1 or span_info.activation_fn != 'softmax':
+                    # not discrete column
+                    st += span_info.dim
+                else:
+                    ed = st + span_info.dim
+                    ed_c = st_c + span_info.dim
+                    tmp = functional.cross_entropy(
+                        data[:, st:ed], torch.argmax(c[:, st_c:ed_c], dim=1), reduction='none'
+                    )
+                    loss.append(tmp)
+                    st = ed
+                    st_c = ed_c
+
+        loss = torch.stack(loss, dim=1)  # noqa: PD013
+
+        return (loss * m).sum() / data.size()[0]
+
+    def _validate_discrete_columns(self, train_data, discrete_columns):
+        """Check whether ``discrete_columns`` exists in ``train_data``.
+
+        Args:
+            train_data (numpy.ndarray or pandas.DataFrame):
+                Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
+            discrete_columns (list-like):
+                List of discrete columns to be used to generate the Conditional
+                Vector. If ``train_data`` is a Numpy array, this list should
+                contain the integer indices of the columns. Otherwise, if it is
+                a ``pandas.DataFrame``, this list should contain the column names.
+        """
+        if isinstance(train_data, pd.DataFrame):
+            invalid_columns = set(discrete_columns) - set(train_data.columns)
+        elif isinstance(train_data, np.ndarray):
+            invalid_columns = []
+            for column in discrete_columns:
+                if column < 0 or column >= train_data.shape[1]:
+                    invalid_columns.append(column)
+        else:
+            raise TypeError('``train_data`` should be either pd.DataFrame or np.array.')
+
+        if invalid_columns:
+            raise ValueError(f'Invalid columns found: {invalid_columns}')
+
+    @random_state
+    def fit(self, train_data, discrete_columns=(), epochs=None):
+        """Fit the CTGAN Synthesizer models to the training data.
+
+        Args:
+            train_data (numpy.ndarray or pandas.DataFrame):
+                Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
+            discrete_columns (list-like):
+                List of discrete columns to be used to generate the Conditional
+                Vector. If ``train_data`` is a Numpy array, this list should
+                contain the integer indices of the columns. Otherwise, if it is
+                a ``pandas.DataFrame``, this list should contain the column names.
+        """
+        self._validate_discrete_columns(train_data, discrete_columns)
+
+        if epochs is None:
+            epochs = self._epochs
+        else:
+            warnings.warn(
+                (
+                    '`epochs` argument in `fit` method has been deprecated and will be removed '
+                    'in a future version. Please pass `epochs` to the constructor instead'
+                ),
+                DeprecationWarning,
+            )
+
+        self._transformer = DataTransformer()
+        self._transformer.fit(train_data, discrete_columns)
+
+        train_data = self._transformer.transform(train_data)
+
+        self._data_sampler = DataSampler(
+            train_data, self._transformer.output_info_list, self._log_frequency
+        )
+
+        data_dim = self._transformer.output_dimensions
+
+        self._generator = Generator(
+            self._embedding_dim + self._data_sampler.dim_cond_vec(), self._generator_dim, data_dim
+        ).to(self._device)
+
+        discriminator = Discriminator(
+            data_dim + self._data_sampler.dim_cond_vec(), self._discriminator_dim, pac=self.pac
+        ).to(self._device)
+        self._discriminator = discriminator  # keep a reference so `predict` can reuse the trained critic
+
+        optimizerG = optim.Adam(
+            self._generator.parameters(),
+            lr=self._generator_lr,
+            betas=(0.5, 0.9),
+            weight_decay=self._generator_decay,
+        )
+
+        optimizerD = optim.Adam(
+            discriminator.parameters(),
+            lr=self._discriminator_lr,
+            betas=(0.5, 0.9),
+            weight_decay=self._discriminator_decay,
+        )
+
+        mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device)
+        std = mean + 1
+
+        self.loss_values = pd.DataFrame(columns=['Epoch', 'Generator Loss', 'Discriminator Loss'])
+
+        epoch_iterator = tqdm(range(epochs), disable=(not self._verbose), file=sys.stdout)
+        description = 'Gen. ({gen:.2f}) | Discrim. ({dis:.2f})'
+        if self._verbose:
+            epoch_iterator.set_description(description.format(gen=0, dis=0))
+
+        steps_per_epoch = max(len(train_data) // self._batch_size, 1)
+        for i in epoch_iterator:
+            logging.info(f"Epoch {i}")
+            for id_ in range(steps_per_epoch):
+                for n in range(self._discriminator_steps):
+                    fakez = torch.normal(mean=mean, std=std)
+
+                    condvec = self._data_sampler.sample_condvec(self._batch_size)
+                    if condvec is None:
+                        c1, m1, col, opt = None, None, None, None
+                        real = self._data_sampler.sample_data(
+                            train_data, self._batch_size, col, opt
+                        )
+                    else:
+                        c1, m1, col, opt = condvec
+                        c1 = torch.from_numpy(c1).to(self._device)
+                        m1 = torch.from_numpy(m1).to(self._device)
+                        fakez = torch.cat([fakez, c1], dim=1)
+
+                        perm = np.arange(self._batch_size)
+                        np.random.shuffle(perm)
+                        real = self._data_sampler.sample_data(
+                            train_data, self._batch_size, col[perm], opt[perm]
+                        )
+                        c2 = c1[perm]
+
+                    fake = self._generator(fakez)
+                    fakeact = self._apply_activate(fake)
+
+                    real = torch.from_numpy(real.astype('float32')).to(self._device)
+
+                    if c1 is not None:
+                        fake_cat = torch.cat([fakeact, c1], dim=1)
+                        real_cat = torch.cat([real, c2], dim=1)
+                    else:
+                        real_cat = real
+                        fake_cat = fakeact
+
+                    y_fake = discriminator(fake_cat)
+                    y_real = discriminator(real_cat)
+
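+                    # WGAN-GP critic objective: gradient penalty plus the Wasserstein loss below.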
+                    pen = discriminator.calc_gradient_penalty(
+                        real_cat, fake_cat, self._device, self.pac
+                    )
+                    loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
+
+                    optimizerD.zero_grad(set_to_none=False)
+                    pen.backward(retain_graph=True)
+                    loss_d.backward()
+                    optimizerD.step()
+
+                fakez = torch.normal(mean=mean, std=std)
+                condvec = self._data_sampler.sample_condvec(self._batch_size)
+
+                if condvec is None:
+                    c1, m1, col, opt = None, None, None, None
+                else:
+                    c1, m1, col, opt = condvec
+                    c1 = torch.from_numpy(c1).to(self._device)
+                    m1 = torch.from_numpy(m1).to(self._device)
+                    fakez = torch.cat([fakez, c1], dim=1)
+
+                fake = self._generator(fakez)
+                fakeact = self._apply_activate(fake)
+
+                if c1 is not None:
+                    y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
+                else:
+                    y_fake = discriminator(fakeact)
+
+                if condvec is None:
+                    cross_entropy = 0
+                else:
+                    cross_entropy = self._cond_loss(fake, c1, m1)
+
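+                # Generator objective: fool the critic, plus cross-entropy on the conditioned discrete column.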
+                loss_g = -torch.mean(y_fake) + cross_entropy
+
+                optimizerG.zero_grad(set_to_none=False)
+                loss_g.backward()
+                optimizerG.step()
+
+            generator_loss = loss_g.detach().cpu().item()
+            discriminator_loss = loss_d.detach().cpu().item()
+
+            epoch_loss_df = pd.DataFrame({
+                'Epoch': [i],
+                'Generator Loss': [generator_loss],
+                'Discriminator Loss': [discriminator_loss],
+            })
+            if not self.loss_values.empty:
+                self.loss_values = pd.concat([self.loss_values, epoch_loss_df]).reset_index(
+                    drop=True
+                )
+            else:
+                self.loss_values = epoch_loss_df
+            print(description.format(gen=generator_loss, dis=discriminator_loss))
+            if self._verbose:
+                epoch_iterator.set_description(
+                    description.format(gen=generator_loss, dis=discriminator_loss)
+                )
+
+    @random_state
+    def sample(self, n, condition_column=None, condition_value=None):
+        """Sample data similar to the training data.
+
+        Choosing a condition_column and condition_value will increase the probability of the
+        discrete condition_value happening in the condition_column.
+
+        Args:
+            n (int):
+                Number of rows to sample.
+            condition_column (string):
+                Name of a discrete column.
+            condition_value (string):
+                Name of the category in the condition_column which we wish to increase the
+                probability of happening.
+
+        Returns:
+            numpy.ndarray or pandas.DataFrame
+        """
+        if condition_column is not None and condition_value is not None:
+            condition_info = self._transformer.convert_column_name_value_to_id(
+                condition_column, condition_value
+            )
+            global_condition_vec = self._data_sampler.generate_cond_from_condition_column_info(
+                condition_info, self._batch_size
+            )
+        else:
+            global_condition_vec = None
+
+        steps = n // self._batch_size + 1
+        data = []
+        for i in range(steps):
+            mean = torch.zeros(self._batch_size, self._embedding_dim)
+            std = mean + 1
+            fakez = torch.normal(mean=mean, std=std).to(self._device)
+
+            if global_condition_vec is not None:
+                condvec = global_condition_vec.copy()
+            else:
+                condvec = self._data_sampler.sample_original_condvec(self._batch_size)
+
+            if condvec is not None:
+                c1 = torch.from_numpy(condvec).to(self._device)
+                fakez = torch.cat([fakez, c1], dim=1)
+
+            fake = self._generator(fakez)
+            fakeact = self._apply_activate(fake)
+            data.append(fakeact.detach().cpu().numpy())
+
+        data = np.concatenate(data, axis=0)
+        data = data[:n]
+
+        return self._transformer.inverse_transform(data)
+
+    def set_device(self, device):
+        """Set the `device` to be used ('GPU' or 'CPU)."""
+        self._device = device
+        if self._generator is not None:
+            self._generator.to(self._device)
+
+    def predict(self, data):
+        """Score `data` with the discriminator trained during `fit`.
+
+        Higher scores indicate rows the critic considers more realistic.
+        """
+        if (
+            self._generator is None
+            or self._discriminator is None
+            or self._transformer is None
+            or self._data_sampler is None
+        ):
+            raise RuntimeError('The model has not been fitted yet. Please call `fit` first.')
+
+        # Preprocess the input data
+        if isinstance(data, pd.DataFrame):
+            data = self._transformer.transform(data)
+        elif isinstance(data, np.ndarray):
+            if data.shape[1] != self._transformer.output_dimensions:
+                raise ValueError(
+                    f'Number of input columns ({data.shape[1]}) does not match the '
+                    f'transformed training data ({self._transformer.output_dimensions}).'
+                )
+        else:
+            raise TypeError('Input data must be a pandas.DataFrame or a numpy.ndarray.')
+
+        # The discriminator groups rows into packs of `pac` samples
+        if data.shape[0] % self.pac != 0:
+            raise ValueError(f'The number of rows ({data.shape[0]}) must be a multiple of pac ({self.pac}).')
+
+        # Convert the data to a PyTorch tensor
+        data_tensor = torch.from_numpy(data.astype('float32')).to(self._device)
+
+        # Append a condition vector so the input matches what the discriminator was trained on
+        condvec = self._data_sampler.sample_original_condvec(data.shape[0])
+        if condvec is not None:
+            c1 = torch.from_numpy(condvec).to(self._device)
+            data_tensor = torch.cat([data_tensor, c1], dim=1)
+
+        # Score the rows with the discriminator kept from training
+        with torch.no_grad():
+            scores = self._discriminator(data_tensor)
+
+        return scores.detach().cpu().numpy()
\ No newline at end of file
diff --git a/ctgan/synthesizers/tvae.py b/ctgan/synthesizers/tvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecefbb5f5040570c56469e4fc5686852d79b941f
--- /dev/null
+++ b/ctgan/synthesizers/tvae.py
@@ -0,0 +1,246 @@
+"""TVAE module."""
+
+import numpy as np
+import pandas as pd
+import torch
+from torch.nn import Linear, Module, Parameter, ReLU, Sequential
+from torch.nn.functional import cross_entropy
+from torch.optim import Adam
+from torch.utils.data import DataLoader, TensorDataset
+from tqdm import tqdm
+
+from ctgan.data_transformer import DataTransformer
+from ctgan.synthesizers.base import BaseSynthesizer, random_state
+
+
+class Encoder(Module):
+    """Encoder for the TVAE.
+
+    Args:
+        data_dim (int):
+            Dimensions of the data.
+        compress_dims (tuple or list of ints):
+            Size of each hidden layer.
+        embedding_dim (int):
+            Size of the output vector.
+    """
+
+    def __init__(self, data_dim, compress_dims, embedding_dim):
+        super(Encoder, self).__init__()
+        dim = data_dim
+        seq = []
+        for item in list(compress_dims):
+            seq += [Linear(dim, item), ReLU()]
+            dim = item
+
+        self.seq = Sequential(*seq)
+        self.fc1 = Linear(dim, embedding_dim)
+        self.fc2 = Linear(dim, embedding_dim)
+
+    def forward(self, input_):
+        """Encode the passed `input_`."""
+        feature = self.seq(input_)
+        mu = self.fc1(feature)
+        logvar = self.fc2(feature)
+        std = torch.exp(0.5 * logvar)
+        return mu, std, logvar
+
+
+class Decoder(Module):
+    """Decoder for the TVAE.
+
+    Args:
+        embedding_dim (int):
+            Size of the input vector.
+        decompress_dims (tuple or list of ints):
+            Size of each hidden layer.
+        data_dim (int):
+            Dimensions of the data.
+    """
+
+    def __init__(self, embedding_dim, decompress_dims, data_dim):
+        super(Decoder, self).__init__()
+        dim = embedding_dim
+        seq = []
+        for item in list(decompress_dims):
+            seq += [Linear(dim, item), ReLU()]
+            dim = item
+
+        seq.append(Linear(dim, data_dim))
+        self.seq = Sequential(*seq)
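+        # Learnable per-column output standard deviation, clamped to [0.01, 1.0] during training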
+        self.sigma = Parameter(torch.ones(data_dim) * 0.1)
+
+    def forward(self, input_):
+        """Decode the passed `input_`."""
+        return self.seq(input_), self.sigma
+
+
+def _loss_function(recon_x, x, sigmas, mu, logvar, output_info, factor):
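+    """Return the reconstruction loss and the KL divergence, both averaged over the batch.
+
+    Continuous spans use a Gaussian/tanh term with the learned ``sigmas``; one-hot spans use cross-entropy.
+    """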
+    st = 0
+    loss = []
+    for column_info in output_info:
+        for span_info in column_info:
+            if span_info.activation_fn != 'softmax':
+                ed = st + span_info.dim
+                std = sigmas[st]
+                eq = x[:, st] - torch.tanh(recon_x[:, st])
+                loss.append((eq**2 / 2 / (std**2)).sum())
+                loss.append(torch.log(std) * x.size()[0])
+                st = ed
+
+            else:
+                ed = st + span_info.dim
+                loss.append(
+                    cross_entropy(
+                        recon_x[:, st:ed], torch.argmax(x[:, st:ed], dim=-1), reduction='sum'
+                    )
+                )
+                st = ed
+
+    assert st == recon_x.size()[1]
+    KLD = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp())
+    return sum(loss) * factor / x.size()[0], KLD / x.size()[0]
+
+
+class TVAE(BaseSynthesizer):
+    """TVAE."""
+
+    def __init__(
+        self,
+        embedding_dim=128,
+        compress_dims=(128, 128),
+        decompress_dims=(128, 128),
+        l2scale=1e-5,
+        batch_size=500,
+        epochs=300,
+        loss_factor=2,
+        cuda=True,
+        verbose=False,
+    ):
+        self.embedding_dim = embedding_dim
+        self.compress_dims = compress_dims
+        self.decompress_dims = decompress_dims
+
+        self.l2scale = l2scale
+        self.batch_size = batch_size
+        self.loss_factor = loss_factor
+        self.epochs = epochs
+        self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
+        self.verbose = verbose
+
+        if not cuda or not torch.cuda.is_available():
+            device = 'cpu'
+        elif isinstance(cuda, str):
+            device = cuda
+        else:
+            device = 'cuda'
+
+        self._device = torch.device(device)
+
+    @random_state
+    def fit(self, train_data, discrete_columns=()):
+        """Fit the TVAE Synthesizer models to the training data.
+
+        Args:
+            train_data (numpy.ndarray or pandas.DataFrame):
+                Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
+            discrete_columns (list-like):
+                List of discrete columns to be used to generate the Conditional
+                Vector. If ``train_data`` is a Numpy array, this list should
+                contain the integer indices of the columns. Otherwise, if it is
+                a ``pandas.DataFrame``, this list should contain the column names.
+        """
+        self.transformer = DataTransformer()
+        self.transformer.fit(train_data, discrete_columns)
+        train_data = self.transformer.transform(train_data)
+        dataset = TensorDataset(torch.from_numpy(train_data.astype('float32')).to(self._device))
+        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=False)
+
+        data_dim = self.transformer.output_dimensions
+        encoder = Encoder(data_dim, self.compress_dims, self.embedding_dim).to(self._device)
+        self.decoder = Decoder(self.embedding_dim, self.decompress_dims, data_dim).to(self._device)
+        optimizerAE = Adam(
+            list(encoder.parameters()) + list(self.decoder.parameters()), weight_decay=self.l2scale
+        )
+
+        self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
+        iterator = tqdm(range(self.epochs), disable=(not self.verbose))
+        if self.verbose:
+            iterator_description = 'Loss: {loss:.3f}'
+            iterator.set_description(iterator_description.format(loss=0))
+
+        for i in iterator:
+            loss_values = []
+            batch = []
+            for id_, data in enumerate(loader):
+                optimizerAE.zero_grad()
+                real = data[0].to(self._device)
+                mu, std, logvar = encoder(real)
+                eps = torch.randn_like(std)
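+                # Reparameterization trick: sample the latent embedding as mu + eps * std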
+                emb = eps * std + mu
+                rec, sigmas = self.decoder(emb)
+                loss_1, loss_2 = _loss_function(
+                    rec,
+                    real,
+                    sigmas,
+                    mu,
+                    logvar,
+                    self.transformer.output_info_list,
+                    self.loss_factor,
+                )
+                loss = loss_1 + loss_2
+                loss.backward()
+                optimizerAE.step()
+                self.decoder.sigma.data.clamp_(0.01, 1.0)
+
+                batch.append(id_)
+                loss_values.append(loss.detach().cpu().item())
+
+            epoch_loss_df = pd.DataFrame({
+                'Epoch': [i] * len(batch),
+                'Batch': batch,
+                'Loss': loss_values,
+            })
+            if not self.loss_values.empty:
+                self.loss_values = pd.concat([self.loss_values, epoch_loss_df]).reset_index(
+                    drop=True
+                )
+            else:
+                self.loss_values = epoch_loss_df
+
+            if self.verbose:
+                iterator.set_description(
+                    iterator_description.format(loss=loss.detach().cpu().item())
+                )
+
+    @random_state
+    def sample(self, samples):
+        """Sample data similar to the training data.
+
+        Args:
+            samples (int):
+                Number of rows to sample.
+
+        Returns:
+            numpy.ndarray or pandas.DataFrame
+        """
+        self.decoder.eval()
+
+        steps = samples // self.batch_size + 1
+        data = []
+        for _ in range(steps):
+            mean = torch.zeros(self.batch_size, self.embedding_dim)
+            std = mean + 1
+            noise = torch.normal(mean=mean, std=std).to(self._device)
+            fake, sigmas = self.decoder(noise)
+            fake = torch.tanh(fake)
+            data.append(fake.detach().cpu().numpy())
+
+        data = np.concatenate(data, axis=0)
+        data = data[:samples]
+        return self.transformer.inverse_transform(data, sigmas.detach().cpu().numpy())
+
+    def set_device(self, device):
+        """Set the `device` to be used ('GPU' or 'CPU)."""
+        self._device = device
+        self.decoder.to(self._device)
diff --git a/dataprocess/Dockerfile b/dataprocess/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..b5b5ebcaa12ed30465c0c8b15a776f4505263606
--- /dev/null
+++ b/dataprocess/Dockerfile
@@ -0,0 +1,20 @@
+# Use the official Python base image
+FROM python:3.11-slim
+
+# Set the working directory
+WORKDIR /opt/ml/processing
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    build-essential \
+    && rm -rf /var/lib/apt/lists/*
+
+# Upgrade pip and install the required Python packages
+RUN pip install --upgrade pip
+RUN pip install pandas numpy scikit-learn imbalanced-learn
+
+# Copy the processing script into the container
+COPY dataprocess.py /opt/ml/processing/process_data.py
+
+# Set the default command
+ENTRYPOINT ["python", "/opt/ml/processing/process_data.py"]
diff --git a/dataprocess/dataprocess.py b/dataprocess/dataprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7399bed5832a9e4a2bbaf8c937f55e32467f645
--- /dev/null
+++ b/dataprocess/dataprocess.py
@@ -0,0 +1,191 @@
+import argparse
+import os
+import pandas as pd
+import numpy as np
+from sklearn.impute import SimpleImputer
+from sklearn.preprocessing import LabelEncoder
+from imblearn.over_sampling import SMOTE
+import json
+import logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[
+        logging.StreamHandler()  # Log to stdout
+    ]
+)
+label_map = {}  # LabelEncoder mapping (label <-> number)
+params = {}  # scaling parameters for each numeric column
+
+
+def encode_numeric_range(df, names, normalized_low=0, normalized_high=1):
+    """Normalize numeric columns to the range [normalized_low, normalized_high]."""
+    global params
+    for name in names:
+        data_low = df[name].min()
+        data_high = df[name].max()
+        # If NaN, fall back to 0
+        data_low = 0 if np.isnan(data_low) else data_low
+        data_high = 0 if np.isnan(data_high) else data_high
+        # Avoid ZeroDivisionError (if min == max, set the column to 0)
+        if data_high != data_low:
+            df[name] = ((df[name] - data_low) / (data_high - data_low)) \
+                        * (normalized_high - normalized_low) + normalized_low
+        else:
+            df[name] = 0
+        if name not in params:
+            params[name] = {}
+        # Cast to plain floats so the parameters stay JSON-serializable
+        params[name]['min'] = float(data_low)
+        params[name]['max'] = float(data_high)
+    return df
+
+def encode_numeric_zscore(df, names):
+    """Standardize numeric columns to Z-scores."""
+    global params
+    for name in names:
+        mean = df[name].mean()
+        sd = df[name].std()
+        # If NaN, fall back to 0
+        mean = 0 if np.isnan(mean) else mean
+        sd = 0 if np.isnan(sd) else sd
+        # Avoid division by zero (if the standard deviation is 0, set the column to 0)
+        if sd != 0:
+            df[name] = (df[name] - mean) / sd
+        else:
+            df[name] = 0
+        if name not in params:
+            params[name] = {}
+        # Cast to plain floats so the parameters stay JSON-serializable
+        params[name]['mean'] = float(mean)
+        params[name]['std'] = float(sd)
+    return df
+
+def numerical_encoding(df, label_column):
+    """Encode categorical labels as 0, 1, 2, ... and store the mapping."""
+    global label_map
+    label_encoder = LabelEncoder()
+    df[label_column] = label_encoder.fit_transform(df[label_column])
+
+    # Build label -> number and number -> label mappings
+    label_map = {
+        "label_to_number": {label: int(num) for num, label in enumerate(label_encoder.classes_)},
+        "number_to_label": {int(num): label for num, label in enumerate(label_encoder.classes_)}
+    }
+    return df
+
+# def oversample_data(X, y):
+#     """Oversample the minority class using replication to avoid heavy computation."""
+#     lists = [X[y == label].values.tolist() for label in np.unique(y)]
+#     totall_list = []
+#     for i in range(len(lists)):
+#         while len(lists[i]) < 5000:
+#             lists[i].extend(lists[i])
+#         totall_list.extend(lists[i])
+#     new_y = np.concatenate([[label] * len(lists[i]) for i, label in enumerate(np.unique(y))])
+#     new_X = np.array(totall_list)
+#     return pd.DataFrame(new_X, columns=X.columns), pd.Series(new_y)
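+
+# Illustrative sketch only (not called by this script): shows how the parameters saved in
+# `scaling_params.json` could be re-applied to a single record at inference time. It mirrors
+# the order used below (z-score first, then min-max); the helper name is hypothetical.
+def apply_saved_scaling(row, saved_params):
+    scaled = {}
+    for name, value in row.items():
+        p = saved_params.get(name, {})
+        std = p.get('std', 0)
+        value = (value - p.get('mean', 0)) / std if std else 0
+        value_range = p.get('max', 0) - p.get('min', 0)
+        scaled[name] = (value - p.get('min', 0)) / value_range if value_range else 0
+    return scaled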
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input", type=str, default=r"E:\datasource\MachineLearningCSV\MachineLearningCVE", help="Input folder path containing CSV files")
+    parser.add_argument("--output", type=str, default=r"D:\djangoProject\talktive\dataprocess\result.csv", help="Output folder path for processed CSV files")
+    args = parser.parse_args()
+    # Feature and label column definitions
+    selected_features = [
+        "Destination Port",  # formatted_data["Destination Port"]
+        "Flow Duration",  # formatted_data["Flow Duration (ms)"]
+        "Total Fwd Packets",  # formatted_data["Total Fwd Packets"]
+        "Total Backward Packets",  # formatted_data["Total Backward Packets"]
+        "Total Length of Fwd Packets",  # formatted_data["Total Length of Fwd Packets"]
+        "Total Length of Bwd Packets",  # formatted_data["Total Length of Bwd Packets"]
+        "Fwd Packet Length Max",  # formatted_data["Fwd Packet Length Max"]
+        "Fwd Packet Length Min",  # formatted_data["Fwd Packet Length Min"]
+        "Fwd Packet Length Mean",  # formatted_data["Fwd Packet Length Mean"]
+        "Fwd Packet Length Std",  # formatted_data["Fwd Packet Length Stddev"]
+        "Bwd Packet Length Max",  # formatted_data["Bwd Packet Length Max"]
+        "Bwd Packet Length Min",  # formatted_data["Bwd Packet Length Min"]
+        "Bwd Packet Length Mean",  # formatted_data["Bwd Packet Length Mean"]
+        "Bwd Packet Length Std",  # formatted_data["Bwd Packet Length Stddev"]
+        "Flow Bytes/s",  # formatted_data["Flow Bytes/s"]
+        "Flow Packets/s",  # formatted_data["Flow Packets/s"]
+        "Flow IAT Mean",  # formatted_data["Flow IAT Mean (ms)"]
+        "Flow IAT Std",  # formatted_data["Flow IAT Stddev (ms)"]
+        "Flow IAT Max",  # formatted_data["Flow IAT Max (ms)"]
+        "Flow IAT Min",  # formatted_data["Flow IAT Min (ms)"]
+        "Fwd IAT Mean",  # formatted_data["Fwd IAT Mean (ms)"]
+        "Fwd IAT Std",  # formatted_data["Fwd IAT Stddev (ms)"]
+        "Fwd IAT Max",  # formatted_data["Fwd IAT Max (ms)"]
+        "Fwd IAT Min",  # formatted_data["Fwd IAT Min (ms)"]
+        "Bwd IAT Mean",  # formatted_data["Bwd IAT Mean (ms)"]
+        "Bwd IAT Std",  # formatted_data["Bwd IAT Stddev (ms)"]
+        "Bwd IAT Max",  # formatted_data["Bwd IAT Max (ms)"]
+        "Bwd IAT Min",  # formatted_data["Bwd IAT Min (ms)"]
+        "Fwd PSH Flags",  # formatted_data["Fwd PSH Flags"]
+        "Bwd PSH Flags",  # formatted_data["Bwd PSH Flags"]
+        "Fwd URG Flags",  # formatted_data["Fwd URG Flags"]
+        "Bwd URG Flags",  # formatted_data["Bwd URG Flags"]
+        "Fwd Packets/s",  # formatted_data["Fwd Packets/s"]
+        "Bwd Packets/s",  # formatted_data["Bwd Packets/s"]
+        "Down/Up Ratio",  # formatted_data["down_up_ratio"]
+        "Average Packet Size",  # formatted_data["average_packet_size"]
+        "Avg Fwd Segment Size",  # formatted_data["avg_fwd_segment_size"]
+        "Avg Bwd Segment Size",  # formatted_data["avg_bwd_segment_size"]
+        "Packet Length Mean",  # formatted_data["Packet Length Mean"]
+        "Packet Length Std",  # formatted_data["Packet Length Std"]
+        "FIN Flag Count",  # formatted_data["FIN Flag Count"]
+        "SYN Flag Count",  # formatted_data["SYN Flag Count"]
+        "RST Flag Count",  # formatted_data["RST Flag Count"]
+        "PSH Flag Count",  # formatted_data["PSH Flag Count"]
+        "ACK Flag Count",  # formatted_data["ACK Flag Count"]
+        "URG Flag Count",  # formatted_data["URG Flag Count"]
+        "CWE Flag Count",  # formatted_data["CWE Flag Count"]
+        "ECE Flag Count",  # formatted_data["ECE Flag Count"]
+        "Subflow Fwd Packets",  # formatted_data["Subflow Fwd Packets"]
+        "Subflow Fwd Bytes",  # formatted_data["Subflow Fwd Bytes"]
+        "Subflow Bwd Packets",  # formatted_data["Subflow Bwd Packets"]
+        "Subflow Bwd Bytes",  # formatted_data["Subflow Bwd Bytes"]
+        # "Init_Win_bytes_forward",  # formatted_data["Init_Win_bytes_forward"]
+        # "Init_Win_bytes_backward",  # formatted_data["Init_Win_bytes_backward"],
+        "Label"
+    ]
+
+    standardized_features = [col.strip().lower().replace(" ", "_") for col in selected_features]
+    all_data = []
+    logging.info("Starting data processing...")
+
+    # Load and process CSV files
+    for file_name in os.listdir(args.input):
+        if file_name.endswith(".csv"):
+            file_path = os.path.join(args.input, file_name)
+            logging.info(f"Processing file: {file_path}")
+            df = pd.read_csv(file_path)
+            df.columns = [col.strip().lower().replace(" ", "_") for col in df.columns]
+            df_selected = df[[col for col in standardized_features if col in df.columns]].dropna(how="any", subset=["label"])
+            all_data.append(df_selected)
+
+    # Combine all data
+    if all_data:
+        final_data = pd.concat(all_data, ignore_index=True)
+    else:
+        raise ValueError("No valid CSV files found in the input folder.")
+    final_data = numerical_encoding(final_data, "label")
+    X = final_data.drop(columns=["label"])
+    y = final_data["label"]
+    # Data cleaning and transformation
+    X.replace([float('inf'), float('-inf')], np.nan, inplace=True)
+    X.fillna(X.median(), inplace=True)
+    X = encode_numeric_zscore(X, X.columns)
+    X = encode_numeric_range(X, X.columns)
+    # Persist the scaling parameters and the label mapping for later reuse
+    scaling_params_file = os.path.join(args.output, "scaling_params.json")
+    with open(scaling_params_file, "w") as f:
+        json.dump(params, f)
+    label_map_file = os.path.join(args.output, "label_map.json")
+    with open(label_map_file, "w") as f:
+        json.dump(label_map, f)
+
+    final_data = pd.concat([X, y], axis=1)
+    # Save the final data to a CSV file
+    output_path = os.path.join(args.output, "processed_data.csv")
+    final_data.to_csv(output_path, index=False)
+    logging.info(f"Processed data saved to {output_path}")
+    logging.info("Data processing completed successfully!")