|
| 1 | +# nids/retraining.py |
| 2 | + |
| 3 | +import pandas as pd |
| 4 | +import pickle |
| 5 | +import torch |
| 6 | +import torch.nn as nn |
| 7 | +import torch.optim as optim |
| 8 | +from torch.utils.data import DataLoader, TensorDataset |
| 9 | +from sklearn.model_selection import train_test_split |
| 10 | +from sklearn.preprocessing import StandardScaler |
| 11 | +from nids.model import Net |
| 12 | + |
def load_and_preprocess_data(file_path):
    """Load a labelled CSV and prepare it for training.

    Drops the non-feature ``timestamp`` column (if present), one-hot
    encodes categorical features, standardizes all features, and splits
    into train/test sets.

    Parameters
    ----------
    file_path : str
        Path to a CSV file containing a ``label`` column and, optionally,
        a ``timestamp`` column.

    Returns
    -------
    tuple
        ``((X_train, X_test, y_train, y_test), scaler)`` — an 80/20 split
        with a fixed seed, plus the fitted ``StandardScaler`` so the same
        scaling can be applied at inference time.
    """
    data = pd.read_csv(file_path)
    # 'timestamp' is metadata, not a feature; errors='ignore' tolerates
    # CSVs that were exported without it instead of raising KeyError.
    data = data.drop(columns=['timestamp'], errors='ignore')
    X = data.drop(columns=['label'])
    y = data['label']
    # One-hot encode categoricals, then cast explicitly: pandas >= 2.0
    # emits bool dummy columns, and a float32 matrix makes the later
    # torch.tensor(..., dtype=torch.float32) conversion well-defined.
    X = pd.get_dummies(X).astype('float32')
    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    return train_test_split(X, y, test_size=0.2, random_state=42), scaler
| 22 | + |
def retrain_model(model, train_loader, epochs=10, lr=0.001):
    """Fine-tune an existing classifier on new labelled batches.

    Parameters
    ----------
    model : nn.Module
        Classifier producing raw logits; updated in place.
    train_loader : DataLoader
        Yields ``(inputs, labels)`` batches with integer class labels.
    epochs : int, optional
        Number of passes over the data. Default 10 — deliberately few,
        since this is incremental retraining, not training from scratch.
    lr : float, optional
        Adam learning rate. Default 0.001.

    Returns
    -------
    nn.Module
        The same (mutated) model, for call-chaining convenience.
    """
    # Ensure dropout / batch-norm layers are in training mode — the loaded
    # checkpoint may have been left in eval() mode by a previous consumer.
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    return model
| 34 | + |
def retrain(csv_files_path):
    """Retrain the saved NIDS model on new labelled data and persist artifacts.

    Loads ``nids/model.pth`` + ``nids/model_metadata.pkl``, fine-tunes the
    network on the CSV at *csv_files_path*, then writes the updated weights
    to ``nids/updated_model.pth`` and refreshes ``nids/scaler.pkl`` and the
    metadata pickle.

    Parameters
    ----------
    csv_files_path : str
        Path to a labelled CSV accepted by ``load_and_preprocess_data``.

    Raises
    ------
    ValueError
        If the new data's feature width differs from what the saved model
        expects.
    """
    (X_train, X_test, y_train, y_test), scaler = load_and_preprocess_data(csv_files_path)

    # Convert the training split to PyTorch tensors and batch it.
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train.values, dtype=torch.long)
    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # Load the architecture parameters the saved checkpoint was built with.
    with open('nids/model_metadata.pkl', 'rb') as f:
        metadata = pickle.load(f)
    num_features = metadata['num_features']
    num_classes = metadata['num_classes']

    # Fail fast with a clear message: get_dummies on a new CSV can produce
    # a different one-hot width, which would otherwise surface as an opaque
    # shape error inside the first forward pass.
    if X_train_tensor.shape[1] != num_features:
        raise ValueError(
            f"Feature mismatch: new data has {X_train_tensor.shape[1]} features "
            f"but the saved model expects {num_features}."
        )

    model = Net(num_features, num_classes)
    model.load_state_dict(torch.load('nids/model.pth'))

    # Fine-tune on the new data.
    model = retrain_model(model, train_loader)

    # Persist the updated weights alongside (not over) the original model.
    torch.save(model.state_dict(), 'nids/updated_model.pth')

    # Persist the newly fitted scaler and the (unchanged) metadata so
    # inference code stays consistent with this retraining run.
    with open('nids/scaler.pkl', 'wb') as f:
        pickle.dump(scaler, f)
    with open('nids/model_metadata.pkl', 'wb') as f:
        pickle.dump(metadata, f)

    print("Model, scaler, and metadata (number of features and classes) updated successfully.")
0 commit comments