-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
125 lines (104 loc) · 4.71 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import yaml
import logging
import argparse
from threading import Thread
from data_preprocessing import preprocess
from feature_selection import hhosssa_feature_selection
from data_balancing import hhosssa_smote
from models import train_models
from evaluation import evaluation_metrics, cross_validation
from real_time_detection import data_ingestion, prediction_engine
from dashboard import app
# Root logger setup, performed once at import time: records at INFO level and
# above are emitted (through a default stderr handler, unless the root logger
# already has handlers configured, in which case this call is a no-op).
logging.basicConfig(
    level=logging.INFO,
)
def load_config(path='config.yaml'):
    """Load the application configuration from a YAML file.

    Args:
        path: Path of the YAML file to read. Defaults to ``'config.yaml'``,
            resolved relative to the current working directory (the previous,
            hard-coded behaviour).

    Returns:
        The parsed configuration. An empty file yields ``{}`` rather than
        ``None`` so callers can always index into the result.

    Raises:
        OSError: if the file cannot be opened (e.g. ``FileNotFoundError``).
        yaml.YAMLError: if the file is not valid YAML.
    """
    # Explicit encoding: relying on the platform default (e.g. cp1252 on
    # Windows) could mis-decode non-ASCII characters in the config file.
    with open(path, 'r', encoding='utf-8') as file:
        # safe_load builds only plain Python objects, never arbitrary classes.
        config = yaml.safe_load(file)
    # safe_load returns None for an empty document; normalise to a mapping.
    return {} if config is None else config
def run_data_ingestion(config):
    """Run the real-time data ingestion loop; intended as a ``Thread`` target.

    Args:
        config: Application configuration. NOTE(review): not used by this
            body; presumably kept so both thread targets share a signature.

    Any exception raised by ``data_ingestion.run()`` is logged with its
    traceback and ends the thread instead of propagating.
    """
    logging.info("Starting real-time data ingestion...")
    try:
        data_ingestion.run()
    except Exception:
        # Exceptions raised inside a worker thread never reach the try/except
        # of the code that started it, so record them in the application log
        # here (with traceback) rather than only via threading.excepthook.
        logging.exception("Real-time data ingestion failed")
    else:
        logging.info("Real-time data ingestion and prediction setup completed.")
def run_dashboard(config):
    """Serve the dashboard app; intended as a ``Thread`` target.

    Args:
        config: Application configuration; ``config['dashboard']`` must
            provide ``host``, ``port`` and ``debug``.

    Any exception raised while serving is logged with its traceback and ends
    the thread instead of propagating.
    """
    logging.info("Starting dashboard...")
    dashboard_cfg = config['dashboard']
    try:
        app.run(
            host=dashboard_cfg['host'],
            port=dashboard_cfg['port'],
            debug=dashboard_cfg['debug'],
            # This function is run in a threading.Thread by this script.
            # In debug mode the Werkzeug reloader is enabled by default; it
            # installs signal handlers (which only works in the main thread)
            # and would re-execute the whole script, including model
            # training, in a child process. Disable it explicitly.
            # NOTE(review): assumes `app` is a Flask/Dash app whose run()
            # accepts/forwards `use_reloader` — confirm.
            use_reloader=False,
        )
    except Exception:
        # A worker thread's exception cannot reach the starter's try/except;
        # log it so it appears in the application log.
        logging.exception("Dashboard terminated with an error")
def parse_arguments(argv=None):
    """Parse the command-line flags of the APT Detection System.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before; passing an
            explicit list makes the parser usable from tests or other code.

    Returns:
        ``argparse.Namespace`` with boolean attributes ``train``,
        ``predict``, ``dashboard`` and ``all`` (each ``False`` unless given).
    """
    parser = argparse.ArgumentParser(description='APT Detection System')
    parser.add_argument('--train', action='store_true', help='Train models')
    parser.add_argument('--predict', action='store_true', help='Run prediction engine')
    parser.add_argument('--dashboard', action='store_true', help='Run dashboard')
    parser.add_argument('--all', action='store_true', help='Run all components')
    return parser.parse_args(argv)
def _train_pipeline(config):
    """Preprocess, select features, balance, train and evaluate the models.

    Args:
        config: Application configuration; ``config['data_paths']['dataset']``
            is the dataset path relative to the current working directory.

    Returns:
        ``{'lgbm_model': ..., 'bilstm_model': ...}`` — the models handed to the
        prediction engine. The hybrid model is only used for evaluation here.
    """
    # Load and preprocess data
    logging.info("Starting data preprocessing...")
    dataset_path = os.path.join(os.getcwd(), config['data_paths']['dataset'])
    df = preprocess.run(dataset_path)
    logging.info("Data preprocessing completed.")

    # Feature selection
    logging.info("Starting feature selection...")
    selected_features = hhosssa_feature_selection.run(df)
    logging.info("Feature selection completed.")

    # Data balancing
    logging.info("Starting data balancing...")
    balanced_data = hhosssa_smote.run(selected_features)
    logging.info("Data balancing completed.")

    # Train models; save=True presumably persists them so later runs can
    # load them from disk (use_saved_models=True) — TODO confirm.
    logging.info("Starting model training...")
    lgbm_model, bilstm_model, hybrid_model = train_models.run(balanced_data, save=True)
    logging.info("Model training completed.")

    # Evaluate models.
    # NOTE(review): evaluation uses the same balanced data passed to training
    # (including any synthetic oversampled rows); unless train_models /
    # evaluate split internally, these scores are optimistic.
    logging.info("Starting model evaluation...")
    accuracy, roc_auc = evaluation_metrics.evaluate(hybrid_model, balanced_data)
    logging.info(
        "Model evaluation completed with Accuracy: %s, ROC-AUC: %s", accuracy, roc_auc
    )
    return {'lgbm_model': lgbm_model, 'bilstm_model': bilstm_model}


def _start_prediction_engine(models):
    """Start the prediction engine, best-effort.

    Args:
        models: Freshly trained models, or ``None``/empty to have the engine
            load saved models from disk.

    Returns:
        Whatever ``prediction_engine.run`` returns, or ``None`` if it failed.
        Failures are logged and swallowed so the other components still run.
    """
    logging.info("Starting prediction engine...")
    try:
        if models:
            # Use freshly trained models
            predict_fn = prediction_engine.run(models, use_saved_models=False)
        else:
            # Load models from disk
            predict_fn = prediction_engine.run(use_saved_models=True)
    except Exception as e:
        # Deliberately non-fatal: continue with the other components, but keep
        # the traceback in the log.
        logging.exception("Failed to start prediction engine: %s", e)
        return None
    logging.info("Prediction engine started successfully.")
    return predict_fn


def main():
    """Parse CLI flags, load the config and run the requested components.

    With no component flag, everything (training, prediction, dashboard) runs.
    Ingestion and the dashboard run in threads that are joined before return.

    Returns:
        Process exit status: 0 on success, 1 if an unhandled error occurred.
    """
    # Parse arguments before reading config.yaml so that ``--help`` and usage
    # errors work even when the config file is missing or invalid.
    args = parse_arguments()
    try:
        config = load_config()

        # If no arguments provided, run all components
        if not (args.train or args.predict or args.dashboard):
            args.all = True

        # Freshly trained models take precedence over saved ones.
        models = None
        if args.train or args.all:
            models = _train_pipeline(config)

        ingestion_thread = None
        dashboard_thread = None

        if args.predict or args.all:
            # Real-time ingestion runs in the background while the prediction
            # engine is started from this thread.
            ingestion_thread = Thread(target=run_data_ingestion, args=(config,))
            ingestion_thread.start()
            _start_prediction_engine(models)

        if args.dashboard or args.all:
            dashboard_thread = Thread(target=run_dashboard, args=(config,))
            dashboard_thread.start()

        # Wait for the background components (dashboard first, as before).
        for thread in (dashboard_thread, ingestion_thread):
            if thread is not None:
                thread.join()
    except Exception as e:
        # Top-level boundary: keep the traceback and report failure to the
        # shell instead of exiting with status 0.
        logging.exception("An error occurred: %s", e)
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())