From 48d5f3f180687d9b6ada4bfbbec93bd6294f6bd5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B1=AA=E5=89=8D?= <1104239712@qq.com>
Date: Sun, 9 Feb 2025 15:26:06 +0800
Subject: [PATCH] icnnprocess

---
 classfications/__init__.py                    |  11 ++
 classfications/__main__.py                    | 179 ++++++++++++++++++
 .../__pycache__/__init__.cpython-311.pyc      | Bin 0 -> 493 bytes
 .../__pycache__/train_data.cpython-311.pyc    | Bin 0 -> 1338 bytes
 classfications/datas/train_data.py            |  16 ++
 classfications/evalute.py                     | 144 ++++++++++++++
 .../model/__pycache__/icnn.cpython-311.pyc    | Bin 0 -> 2839 bytes
 classfications/model/icnn.py                  |  44 +++++
 classfications/predict.py                     | 149 +++++++++++++++
 classfications/test/test1.py                  |  44 +++++
 10 files changed, 587 insertions(+)
 create mode 100644 classfications/__init__.py
 create mode 100644 classfications/__main__.py
 create mode 100644 classfications/__pycache__/__init__.cpython-311.pyc
 create mode 100644 classfications/datas/__pycache__/train_data.cpython-311.pyc
 create mode 100644 classfications/datas/train_data.py
 create mode 100644 classfications/evalute.py
 create mode 100644 classfications/model/__pycache__/icnn.cpython-311.pyc
 create mode 100644 classfications/model/icnn.py
 create mode 100644 classfications/predict.py
 create mode 100644 classfications/test/test1.py

diff --git a/classfications/__init__.py b/classfications/__init__.py
new file mode 100644
index 0000000..ee2bb24
--- /dev/null
+++ b/classfications/__init__.py
@@ -0,0 +1,11 @@
+# -*- coding: utf-8 -*-
+
+"""Top-level package for ICNN."""
+
+__author__ = 'Qian Wang, Inc.'
+__email__ = 'info@sdv.dev'
+__version__ = '0.10.2.dev0'
+
+from classfications.model . icnn import CNN
+from classfications.datas.train_data import LoadData
+__all__ = ('CNN','LoadData')
\ No newline at end of file
diff --git a/classfications/__main__.py b/classfications/__main__.py
new file mode 100644
index 0000000..c309926
--- /dev/null
+++ b/classfications/__main__.py
@@ -0,0 +1,179 @@
+import argparse
+import logging
+import os
+
+import torch
+from matplotlib import pyplot as plt
+from torch.utils.data import DataLoader
+
+from classfications import LoadData
+from classfications import CNN
+import pandas as pd
+import json
+# 配置日志
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[logging.StreamHandler()]
+)
+
+# 设定设备
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+def loss_value_plot(losses, iter):  # Plot sampled losses; `iter` (shadows the builtin) = number of points
+    plt.figure()
+    plt.plot([i for i in range(1, iter + 1)], losses)  # one point per 100 training iterations
+    plt.xlabel('Iterations (×100)')
+    plt.ylabel('Loss Value')
+    plt.show()  # NOTE(review): blocks/no-ops on headless hosts (e.g. SageMaker) — confirm a display exists
+
+def train(model, optimizer, loss_fn, train_dataloader, epochs, X_dimension):
+    losses = []  # loss values sampled every 100 batches (for plotting)
+    iter = 0  # count of sampled losses; NOTE: name shadows the builtin `iter`
+    # Standard supervised loop: forward, loss, backward, optimizer step.
+    for epoch in range(epochs):
+        logging.info(f"Epoch {epoch + 1}/{epochs} started")
+        for i, (X, y) in enumerate(train_dataloader):
+            X, y = X.to(device), y.to(device)
+            X = X.reshape(X.shape[0], 1, X_dimension)  # Dynamic dimension
+            y_pred = model(X)
+            loss = loss_fn(y_pred, y)
+            # Backpropagate and update the weights.
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            # Log and record the loss once every 100 batches.
+            if i % 100 == 0:
+                logging.info(f"Loss: {loss.item():.4f}\t[{(i + 1) * len(X)}/{len(train_dataloader.dataset)}]")
+
+                iter += 1
+                losses.append(loss.item())
+
+    return losses, iter  # sampled losses and how many were recorded
+
+
+# Evaluation helper: accuracy and mean per-batch loss over a dataloader.
+def test(model, test_dataloader, loss_fn, X_dimension):
+    positive = 0  # correctly classified samples
+    negative = 0  # misclassified samples
+    loss_sum = 0
+    iter = 0  # batch counter; NOTE: name shadows the builtin `iter`
+    # NOTE(review): model.eval() is never called — safe only if the model has no dropout/BN; confirm.
+    with torch.no_grad():
+        for X, y in test_dataloader:
+            X, y = X.to(device), y.to(device)
+            X = X.reshape(X.shape[0], 1, X_dimension)  # Dynamic dimension
+            y_pred = model(X)
+            loss = loss_fn(y_pred, y)
+            loss_sum += loss.item()
+            iter += 1
+            # Tally per-sample predictions against the labels.
+            for item in zip(y_pred, y):
+                if torch.argmax(item[0]) == item[1]:
+                    positive += 1
+                else:
+                    negative += 1
+
+    acc = positive / (positive + negative)
+    avg_loss = loss_sum / iter  # mean of per-batch losses
+    logging.info(f"Accuracy: {acc:.4f}")
+    logging.info(f"Average Loss: {avg_loss:.4f}")
+    return acc, avg_loss
+
+class HyperparametersParser:
+    def __init__(self, hyperparameters_file):  # path to a SageMaker-style hyperparameters.json
+        self.hyperparameters_file = hyperparameters_file
+        self.hyperparameters = {}  # populated by parse_hyperparameters()
+
+    def parse_hyperparameters(self):
+        """
+        Parse the hyperparameters file and coerce each value to a Python type.
+        """
+        try:
+            with open(self.hyperparameters_file, "r") as f:
+                injected_hyperparameters = json.load(f)
+
+            logging.info("Parsing hyperparameters...")
+            for key, value in injected_hyperparameters.items():
+                logging.info(f"Raw hyperparameter - Key: {key}, Value: {value}")
+                self.hyperparameters[key] = self._convert_type(value)
+        except FileNotFoundError:
+            logging.warning(f"Hyperparameters file {self.hyperparameters_file} not found. Using default parameters.")
+        except Exception as e:
+            logging.error(f"Error parsing hyperparameters: {e}")
+        return self.hyperparameters
+
+    @staticmethod
+    def _convert_type(value):
+        """
+        Convert a hyperparameter value to a suitable Python type (str, int, float or bool).
+        """
+        if isinstance(value, str):
+            # Booleans arrive as the strings "true"/"false".
+            if value.lower() == "true":
+                return True
+            elif value.lower() == "false":
+                return False
+
+            # Numeric strings: "." implies float, otherwise int.
+            try:
+                if "." in value:
+                    return float(value)
+                else:
+                    return int(value)
+            except ValueError:
+                pass  # not a number — keep the original string
+        return value
+# Entry point: read hyperparameters, train the CNN, plot losses, save weights.
+
+if __name__ == '__main__':
+    # Hyperparameters file path (SageMaker injects this into the container).
+    hyperparameters_file = "/opt/ml/input/config/hyperparameters.json"
+
+    # Parse hyperparameters.
+    parser = HyperparametersParser(hyperparameters_file)
+    hyperparameters = parser.parse_hyperparameters()
+
+    # Extract hyperparameter values (Windows paths are local-dev fallbacks).
+    input_data_path = hyperparameters.get("input_data", r"D:/djangoProject/talktive/classfications/sample_dataas/processed_data.csv")
+    output_path = hyperparameters.get("output_path", r"D:/djangoProject/talktive/classfications/sample_dataas")
+    epochs = hyperparameters.get("epochs", 10)
+
+    # Load the training data.
+    logging.info("Loading data...")
+    df = pd.read_csv(input_data_path)
+    X = df.drop(columns=["label"])  # Assuming "label" is the column to predict
+    y = df["label"]
+
+    # Infer the input feature dimension automatically.
+    X_dimension = X.shape[1]
+    logging.info(f"X_dimension (number of features): {X_dimension}")
+
+    # Build the dataset and dataloader.
+    dataset = LoadData(X, y)
+    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
+
+    # Number of classes.
+    y_dimension = len(y.unique())
+    logging.info(f"y_dimension (number of classes): {y_dimension}")
+
+    # Build the model and move it to the target device.
+    model = CNN(y_dimension)  # NOTE(review): CNN internally assumes 52 features — confirm it matches X_dimension
+    model.to(device)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+    loss_fn = torch.nn.CrossEntropyLoss()
+
+    logging.info(f"Training the model for {epochs} epochs...")
+
+    # Train the model.
+    losses, iter = train(model, optimizer, loss_fn, train_loader, epochs, X_dimension)
+
+    # Plot the loss curve (NOTE(review): plt.show() may block/no-op on headless hosts).
+    loss_value_plot(losses, iter)
+
+    # Save the trained model weights.
+    os.makedirs(output_path, exist_ok=True)
+    model_save_path = os.path.join(output_path, "trained_model.pth")
+    torch.save(model.state_dict(), model_save_path)
+    logging.info(f"Model saved to {model_save_path}")
diff --git a/classfications/__pycache__/__init__.cpython-311.pyc b/classfications/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..891a91dc1630421dfd9b5b6b3271f5a538d30605
GIT binary patch
literal 493
zcmZ3^%ge>Uz`(#`+np}O$iVOz#DQTZDC4sc0|Uc!h7^V<h7`sq#uTO~rWEEV<`k9`
z)*O~x)+km+h7^_*wj8!x_9%8nh7|TFjuehy22IW?>5%*a-JI02)EtF^#N_P6^i+kk
z{2~QUXFor^D*nLC#5{%Y#JqGJ1<$->y(*r}ytI6W;*>JIl+>~+ZUa3-13e=U+u$Y0
zC7O)4n4SInG#PJk_~a+1xFnV&YBCiuGcYg|u`n<&_-S(9Vvmnc$xn`tzr__FpIBOw
zkzW)ae~U9dJ~cNnGbcX&7I%DnS!z*nW`16L{4L4koW$bdw9MqhlFa<PV!hn_l++x(
z%;da0u%Wl)(d1JSOA?FqN{SLQ^Ws5lu&p2~6LWIn<5x0#208MViHlWCN>*ZCdVWAr
zepYI7NlZy%PIgIVS!xWbc`@<vnR%Hd@$q^EmA^P_a`RJ4b5iY!_!$@&7#SECiai+^
z7(OsFGBSQ(V_;Oez@YemSFVBU29Hz&=M6sT2JQz8+80pK4F;17sOSclOanUz7V$DL
GFaQAjqL4%Y

literal 0
HcmV?d00001

diff --git a/classfications/datas/__pycache__/train_data.cpython-311.pyc b/classfications/datas/__pycache__/train_data.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b2711635f9b33c5b84dc358d65e6b104833db6b4
GIT binary patch
literal 1338
zcmZ3^%ge>Uz`zg~(Vf1Rk%8echy%l{P{wC71_p-d3@HpLj5!QZAet$MF_$TdiIIVc
z!JQ$6xrHHxC6#FzGXukFW~d&9DCQK#U<OUrmmmQ@O~za7E{P?H#i=F9AORR=gfc!W
zfUQVnh+<4(h+;}%Yyla-)WQ(O0x~{|HH8Ia*ewp9{KOQHLCJ877#J9s7#J8p{Le~Y
zQ)(D%7~-J@GceRJ)iA`v>|sb@2xeHx=%>kai!tIBW2GkZEtcZcoU|ei1_p*AP6h^s
zVvs=!3JSk$T&!YJvJ&&s^8<?Vvr>~wVoDNovP&|{Qe%>H5{rw|GLsWaGV}9_V^TmK
zizz8e%*=}ivGocnZ*j!OXXa&=#K%_&!knxJ6Ji5dBf`MI(7<p*P^g2YhYd<~uywF~
z201(#?p~<RK#uyX1@>JDl3Wc#7M!<?k%3_~oEOYc#LmFLpviQLIVUv_>^)7!B2Zix
zgVZP#aWgP5++vT9&q>XTkFOFyawJ#;tWp5%#=HDN6HKO<byoLOU*M1`;$>i9NCt&C
z$e|$20(Il(QyA_;xUPmF3$7;%<QtGqh8jjmFbyI~Km-(Lfy@N6*Dx+)VqjPe*IUbk
z?0*#VVab(&2=hTvfZcrN8m1cNG^St%O(syPWlbrmEJ)R4zr|XTUzD72i>)LzuQ<O5
z9H6&YGIR2iZ?UK4<R_LG8-bJ0EtZ`8ymU<#a8Tc3&CE+ltpF<q%NB$DrJ$ey4Sw$U
z`1I70%#zgH`1mRjq$q%kf&#nP6A}$Pf;~PnMCORikzNpbMN)gK=@#1yJdPK69Ix;=
zb}-(>l(xFSV|S6q?h21x2jdNXfeDHoB^})0`18|b1EtWs#N5>Q_*-1@@wxdar8yur
zPkek~X<`mUricfWD?q8B2oz5%8NiysNv8-D+F%02D0T!pv4H^sKQJ+|3V&dL6C8}J
z(jOS$1RFD}$Oi^Y0%AJIJWZA&VNm)M00+8WX-Q^Iu^uRE6@iit*xDjlklR4v19s0Z
z4jYIc?26<W7#Kilyf}n`f#Cx)BO~Jt2A&2myuqMy0Tq2<<7O23z<^0~lz#+?egP3^
I@?duY00m_b8vp<R

literal 0
HcmV?d00001

diff --git a/classfications/datas/train_data.py b/classfications/datas/train_data.py
new file mode 100644
index 0000000..caa6bd8
--- /dev/null
+++ b/classfications/datas/train_data.py
@@ -0,0 +1,16 @@
+import torch
+from torch.utils.data import Dataset
+
+
+class LoadData(Dataset):
+    def __init__(self, X, y):  # X: pandas DataFrame of features; y: pandas Series of labels — TODO confirm
+        self.X = X
+        self.y = y
+
+    def __len__(self):
+        return len(self.X)  # number of samples (rows)
+
+    def __getitem__(self, index):
+        X = torch.tensor(self.X.iloc[index], dtype=torch.float32)  # Ensure float32
+        y = torch.tensor(self.y.iloc[index], dtype=torch.long)  # Ensure long type for labels
+        return X, y
diff --git a/classfications/evalute.py b/classfications/evalute.py
new file mode 100644
index 0000000..f6bd2a7
--- /dev/null
+++ b/classfications/evalute.py
@@ -0,0 +1,144 @@
+import argparse
+import logging
+import os
+import tarfile
+
+import torch
+import pandas as pd
+from torch.utils.data import DataLoader
+from classfications import LoadData
+from classfications import CNN
+import json
+from matplotlib import pyplot as plt
+# 配置日志
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[logging.StreamHandler()]
+)
+
+# 设定设备
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def loss_value_plot(losses, iter):  # Plot sampled losses; `iter` (shadows the builtin) = number of points
+    plt.figure()
+    plt.plot([i for i in range(1, iter + 1)], losses)  # one point per 100 iterations
+    plt.xlabel('Iterations (×100)')
+    plt.ylabel('Loss Value')
+    plt.show()  # NOTE(review): never called in this file's visible code; blocks on headless hosts
+
+
+def extract_model_and_tokenizer(tar_path, extract_path="/tmp"):
+    """
+    Extract a model tar.gz archive and locate the model weights file.
+
+    Args:
+        tar_path (str): path to the model tar.gz archive.
+        extract_path (str): directory to extract files into.
+
+    Returns:
+        model_path (str): path to the extracted .pth model file.
+        (Despite the function name, no tokenizer path is returned.)
+    """
+    with tarfile.open(tar_path, "r:gz") as tar:
+        tar.extractall(path=extract_path)  # NOTE(review): no member sanitization — path-traversal risk on untrusted archives
+
+    model_path = None
+
+    # Walk the extraction directory; the last .pth found wins.
+    for root, dirs, files in os.walk(extract_path):
+        for file in files:
+            if file.endswith(".pth"):
+                model_path = os.path.join(root, file)
+
+
+    if not model_path:
+        raise FileNotFoundError("No .pth model file found in the extracted tar.gz.")
+
+
+    return model_path
+
+
+def evaluate(model, validate_dataloader, loss_fn, X_dimension):  # returns {"metrics": {"accuracy", "avg_loss"}}
+    positive = 0  # correctly classified samples
+    negative = 0  # misclassified samples
+    loss_sum = 0
+    iter = 0  # batch counter; NOTE: name shadows the builtin `iter`
+    with torch.no_grad():
+        for X, y in validate_dataloader:
+            X, y = X.to(device), y.to(device)
+            X = X.reshape(X.shape[0], 1, X_dimension)  # Dynamic dimension
+            y_pred = model(X)
+            loss = loss_fn(y_pred, y)
+            loss_sum += loss.item()
+            iter += 1
+            # Tally per-sample predictions against the labels.
+            for item in zip(y_pred, y):
+                if torch.argmax(item[0]) == item[1]:
+                    positive += 1
+                else:
+                    negative += 1
+
+    accuracy = positive / (positive + negative)
+    avg_loss = loss_sum / iter  # mean of per-batch losses
+    logging.info(f"Evaluation Accuracy: {accuracy:.4f}")
+    logging.info(f"Average Evaluation Loss: {avg_loss:.4f}")
+
+    evaluation_result = {"metrics": {"accuracy": accuracy, "avg_loss": avg_loss}}
+    return evaluation_result
+
+
+# Command-line arguments: data, model, and validation paths for evaluation.
+def read_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_data", type=str,
+                        default=r"D:\djangoProject\talktive\classfications\sample_dataas\processed_data.csv",
+                        required=False, help="Input CSV file path")
+    parser.add_argument("--output_path", type=str, default=r"D:\djangoProject\talktive\classfications\sample_dataas",
+                        required=False, help="Path to save the model and results")
+    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model file")
+    parser.add_argument("--validate_path", type=str, required=True, help="Path to the validation dataset")
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    # 读取参数
+    args = read_args()
+    # 加载数据
+    logging.info("Loading training data...")
+    inputpath=os.path.join(args.input_data, "generated_samples.csv")
+    df = pd.read_csv(inputpath)
+    X = df.drop(columns=["label"])  # Assuming "label" is the column to predict
+    X_dimension = X.shape[1]
+    y = df["label"]
+    # 计算 y_dimension
+    y_dimension = len(y.unique())
+    logging.info(f"y_dimension (number of classes): {y_dimension}")
+    # 载入模型
+    model = CNN(y_dimension)
+    model_path=extract_model_and_tokenizer(args.model_path)
+    model.load_state_dict(torch.load(model_path))
+    model.to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+    loss_fn = torch.nn.CrossEntropyLoss()
+    epochs = 10  # Number of epochs
+    logging.info(f"Training the model for {epochs} epochs...")
+    # 训练模型
+    # 加载验证数据并评估
+    logging.info("Loading validation data...")
+    validate_path = os.path.join(args.validate_path,"generated_samples.csv")
+    validate_df = pd.read_csv(validate_path)
+    X_validate = validate_df.drop(columns=["label"])  # Assuming "label" is the column to predict
+    y_validate = validate_df["label"]
+    # 创建验证数据集和加载器
+    validate_dataset = LoadData(X_validate, y_validate)
+    validate_loader = DataLoader(validate_dataset, batch_size=64, shuffle=False)
+    # 进行评估
+    evaluation_result = evaluate(model, validate_loader, loss_fn, X_dimension)
+    # 保存评估结果
+    output_file = os.path.join(args.output_path, "evaluation.json")
+    with open(output_file, "w") as f:
+        json.dump(evaluation_result, f)
+    logging.info(f"Evaluation results saved to {output_file}")
diff --git a/classfications/model/__pycache__/icnn.cpython-311.pyc b/classfications/model/__pycache__/icnn.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..acb377ec8c2255973d05e218c45740a405763457
GIT binary patch
literal 2839
zcmZ3^%ge>Uz`(F0vOE1FD+9x05C?`?p^VQy3=9m@8B!Qh7;_kM8KW2(L2RZRhA0Tl
z1QBCmaA!ziZed7aNoCAphU#Tl#>~L58qSMiNnr|R&}4lHlGbFr#gv!lmkeUVFcXyV
zS&M;zp`BqmLn=cQV+unQQwKvDV+vynM-+1kQwu{B3&^G@)((aW#wfO622GY*%+7v(
z$uR98Q@9xz7(f<(e#XSWFqLsSLkS~D7=o8EFfgo!Fc}ya79h#M1sPD)Ffx?D0)&Bq
z0pURuQ5Kl)6vh;$6y`N7sNTn+mMMjmfLgfCDa<L%%a|A#R>OUY!%S9!W+L3mn!-lD
z{ma-G7*?~v{9VHk4`-L~!$cVvvN&Mu6vh<B6n4~bsbPrcgvr-1EC44-7#~L0Fl52i
zEMsI~SPhS}8isgKT!R&&ga#<~F-3X6LJSNHH4IsBeJQMXLX{V$12sjU*abH=g_W9S
z)i9#EYZ)s8!)kbHs9}ibgSjn*A(%mv!|x?1SY9eHFfhCXB@s==Tin^HMR}<?@x_@{
zsV^Ns5)B~2gn@xUlm8ZLacMzn(JhYn_{_Y_lKA){P@vu73QjF7P0cIGOw75(=A55b
zW|(q|(>JjqAU{9HFy$6UQetv;Qhr|QE!L9!qU4NQtW~K+`Ng-mQ%ZAlE8{cs3Q9|E
z3CE|Wmc*y!B$kw<=B1{9?Z3r`#%J@%%u7uyy2TQd>Jxg4DJ@x(`4&rYYEIfM?#lR-
z%-qzx;>`TKTb%Ln$vKI|#qsgQAonRKH2gAiv5HB_O3X{o4=BpdN=+__DM`%9F3BuQ
zjR7l7%S=uz$;{6yj>*kWNzIALOwP;GE2u02CCVxuXFoqZn3q8bwpfFaf#F93!(9&E
zPWB%5>l_l7I3zA|NL}HOy1*gzm5o7CdJgYh9=;yW35hd=RtR3;(Z0x|eT7FGB9=5m
zX@%ki9-WIkI#+mfKw>^K1ZOy2<dM6=BM0JoPQX+G5>sAbc!5XvB9HDB9$k<+ByAsg
z*(JHYFtAH<b+GjCeidVoQJiDFpm2railB?q23MpFI#{l7NZgQ9x-MsSNzQCT$rW?=
zi*g=U<UBwU(qGvaghi)-4fULmaz#vYh0#SGy(>I=AZy$^J$pPs43Ha?F7n7;;gJP#
zQN*<`@~B<mQTxcvCd~DPflV0X4)zZA4))KWG>1~wf`SkfF`uu2s|w_jHVc#>K^oA?
zQ&R>;h8l(iphOCmMJ8$(QFC}LBZ`S2o50Fz7*o*lM=hw_WlCW}%{7b+dECkjh783l
z<xCaKkqqUGj0}+sj35^<E`XPVU?ZVK3YrUo88n$or5P9)UaVXAeBE>fH%PKpcsil)
z`K*?wOO`*~&|k#Pz`)?A$yy}Bz`$^eJug2#y(lrINEpn6ru$ngplqngR3rluWsE2W
z`CmbyNS=X#p-7y8fuTwqoHQ}Bjv@mCL-BM528ITPr@TTvv7PZf@gEqNI90&J4FQqq
zypwomBwiFyxFVp?;doa-bPDf`#07#E1(dG{D0eu1V1=m#YxoG#`UOOQxj2a*KYn19
zV`C5y>@Vvon~>O9*;9FeL!wxafq`L)wj($5L2edDe)fYxjE?;5D5(S#(xABe+yYJ|
zH4F=&p^ZvmiG7eRsG=-Th=JMIV!wuA0g_s<GBAN*CQ}WlA;RPbjx<dca4dl%u1Faa
z6`-irWCq7lkt!&v*mCmIGfRr0#SbLHZn3B37nLU#rBrc)A`QyX07aAqIHGRw2=x1Q
z`OZ+h$Rl%wN2Y`EE{|Z3?+nEmffsq?ukgrsFy0Upoi08}e1_mesVPz&Y&ZBtJ9vvg
zrMD(Ks7%gF%uS7tzr__FpPQdjnge3<#K#wwCgwn7z@-^D-4ubWy~PVHCzDfia^mAP
z8H?mVjsfLB2S^eFRUJhj0)&gh7{G;I0|NwpV3K6z`oMsNFk@wv{J;Pw_*huQJ}}@T
zz<vN*1kDeSdW8dAUFhZI-D2|vdkCy#B?Ck`C=>kRuz`5huE>CafdS<7;tU1`h7Zh)
vjEpxJL@uDB8w@-RV0eQ;^a3jSz@pA5_JIMD=uqm>{0I{N0wT~<fL#IrGU7K)

literal 0
HcmV?d00001

diff --git a/classfications/model/icnn.py b/classfications/model/icnn.py
new file mode 100644
index 0000000..b25bcd4
--- /dev/null
+++ b/classfications/model/icnn.py
@@ -0,0 +1,44 @@
+from torch import nn
+import torch
+import torch.nn as nn
+
+class CNN(nn.Module):
+    def __init__(self, y_dimension):
+        super().__init__()
+
+        # 定义卷积层
+        self.backbone = nn.Sequential(
+            nn.Conv1d(1, 32, kernel_size=2),
+            nn.Conv1d(32, 64, kernel_size=2),
+            nn.MaxPool1d(2, 2),
+            nn.Conv1d(64, 64, kernel_size=2),
+            nn.Conv1d(64, 128, kernel_size=2),
+            nn.MaxPool1d(2, 2),
+        )
+
+        # 动态计算全连接层输入的维度
+        # 假设输入数据的每个样本有 X_dimension 个特征
+        self.dummy_input = torch.zeros(1, 1, 52)  # (batch_size, channels, X_dimension) 这里假设 X_dimension=52
+        self.flattened_size = self._get_flattened_size(self.dummy_input)
+
+        # 定义全连接层
+        self.fc = nn.Sequential(
+            nn.Linear(self.flattened_size, 64),
+            nn.ReLU(),
+            nn.Linear(64, 64),
+            nn.ReLU(),
+            nn.Linear(64, y_dimension)  # 输出维度为 y_dimension
+        )
+
+    def _get_flattened_size(self, X):
+        """计算 Flatten 后的大小"""
+        with torch.no_grad():
+            X = self.backbone(X)  # 前向传播通过卷积层
+            X = torch.flatten(X, 1)  # 扁平化,保留 batch_size 维度
+        return X.size(1)
+
+    def forward(self, X):
+        X = self.backbone(X)
+        X = torch.flatten(X, 1)  # 扁平化操作
+        logits = self.fc(X)
+        return logits
diff --git a/classfications/predict.py b/classfications/predict.py
new file mode 100644
index 0000000..d2cdb7b
--- /dev/null
+++ b/classfications/predict.py
@@ -0,0 +1,149 @@
+import os
+import json
+import torch
+import logging
+from flask import Flask, request, jsonify
+from torch.utils.data import DataLoader
+import pandas as pd
+from classfications import CNN, LoadData
+
+# 配置日志记录
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[logging.StreamHandler()]
+)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+# Model handler: loads the trained CNN and serves batched predictions.
+class ICNNModelHandler:
+    def __init__(self):
+        self.device = device
+        self.model = None  # set by load_model()
+        self.X_dimension = 52  # input feature dimension (hard-coded; must match training — confirm)
+        self.y_dimension = 15  # number of output classes (hard-coded; must match training — confirm)
+
+    def load_model(self, model_dir):
+        """
+        Load the trained model weights from `model_dir`.
+        """
+        logging.info("Loading model...")
+        model_path = os.path.join(model_dir, "trained_model.pth")
+
+        if not os.path.exists(model_path):
+            raise FileNotFoundError(f"Model file not found: {model_path}")
+
+        # Initialize the model and load the trained weights.
+        self.model = CNN(self.y_dimension)
+        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
+        self.model.to(self.device)
+        self.model.eval()
+        logging.info("Model loaded successfully.")
+
+    def predict(self, input_data):
+        """
+        Batched prediction.
+        Args:
+            input_data (pd.DataFrame): input rows containing the feature columns.
+        Returns:
+            list: predicted class indices, one per input row.
+        """
+        if self.X_dimension is None or self.y_dimension is None:
+            raise ValueError("Model metadata is not loaded. Cannot perform predictions.")
+
+        # Build a DataLoader (labels are zero placeholders — only X is used).
+        dataset = LoadData(input_data, pd.Series([0] * len(input_data)))
+        dataloader = DataLoader(dataset, batch_size=64, shuffle=False)
+
+        all_predictions = []
+        with torch.no_grad():
+            for X, _ in dataloader:
+                X = X.to(self.device).float()
+                X = X.reshape(X.shape[0], 1, self.X_dimension)
+                outputs = self.model(X)
+                predictions = torch.argmax(outputs, dim=1)
+                all_predictions.extend(predictions.cpu().tolist())
+
+        return all_predictions
+
+
+# 初始化模型处理器
+model_handler = ICNNModelHandler()
+
+
+# SageMaker-required inference functions.
+def model_fn(model_dir):
+    """
+    Load the model into memory (called once per worker).
+    """
+    model_handler.load_model(model_dir)
+    return model_handler
+
+
+def input_fn(request_body, request_content_type):
+    """
+    Parse the request body into a DataFrame (JSON only).
+    """
+    if request_content_type == "application/json":
+        input_data = json.loads(request_body)
+        if isinstance(input_data, dict) and "data" in input_data:
+            return pd.DataFrame(input_data["data"])
+        elif isinstance(input_data, list):
+            return pd.DataFrame(input_data)
+        else:
+            raise ValueError("Invalid input format. Must be a JSON with 'data' key or a list of feature dictionaries.")
+    else:
+        raise ValueError(f"Unsupported content type: {request_content_type}")
+
+
+def predict_fn(input_data, model):
+    """
+    Generate predictions with the loaded model handler.
+    """
+    predictions = model.predict(input_data)
+    return predictions
+
+
+def output_fn(prediction, content_type):
+    """
+    Serialize the predictions as a JSON response.
+    """
+    if content_type == "application/json":
+        return json.dumps({"predictions": prediction})
+    else:
+        raise ValueError(f"Unsupported content type: {content_type}")
+
+
+# Flask app for local testing of the inference handlers.
+if __name__ == "__main__":
+    app = Flask(__name__)
+
+    @app.route("/ping", methods=["GET"])
+    def ping():
+        """
+        Health check: 200 once the model is loaded, 500 otherwise.
+        """
+        status = 200 if model_handler.model else 500
+        return jsonify({"status": "ok" if status == 200 else "error"}), status
+
+    @app.route("/invocations", methods=["POST"])
+    def invocations():
+        """
+        Handle an inference request end-to-end (parse, predict, serialize).
+        """
+        try:
+            input_data = input_fn(request.data, request.content_type)
+            predictions = predict_fn(input_data, model_handler)
+            return output_fn(predictions, "application/json")
+        except Exception as e:
+            logging.error(f"Prediction error: {e}")
+            return jsonify({"error": str(e)}), 500
+
+    # Load the model before serving.
+    logging.info("Starting model loading...")
+    model_handler.load_model("/opt/ml/model")
+
+    # Start the Flask server.
+    app.run(host="0.0.0.0", port=8080)
\ No newline at end of file
diff --git a/classfications/test/test1.py b/classfications/test/test1.py
new file mode 100644
index 0000000..0ccee23
--- /dev/null
+++ b/classfications/test/test1.py
@@ -0,0 +1,44 @@
+import boto3
+import json
+
+# 创建 SageMaker Runtime 客户端
+sagemaker_client = boto3.client(
+    'sagemaker-runtime',
+    region_name='ap-south-1',
+    aws_access_key_id='AKIA46ZDFEKUZGSUBIOJ',  # 替换为您的 Access Key ID
+    aws_secret_access_key='/hJ2+4k+MkaZb7XtlR0ydmP2uLWBUqBk7hoRTrgi'  # 替换为您的 Secret Access Key
+)
+
+# 定义推理请求的端点名称
+endpoint_name = "icnn-predict-endpoint"
+data = [
+    [
+        0.001220722, 0.003738858, 9.10E-06, 2.06E-05, 2.02E-06, 1.77E-05, 0.000805802, 0,
+        0.001458824, 0.001440329, 0.299027138, 0, 0.333505732, 0.309826049, 0.111932217,
+        0.333336677, 0.000467452, 0.00185729, 0.003715658, 5.08E-07, 1.01E-05, 1.94E-05,
+        1.97E-05, 5.08E-07, 0.000747608, 0.002358479, 0.00371555, 3.92E-07, 0, 0, 0, 0,
+        2.23E-06, 6.69E-06, 0.012820513, 0.331992009, 0.001458824, 0.333505732, 0.34859161,
+        0.451932586, 0, 0, 0, 1, 0, 0, 0, 0, 9.10E-06, 2.02E-06, 2.06E-05, 1.77E-05
+    ]
+]
+
+print(len(data[0]))
+# 准备输入数据
+payload = {
+    "data":data
+}
+
+# 发送推理请求
+try:
+    response = sagemaker_client.invoke_endpoint(
+        EndpointName=endpoint_name,
+        ContentType="application/json",
+        Body=json.dumps(payload)
+    )
+
+    # 解析响应
+    result = response["Body"].read().decode("utf-8")
+    print("Inference result:", result)
+
+except Exception as e:
+    print("Error invoking endpoint:", str(e))
-- 
GitLab