From 48d5f3f180687d9b6ada4bfbbec93bd6294f6bd5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B1=AA=E5=89=8D?= <1104239712@qq.com>
Date: Sun, 9 Feb 2025 15:26:06 +0800
Subject: [PATCH] Add ICNN classification package (training, evaluation, prediction)

Adds the classfications package: a 1-D CNN classifier (icnn.py), a
pandas-backed Dataset (train_data.py), a SageMaker training entry point
(__main__.py), an evaluation script (evalute.py), a serving handler with
a local Flask harness (predict.py), and an endpoint smoke test (test1.py).
---
 classfications/__init__.py         |  11 ++
 classfications/__main__.py         | 179 ++++++++++++++++++
 classfications/datas/train_data.py |  16 ++
 classfications/evalute.py          | 144 ++++++++++++++
 classfications/model/icnn.py       |  44 +++++
 classfications/predict.py          | 149 +++++++++++++++
 classfications/test/test1.py       |  44 +++++
 7 files changed, 587 insertions(+)
 create mode 100644 classfications/__init__.py
 create mode 100644 classfications/__main__.py
 create mode 100644 classfications/datas/train_data.py
 create mode 100644 classfications/evalute.py
 create mode 100644 classfications/model/icnn.py
 create mode 100644 classfications/predict.py
 create mode 100644 classfications/test/test1.py

diff --git a/classfications/__init__.py b/classfications/__init__.py
new file mode 100644
index 0000000..ee2bb24
--- /dev/null
+++ b/classfications/__init__.py
@@ -0,0 +1,11 @@
+# -*- coding: utf-8 -*-
+
+"""Top-level package for ICNN."""
+
+__author__ = 'Qian Wang'
+__email__ = '1104239712@qq.com'
+__version__ = '0.10.2.dev0'
+
+from classfications.model.icnn import CNN
+from classfications.datas.train_data import LoadData
+__all__ = ('CNN', 'LoadData')
diff --git a/classfications/__main__.py b/classfications/__main__.py
new file mode 100644
index 0000000..c309926
--- /dev/null
+++ b/classfications/__main__.py
@@ -0,0 +1,179 @@
+import json
+import logging
+import os
+
+import pandas as pd
+import torch
+from matplotlib import pyplot as plt
+from torch.utils.data import DataLoader
+
+from classfications import CNN
+from classfications import LoadData
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[logging.StreamHandler()]
+)
+
+# Select the device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def loss_value_plot(losses, num_points):
+    plt.figure()
+    plt.plot(list(range(1, num_points + 1)), losses)
+    plt.xlabel('Iterations (×100)')
+    plt.ylabel('Loss Value')
+    plt.show()
+
+
+def train(model, optimizer, loss_fn, train_dataloader, epochs, X_dimension):
+    losses = []
+    num_iters = 0
+
+    model.train()
+    for epoch in range(epochs):
+        logging.info(f"Epoch {epoch + 1}/{epochs} started")
+        for i, (X, y) in enumerate(train_dataloader):
+            X, y = X.to(device), y.to(device)
+            X = X.reshape(X.shape[0], 1, X_dimension)  # Dynamic feature dimension
+            y_pred = model(X)
+            loss = loss_fn(y_pred, y)
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            if i % 100 == 0:
+                logging.info(f"Loss: {loss.item():.4f}\t[{(i + 1) * len(X)}/{len(train_dataloader.dataset)}]")
+                # Record one loss sample every 100 batches, matching the plot's x-axis
+                num_iters += 1
+                losses.append(loss.item())
+
+    return losses, num_iters
+
+
+# Test function
+def test(model, test_dataloader, loss_fn, X_dimension):
+    positive = 0
+    negative = 0
+    loss_sum = 0
+    num_iters = 0
+
+    model.eval()
+    with torch.no_grad():
+        for X, y in test_dataloader:
+            X, y = X.to(device), y.to(device)
+            X = X.reshape(X.shape[0], 1, X_dimension)  # Dynamic feature dimension
+            y_pred = model(X)
+            loss = loss_fn(y_pred, y)
+            loss_sum += loss.item()
+            num_iters += 1
+
+            for pred, label in zip(y_pred, y):
+                if torch.argmax(pred) == label:
+                    positive += 1
+                else:
+                    negative += 1
+
+    acc = positive / (positive + negative)
+    avg_loss = loss_sum / num_iters
+    logging.info(f"Accuracy: {acc:.4f}")
+    logging.info(f"Average Loss: {avg_loss:.4f}")
+    return acc, avg_loss
+
+
+class HyperparametersParser:
+    def __init__(self, hyperparameters_file):
+        self.hyperparameters_file = hyperparameters_file
+        self.hyperparameters = {}
+
+    def parse_hyperparameters(self):
+        """Parse the hyperparameters file and convert value types."""
+        try:
+            with open(self.hyperparameters_file, "r") as f:
+                injected_hyperparameters = json.load(f)
+
+            logging.info("Parsing hyperparameters...")
+            for key, value in injected_hyperparameters.items():
+                logging.info(f"Raw hyperparameter - Key: {key}, Value: {value}")
+                self.hyperparameters[key] = self._convert_type(value)
+        except FileNotFoundError:
+            logging.warning(f"Hyperparameters file {self.hyperparameters_file} not found. Using default parameters.")
+        except Exception as e:
+            logging.error(f"Error parsing hyperparameters: {e}")
+        return self.hyperparameters
+
+    @staticmethod
+    def _convert_type(value):
+        """Convert a hyperparameter value to a suitable Python type (string, int, float, or bool)."""
+        if isinstance(value, str):
+            # Booleans
+            if value.lower() == "true":
+                return True
+            elif value.lower() == "false":
+                return False
+
+            # Integers and floats
+            try:
+                if "." in value:
+                    return float(value)
+                else:
+                    return int(value)
+            except ValueError:
+                pass  # Keep the string if it cannot be converted
+        return value
+
+
+if __name__ == '__main__':
+    # Path of the hyperparameters file injected by SageMaker
+    hyperparameters_file = "/opt/ml/input/config/hyperparameters.json"
+
+    # Parse the hyperparameters
+    parser = HyperparametersParser(hyperparameters_file)
+    hyperparameters = parser.parse_hyperparameters()
+
+    # Extract hyperparameter values
+    input_data_path = hyperparameters.get("input_data", r"D:/djangoProject/talktive/classfications/sample_dataas/processed_data.csv")
+    output_path = hyperparameters.get("output_path", r"D:/djangoProject/talktive/classfications/sample_dataas")
+    epochs = hyperparameters.get("epochs", 10)
+
+    # Load the data
+    logging.info("Loading data...")
+    df = pd.read_csv(input_data_path)
+    X = df.drop(columns=["label"])  # Assuming "label" is the column to predict
+    y = df["label"]
+
+    # Derive the input feature dimension automatically
+    X_dimension = X.shape[1]
+    logging.info(f"X_dimension (number of features): {X_dimension}")
+
+    # Build the dataset and data loader
+    dataset = LoadData(X, y)
+    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
+
+    # Derive y_dimension from the number of distinct labels
+    y_dimension = len(y.unique())
+    logging.info(f"y_dimension (number of classes): {y_dimension}")
+
+    # Build the model and move it to the device
+    model = CNN(y_dimension)
+    model.to(device)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+    loss_fn = torch.nn.CrossEntropyLoss()
+
+    logging.info(f"Training the model for {epochs} epochs...")
+
+    # Train the model
+    losses, num_iters = train(model, optimizer, loss_fn, train_loader, epochs, X_dimension)
+
+    # Plot the loss curve
+    loss_value_plot(losses, num_iters)
+
+    # Save the trained model
+    os.makedirs(output_path, exist_ok=True)
+    model_save_path = os.path.join(output_path, "trained_model.pth")
+    torch.save(model.state_dict(), model_save_path)
+    logging.info(f"Model saved to {model_save_path}")
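SageMaker injects every hyperparameter as a JSON string, which is what _convert_type above undoes. A minimal sketch of the expected conversions, assuming HyperparametersParser can be imported from classfications.__main__; the parameter values and the throwaway temp file are hypothetical:

import json
import tempfile

from classfications.__main__ import HyperparametersParser  # assumed import path

# Hypothetical hyperparameters, all strings, as SageMaker would inject them
params = {"epochs": "25", "lr": "0.0001", "shuffle": "true", "input_data": "/opt/ml/input/data/train"}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(params, f)

hp = HyperparametersParser(f.name).parse_hyperparameters()
assert hp["epochs"] == 25                   # digits only -> int
assert hp["lr"] == 1e-4                     # contains "." -> float
assert hp["shuffle"] is True                # "true"/"false" -> bool
assert hp["input_data"].startswith("/opt")  # anything else stays a string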
diff --git a/classfications/datas/train_data.py b/classfications/datas/train_data.py
new file mode 100644
index 0000000..caa6bd8
--- /dev/null
+++ b/classfications/datas/train_data.py
@@ -0,0 +1,16 @@
+import torch
+from torch.utils.data import Dataset
+
+
+class LoadData(Dataset):
+    def __init__(self, X, y):
+        self.X = X
+        self.y = y
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        X = torch.tensor(self.X.iloc[index], dtype=torch.float32)  # Features as float32
+        y = torch.tensor(self.y.iloc[index], dtype=torch.long)  # Labels as long for CrossEntropyLoss
+        return X, y
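LoadData indexes pandas objects positionally with .iloc, so it drops straight into a DataLoader. A quick sketch with made-up data showing the dtypes and shapes that come out of a batch:

import pandas as pd
import torch
from torch.utils.data import DataLoader

from classfications import LoadData

# Hypothetical frame: 4 samples, 3 features, integer labels
df = pd.DataFrame({"f1": [0.1, 0.2, 0.3, 0.4],
                   "f2": [1.0, 0.0, 1.0, 0.0],
                   "f3": [5.0, 6.0, 7.0, 8.0],
                   "label": [0, 1, 0, 1]})
dataset = LoadData(df.drop(columns=["label"]), df["label"])

loader = DataLoader(dataset, batch_size=2, shuffle=False)
X, y = next(iter(loader))
assert X.dtype == torch.float32 and X.shape == (2, 3)
assert y.dtype == torch.long and y.shape == (2,)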
diff --git a/classfications/evalute.py b/classfications/evalute.py
new file mode 100644
index 0000000..f6bd2a7
--- /dev/null
+++ b/classfications/evalute.py
@@ -0,0 +1,144 @@
+import argparse
+import json
+import logging
+import os
+import tarfile
+
+import pandas as pd
+import torch
+from torch.utils.data import DataLoader
+
+from classfications import CNN
+from classfications import LoadData
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[logging.StreamHandler()]
+)
+
+# Select the device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def extract_model(tar_path, extract_path="/tmp"):
+    """
+    Extract the model file from a tar.gz archive.
+
+    Args:
+        tar_path (str): Path to the model tar.gz file.
+        extract_path (str): Directory to extract files into.
+
+    Returns:
+        model_path (str): Path of the extracted model file (.pth).
+    """
+    with tarfile.open(tar_path, "r:gz") as tar:
+        tar.extractall(path=extract_path)
+
+    model_path = None
+    for root, dirs, files in os.walk(extract_path):
+        for file in files:
+            if file.endswith(".pth"):
+                model_path = os.path.join(root, file)
+
+    if not model_path:
+        raise FileNotFoundError("No .pth model file found in the extracted tar.gz.")
+
+    return model_path
+
+
+def evaluate(model, validate_dataloader, loss_fn, X_dimension):
+    positive = 0
+    negative = 0
+    loss_sum = 0
+    num_iters = 0
+
+    model.eval()
+    with torch.no_grad():
+        for X, y in validate_dataloader:
+            X, y = X.to(device), y.to(device)
+            X = X.reshape(X.shape[0], 1, X_dimension)  # Dynamic feature dimension
+            y_pred = model(X)
+            loss = loss_fn(y_pred, y)
+            loss_sum += loss.item()
+            num_iters += 1
+
+            for pred, label in zip(y_pred, y):
+                if torch.argmax(pred) == label:
+                    positive += 1
+                else:
+                    negative += 1
+
+    accuracy = positive / (positive + negative)
+    avg_loss = loss_sum / num_iters
+    logging.info(f"Evaluation Accuracy: {accuracy:.4f}")
+    logging.info(f"Average Evaluation Loss: {avg_loss:.4f}")
+
+    evaluation_result = {"metrics": {"accuracy": accuracy, "avg_loss": avg_loss}}
+    return evaluation_result
+
+
+def read_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_data", type=str,
+                        default=r"D:\djangoProject\talktive\classfications\sample_dataas",
+                        required=False, help="Directory containing the input CSV file")
+    parser.add_argument("--output_path", type=str, default=r"D:\djangoProject\talktive\classfications\sample_dataas",
+                        required=False, help="Path to save the evaluation results")
+    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model tar.gz file")
+    parser.add_argument("--validate_path", type=str, required=True, help="Path to the validation dataset")
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    # Read the arguments
+    args = read_args()
+
+    # Load the input data
+    logging.info("Loading input data...")
+    input_path = os.path.join(args.input_data, "generated_samples.csv")
+    df = pd.read_csv(input_path)
+    X = df.drop(columns=["label"])  # Assuming "label" is the column to predict
+    X_dimension = X.shape[1]
+    y = df["label"]
+
+    # Derive y_dimension from the number of distinct labels
+    y_dimension = len(y.unique())
+    logging.info(f"y_dimension (number of classes): {y_dimension}")
+
+    # Load the model
+    model = CNN(y_dimension)
+    model_path = extract_model(args.model_path)
+    model.load_state_dict(torch.load(model_path, map_location=device))
+    model.to(device)
+    loss_fn = torch.nn.CrossEntropyLoss()
+
+    # Load the validation data and evaluate
+    logging.info("Loading validation data...")
+    validate_path = os.path.join(args.validate_path, "generated_samples.csv")
+    validate_df = pd.read_csv(validate_path)
+    X_validate = validate_df.drop(columns=["label"])  # Assuming "label" is the column to predict
+    y_validate = validate_df["label"]
+
+    # Build the validation dataset and loader
+    validate_dataset = LoadData(X_validate, y_validate)
+    validate_loader = DataLoader(validate_dataset, batch_size=64, shuffle=False)
+
+    # Run the evaluation
+    evaluation_result = evaluate(model, validate_loader, loss_fn, X_dimension)
+
+    # Save the evaluation results
+    output_file = os.path.join(args.output_path, "evaluation.json")
+    with open(output_file, "w") as f:
+        json.dump(evaluation_result, f)
+    logging.info(f"Evaluation results saved to {output_file}")
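extract_model assumes the usual SageMaker artifact layout: a model.tar.gz with a .pth file somewhere inside. A round-trip sketch under that assumption; the 15-class count and the temp paths are arbitrary, and the import path assumes extract_model is importable from classfications.evalute:

import os
import tarfile
import tempfile

import torch

from classfications import CNN
from classfications.evalute import extract_model  # assumed import path

workdir = tempfile.mkdtemp()
weights = os.path.join(workdir, "trained_model.pth")
torch.save(CNN(15).state_dict(), weights)  # 15 classes is an arbitrary choice

# Package the weights the way SageMaker ships them
tar_path = os.path.join(workdir, "model.tar.gz")
with tarfile.open(tar_path, "w:gz") as tar:
    tar.add(weights, arcname="trained_model.pth")

# Recover the .pth and reload the state dict
model_path = extract_model(tar_path, extract_path=workdir)
state = torch.load(model_path, map_location="cpu")
assert any(k.startswith("backbone") for k in state)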
diff --git a/classfications/model/icnn.py b/classfications/model/icnn.py
new file mode 100644
index 0000000..b25bcd4
--- /dev/null
+++ b/classfications/model/icnn.py
@@ -0,0 +1,44 @@
+import torch
+from torch import nn
+
+
+class CNN(nn.Module):
+    def __init__(self, y_dimension, X_dimension=52):
+        super().__init__()
+
+        # Convolutional layers
+        self.backbone = nn.Sequential(
+            nn.Conv1d(1, 32, kernel_size=2),
+            nn.Conv1d(32, 64, kernel_size=2),
+            nn.MaxPool1d(2, 2),
+            nn.Conv1d(64, 64, kernel_size=2),
+            nn.Conv1d(64, 128, kernel_size=2),
+            nn.MaxPool1d(2, 2),
+        )
+
+        # Compute the fully connected layer's input size dynamically by
+        # running a dummy sample with X_dimension features through the backbone
+        self.dummy_input = torch.zeros(1, 1, X_dimension)  # (batch_size, channels, X_dimension)
+        self.flattened_size = self._get_flattened_size(self.dummy_input)
+
+        # Fully connected layers
+        self.fc = nn.Sequential(
+            nn.Linear(self.flattened_size, 64),
+            nn.ReLU(),
+            nn.Linear(64, 64),
+            nn.ReLU(),
+            nn.Linear(64, y_dimension)  # Output dimension is y_dimension
+        )
+
+    def _get_flattened_size(self, X):
+        """Compute the size after flattening."""
+        with torch.no_grad():
+            X = self.backbone(X)  # Forward pass through the convolutional layers
+            X = torch.flatten(X, 1)  # Flatten, keeping the batch dimension
+        return X.size(1)
+
+    def forward(self, X):
+        X = self.backbone(X)
+        X = torch.flatten(X, 1)  # Flatten
+        logits = self.fc(X)
+        return logits
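For the default 52-feature input, the backbone length can be tracked by hand: each kernel-2 convolution shortens the sequence by one and each MaxPool1d(2, 2) roughly halves it, so 52 → 51 → 50 → 25 → 24 → 23 → 11, and flattening 128 channels of length 11 gives 1408 features into the first Linear layer. A short sketch that checks this; the 15-class count is an arbitrary choice:

import torch

from classfications import CNN

model = CNN(y_dimension=15)
assert model.flattened_size == 128 * 11  # 1408, per the arithmetic above

logits = model(torch.randn(8, 1, 52))  # (batch, channels, X_dimension)
assert logits.shape == (8, 15)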
diff --git a/classfications/predict.py b/classfications/predict.py
new file mode 100644
index 0000000..d2cdb7b
--- /dev/null
+++ b/classfications/predict.py
@@ -0,0 +1,149 @@
+import os
+import json
+import logging
+
+import pandas as pd
+import torch
+from flask import Flask, request, jsonify
+from torch.utils.data import DataLoader
+
+from classfications import CNN, LoadData
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[logging.StreamHandler()]
+)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class ICNNModelHandler:
+    """Loads the trained model and runs batched predictions."""
+
+    def __init__(self):
+        self.device = device
+        self.model = None
+        self.X_dimension = 52  # Input feature dimension
+        self.y_dimension = 15  # Number of output classes
+
+    def load_model(self, model_dir):
+        """Load the model and its weights."""
+        logging.info("Loading model...")
+        model_path = os.path.join(model_dir, "trained_model.pth")
+
+        if not os.path.exists(model_path):
+            raise FileNotFoundError(f"Model file not found: {model_path}")
+
+        # Initialize the model
+        self.model = CNN(self.y_dimension)
+        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
+        self.model.to(self.device)
+        self.model.eval()
+        logging.info("Model loaded successfully.")
+
+    def predict(self, input_data):
+        """
+        Batch prediction.
+
+        Args:
+            input_data (pd.DataFrame): Input data containing the feature columns.
+
+        Returns:
+            list: Predicted class indices.
+        """
+        if self.model is None:
+            raise ValueError("Model is not loaded. Cannot perform predictions.")
+
+        # Build a DataLoader; the labels are placeholders, only the features are used
+        dataset = LoadData(input_data, pd.Series([0] * len(input_data)))
+        dataloader = DataLoader(dataset, batch_size=64, shuffle=False)
+
+        all_predictions = []
+        with torch.no_grad():
+            for X, _ in dataloader:
+                X = X.to(self.device).float()
+                X = X.reshape(X.shape[0], 1, self.X_dimension)
+                outputs = self.model(X)
+                predictions = torch.argmax(outputs, dim=1)
+                all_predictions.extend(predictions.cpu().tolist())
+
+        return all_predictions
+
+
+# Initialize the model handler
+model_handler = ICNNModelHandler()
+
+
+# Functions required by SageMaker
+def model_fn(model_dir):
+    """Load the model into memory."""
+    model_handler.load_model(model_dir)
+    return model_handler
+
+
+def input_fn(request_body, request_content_type):
+    """Parse the input data."""
+    if request_content_type == "application/json":
+        input_data = json.loads(request_body)
+        if isinstance(input_data, dict) and "data" in input_data:
+            return pd.DataFrame(input_data["data"])
+        elif isinstance(input_data, list):
+            return pd.DataFrame(input_data)
+        else:
+            raise ValueError("Invalid input format. Must be a JSON object with a 'data' key or a list of feature rows.")
+    else:
+        raise ValueError(f"Unsupported content type: {request_content_type}")
+
+
+def predict_fn(input_data, model):
+    """Generate predictions with the model."""
+    predictions = model.predict(input_data)
+    return predictions
+
+
+def output_fn(prediction, content_type):
+    """Format the prediction output."""
+    if content_type == "application/json":
+        return json.dumps({"predictions": prediction})
+    else:
+        raise ValueError(f"Unsupported content type: {content_type}")
+
+
+# Flask app for local testing
+if __name__ == "__main__":
+    app = Flask(__name__)
+
+    @app.route("/ping", methods=["GET"])
+    def ping():
+        """Health check."""
+        status = 200 if model_handler.model else 500
+        return jsonify({"status": "ok" if status == 200 else "error"}), status
+
+    @app.route("/invocations", methods=["POST"])
+    def invocations():
+        """Handle inference requests."""
+        try:
+            input_data = input_fn(request.data, request.content_type)
+            predictions = predict_fn(input_data, model_handler)
+            return output_fn(predictions, "application/json")
+        except Exception as e:
+            logging.error(f"Prediction error: {e}")
+            return jsonify({"error": str(e)}), 500
+
+    # Load the model
+    logging.info("Starting model loading...")
+    model_handler.load_model("/opt/ml/model")
+
+    # Start the Flask server
+    app.run(host="0.0.0.0", port=8080)
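The serving contract is easiest to see as a round trip through input_fn and output_fn. A sketch with zero-filled placeholder rows, assuming both functions are importable from classfications.predict; the predictions passed to output_fn are illustrative values, not real model output:

import json

import pandas as pd

from classfications.predict import input_fn, output_fn  # assumed import path

# Either {"data": [...]} or a bare list of rows is accepted
body = json.dumps({"data": [[0.0] * 52, [0.0] * 52]})
frame = input_fn(body, "application/json")
assert isinstance(frame, pd.DataFrame) and frame.shape == (2, 52)

# After model_fn/predict_fn, output_fn wraps the labels for the client
print(output_fn([3, 7], "application/json"))  # {"predictions": [3, 7]}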
diff --git a/classfications/test/test1.py b/classfications/test/test1.py
new file mode 100644
index 0000000..0ccee23
--- /dev/null
+++ b/classfications/test/test1.py
@@ -0,0 +1,44 @@
+import json
+import os
+
+import boto3
+
+# Create the SageMaker Runtime client. Credentials are read from the
+# environment; never commit real keys to source control.
+sagemaker_client = boto3.client(
+    'sagemaker-runtime',
+    region_name='ap-south-1',
+    aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'),          # Set in your environment
+    aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY')   # Set in your environment
+)
+
+# Name of the inference endpoint
+endpoint_name = "icnn-predict-endpoint"
+data = [
+    [
+        0.001220722, 0.003738858, 9.10E-06, 2.06E-05, 2.02E-06, 1.77E-05, 0.000805802, 0,
+        0.001458824, 0.001440329, 0.299027138, 0, 0.333505732, 0.309826049, 0.111932217,
+        0.333336677, 0.000467452, 0.00185729, 0.003715658, 5.08E-07, 1.01E-05, 1.94E-05,
+        1.97E-05, 5.08E-07, 0.000747608, 0.002358479, 0.00371555, 3.92E-07, 0, 0, 0, 0,
+        2.23E-06, 6.69E-06, 0.012820513, 0.331992009, 0.001458824, 0.333505732, 0.34859161,
+        0.451932586, 0, 0, 0, 1, 0, 0, 0, 0, 9.10E-06, 2.02E-06, 2.06E-05, 1.77E-05
+    ]
+]
+
+print(len(data[0]))
+
+# Prepare the request payload
+payload = {
+    "data": data
+}
+
+# Send the inference request
+try:
+    response = sagemaker_client.invoke_endpoint(
+        EndpointName=endpoint_name,
+        ContentType="application/json",
+        Body=json.dumps(payload)
+    )
+
+    # Parse the response
+    result = response["Body"].read().decode("utf-8")
+    print("Inference result:", result)
+
+except Exception as e:
+    print("Error invoking endpoint:", str(e))
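The same request can be sent without any key material in code at all: when no explicit keys are passed, boto3 resolves credentials from the environment, the shared config files, or an attached IAM role. A minimal sketch with a placeholder payload:

import json

import boto3

# No keys in code: a plain Session resolves credentials automatically
session = boto3.Session(region_name="ap-south-1")
client = session.client("sagemaker-runtime")

response = client.invoke_endpoint(
    EndpointName="icnn-predict-endpoint",
    ContentType="application/json",
    Body=json.dumps({"data": [[0.0] * 52]}),  # placeholder feature row
)
print(response["Body"].read().decode("utf-8"))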
--
GitLab