
Commit c39922a: "run method"
1 parent 09630b4 commit c39922a

File tree

.gitignore
nids/flow.py
nids/logging.py
nids/prediction.py
nids/retraining.py
run.py

6 files changed: +157 -45 lines changed

.gitignore (+7, -1)

@@ -1,2 +1,8 @@
 nids/__pycache__
-dataset/MachineLearningCVE
+dataset/MachineLearningCVE
+nids_logs.log
+packet_sniffer.log
+nids/model_metadata.pkl
+nids/model.pth
+nids/scaler.pkl
+nids/training_test_accuracy.png

nids/flow.py (new file, +59)

@@ -0,0 +1,59 @@
+# nids/flow.py
+
+import time
+import pandas as pd
+from scapy.all import IP, IPv6, TCP, UDP
+
+class Flow:
+    def __init__(self, packet):
+        self.packets = [packet]
+        self.start_time = time.time()
+        self.end_time = time.time()
+        if IP in packet:
+            self.src_ip = packet[IP].src
+            self.dst_ip = packet[IP].dst
+            self.src_port = packet[IP].sport if TCP in packet or UDP in packet else 0
+            self.dst_port = packet[IP].dport if TCP in packet or UDP in packet else 0
+        elif IPv6 in packet:
+            self.src_ip = packet[IPv6].src
+            self.dst_ip = packet[IPv6].dst
+            self.src_port = packet[IPv6].sport if TCP in packet or UDP in packet else 0
+            self.dst_port = packet[IPv6].dport if TCP in packet or UDP in packet else 0
+        self.total_fwd_packets = 1 if self.src_ip == (packet[IP].src if IP in packet else packet[IPv6].src) else 0
+        self.total_bwd_packets = 1 if self.src_ip == (packet[IP].dst if IP in packet else packet[IPv6].dst) else 0
+        self.total_fwd_bytes = len(packet) if self.src_ip == (packet[IP].src if IP in packet else packet[IPv6].src) else 0
+        self.total_bwd_bytes = len(packet) if self.src_ip == (packet[IP].dst if IP in packet else packet[IPv6].dst) else 0
+
+    def update(self, packet):
+        self.packets.append(packet)
+        self.end_time = time.time()
+        if self.src_ip == (packet[IP].src if IP in packet else packet[IPv6].src):
+            self.total_fwd_packets += 1
+            self.total_fwd_bytes += len(packet)
+        else:
+            self.total_bwd_packets += 1
+            self.total_bwd_bytes += len(packet)
+
+    def get_duration(self):
+        return (self.end_time - self.start_time) * 1e6  # duration in microseconds
+
+    def get_features(self):
+        features = {
+            'Destination Port': self.dst_port,
+            'Flow Duration': self.get_duration(),
+            'Total Fwd Packets': self.total_fwd_packets,
+            'Total Backward Packets': self.total_bwd_packets,
+            'Total Length of Fwd Packets': self.total_fwd_bytes,
+            'Total Length of Bwd Packets': self.total_bwd_bytes,
+            'Fwd Packet Length Max': max([len(p) for p in self.packets if self.src_ip == (p[IP].src if IP in p else p[IPv6].src)], default=0),
+            'Fwd Packet Length Min': min([len(p) for p in self.packets if self.src_ip == (p[IP].src if IP in p else p[IPv6].src)], default=0),
+            'Fwd Packet Length Mean': sum([len(p) for p in self.packets if self.src_ip == (p[IP].src if IP in p else p[IPv6].src)]) / self.total_fwd_packets if self.total_fwd_packets > 0 else 0,
+            'Fwd Packet Length Std': pd.Series([len(p) for p in self.packets if self.src_ip == (p[IP].src if IP in p else p[IPv6].src)]).std(),
+            'Bwd Packet Length Max': max([len(p) for p in self.packets if self.src_ip == (p[IP].dst if IP in p else p[IPv6].dst)], default=0),
+            'Bwd Packet Length Min': min([len(p) for p in self.packets if self.src_ip == (p[IP].dst if IP in p else p[IPv6].dst)], default=0),
+            'Bwd Packet Length Mean': sum([len(p) for p in self.packets if self.src_ip == (p[IP].dst if IP in p else p[IPv6].dst)]) / self.total_bwd_packets if self.total_bwd_packets > 0 else 0,
+            'Bwd Packet Length Std': pd.Series([len(p) for p in self.packets if self.src_ip == (p[IP].dst if IP in p else p[IPv6].dst)]).std(),
+            'Flow Bytes/s': (self.total_fwd_bytes + self.total_bwd_bytes) / self.get_duration() * 1e6 if self.get_duration() > 0 else 0,
+            'Flow Packets/s': (self.total_fwd_packets + self.total_bwd_packets) / self.get_duration() * 1e6 if self.get_duration() > 0 else 0,
+        }
+        return features
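For context, a minimal sketch (not part of this commit) of how the new Flow class could be exercised with a hand-built Scapy packet; the addresses and ports are made up for illustration.

# Hypothetical usage sketch for nids/flow.py; values are invented.
from scapy.all import IP, TCP
from nids.flow import Flow

pkt = IP(src="10.0.0.1", dst="10.0.0.2") / TCP(sport=12345, dport=80)
flow = Flow(pkt)  # first packet opens the flow and fixes the forward direction
flow.update(IP(src="10.0.0.1", dst="10.0.0.2") / TCP(sport=12345, dport=80))

features = flow.get_features()
print(features['Total Fwd Packets'])  # 2: both packets travel in the forward direction
print(features['Destination Port'])   # 80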

nids/logging.py (+15, -4)

@@ -1,11 +1,22 @@
 # nids/logging.py
 
 import logging
+from scapy.all import IP, IPv6
+import os
 
 def setup_logging():
-    # Configure logging
-    logging.basicConfig(filename='nids_logs.log', level=logging.INFO,
+    # Clear the log file if it exists
+    log_file = 'nids_logs.log'
+    if os.path.exists(log_file):
+        with open(log_file, 'w'):
+            pass
+
+    # Setup logging configuration
+    logging.basicConfig(filename=log_file, level=logging.INFO,
                         format='%(asctime)s:%(levelname)s:%(message)s')
 
-def log_prediction(data, prediction):
-    logging.info(f'Data: {data}, Prediction: {prediction.item()}')
+def log_prediction(packet, prediction, original_data, traffic_type, src_ip, dst_ip):
+    summary = packet.summary() if IP in packet or IPv6 in packet else "Non-IP packet"
+    features_str = ', '.join([f'{k}: {v}' for k, v in original_data.items()])
+    log_message = f'Packet: {summary}, Prediction: {prediction.item()} ({traffic_type}), Source IP: {src_ip}, Destination IP: {dst_ip}, Features: [{features_str}]'
+    logging.info(log_message)
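A rough sketch (not from the repo) of how the new log_prediction signature might be driven; every value below is invented purely for illustration.

# Hypothetical call to the new log_prediction; values are made up.
import torch
from scapy.all import IP, TCP
from nids.logging import setup_logging, log_prediction

setup_logging()
pkt = IP(src="10.0.0.1", dst="10.0.0.2") / TCP(sport=12345, dport=80)
features = {'Destination Port': 80, 'Flow Duration': 1500.0}
log_prediction(pkt, torch.tensor([0]), features, "BENIGN", "10.0.0.1", "10.0.0.2")
# Writes one INFO line to nids_logs.log combining the packet summary, the numeric
# label with its traffic type, both IPs, and the feature dump.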

nids/prediction.py (+49, -20)

@@ -1,34 +1,63 @@
 # nids/prediction.py
 
-from kafka import KafkaConsumer
 import torch
 import pandas as pd
-import json
-from nids.model import Net
 from nids.logging import log_prediction
+from nids.flow import Flow
 from sklearn.preprocessing import StandardScaler
+from scapy.all import IP, TCP, UDP, IPv6
+import pickle
 
-def preprocess_data(data, scaler):
-    data = pd.DataFrame([data])
-    data = pd.get_dummies(data)
-    data = scaler.transform(data)
-    return torch.tensor(data, dtype=torch.float32)
-
-def run_prediction(model, scaler):
-    # Initialize Kafka consumer
-    consumer = KafkaConsumer('network_traffic',
-                             bootstrap_servers='localhost:9092',
-                             value_deserializer=lambda v: json.loads(v.decode('utf-8')))
+# Load the metadata including class mapping and feature names
+with open('nids/model_metadata.pkl', 'rb') as f:
+    metadata = pickle.load(f)
+LABEL_TO_TRAFFIC_TYPE = metadata['class_mapping']
+FEATURE_NAMES = metadata['feature_names']
+
+flows = {}
+
+def preprocess_packet(packet, scaler):
+    if IP in packet:
+        flow_key = (packet[IP].src, packet[IP].dst, packet[IP].sport, packet[IP].dport)
+    elif IPv6 in packet:
+        flow_key = (packet[IPv6].src, packet[IPv6].dst, packet[IPv6].sport, packet[IPv6].dport)
+    else:
+        return None, None, None, None
+
+    if flow_key not in flows:
+        flows[flow_key] = Flow(packet)
+    else:
+        flows[flow_key].update(packet)
 
+    flow = flows[flow_key]
+    features = flow.get_features()
+
+    # Create DataFrame with consistent feature names
+    df = pd.DataFrame([features])
+    df = pd.get_dummies(df)
+    df = df.reindex(columns=scaler.feature_names_in_, fill_value=0)  # Ensure consistent feature order
+
+    scaled_data = scaler.transform(df)
+    src_ip = flow.src_ip
+    dst_ip = flow.dst_ip
+    return torch.tensor(scaled_data, dtype=torch.float32), features, src_ip, dst_ip
+
+def run_prediction(packet, model, scaler):
     model.eval()
-    # Real-time prediction loop
-    for message in consumer:
-        data = message.value
-        data_tensor = preprocess_data(data, scaler)
+    try:
+        data_tensor, original_data, src_ip, dst_ip = preprocess_packet(packet, scaler)
+        if data_tensor is None:
+            print(f"Packet ignored: {packet.summary()}")
+            return
 
         # Make prediction
         with torch.no_grad():
            output = model(data_tensor)
            _, prediction = torch.max(output, 1)
-        log_prediction(data, prediction)
-        print(f'Prediction: {prediction.item()}')
+        traffic_type = LABEL_TO_TRAFFIC_TYPE.get(prediction.item(), "Unknown")
+
+        # Log detailed information
+        log_prediction(packet, prediction, original_data, traffic_type, src_ip, dst_ip)
+        print(f'Prediction: {prediction.item()} ({traffic_type}), Source IP: {src_ip}, Destination IP: {dst_ip}')
+    except Exception as e:
+        print(f'Error processing packet: {e}')
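The module expects nids/model_metadata.pkl to carry a class mapping and the training-time feature names, and run.py below also reads 'num_features' and 'num_classes' from it. A hedged illustration of that structure; the concrete labels and feature names are assumptions, not taken from the repo.

# Assumed shape of nids/model_metadata.pkl, inferred from the keys read in this
# commit; the labels and features below are illustrative only.
metadata = {
    'class_mapping': {0: 'BENIGN', 1: 'DDoS'},               # label index -> traffic type name
    'feature_names': ['Destination Port', 'Flow Duration'],  # columns used at training time
    'num_features': 2,                                       # read by run.py to size the Net input
    'num_classes': 2,                                        # read by run.py to size the Net output
}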

nids/retraining.py (+1, -1)

@@ -8,7 +8,7 @@
 from torch.utils.data import DataLoader, TensorDataset
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import StandardScaler
-from nids.model import Net
+from nids import Net
 
 def load_and_preprocess_data(file_path):
     data = pd.read_csv(file_path)
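The switch to `from nids import Net` assumes the package's __init__ re-exports Net, which the pre-change run.py import (`from nids import Net, setup_logging, run_prediction`) suggests it does.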

run.py (+26, -19)

@@ -2,24 +2,31 @@
 
 import torch
 import pickle
-from nids import Net, setup_logging, run_prediction
+from scapy.all import sniff, IP, IPv6
+from nids import setup_logging
+from nids.model import Net
+from nids.prediction import run_prediction
 
-if __name__ == '__main__':
-    # Setup logging
-    setup_logging()
+# Setup logging
+setup_logging()
+
+# Load the model and scaler
+with open('nids/model_metadata.pkl', 'rb') as f:
+    metadata = pickle.load(f)
+
+model = Net(input_size=metadata['num_features'], num_classes=metadata['num_classes'])
+model.load_state_dict(torch.load('nids/model.pth'))
+model.eval()
 
-    # Load the number of features and classes
-    with open('nids/model_metadata.pkl', 'rb') as f:
-        metadata = pickle.load(f)
-    num_features = metadata['num_features']
-    num_classes = metadata['num_classes']
-
-    # Load the model and scaler
-    model = Net(input_size=num_features, num_classes=num_classes)
-    model.load_state_dict(torch.load('nids/model.pth'))
-
-    with open('nids/scaler.pkl', 'rb') as f:
-        scaler = pickle.load(f)
-
-    # Run real-time prediction
-    run_prediction(model, scaler)
+with open('nids/scaler.pkl', 'rb') as f:
+    scaler = pickle.load(f)
+
+def packet_handler(packet):
+    if IP in packet or IPv6 in packet:
+        run_prediction(packet, model, scaler)
+    else:
+        print("Non-IP/IPv6 packet ignored")
+
+if __name__ == '__main__':
+    print("Starting packet capture...")
+    sniff(prn=packet_handler, store=0)  # prn specifies the function to apply to each packet
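With these changes the script no longer consumes flows from Kafka; it captures live traffic directly and classifies it per packet. Scapy's sniff() generally needs elevated privileges, so something like `sudo python run.py` is the expected way to launch it, assuming the model, scaler, and metadata files now listed in .gitignore have already been produced by training.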
