Skip to content

Commit 06166d5

Browse files
Redid full classifier to ensure it was trained
correctly
1 parent 62cee9b commit 06166d5

File tree

2 files changed

+87
-12
lines changed

2 files changed

+87
-12
lines changed

CDC_Classifier_Period.ipynb

+87-12
Original file line numberDiff line numberDiff line change
@@ -10383,7 +10383,34 @@
1038310383
"name": "stdout",
1038410384
"output_type": "stream",
1038510385
"text": [
10386-
"0\n"
10386+
"0\n",
10387+
"1\n",
10388+
"2\n",
10389+
"3\n",
10390+
"4\n",
10391+
"5\n",
10392+
"6\n",
10393+
"7\n",
10394+
"8\n",
10395+
"9\n",
10396+
"MCC: 0.7038956549471436\n",
10397+
"Accuracy: 0.9134016668263244\n",
10398+
"auROC: 0.8872189145023346\n"
10399+
]
10400+
},
10401+
{
10402+
"ename": "ValueError",
10403+
"evalue": "Found input variables with inconsistent numbers of samples: [2409, 10439]",
10404+
"output_type": "error",
10405+
"traceback": [
10406+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
10407+
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
10408+
"Cell \u001b[0;32mIn[126], line 28\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mAccuracy:\u001b[39m\u001b[39m\"\u001b[39m, accuracy)\n\u001b[1;32m 26\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mauROC:\u001b[39m\u001b[39m\"\u001b[39m, ROC)\n\u001b[0;32m---> 28\u001b[0m \u001b[39mprint\u001b[39m(confusion_matrix(y_test, y_pred))\n",
10409+
"File \u001b[0;32m~/miniconda3/envs/COVID_forecasting/lib/python3.11/site-packages/sklearn/utils/_param_validation.py:211\u001b[0m, in \u001b[0;36mvalidate_params.<locals>.decorator.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 206\u001b[0m \u001b[39mwith\u001b[39;00m config_context(\n\u001b[1;32m 207\u001b[0m skip_parameter_validation\u001b[39m=\u001b[39m(\n\u001b[1;32m 208\u001b[0m prefer_skip_nested_validation \u001b[39mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 209\u001b[0m )\n\u001b[1;32m 210\u001b[0m ):\n\u001b[0;32m--> 211\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[1;32m 212\u001b[0m \u001b[39mexcept\u001b[39;00m InvalidParameterError \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 213\u001b[0m \u001b[39m# When the function is just a wrapper around an estimator, we allow\u001b[39;00m\n\u001b[1;32m 214\u001b[0m \u001b[39m# the function to delegate validation to the estimator, but we replace\u001b[39;00m\n\u001b[1;32m 215\u001b[0m \u001b[39m# the name of the estimator by the name of the function in the error\u001b[39;00m\n\u001b[1;32m 216\u001b[0m \u001b[39m# message to avoid confusion.\u001b[39;00m\n\u001b[1;32m 217\u001b[0m msg \u001b[39m=\u001b[39m re\u001b[39m.\u001b[39msub(\n\u001b[1;32m 218\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mparameter of \u001b[39m\u001b[39m\\\u001b[39m\u001b[39mw+ must be\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 219\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mparameter of \u001b[39m\u001b[39m{\u001b[39;00mfunc\u001b[39m.\u001b[39m\u001b[39m__qualname__\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m must be\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 220\u001b[0m \u001b[39mstr\u001b[39m(e),\n\u001b[1;32m 221\u001b[0m )\n",
10410+
"File \u001b[0;32m~/miniconda3/envs/COVID_forecasting/lib/python3.11/site-packages/sklearn/metrics/_classification.py:326\u001b[0m, in \u001b[0;36mconfusion_matrix\u001b[0;34m(y_true, y_pred, labels, sample_weight, normalize)\u001b[0m\n\u001b[1;32m 231\u001b[0m \u001b[39m@validate_params\u001b[39m(\n\u001b[1;32m 232\u001b[0m {\n\u001b[1;32m 233\u001b[0m \u001b[39m\"\u001b[39m\u001b[39my_true\u001b[39m\u001b[39m\"\u001b[39m: [\u001b[39m\"\u001b[39m\u001b[39marray-like\u001b[39m\u001b[39m\"\u001b[39m],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 242\u001b[0m y_true, y_pred, \u001b[39m*\u001b[39m, labels\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, sample_weight\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, normalize\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m\n\u001b[1;32m 243\u001b[0m ):\n\u001b[1;32m 244\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Compute confusion matrix to evaluate the accuracy of a classification.\u001b[39;00m\n\u001b[1;32m 245\u001b[0m \n\u001b[1;32m 246\u001b[0m \u001b[39m By definition a confusion matrix :math:`C` is such that :math:`C_{i, j}`\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 324\u001b[0m \u001b[39m (0, 2, 1, 1)\u001b[39;00m\n\u001b[1;32m 325\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 326\u001b[0m y_type, y_true, y_pred \u001b[39m=\u001b[39m _check_targets(y_true, y_pred)\n\u001b[1;32m 327\u001b[0m \u001b[39mif\u001b[39;00m y_type \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m (\u001b[39m\"\u001b[39m\u001b[39mbinary\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mmulticlass\u001b[39m\u001b[39m\"\u001b[39m):\n\u001b[1;32m 328\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m is not supported\u001b[39m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m y_type)\n",
10411+
"File \u001b[0;32m~/miniconda3/envs/COVID_forecasting/lib/python3.11/site-packages/sklearn/metrics/_classification.py:84\u001b[0m, in \u001b[0;36m_check_targets\u001b[0;34m(y_true, y_pred)\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_check_targets\u001b[39m(y_true, y_pred):\n\u001b[1;32m 58\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Check that y_true and y_pred belong to the same classification task.\u001b[39;00m\n\u001b[1;32m 59\u001b[0m \n\u001b[1;32m 60\u001b[0m \u001b[39m This converts multiclass or binary types to a common shape, and raises a\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[39m y_pred : array or indicator matrix\u001b[39;00m\n\u001b[1;32m 83\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 84\u001b[0m check_consistent_length(y_true, y_pred)\n\u001b[1;32m 85\u001b[0m type_true \u001b[39m=\u001b[39m type_of_target(y_true, input_name\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39my_true\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 86\u001b[0m type_pred \u001b[39m=\u001b[39m type_of_target(y_pred, input_name\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39my_pred\u001b[39m\u001b[39m\"\u001b[39m)\n",
10412+
"File \u001b[0;32m~/miniconda3/envs/COVID_forecasting/lib/python3.11/site-packages/sklearn/utils/validation.py:409\u001b[0m, in \u001b[0;36mcheck_consistent_length\u001b[0;34m(*arrays)\u001b[0m\n\u001b[1;32m 407\u001b[0m uniques \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39munique(lengths)\n\u001b[1;32m 408\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(uniques) \u001b[39m>\u001b[39m \u001b[39m1\u001b[39m:\n\u001b[0;32m--> 409\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 410\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mFound input variables with inconsistent numbers of samples: \u001b[39m\u001b[39m%r\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[1;32m 411\u001b[0m \u001b[39m%\u001b[39m [\u001b[39mint\u001b[39m(l) \u001b[39mfor\u001b[39;00m l \u001b[39min\u001b[39;00m lengths]\n\u001b[1;32m 412\u001b[0m )\n",
10413+
"\u001b[0;31mValueError\u001b[0m: Found input variables with inconsistent numbers of samples: [2409, 10439]"
1038710414
]
1038810415
}
1038910416
],
@@ -10414,13 +10441,12 @@
1041410441
"print(\"MCC:\", MCC)\n",
1041510442
"print(\"Accuracy:\", accuracy)\n",
1041610443
"print(\"auROC:\", ROC)\n",
10417-
"\n",
10418-
"print(confusion_matrix(y_test, y_pred))"
10444+
"\n"
1041910445
]
1042010446
},
1042110447
{
1042210448
"cell_type": "code",
10423-
"execution_count": 122,
10449+
"execution_count": 127,
1042410450
"metadata": {},
1042510451
"outputs": [
1042610452
{
@@ -10453,12 +10479,61 @@
1045310479
},
1045410480
{
1045510481
"cell_type": "code",
10456-
"execution_count": 165,
10482+
"execution_count": 131,
10483+
"metadata": {},
10484+
"outputs": [
10485+
{
10486+
"name": "stdout",
10487+
"output_type": "stream",
10488+
"text": [
10489+
"MCC: 0.7038956549471436\n",
10490+
"Accuracy: 0.9134016668263244\n",
10491+
"auROC: 0.8872189145023346\n"
10492+
]
10493+
}
10494+
],
10495+
"source": [
10496+
"y_pred = clf.predict(X_test_full)\n",
10497+
"y_pred_proba = clf.predict_proba(X_test_full)\n",
10498+
"\n",
10499+
"# Evaluate the accuracy of the model\n",
10500+
"accuracy = accuracy_score(y_test_full, y_pred)\n",
10501+
"ROC = roc_auc_score(y_test_full, y_pred_proba[:,1])\n",
10502+
"MCC = (matthews_corrcoef(y_test_full, y_pred) + 1)/2\n",
10503+
"\n",
10504+
"print(\"MCC:\", MCC)\n",
10505+
"print(\"Accuracy:\", accuracy)\n",
10506+
"print(\"auROC:\", ROC)\n"
10507+
]
10508+
},
10509+
{
10510+
"cell_type": "code",
10511+
"execution_count": 132,
1045710512
"metadata": {},
1045810513
"outputs": [],
1045910514
"source": [
10460-
"#model_name = f\"CDC_classifier_auroc_{ROC:.4f}_CDC_period_full.sav\"\n",
10461-
"#pickle.dump(clf, open(model_name, 'wb'))"
10515+
"model_name = f\"CDC_classifier_auroc_{ROC:.4f}_CDC_period_full.sav\"\n",
10516+
"pickle.dump(clf, open(model_name, 'wb'))"
10517+
]
10518+
},
10519+
{
10520+
"cell_type": "code",
10521+
"execution_count": 133,
10522+
"metadata": {},
10523+
"outputs": [
10524+
{
10525+
"data": {
10526+
"text/plain": [
10527+
"'CDC_classifier_auroc_0.8872_CDC_period_full.sav'"
10528+
]
10529+
},
10530+
"execution_count": 133,
10531+
"metadata": {},
10532+
"output_type": "execute_result"
10533+
}
10534+
],
10535+
"source": [
10536+
"model_name"
1046210537
]
1046310538
},
1046410539
{
@@ -10604,23 +10679,23 @@
1060410679
},
1060510680
{
1060610681
"cell_type": "code",
10607-
"execution_count": 92,
10682+
"execution_count": 134,
1060810683
"metadata": {},
1060910684
"outputs": [
1061010685
{
1061110686
"name": "stdout",
1061210687
"output_type": "stream",
1061310688
"text": [
10614-
"MCC: 0.47932871282255496\n",
10615-
"Accuracy: 0.14404317144043172\n",
10616-
"auROC: 0.531625516761377\n"
10689+
"MCC: 0.7089756779440055\n",
10690+
"Accuracy: 0.7667081776670818\n",
10691+
"auROC: 0.8185285531837256\n"
1061710692
]
1061810693
}
1061910694
],
1062010695
"source": [
1062110696
"X_test, y_test, weights_test, missing_data_test_HSA = prep_training_test_data(all_HSA_ID_weekly_data, no_weeks = range(2, 5), weeks_in_future = 3, geography = 'HSA_ID', weight_col = 'weight', keep_output = True) # account for the fact that week 1 is the week included to allow for calculation of delta\n",
1062210697
"\n",
10623-
"full_model = pickle.load(open('/Users/rem76/Documents/COVID_projections/COVID_forecasting/CDC_classifier_auroc_0.9091_CDC_period_full.sav', 'rb'))\n",
10698+
"full_model = pickle.load(open('/Users/rem76/Documents/COVID_projections/COVID_forecasting/CDC_classifier_auroc_0.8872_CDC_period_full.sav', 'rb'))\n",
1062410699
"# Train the decision tree classifier\n",
1062510700
"\n",
1062610701
"# Make predictions on the test set\n",
2.34 KB
Binary file not shown.

0 commit comments

Comments
 (0)