{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "223810fd-84a1-40f9-a303-2a64403b49fe", "metadata": {}, "source": [ "# 1. Evaluating supervised learners in CapyMOA\n", "\n", "This notebook further explores **high-level evaluation functions**, **Data Abstraction** and **Classifiers**.\n", "\n", "* **High-level evaluation functions**\n", "    * We demonstrate how to use ```prequential_evaluation()``` and how to further encapsulate prequential evaluation using ```prequential_evaluation_multiple_learners```.\n", "    * We also discuss how these evaluation functions relate to the way research in the field has developed, and to how evaluation is commonly performed and presented.\n", "\n", "* **Supervised Learning**\n", "    * We clarify important details about the usage of **Classifiers** and their predictions.\n", "    * We include some examples using **Regressors**, which highlight that their evaluation is identical to that of **Classifiers** (i.e. the same high-level evaluation functions are used).\n", " \n", "---\n", "\n", "*More information about CapyMOA can be found at* https://www.capymoa.org\n", "\n", "**Last updated on 25/07/2024**" ] }, { "attachments": {}, "cell_type": "markdown", "id": "df79e7d7-bde3-4b59-8479-a44247b3c26f", "metadata": {}, "source": [ "## 1. The difference between Evaluators\n", "\n", "* The following example implements a **while loop** that updates a ```ClassificationWindowedEvaluator``` and a ```ClassificationEvaluator``` for the same learner. \n", "* The ```ClassificationWindowedEvaluator``` updates the metrics over tumbling windows, which 'forget' old correct and incorrect predictions. This allows us to observe how well the learner performs on shorter windows. \n", "* The ```ClassificationEvaluator``` updates the metrics taking into account all the correct and incorrect predictions made so far. 
It is useful to observe the overall performance after processing hundreds of thousands of instances.\n", "\n", "* **Two important points**:\n", "    1. Regarding **window_size** in ```ClassificationEvaluator```: a ```ClassificationEvaluator``` also allows us to specify a window size, but it only controls the frequency at which cumulative metrics are calculated.\n", "    2. If we access metrics directly (not through ```metrics_per_window()```) in ```ClassificationWindowedEvaluator```, we will be looking at the metrics corresponding to the last window.\n", " \n", "For further insight into the specifics of the Evaluators, please refer to the documentation: https://www.capymoa.org" ] }, { "cell_type": "code", "execution_count": 1, "id": "df2d9ca9", "metadata": { "execution": { "iopub.execute_input": "2024-09-23T00:26:45.015368Z", "iopub.status.busy": "2024-09-23T00:26:45.014637Z", "iopub.status.idle": "2024-09-23T00:26:45.041402Z", "shell.execute_reply": "2024-09-23T00:26:45.039349Z" }, "nbsphinx": "hidden" }, "outputs": [], "source": [ "# This cell is hidden on capymoa.org. 
See docs/contributing/docs.rst\n", "from util.nbmock import mock_datasets, is_nb_fast\n", "\n", "if is_nb_fast():\n", "    mock_datasets()" ] }, { "cell_type": "code", "execution_count": 2, "id": "9b29ce7d-7e76-4741-a728-2bd2e795eb79", "metadata": { "execution": { "iopub.execute_input": "2024-09-23T00:26:45.049266Z", "iopub.status.busy": "2024-09-23T00:26:45.048500Z", "iopub.status.idle": "2024-09-23T00:26:57.757668Z", "shell.execute_reply": "2024-09-23T00:26:57.757178Z" }, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ClassificationWindowedEvaluator] Windowed accuracy reported for every window of window_size instances\n", "[89.57777777777778, 89.46666666666667, 90.2, 89.71111111111111, 88.68888888888888, 88.48888888888888, 87.6888888888889, 88.88888888888889, 89.28888888888889, 91.06666666666666]\n", "[ClassificationEvaluator] Cumulative accuracy: 89.32953742937853\n" ] } ], "source": [ "from capymoa.datasets import Electricity\n", "from capymoa.evaluation import ClassificationWindowedEvaluator, ClassificationEvaluator\n", "from capymoa.classifier import AdaptiveRandomForestClassifier\n", "\n", "stream = Electricity()\n", "\n", "ARF = AdaptiveRandomForestClassifier(schema=stream.get_schema(), ensemble_size=10)\n", "\n", "# The window_size in ClassificationWindowedEvaluator specifies the number of instances used per evaluation window\n", "windowedEvaluatorARF = ClassificationWindowedEvaluator(\n", "    schema=stream.get_schema(), window_size=4500\n", ")\n", "# The window_size in ClassificationEvaluator only specifies the frequency at which the cumulative metrics are stored\n", "classificationEvaluatorARF = ClassificationEvaluator(\n", "    schema=stream.get_schema(), window_size=4500\n", ")\n", "\n", "while stream.has_more_instances():\n", "    instance = stream.next_instance()\n", "    # Test-then-train: predict first and update both evaluators, then train\n", "    prediction = ARF.predict(instance)\n", "    windowedEvaluatorARF.update(instance.y_index, prediction)\n", "    classificationEvaluatorARF.update(instance.y_index, prediction)\n",
"    ARF.train(instance)\n", "\n", "# Showing only the 'classifications correct (percent)' (i.e. accuracy)\n", "print(\n", "    \"[ClassificationWindowedEvaluator] Windowed accuracy reported for every window of window_size instances\"\n", ")\n", "print(windowedEvaluatorARF.accuracy())\n", "\n", "print(\n", "    f\"[ClassificationEvaluator] Cumulative accuracy: {classificationEvaluatorARF.accuracy()}\"\n", ")\n", "# We could report the cumulative accuracy every window_size instances with the following code, but that is normally not very insightful.\n", "# display(classificationEvaluatorARF.metrics_per_window())" ] }, { "cell_type": "markdown", "id": "c61735c4-c41a-443c-be3c-e6cc3b526395", "metadata": {}, "source": [ "## 2. High-level evaluation functions\n", "\n", "In CapyMOA, supervised learning has one primary evaluation function that manages the Evaluators for the user: `prequential_evaluation()`. This function streamlines the process, so users do not need to update the Evaluators directly. Essentially, it executes the evaluation loop and updates the relevant Evaluators:\n", "\n", "`prequential_evaluation()` utilises `ClassificationEvaluator` and `ClassificationWindowedEvaluator`.\n", "\n", "Previously, CapyMOA included two other functions: `cumulative_evaluation()` and `windowed_evaluation()`. However, since `prequential_evaluation()` incorporates the functionality of both, we decided to remove those functions and focus on `prequential_evaluation()`.\n", "It is important to note that `prequential_evaluation()` is applicable to `Regression` and `Prediction Intervals` as well as `Classification`. 
The functionality and interpretation remain the same across these cases, but the metrics differ.\n", "\n", "**Result of a high-level function**\n", "\n", "* The value returned by `prequential_evaluation()` is a `PrequentialResults` object, which provides access to the `cumulative` and `windowed` metrics as well as some other measurements (such as wall-clock and CPU time).\n", "\n", "**Common characteristics for all high-level evaluation functions**\n", "\n", "* `prequential_evaluation()` accepts a `max_instances` parameter, which defaults to `None`. Depending on the source of the data (e.g. a real stream or a synthetic stream), the function may then never stop! This is intentional: streams are conceptually infinite, and we process them as such. Therefore, it is a good idea to specify `max_instances` unless you are using a snapshot of a stream (i.e. a `Dataset` like `Electricity`).\n", "\n", "**Evaluation practices in the literature (and practice)**\n", "\n", "Interested readers might want to peruse section **6.1.1 Error Estimation** of the [Machine Learning for Data Streams](https://moa.cms.waikato.ac.nz/book-html/) book. We further discuss the relationship between the literature and our evaluation functions in the documentation: https://www.capymoa.org" ] }, { "attachments": {}, "cell_type": "markdown", "id": "99802e5d-c85c-448a-86ca-93910a245666", "metadata": {}, "source": [ "### 2.1 prequential_evaluation()\n", "\n", "`prequential_evaluation()` performs a windowed evaluation and a cumulative evaluation at once. Internally, it maintains a `ClassificationWindowedEvaluator` (for the windowed metrics) and a ```ClassificationEvaluator``` (for the cumulative metrics). This gives us access to both the **cumulative** and **windowed** results without running two separate evaluation functions. 
\n", "\n", "* The results object returned by ```prequential_evaluation()``` gives direct access to the Evaluator objects ```ClassificationWindowedEvaluator``` (attribute `windowed`) and ```ClassificationEvaluator``` (attribute `cumulative`). \n", " \n", "* Note that training and assessing the same model twice (once per evaluation loop) costs far more than the minimal overhead of updating two Evaluators within a single loop. Thus, it is advisable to use the `prequential_evaluation()` function instead of creating separate `while` loops for each kind of evaluation.\n", "\n", "* Advanced users might intuitively request metrics directly from the `results` object, which will return the `cumulative` metrics. For example, assuming `results = prequential_evaluation(...)`, `results.accuracy()` will return the `cumulative` accuracy. \n", "**IMPORTANT**: There are no IDE hints for these metrics, as they are resolved dynamically via `__getattr__`. We advise accessing metrics explicitly through `results.cumulative` (or `results['cumulative']`) or `results.windowed` (or `results['windowed']`).\n", "\n", "* Invoking `results.metrics_per_window()` returns a dataframe with the `windowed` results.\n", "\n", "* `results.write_to_file()` writes the `cumulative` and `windowed` results to a directory.\n", "\n", "* `results.cumulative.metrics_dict()` returns a dictionary mapping each cumulative metric identifier to its value.\n", "\n", "* Invoking `plot_windowed_results()` with a `PrequentialResults` object plots its `windowed` results.\n", "\n", "* For plotting and analysis purposes, one might want to set `store_predictions=True` and `store_y=True` when calling `prequential_evaluation()`, which stores all the predictions and ground-truth `y` values in the `PrequentialResults` object. Note that this can be costly in terms of memory, depending on the size of the stream." 
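, "\n",
"\n",
"The test-then-train loop that `prequential_evaluation()` automates can be sketched in plain Python. This is a minimal illustration under stated assumptions, **not** CapyMOA's implementation: `prequential_sketch()` is a hypothetical helper, the learner is a toy majority-class predictor, and the stream is a hypothetical infinite iterable of `(x, y)` pairs built with `itertools.cycle`. A single pass updates a cumulative accuracy (never reset) and a tumbling-window accuracy (reset every `window_size` instances), and `max_instances` is what makes the loop terminate on an unbounded stream.\n",
"\n",
"```python\n",
"from collections import Counter\n",
"from itertools import cycle, islice\n",
"\n",
"\n",
"def prequential_sketch(stream, window_size, max_instances=None):\n",
"    # Hypothetical helper, NOT part of CapyMOA: a toy majority-class learner\n",
"    # evaluated test-then-train on any iterable of (x, y) pairs.\n",
"    if max_instances is not None:\n",
"        stream = islice(stream, max_instances)  # cap an unbounded stream\n",
"    counts = Counter()  # class frequencies: the learner's only state\n",
"    seen = correct = 0  # cumulative totals, never reset\n",
"    win_seen = win_correct = 0  # totals for the current tumbling window\n",
"    windowed = []  # accuracy (%) of each completed window\n",
"    for _, y in stream:\n",
"        # Test: predict before the learner has seen this label.\n",
"        prediction = counts.most_common(1)[0][0] if counts else None\n",
"        hit = int(prediction == y)\n",
"        seen += 1\n",
"        correct += hit\n",
"        win_seen += 1\n",
"        win_correct += hit\n",
"        if win_seen == window_size:\n",
"            # Window complete: record its accuracy, then forget its counts.\n",
"            windowed.append(100.0 * win_correct / win_seen)\n",
"            win_seen = win_correct = 0\n",
"        # Train: only now is the label used to update the learner.\n",
"        counts[y] += 1\n",
"    cumulative = 100.0 * correct / seen if seen else 0.0\n",
"    return cumulative, windowed\n",
"\n",
"\n",
"# Hypothetical stream that never ends, so max_instances is required here.\n",
"toy_stream = cycle([(None, 'a'), (None, 'a'), (None, 'b')])\n",
"cumulative, windowed = prequential_sketch(toy_stream, window_size=3, max_instances=9)\n",
"print(cumulative, windowed)\n",
"```\n",
"\n",
"The cumulative value summarises every prediction made so far, as `ClassificationEvaluator` does, whereas each windowed value reflects only its own `window_size` predictions, analogous to `ClassificationWindowedEvaluator`."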
] }, { "cell_type": "code", "execution_count": 3, "id": "56ad96a5-f8cf-4a49-8817-07c2b284fbe6", "metadata": { "execution": { "iopub.execute_input": "2024-09-23T00:26:57.759462Z", "iopub.status.busy": "2024-09-23T00:26:57.759272Z", "iopub.status.idle": "2024-09-23T00:26:58.449867Z", "shell.execute_reply": "2024-09-23T00:26:58.449300Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\tDifferent ways of accessing metrics:\n", "results_ht['wallclock']: 0.018116235733032227 results_ht.wallclock(): 0.018116235733032227\n", "results_ht['cpu_time']: 0.06709921499999894 results_ht.cpu_time(): 0.06709921499999894\n", "results_ht.cumulative.accuracy() = 83.85000000000001\n", "results_ht.cumulative['accuracy'] = 83.85000000000001\n", "results_ht['cumulative'].accuracy() = 83.85000000000001\n", "results_ht.accuracy() = 83.85000000000001\n", "\n", "\tAll the cumulative results:\n", "{'instances': 2000.0, 'accuracy': 83.85000000000001, 'kappa': 66.04003700899992, 'kappa_t': -14.946619217081869, 'kappa_m': 59.010152284263974, 'f1_score': 83.03346476507683, 'f1_score_0': 86.77855096193205, 'f1_score_1': 79.25497752087348, 'precision': 83.24177714270593, 'precision_0': 85.82995951417004, 'precision_1': 80.65359477124183, 'recall': 82.82619238745067, 'recall_0': 87.74834437086093, 'recall_1': 77.90404040404042}\n", "\n", "\tAll the windowed results:\n" ] }, { "data": { "text/html": [
"<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n",
"      <th></th>\n", "      <th>instances</th>\n", "      <th>accuracy</th>\n", "      <th>kappa</th>\n", "      <th>kappa_t</th>\n", "      <th>kappa_m</th>\n", "      <th>f1_score</th>\n", "      <th>f1_score_0</th>\n", "      <th>f1_score_1</th>\n", "      <th>precision</th>\n", "      <th>precision_0</th>\n", "      <th>precision_1</th>\n", "      <th>recall</th>\n", "      <th>recall_0</th>\n", "      <th>recall_1</th>\n",
"    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>100.0</td>\n", "      <td>89.0</td>\n", "      <td>75.663717</td>\n", "      <td>31.250000</td>\n", "      <td>64.516129</td>\n", "      <td>87.841244</td>\n", "      <td>91.603053</td>\n", "      <td>84.057971</td>\n", "      <td>87.582418</td>\n", "      <td>92.307692</td>\n", "      <td>82.857143</td>\n", "      <td>88.101604</td>\n", "      <td>90.909091</td>\n", "      <td>85.294118</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>200.0</td>\n", "      <td>80.0</td>\n", "      <td>49.367089</td>\n", "      <td>-42.857143</td>\n", "      <td>67.213115</td>\n", "      <td>78.947368</td>\n", "      <td>60.000000</td>\n", "      <td>86.666667</td>\n", "      <td>88.235294</td>\n", "      <td>100.000000</td>\n", "      <td>76.470588</td>\n", "      <td>71.428571</td>\n", "      <td>42.857143</td>\n", "      <td>100.000000</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>2</th>\n", "      <td>300.0</td>\n", "      <td>71.0</td>\n", "      <td>16.953036</td>\n", "      <td>-141.666667</td>\n", "      <td>29.268293</td>\n", "      <td>58.514754</td>\n", "      <td>81.290323</td>\n", "      <td>35.555556</td>\n", "      <td>58.114035</td>\n", "      <td>82.894737</td>\n", "      <td>33.333333</td>\n", "      <td>58.921037</td>\n", "      <td>79.746835</td>\n", "      <td>38.095238</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>400.0</td>\n", "      <td>85.0</td>\n", "      <td>66.637011</td>\n", "      <td>-36.363636</td>\n", "      <td>77.941176</td>\n", "      <td>84.021504</td>\n", "      <td>77.611940</td>\n", "      <td>88.721805</td>\n", "      <td>86.376882</td>\n", "      <td>89.655172</td>\n", "      <td>83.098592</td>\n", "      <td>81.791171</td>\n", "      <td>68.421053</td>\n", "      <td>95.161290</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>500.0</td>\n", "      <td>87.0</td>\n", "      <td>73.684211</td>\n", "      <td>-8.333333</td>\n", "      <td>80.000000</td>\n", "      <td>87.218591</td>\n", "      <td>88.495575</td>\n", "      <td>85.057471</td>\n", "      <td>87.916667</td>\n", "      <td>83.333333</td>\n", "      <td>92.500000</td>\n", "      <td>86.531513</td>\n", "      <td>94.339623</td>\n", "      <td>78.723404</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>5</th>\n", "      <td>600.0</td>\n", "      <td>84.0</td>\n", "      <td>64.221825</td>\n", "      <td>-14.285714</td>\n", "      <td>54.285714</td>\n", "      <td>82.965706</td>\n", "      <td>88.059701</td>\n", "      <td>75.757576</td>\n", "      <td>85.615079</td>\n", "      <td>81.944444</td>\n", "      <td>89.285714</td>\n", "      <td>80.475382</td>\n", "      <td>95.161290</td>\n", "      <td>65.789474</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>6</th>\n", "      <td>700.0</td>\n", "      <td>85.0</td>\n", "      <td>70.000000</td>\n", "      <td>16.666667</td>\n", "      <td>70.588235</td>\n", "      <td>85.880856</td>\n", "      <td>83.146067</td>\n", "      <td>86.486486</td>\n", "      <td>85.000000</td>\n", "      <td>74.000000</td>\n", "      <td>96.000000</td>\n", "      <td>86.780160</td>\n", "      <td>94.871795</td>\n", "      <td>78.688525</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>7</th>\n", "      <td>800.0</td>\n", "      <td>99.0</td>\n", "      <td>97.954173</td>\n", "      <td>94.117647</td>\n", "      <td>97.674419</td>\n", "      <td>98.987342</td>\n", "      <td>98.823529</td>\n", "      <td>99.130435</td>\n", "      <td>99.137931</td>\n", "      <td>100.000000</td>\n", "      <td>98.275862</td>\n", "      <td>98.837209</td>\n", "      <td>97.674419</td>\n", "      <td>100.000000</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>8</th>\n", "      <td>900.0</td>\n", "      <td>78.0</td>\n", "      <td>57.446809</td>\n", "      <td>-15.789474</td>\n", "      <td>56.862745</td>\n", "      <td>81.751825</td>\n", "      <td>80.357143</td>\n", "      <td>75.000000</td>\n", "      <td>83.582090</td>\n", "      <td>67.164179</td>\n", "      <td>100.000000</td>\n", "      <td>80.000000</td>\n", "      <td>100.000000</td>\n", "      <td>60.000000</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>9</th>\n", "      <td>1000.0</td>\n", "      <td>96.0</td>\n", "      <td>91.922456</td>\n", "      <td>50.000000</td>\n", "      <td>92.727273</td>\n", "      <td>95.998775</td>\n", "      <td>95.555556</td>\n", "      <td>96.363636</td>\n", "      <td>96.185065</td>\n", "      <td>97.727273</td>\n", "      <td>94.642857</td>\n", "      <td>95.813205</td>\n", "      <td>93.478261</td>\n", "      <td>98.148148</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>10</th>\n", "      <td>1100.0</td>\n", "      <td>83.0</td>\n", "      <td>1.162791</td>\n", "      <td>-142.857143</td>\n", "      <td>0.000000</td>\n", "      <td>50.583013</td>\n", "      <td>90.607735</td>\n", "      <td>10.526316</td>\n", "      <td>50.555556</td>\n", "      <td>91.111111</td>\n", "      <td>10.000000</td>\n", "      <td>50.610501</td>\n", "      <td>90.109890</td>\n", "      <td>11.111111</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>11</th>\n", "      <td>1200.0</td>\n", "      <td>76.0</td>\n", "      <td>10.979228</td>\n", "      <td>-100.000000</td>\n", "      <td>7.692308</td>\n", "      <td>66.740576</td>\n", "      <td>86.046512</td>\n", "      <td>14.285714</td>\n", "      <td>87.755102</td>\n", "      <td>75.510204</td>\n", "      <td>100.000000</td>\n", "      <td>53.846154</td>\n", "      <td>100.000000</td>\n", "      <td>7.692308</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>12</th>\n", "      <td>1300.0</td>\n", "      <td>87.0</td>\n", "      <td>66.529351</td>\n", "      <td>-62.500000</td>\n", "      <td>59.375000</td>\n", "      <td>85.391617</td>\n", "      <td>91.275168</td>\n", "      <td>74.509804</td>\n", "      <td>91.975309</td>\n", "      <td>83.950617</td>\n", "      <td>100.000000</td>\n", "      <td>79.687500</td>\n", "      <td>100.000000</td>\n", "      <td>59.375000</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>13</th>\n", "      <td>1400.0</td>\n", "      <td>91.0</td>\n", "      <td>64.285714</td>\n", "      <td>57.142857</td>\n", "      <td>52.631579</td>\n", "      <td>84.639017</td>\n", "      <td>94.736842</td>\n", "      <td>68.965517</td>\n", "      <td>95.000000</td>\n", "      <td>90.000000</td>\n", "      <td>100.000000</td>\n", "      <td>76.315789</td>\n", "      <td>100.000000</td>\n", "      <td>52.631579</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>14</th>\n", "      <td>1500.0</td>\n", "      <td>92.0</td>\n", "      <td>62.686567</td>\n", "      <td>42.857143</td>\n", "      <td>50.000000</td>\n", "      <td>84.076433</td>\n", "      <td>95.454545</td>\n", "      <td>66.666667</td>\n", "      <td>95.652174</td>\n", "      <td>91.304348</td>\n", "      <td>100.000000</td>\n", "      <td>75.000000</td>\n", "      <td>100.000000</td>\n", "      <td>50.000000</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>15</th>\n", "      <td>1600.0</td>\n", "      <td>89.0</td>\n", "      <td>73.170732</td>\n", "      <td>21.428571</td>\n", "      <td>65.625000</td>\n", "      <td>87.145717</td>\n", "      <td>92.307692</td>\n", "      <td>80.701754</td>\n", "      <td>90.000000</td>\n", "      <td>88.000000</td>\n", "      <td>92.000000</td>\n", "      <td>84.466912</td>\n", "      <td>97.058824</td>\n", "      <td>71.875000</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>16</th>\n", "      <td>1700.0</td>\n", "      <td>89.0</td>\n", "      <td>78.000000</td>\n", "      <td>8.333333</td>\n", "      <td>76.086957</td>\n", "      <td>89.196535</td>\n", "      <td>89.523810</td>\n", "      <td>88.421053</td>\n", "      <td>89.393939</td>\n", "      <td>85.454545</td>\n", "      <td>93.333333</td>\n", "      <td>89.000000</td>\n", "      <td>94.000000</td>\n", "      <td>84.000000</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>17</th>\n", "      <td>1800.0</td>\n", "      <td>72.0</td>\n", "      <td>45.141066</td>\n", "      <td>-47.368421</td>\n", "      <td>9.677419</td>\n", "      <td>77.094241</td>\n", "      <td>63.157895</td>\n", "      <td>77.419355</td>\n", "      <td>81.578947</td>\n", "      <td>100.000000</td>\n", "      <td>63.157895</td>\n", "      <td>73.076923</td>\n", "      <td>46.153846</td>\n", "      <td>100.000000</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>18</th>\n", "      <td>1900.0</td>\n", "      <td>58.0</td>\n", "      <td>24.677188</td>\n", "      <td>-200.000000</td>\n", "      <td>-31.250000</td>\n", "      <td>65.568421</td>\n", "      <td>58.823529</td>\n", "      <td>57.142857</td>\n", "      <td>65.329768</td>\n", "      <td>88.235294</td>\n", "      <td>42.424242</td>\n", "      <td>65.808824</td>\n", "      <td>44.117647</td>\n", "      <td>87.500000</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>19</th>\n", "      <td>2000.0</td>\n", "      <td>86.0</td>\n", "      <td>66.410749</td>\n", "      <td>26.315789</td>\n", "      <td>58.823529</td>\n", "      <td>84.238820</td>\n", "      <td>90.140845</td>\n", "      <td>75.862069</td>\n", "      <td>87.938596</td>\n", "      <td>84.210526</td>\n", "      <td>91.666667</td>\n", "      <td>80.837790</td>\n", "      <td>96.969697</td>\n", "      <td>64.705882</td>\n", "