{ "cells": [ { "cell_type": "markdown", "id": "3b115501e7927be", "metadata": {}, "source": [ "# Save and Load a Model\n", "\n", "This tutorial shows how to save and load a model using CapyMOA.\n", "\n", "* We use the SEA synthetic generator as the data source, and the AdaptiveRandomForestClassifier as the learner.\n", "* The trained model is saved to a file, specifically `capymoa_ARF_model.pkl`.\n", "* We then reload the model from the file and resume training and evaluating it on the SEA data.\n", "* As a final step, we delete the model file." ] }, { "cell_type": "markdown", "id": "b4f5698f-c632-44ef-b337-0549dc5a5168", "metadata": {}, "source": [ "## 1. Training and saving the model\n", "\n", "* We train the model on 5k instances from SEA using the `prequential_evaluation` function.\n", "* We then save the model with `save_model(learner, \"capymoa_ARF_model.pkl\")`." ] }, { "cell_type": "code", "execution_count": 1, "id": "b7ca1c5addd95ba3", "metadata": { "ExecuteTime": { "end_time": "2024-05-29T08:18:33.715465Z", "start_time": "2024-05-29T08:18:27.317959Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 87.96\n" ] } ], "source": [ "from capymoa.classifier import AdaptiveRandomForestClassifier\n", "from capymoa.evaluation import prequential_evaluation\n", "from capymoa.stream.generator import SEA\n", "from capymoa.misc import save_model, load_model\n", "\n", "stream = SEA()\n", "learner = AdaptiveRandomForestClassifier(schema=stream.get_schema(), ensemble_size=10)\n", "\n", "results = prequential_evaluation(stream=stream, learner=learner, max_instances=5000)\n", "\n", "print(f\"Accuracy: {results['cumulative'].accuracy():.2f}\")\n", "save_model(learner, \"capymoa_ARF_model.pkl\")  # Save model to capymoa_ARF_model.pkl" ] }, { "cell_type": "markdown", "id": "934db64f-b4ca-4628-9ad0-9a8d669a2c6b", "metadata": {}, "source": [ "## 2. 
Loading and resuming training\n", "\n", "* We use `os.path.getsize()` to inspect the size of the saved file, reported in KB.\n", "* We don't restart the synthetic stream; we continue processing it through another call to `prequential_evaluation`.\n", "* Finally, we print the accuracy obtained in this second run." ] }, { "cell_type": "code", "execution_count": 2, "id": "2f7dd29e2ed65686", "metadata": { "ExecuteTime": { "end_time": "2024-05-29T08:18:37.737826Z", "start_time": "2024-05-29T08:18:33.717028Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The saved model size: 618.28 KB\n", "Updated accuracy: 89.32\n" ] } ], "source": [ "import os\n", "\n", "model_file = 'capymoa_ARF_model.pkl'\n", "\n", "# os.path.getsize() returns bytes; divide by 1024 to report KB\n", "model_size = os.path.getsize(model_file)\n", "print(f\"The saved model size: {model_size / 1024:.2f} KB\")\n", "\n", "restored_learner = load_model(model_file)  # Load model from capymoa_ARF_model.pkl\n", "\n", "# Train on 5k more instances with the restored model\n", "results = prequential_evaluation(stream=stream, learner=restored_learner, max_instances=5000)\n", "\n", "print(f\"Updated accuracy: {results['cumulative'].accuracy():.2f}\")" ] }, { "cell_type": "markdown", "id": "d86a6b63-4a4d-4753-bda5-19a632aad41d", "metadata": {}, "source": [ "## 3. 
Cleanup\n", "\n", "* As a last step, we delete the saved model file." ] }, { "cell_type": "code", "execution_count": 3, "id": "e00a292713d154ee", "metadata": { "ExecuteTime": { "end_time": "2024-05-29T08:18:37.977906Z", "start_time": "2024-05-29T08:18:37.739209Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "File capymoa_ARF_model.pkl has been deleted.\n" ] } ], "source": [ "if os.path.exists(model_file):\n", "    os.remove(model_file)\n", "    print(f\"File {model_file} has been deleted.\")\n", "else:\n", "    print(f\"File {model_file} not found.\")" ] }, { "cell_type": "code", "execution_count": null, "id": "59c7b60f-908e-4640-a6b8-50650d6c9287", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 5 }