Commit b98ab85

more detailed training process to help development of run.py

1 parent cb83810

8 files changed: +22 −43 lines

nids/__init__.py (+2 −3)

```diff
@@ -1,7 +1,6 @@
 # nids/__init__.py
 
+# Import necessary modules and functions
 from .data_preprocessing import load_and_preprocess_data
-from .model import Net, train_model
+from .model import train_model
 from .logging import setup_logging, log_prediction
-from .prediction import run_prediction
-from .retraining import retrain
```
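After this change the package re-exports only the preprocessing, training, and logging entry points. A minimal usage sketch of that surface follows; only `load_and_preprocess_data`'s return tuple is confirmed by the diff below, while the `setup_logging` and `train_model` signatures are assumptions:

```python
# Hypothetical driver using the package-level API after this commit.
from nids import load_and_preprocess_data, train_model, setup_logging

setup_logging()  # assumed to take no required arguments

# Return tuple confirmed by the data_preprocessing.py diff below.
X_train, X_test, y_train, y_test, scaler = load_and_preprocess_data('data/')

# Hypothetical signature; adjust to the real definition in nids/model.py.
model = train_model(X_train, y_train, X_test, y_test)
```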

nids/data_preprocessing.py (+18 −2)

```diff
@@ -5,6 +5,7 @@
 from sklearn.preprocessing import StandardScaler
 from sklearn.model_selection import train_test_split
 import numpy as np
+import pickle
 
 def load_and_preprocess_data(csv_files_path):
     # Load all CSV files
@@ -35,10 +36,13 @@ def load_and_preprocess_data(csv_files_path):
         raise ValueError("The target label column is not found in the dataset.")
 
     # Encode categorical variables
-    data[label_column] = data[label_column].astype('category').cat.codes
+    data[label_column] = data[label_column].astype('category')
+    class_mapping = dict(enumerate(data[label_column].cat.categories))
+    data[label_column] = data[label_column].cat.codes
 
     # Print unique values of the target labels
     print(f"Unique target labels: {data[label_column].unique()}")
+    print(f"Class mapping: {class_mapping}")
 
     # Replace infinite values with NaN
     data.replace([np.inf, -np.inf], np.nan, inplace=True)
@@ -61,4 +65,16 @@
     # Split data into training and testing sets
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
 
-    return X_train, X_test, y_train, y_test, scaler
+    # Save the class mapping, number of features, and feature names
+    metadata = {
+        'num_features': X_train.shape[1],
+        'num_classes': len(class_mapping),
+        'class_mapping': class_mapping,
+        'feature_names': list(X.columns)
+    }
+    with open('nids/model_metadata.pkl', 'wb') as f:
+        pickle.dump(metadata, f)
+
+    print("Metadata (number of features, classes, class mapping and feature names) saved.")
+
+    return X_train, X_test, y_train, y_test, scaler
```
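Splitting `.astype('category').cat.codes` into separate steps is what makes the mapping recoverable: `cat.categories` is ordered so that index `i` holds the original label behind code `i`. A minimal sketch of how a downstream consumer such as run.py could load the saved metadata and map integer predictions back to label names (the predicted codes here are made-up stand-ins):

```python
import pickle

# Load the metadata written by load_and_preprocess_data.
with open('nids/model_metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

print(metadata['num_features'])   # input size for the model
print(metadata['num_classes'])    # output size for the model
print(metadata['feature_names'])  # column order expected at inference time

# Map integer class codes (as produced by cat.codes) back to label names.
predicted_codes = [0, 2, 1]  # hypothetical model outputs
labels = [metadata['class_mapping'][code] for code in predicted_codes]
print(labels)
```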

nids/model.pth (0 Bytes)

Binary file not shown.

nids/model_metadata.pkl (1.78 KB)

Binary file not shown.

nids/training_test_accuracy.png (8 KB)

Binary file not shown.

requirements.txt (+1 −1)

```diff
@@ -1,4 +1,4 @@
 pandas
 scikit-learn
 torch
-kafka-python
+scapy
```
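Swapping kafka-python for scapy suggests run.py will capture packets directly from a network interface rather than consume them from a Kafka topic. A minimal sketch of scapy's sniffing API; the callback body is a placeholder assumption, not code from this repository:

```python
from scapy.all import sniff

def handle_packet(packet):
    # Placeholder: a real run.py would extract flow features here
    # and feed them to the trained model.
    print(packet.summary())

# Capture 10 packets from the default interface (typically requires root).
sniff(prn=handle_packet, count=10)
```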

retrain_and_run.py (−28)

This file was deleted.

train.py (+1 −9)

```diff
@@ -19,13 +19,5 @@
 torch.save(model.state_dict(), 'nids/model.pth')
 with open('nids/scaler.pkl', 'wb') as f:
     pickle.dump(scaler, f)
-
-# Save number of features and classes
-with open('nids/model_metadata.pkl', 'wb') as f:
-    metadata = {
-        'num_features': X_train.shape[1],
-        'num_classes': num_classes
-    }
-    pickle.dump(metadata, f)
 
-print("Model, scaler, and metadata (number of features and classes) saved.")
+print("Model and scaler saved.")
```
