|
10383 | 10383 | "name": "stdout",
|
10384 | 10384 | "output_type": "stream",
|
10385 | 10385 | "text": [
|
10386 |
| - "0\n" |
| 10386 | + "0\n", |
| 10387 | + "1\n", |
| 10388 | + "2\n", |
| 10389 | + "3\n", |
| 10390 | + "4\n", |
| 10391 | + "5\n", |
| 10392 | + "6\n", |
| 10393 | + "7\n", |
| 10394 | + "8\n", |
| 10395 | + "9\n", |
| 10396 | + "MCC: 0.7038956549471436\n", |
| 10397 | + "Accuracy: 0.9134016668263244\n", |
| 10398 | + "auROC: 0.8872189145023346\n" |
| 10399 | + ] |
| 10400 | + }, |
| 10401 | + { |
| 10402 | + "ename": "ValueError", |
| 10403 | + "evalue": "Found input variables with inconsistent numbers of samples: [2409, 10439]", |
| 10404 | + "output_type": "error", |
| 10405 | + "traceback": [ |
| 10406 | + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
| 10407 | + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", |
| 10408 | + "Cell \u001b[0;32mIn[126], line 28\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mAccuracy:\u001b[39m\u001b[39m\"\u001b[39m, accuracy)\n\u001b[1;32m 26\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mauROC:\u001b[39m\u001b[39m\"\u001b[39m, ROC)\n\u001b[0;32m---> 28\u001b[0m \u001b[39mprint\u001b[39m(confusion_matrix(y_test, y_pred))\n", |
| 10409 | + "File \u001b[0;32m~/miniconda3/envs/COVID_forecasting/lib/python3.11/site-packages/sklearn/utils/_param_validation.py:211\u001b[0m, in \u001b[0;36mvalidate_params.<locals>.decorator.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 206\u001b[0m \u001b[39mwith\u001b[39;00m config_context(\n\u001b[1;32m 207\u001b[0m skip_parameter_validation\u001b[39m=\u001b[39m(\n\u001b[1;32m 208\u001b[0m prefer_skip_nested_validation \u001b[39mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 209\u001b[0m )\n\u001b[1;32m 210\u001b[0m ):\n\u001b[0;32m--> 211\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[1;32m 212\u001b[0m \u001b[39mexcept\u001b[39;00m InvalidParameterError \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 213\u001b[0m \u001b[39m# When the function is just a wrapper around an estimator, we allow\u001b[39;00m\n\u001b[1;32m 214\u001b[0m \u001b[39m# the function to delegate validation to the estimator, but we replace\u001b[39;00m\n\u001b[1;32m 215\u001b[0m \u001b[39m# the name of the estimator by the name of the function in the error\u001b[39;00m\n\u001b[1;32m 216\u001b[0m \u001b[39m# message to avoid confusion.\u001b[39;00m\n\u001b[1;32m 217\u001b[0m msg \u001b[39m=\u001b[39m re\u001b[39m.\u001b[39msub(\n\u001b[1;32m 218\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mparameter of \u001b[39m\u001b[39m\\\u001b[39m\u001b[39mw+ must be\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 219\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mparameter of \u001b[39m\u001b[39m{\u001b[39;00mfunc\u001b[39m.\u001b[39m\u001b[39m__qualname__\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m must be\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 220\u001b[0m \u001b[39mstr\u001b[39m(e),\n\u001b[1;32m 221\u001b[0m )\n", |
| 10410 | + "File \u001b[0;32m~/miniconda3/envs/COVID_forecasting/lib/python3.11/site-packages/sklearn/metrics/_classification.py:326\u001b[0m, in \u001b[0;36mconfusion_matrix\u001b[0;34m(y_true, y_pred, labels, sample_weight, normalize)\u001b[0m\n\u001b[1;32m 231\u001b[0m \u001b[39m@validate_params\u001b[39m(\n\u001b[1;32m 232\u001b[0m {\n\u001b[1;32m 233\u001b[0m \u001b[39m\"\u001b[39m\u001b[39my_true\u001b[39m\u001b[39m\"\u001b[39m: [\u001b[39m\"\u001b[39m\u001b[39marray-like\u001b[39m\u001b[39m\"\u001b[39m],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 242\u001b[0m y_true, y_pred, \u001b[39m*\u001b[39m, labels\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, sample_weight\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, normalize\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m\n\u001b[1;32m 243\u001b[0m ):\n\u001b[1;32m 244\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Compute confusion matrix to evaluate the accuracy of a classification.\u001b[39;00m\n\u001b[1;32m 245\u001b[0m \n\u001b[1;32m 246\u001b[0m \u001b[39m By definition a confusion matrix :math:`C` is such that :math:`C_{i, j}`\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 324\u001b[0m \u001b[39m (0, 2, 1, 1)\u001b[39;00m\n\u001b[1;32m 325\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 326\u001b[0m y_type, y_true, y_pred \u001b[39m=\u001b[39m _check_targets(y_true, y_pred)\n\u001b[1;32m 327\u001b[0m \u001b[39mif\u001b[39;00m y_type \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m (\u001b[39m\"\u001b[39m\u001b[39mbinary\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mmulticlass\u001b[39m\u001b[39m\"\u001b[39m):\n\u001b[1;32m 328\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m is not supported\u001b[39m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m y_type)\n", |
| 10411 | + "File \u001b[0;32m~/miniconda3/envs/COVID_forecasting/lib/python3.11/site-packages/sklearn/metrics/_classification.py:84\u001b[0m, in \u001b[0;36m_check_targets\u001b[0;34m(y_true, y_pred)\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_check_targets\u001b[39m(y_true, y_pred):\n\u001b[1;32m 58\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Check that y_true and y_pred belong to the same classification task.\u001b[39;00m\n\u001b[1;32m 59\u001b[0m \n\u001b[1;32m 60\u001b[0m \u001b[39m This converts multiclass or binary types to a common shape, and raises a\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[39m y_pred : array or indicator matrix\u001b[39;00m\n\u001b[1;32m 83\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 84\u001b[0m check_consistent_length(y_true, y_pred)\n\u001b[1;32m 85\u001b[0m type_true \u001b[39m=\u001b[39m type_of_target(y_true, input_name\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39my_true\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 86\u001b[0m type_pred \u001b[39m=\u001b[39m type_of_target(y_pred, input_name\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39my_pred\u001b[39m\u001b[39m\"\u001b[39m)\n", |
| 10412 | + "File \u001b[0;32m~/miniconda3/envs/COVID_forecasting/lib/python3.11/site-packages/sklearn/utils/validation.py:409\u001b[0m, in \u001b[0;36mcheck_consistent_length\u001b[0;34m(*arrays)\u001b[0m\n\u001b[1;32m 407\u001b[0m uniques \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39munique(lengths)\n\u001b[1;32m 408\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(uniques) \u001b[39m>\u001b[39m \u001b[39m1\u001b[39m:\n\u001b[0;32m--> 409\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 410\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mFound input variables with inconsistent numbers of samples: \u001b[39m\u001b[39m%r\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[1;32m 411\u001b[0m \u001b[39m%\u001b[39m [\u001b[39mint\u001b[39m(l) \u001b[39mfor\u001b[39;00m l \u001b[39min\u001b[39;00m lengths]\n\u001b[1;32m 412\u001b[0m )\n", |
| 10413 | + "\u001b[0;31mValueError\u001b[0m: Found input variables with inconsistent numbers of samples: [2409, 10439]" |
10387 | 10414 | ]
|
10388 | 10415 | }
|
10389 | 10416 | ],
|
@@ -10414,13 +10441,12 @@
|
10414 | 10441 | "print(\"MCC:\", MCC)\n",
|
10415 | 10442 | "print(\"Accuracy:\", accuracy)\n",
|
10416 | 10443 | "print(\"auROC:\", ROC)\n",
|
10417 |
| - "\n", |
10418 |
| - "print(confusion_matrix(y_test, y_pred))" |
| 10444 | + "\n" |
10419 | 10445 | ]
|
10420 | 10446 | },
|
10421 | 10447 | {
|
10422 | 10448 | "cell_type": "code",
|
10423 |
| - "execution_count": 122, |
| 10449 | + "execution_count": 127, |
10424 | 10450 | "metadata": {},
|
10425 | 10451 | "outputs": [
|
10426 | 10452 | {
|
@@ -10453,12 +10479,61 @@
|
10453 | 10479 | },
|
10454 | 10480 | {
|
10455 | 10481 | "cell_type": "code",
|
10456 |
| - "execution_count": 165, |
| 10482 | + "execution_count": 131, |
| 10483 | + "metadata": {}, |
| 10484 | + "outputs": [ |
| 10485 | + { |
| 10486 | + "name": "stdout", |
| 10487 | + "output_type": "stream", |
| 10488 | + "text": [ |
| 10489 | + "MCC: 0.7038956549471436\n", |
| 10490 | + "Accuracy: 0.9134016668263244\n", |
| 10491 | + "auROC: 0.8872189145023346\n" |
| 10492 | + ] |
| 10493 | + } |
| 10494 | + ], |
| 10495 | + "source": [ |
| 10496 | + "y_pred = clf.predict(X_test_full)\n", |
| 10497 | + "y_pred_proba = clf.predict_proba(X_test_full)\n", |
| 10498 | + "\n", |
| 10499 | + "# Evaluate the accuracy of the model\n", |
| 10500 | + "accuracy = accuracy_score(y_test_full, y_pred)\n", |
| 10501 | + "ROC = roc_auc_score(y_test_full, y_pred_proba[:,1])\n", |
| 10502 | + "MCC = (matthews_corrcoef(y_test_full, y_pred) + 1)/2\n", |
| 10503 | + "\n", |
| 10504 | + "print(\"MCC:\", MCC)\n", |
| 10505 | + "print(\"Accuracy:\", accuracy)\n", |
| 10506 | + "print(\"auROC:\", ROC)\n" |
| 10507 | + ] |
| 10508 | + }, |
| 10509 | + { |
| 10510 | + "cell_type": "code", |
| 10511 | + "execution_count": 132, |
10457 | 10512 | "metadata": {},
|
10458 | 10513 | "outputs": [],
|
10459 | 10514 | "source": [
|
10460 |
| - "#model_name = f\"CDC_classifier_auroc_{ROC:.4f}_CDC_period_full.sav\"\n", |
10461 |
| - "#pickle.dump(clf, open(model_name, 'wb'))" |
| 10515 | + "model_name = f\"CDC_classifier_auroc_{ROC:.4f}_CDC_period_full.sav\"\n", |
| 10516 | + "pickle.dump(clf, open(model_name, 'wb'))" |
| 10517 | + ] |
| 10518 | + }, |
| 10519 | + { |
| 10520 | + "cell_type": "code", |
| 10521 | + "execution_count": 133, |
| 10522 | + "metadata": {}, |
| 10523 | + "outputs": [ |
| 10524 | + { |
| 10525 | + "data": { |
| 10526 | + "text/plain": [ |
| 10527 | + "'CDC_classifier_auroc_0.8872_CDC_period_full.sav'" |
| 10528 | + ] |
| 10529 | + }, |
| 10530 | + "execution_count": 133, |
| 10531 | + "metadata": {}, |
| 10532 | + "output_type": "execute_result" |
| 10533 | + } |
| 10534 | + ], |
| 10535 | + "source": [ |
| 10536 | + "model_name" |
10462 | 10537 | ]
|
10463 | 10538 | },
|
10464 | 10539 | {
|
@@ -10604,23 +10679,23 @@
|
10604 | 10679 | },
|
10605 | 10680 | {
|
10606 | 10681 | "cell_type": "code",
|
10607 |
| - "execution_count": 92, |
| 10682 | + "execution_count": 134, |
10608 | 10683 | "metadata": {},
|
10609 | 10684 | "outputs": [
|
10610 | 10685 | {
|
10611 | 10686 | "name": "stdout",
|
10612 | 10687 | "output_type": "stream",
|
10613 | 10688 | "text": [
|
10614 |
| - "MCC: 0.47932871282255496\n", |
10615 |
| - "Accuracy: 0.14404317144043172\n", |
10616 |
| - "auROC: 0.531625516761377\n" |
| 10689 | + "MCC: 0.7089756779440055\n", |
| 10690 | + "Accuracy: 0.7667081776670818\n", |
| 10691 | + "auROC: 0.8185285531837256\n" |
10617 | 10692 | ]
|
10618 | 10693 | }
|
10619 | 10694 | ],
|
10620 | 10695 | "source": [
|
10621 | 10696 | "X_test, y_test, weights_test, missing_data_test_HSA = prep_training_test_data(all_HSA_ID_weekly_data, no_weeks = range(2, 5), weeks_in_future = 3, geography = 'HSA_ID', weight_col = 'weight', keep_output = True) # account for the fact that week 1 is the week included to allow for calculation of delta\n",
|
10622 | 10697 | "\n",
|
10623 |
| - "full_model = pickle.load(open('/Users/rem76/Documents/COVID_projections/COVID_forecasting/CDC_classifier_auroc_0.9091_CDC_period_full.sav', 'rb'))\n", |
| 10698 | + "full_model = pickle.load(open('/Users/rem76/Documents/COVID_projections/COVID_forecasting/CDC_classifier_auroc_0.8872_CDC_period_full.sav', 'rb'))\n", |
10624 | 10699 | "# Train the decision tree classifier\n",
|
10625 | 10700 | "\n",
|
10626 | 10701 | "# Make predictions on the test set\n",
|
|
0 commit comments