Commit ea7791c

raspberry pi implementation
1 parent c39922a commit ea7791c

File tree

2 files changed, +149 -0 lines changed


nids/pi_prediction.py

+117
@@ -0,0 +1,117 @@
import torch
import pandas as pd
from nids.logging import log_prediction
from nids.flow import Flow
from sklearn.preprocessing import StandardScaler
from scapy.all import IP, TCP, UDP, IPv6
import pickle
import RPi.GPIO as GPIO
import time
import threading
from queue import Queue

# Load the metadata including class mapping and feature names
with open('nids/model_metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)
LABEL_TO_TRAFFIC_TYPE = metadata['class_mapping']
FEATURE_NAMES = metadata['feature_names']

flows = {}

def preprocess_packet(packet, scaler):
    if IP in packet:
        flow_key = (packet[IP].src, packet[IP].dst, packet[IP].sport, packet[IP].dport)
    elif IPv6 in packet:
        flow_key = (packet[IPv6].src, packet[IPv6].dst, packet[IPv6].sport, packet[IPv6].dport)
    else:
        return None, None, None, None

    if flow_key not in flows:
        flows[flow_key] = Flow(packet)
    else:
        flows[flow_key].update(packet)

    flow = flows[flow_key]
    features = flow.get_features()

    # Create DataFrame with consistent feature names
    df = pd.DataFrame([features])
    df = pd.get_dummies(df)
    df = df.reindex(columns=scaler.feature_names_in_, fill_value=0)  # Ensure consistent feature order

    scaled_data = scaler.transform(df)
    src_ip = flow.src_ip
    dst_ip = flow.dst_ip
    return torch.tensor(scaled_data, dtype=torch.float32), features, src_ip, dst_ip

# Define the GPIO pin numbers for the LEDs
GREEN_LED_PIN = 17
RED_LED_PIN = 27
ORANGE_LED_PIN = 22

# Setup GPIO mode and pins
GPIO.setmode(GPIO.BCM)
GPIO.setup(GREEN_LED_PIN, GPIO.OUT)
GPIO.setup(RED_LED_PIN, GPIO.OUT)
GPIO.setup(ORANGE_LED_PIN, GPIO.OUT)

# Create a queue for red and orange LED events
led_queue = Queue()

# Lock for mutual exclusivity
led_lock = threading.Lock()

def control_led(pin, duration):
    with led_lock:
        GPIO.output(pin, GPIO.HIGH)
        time.sleep(duration)
        GPIO.output(pin, GPIO.LOW)

def led_worker():
    while True:
        pin, duration = led_queue.get()
        control_led(pin, duration)
        led_queue.task_done()

# Start the LED worker thread
led_thread = threading.Thread(target=led_worker, daemon=True)
led_thread.start()

def run_prediction(packet, model, scaler):
    model.eval()
    try:
        data_tensor, original_data, src_ip, dst_ip = preprocess_packet(packet, scaler)
        if data_tensor is None:
            print(f"Packet ignored: {packet.summary()}")
            return

        # Make prediction
        with torch.no_grad():
            output = model(data_tensor)
            _, prediction = torch.max(output, 1)
        traffic_type = LABEL_TO_TRAFFIC_TYPE.get(prediction.item(), "Unknown")

        # Log detailed information
        log_prediction(packet, prediction, original_data, traffic_type, src_ip, dst_ip)
        print(f'Prediction: {prediction.item()} ({traffic_type}), Source IP: {src_ip}, Destination IP: {dst_ip}')

        # Control LEDs based on prediction
        if traffic_type == "Unknown":
            led_queue.put((ORANGE_LED_PIN, 2))
        elif prediction.item() != 0:
            led_queue.put((RED_LED_PIN, 2))
        else:
            threading.Thread(target=control_led, args=(GREEN_LED_PIN, 2)).start()

    except Exception as e:
        print(f"Error during prediction: {e}")

    finally:
        # Ensure all LEDs are turned off after processing
        GPIO.output(GREEN_LED_PIN, GPIO.LOW)
        GPIO.output(RED_LED_PIN, GPIO.LOW)
        GPIO.output(ORANGE_LED_PIN, GPIO.LOW)

# Clean up GPIO on exit
def cleanup_gpio():
    GPIO.cleanup()
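
Note: cleanup_gpio() is defined above but nothing in this commit calls it. A minimal sketch of one way to wire it up, assuming GPIO.cleanup() should run on normal interpreter exit (the atexit registration below is an assumption, not part of this commit):

# Hypothetical, not part of this commit: register cleanup_gpio so that
# GPIO.cleanup() runs automatically when the interpreter exits normally.
import atexit

atexit.register(cleanup_gpio)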

pi_run.py

+32
@@ -0,0 +1,32 @@
# pi_run.py

import torch
import pickle
from scapy.all import sniff, IP, IPv6
from nids import setup_logging
from nids.model import Net
from nids.pi_prediction import run_prediction

# Setup logging
setup_logging()

# Load the model and scaler
with open('nids/model_metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

model = Net(input_size=metadata['num_features'], num_classes=metadata['num_classes'])
model.load_state_dict(torch.load('nids/model.pth'))
model.eval()

with open('nids/scaler.pkl', 'rb') as f:
    scaler = pickle.load(f)

def packet_handler(packet):
    if IP in packet or IPv6 in packet:
        run_prediction(packet, model, scaler)
    else:
        print("Non-IP/IPv6 packet ignored")

if __name__ == '__main__':
    print("Starting packet capture...")
    sniff(prn=packet_handler, store=0)  # prn specifies the function to apply to each packet
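
Note: packet_handler discards non-IP traffic in Python after capture. Scapy's sniff() also accepts a BPF filter string (when libpcap is available), so the same restriction could be applied at capture time, and the GPIO pins could be released when capture stops. A hedged sketch of an alternative __main__ block; the filter string and the try/finally cleanup are assumptions, not part of this commit:

# Hypothetical variant of the __main__ block above, not part of this commit.
from nids.pi_prediction import cleanup_gpio

if __name__ == '__main__':
    print("Starting packet capture...")
    try:
        # "ip or ip6" drops non-IP frames at capture time, before Python sees them
        sniff(filter="ip or ip6", prn=packet_handler, store=0)
    finally:
        cleanup_gpio()  # calls GPIO.cleanup() from nids/pi_prediction.py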
