{ "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [ "# 10. Online Continual Learning\n", "\n", "In machine learning, continual learning is a problem setting where a model is\n", "trained on a sequence of tasks and must perform well on all tasks seen so far. A\n", "task is a specific concept or relation the model is expected to learn. Online\n", "continual learning adds two constraints: the model sees each example exactly\n", "once, and it must be able to make predictions at any time.\n", "\n", "The typical data stream problem setting adapts to changes in the data\n", "distribution by discarding knowledge of the past. In continual learning, by\n", "contrast, the model must retain knowledge of past tasks while learning new ones.\n", "\n", "Continual learning is closely tied to overcoming catastrophic forgetting, a\n", "phenomenon in deep learning where a model trained on a sequence of tasks forgets\n", "how to perform well on the earlier tasks. Non-deep learning models can be immune\n", "to forgetting.\n", "\n", "---\n", "\n", "In this notebook, we implement “Experience Replay” (ER), a classic online\n", "continual learning strategy that stores a buffer of past examples. By replaying\n", "examples from the buffer during training, ER mitigates catastrophic forgetting." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": { "nbsphinx": "hidden" }, "outputs": [], "source": [ "# This cell is hidden on capymoa.org. See docs/contributing/docs.rst\n", "from util.nbmock import mock_datasets, is_nb_fast\n", "\n", "if is_nb_fast():\n", "    mock_datasets()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 10.1 Reservoir Sampling\n", "\n", "Experience Replay uses [reservoir sampling](https://en.wikipedia.org/wiki/Reservoir_sampling)\n", "to construct a simple random sample incrementally from a data stream of unknown\n", "length. Here, we implement reservoir sampling \"Algorithm R\" (Vitter, 1985). Once\n", "the reservoir holds $k$ items, the $n$-th item of the stream replaces a uniformly\n", "chosen slot with probability $k/n$, so every item seen so far is equally likely\n", "to be in the reservoir.\n", "\n", "- Jeffrey S. Vitter. 1985. Random sampling with a reservoir. ACM Trans. Math.\n", "  Softw. 11, 1 (March 1985), 37–57. https://doi.org/10.1145/3147.3165" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from typing import Tuple\n", "from torch import Tensor\n", "\n", "\n", "class ReservoirSampler:\n", "    \"\"\"Fixed-size uniform random sample of a stream (Vitter's Algorithm R).\"\"\"\n", "\n", "    def __init__(self, item_count: int, feature_count: int):\n", "        self.item_count = item_count\n", "        self.feature_count = feature_count\n", "        self.reservoir_x = torch.zeros((item_count, feature_count))\n", "        self.reservoir_y = torch.zeros((item_count,), dtype=torch.long)\n", "        self.count = 0  # number of examples seen so far\n", "\n", "    def update(self, x: Tensor, y: Tensor) -> None:\n", "        batch_size = x.shape[0]\n", "        assert x.shape == (\n", "            batch_size,\n", "            self.feature_count,\n", "        )\n", "        assert y.shape == (batch_size,)\n", "\n", "        for i in range(batch_size):\n", "            if self.count < self.item_count:\n", "                # Fill the reservoir until it holds item_count examples\n", "                self.reservoir_x[self.count] = x[i]\n", "                self.reservoir_y[self.count] = y[i]\n", "            else:\n", "                # Draw a slot uniformly from [0, count]; the new example replaces\n", "                # an existing one with probability item_count / (count + 1)\n", "                index = torch.randint(0, self.count + 1, (1,)).item()\n", "                if index < self.item_count:\n", "                    self.reservoir_x[index] = x[i]\n", "                    self.reservoir_y[index] = y[i]\n", "            self.count += 1\n", "\n", "    def sample_n(self, n: int) -> Tuple[Tensor, Tensor]:\n", "        # Sample n examples uniformly, with replacement, from the filled slots\n", "        indices = torch.randint(0, min(self.count, self.item_count), (n,))\n", "        return self.reservoir_x[indices], self.reservoir_y[indices]" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's check that its samples look uniform: we stream the integers 0 to 999 through a\n", "reservoir of size 500, draw 200 samples, and plot their histogram, which should be\n", "roughly flat." ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from matplotlib import pyplot as plt\n", "\n", "x = torch.arange(0, 1_000).reshape(-1, 1).float()\n", "y = torch.zeros(1_000, dtype=torch.long)\n", "sampler = ReservoirSampler(500, 1)\n", "sampler.update(x, y)\n", "sample_x, _ = sampler.sample_n(200)\n", "plt.hist(sample_x.flatten().numpy(), bins=20)\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 10.2 Experience Replay\n", "\n", "`ExperienceReplay` extends capymoa's `BatchClassifier`. For each mini-batch, it\n", "adds the new examples to the reservoir, concatenates them with `batch_size`\n", "examples sampled from the reservoir, and takes a single optimizer step on the\n", "combined batch." ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from capymoa.base import BatchClassifier\n", "from capymoa.instance import Instance\n", "from capymoa.stream import Schema\n", "from capymoa.type_alias import LabelProbabilities\n", "from torch.nn.functional import cross_entropy\n", "import numpy as np\n", "from torch import nn\n", "\n", "\n", "class ExperienceReplay(BatchClassifier):\n", "    def __init__(\n", "        self,\n", "        schema: Schema,\n", "        model: nn.Module,\n", "        reservoir_size: int,\n", "        batch_size: int,\n", "        learning_rate: float,\n", "        device: str = \"cpu\",\n", "    ):\n", "        super().__init__(schema=schema, batch_size=batch_size)\n", "        self.reservoir = ReservoirSampler(reservoir_size, schema.get_num_attributes())\n", "        # Move the model to the device before handing its parameters to the optimizer\n", "        self.model = model.to(device)\n", "        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)\n", "        self.device = device\n", "        self.batch_size = batch_size\n", "\n", "    def batch_train(self, x: np.ndarray, y: np.ndarray):\n", "        # Cast features to float32 to match the model's default parameter dtype\n", "        x: Tensor = torch.from_numpy(x).float()\n", "        y: Tensor = torch.from_numpy(y).long()\n", "\n", "        # Add the new examples to the reservoir before replaying from it\n", "        self.reservoir.update(x, y)\n", "\n", "        # Concatenate the new batch with batch_size examples replayed from the reservoir\n", "        replay_x, replay_y = self.reservoir.sample_n(self.batch_size)\n", "        train_x = torch.cat((x, replay_x), dim=0).to(self.device)\n", "        train_y = torch.cat((y, replay_y), dim=0).to(self.device)\n", "\n", "        # Take a single gradient step on the combined batch\n", "        self.optimizer.zero_grad()\n", "        y_hat = self.model(train_x)\n", "        loss = cross_entropy(y_hat, train_y)\n", "        loss.backward()\n", "        self.optimizer.step()\n", "\n", "    @torch.no_grad\n", "    def predict_proba(self, instance: Instance) -> LabelProbabilities:\n", "        x = torch.from_numpy(instance.x).float().to(self.device)\n", "        # A single instance yields 1-D logits, so softmax runs over dim 0\n", "        y_hat: Tensor = self.model(x)\n", "        return y_hat.softmax(dim=0).cpu().numpy()\n", "\n", "    def __str__(self) -> str:\n", "        return \"ExperienceReplay\"" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 10.3 Multi-Layer Perceptron\n", "\n", "We create a simple multi-layer perceptron (MLP) with a single hidden layer to demonstrate continual learning.\n", "\n", "The output layer of a neural network is often problematic in continual learning because of the extreme and shifting\n", "class imbalance between tasks. Lesort et al. (2021) suggest mitigating this by using a variant of weight normalization\n", "that parameterizes the weights as a magnitude and a direction, with the magnitude fixed to a unit vector so that only\n", "the direction is learned.\n", "\n", "* Lesort, T., George, T., & Rish, I. (2021). Continual Learning in Deep Networks:\n", "  An Analysis of the Last Layer." ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class SimpleMLP(nn.Module):\n", "    def __init__(self, schema: Schema, hidden_size: int):\n", "        super().__init__()\n", "        num_classes = schema.get_num_classes()\n", "\n", "        self.fc1 = nn.Linear(schema.get_num_attributes(), hidden_size)\n", "        self.fc2 = nn.Linear(hidden_size, num_classes, bias=False)\n", "        self.fc2 = nn.utils.parametrizations.weight_norm(self.fc2, name=\"weight\")\n", "        # original0 is the per-class magnitude g (shape num_classes x 1) and\n", "        # original1 is the direction v, which stays trainable\n", "        weight_g = self.fc2.parametrizations.weight.original0\n", "        # Freeze g at 1/sqrt(num_classes) so that g is a unit vector and every\n", "        # class weight vector has the same norm\n", "        weight_g.requires_grad_(False).fill_(1.0 / (num_classes**0.5))\n", "\n", "    def forward(self, x: Tensor) -> Tensor:\n", "        x = torch.relu(self.fc1(x))\n", "        x = self.fc2(x)\n", "        return x" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e8bfff69ea104021b6e1b51ec4a45896", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Train & Eval: 0%| | 0/560000 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from matplotlib import pyplot as plt\n", "\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", "\n", "cmap = plt.get_cmap(\"tab10\")\n", "for t in range(5):\n", "    # Markers: accuracy_matrix at task boundaries; line: anytime_accuracy_matrix\n", "    ax.scatter(r.task_index, r.accuracy_matrix[:, t], color=cmap(t), label=f\"Task {t}\")\n", "    ax.plot(r.anytime_task_index, r.anytime_accuracy_matrix[:, t], color=cmap(t))\n", "\n", "ax.set_xlabel(\"Task\")\n", "ax.set_xticks(range(6))\n", "ax.set_ylabel(\"Accuracy\")\n", "ax.set_title(\"SplitMNIST Per-Task Accuracy Over Tasks\")\n", "ax.legend(frameon=False)\n", "pass" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def hline(ax, y, label, color_id):\n", "    # Dashed horizontal reference line spanning tasks 0 to 5\n", "    ax.hlines(y, 0, 5, linestyles=\"--\", label=label, color=cmap(color_id))\n", "\n", "\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", "# Plot the accuracy on all tasks as training progresses\n", "ax.scatter(r.task_index, r.accuracy_all, label=\"Acc. (all)\")\n", "ax.plot(r.anytime_task_index, r.anytime_accuracy_all, label=\"Anytime Acc. (all)\")\n", "hline(ax, r.anytime_accuracy_all_avg, \"Avg. Anytime Acc. (all)\", 0)\n", "\n", "# Plot the accuracy on the tasks seen so far as training progresses\n", "ax.scatter(r.task_index, r.accuracy_seen, label=\"Acc. (seen)\")\n", "ax.plot(r.anytime_task_index, r.anytime_accuracy_seen, label=\"Anytime Acc. (seen)\")\n", "hline(ax, r.anytime_accuracy_seen_avg, \"Avg. Anytime Acc. (seen)\", 1)\n", "\n", "ax.legend(ncol=2, frameon=False)\n", "ax.set_xlabel(\"Task\")\n", "ax.set_xticks(range(6))\n", "ax.set_ylabel(\"Accuracy\")\n", "ax.set_title(\"SplitMNIST Accuracy Over Tasks\")\n", "ax.set_ylim(0, 1.05)\n", "pass" ] }
], "metadata": { "kernelspec": { "display_name": "capymoa", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.8" } }, "nbformat": 4, "nbformat_minor": 2 }