10. Online Continual Learning#

In machine learning, continual learning is a problem setting where a model is trained on a sequence of tasks and must perform well on all tasks seen so far. A task is a specific concept or relation the model is expected to learn. Online continual learning adds two constraints: the model sees each example exactly once, and it must be able to make predictions at any time.

The typical data stream setting, in contrast, adapts to changes in the data distribution by discarding knowledge of the past. In continual learning, the model must instead retain knowledge of past tasks while learning new ones.

Continual learning is often treated as synonymous with overcoming catastrophic forgetting, a phenomenon in deep learning where a model trained on a sequence of tasks forgets how to perform the earlier ones. Some non-deep-learning models are largely immune to this kind of forgetting.

In this notebook, we implement “Experience Replay” (ER), a classic online continual learning strategy that stores a buffer of past examples. By mixing examples replayed from this buffer into each training batch, ER mitigates catastrophic forgetting.

10.1 Reservoir Sampling#

Experience Replay uses reservoir sampling to maintain a fixed-size uniform random sample incrementally over a data stream of unknown length. Here, we implement reservoir sampling “Algorithm R” (Vitter, 1985), which guarantees that after n items have been seen, each of them is in the size-k reservoir with probability k/n.

[2]:
import torch
from typing import Tuple
from torch import Tensor


class ReservoirSampler:
    def __init__(self, item_count: int, feature_count: int):
        self.item_count = item_count
        self.feature_count = feature_count
        self.reservoir_x = torch.zeros((item_count, feature_count))
        self.reservoir_y = torch.zeros((item_count,), dtype=torch.long)
        self.count = 0

    def update(self, x: Tensor, y: Tensor) -> None:
        batch_size = x.shape[0]
        assert x.shape == (
            batch_size,
            self.feature_count,
        )
        assert y.shape == (batch_size,)

        for i in range(batch_size):
            if self.count < self.item_count:
                # Fill the reservoir
                self.reservoir_x[self.count] = x[i]
                self.reservoir_y[self.count] = y[i]
            else:
                # Replace a random reservoir slot with probability item_count / (count + 1)
                index = torch.randint(0, self.count + 1, (1,))
                if index < self.item_count:
                    self.reservoir_x[index] = x[i]
                    self.reservoir_y[index] = y[i]
            self.count += 1

    def sample_n(self, n: int) -> Tuple[Tensor, Tensor]:
        # Draw n items uniformly (with replacement) from the filled part of the reservoir
        indices = torch.randint(0, min(self.count, self.item_count), (n,))
        return self.reservoir_x[indices], self.reservoir_y[indices]

Let’s check that its samples look uniform.

[3]:
from matplotlib import pyplot as plt

x = torch.arange(0, 1_000).reshape(-1, 1).float()
y = torch.zeros(1_000, dtype=torch.long)
sampler = ReservoirSampler(500, 1)
sampler.update(x, y)
sample_x, _ = sampler.sample_n(200)
plt.hist(sample_x.numpy(), bins=20)
plt.show()
../_images/notebooks_10_ocl_5_0.png
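
A stronger check of Algorithm R’s guarantee is that every item is kept with probability k/n regardless of where it appears in the stream. The sketch below relies only on the ReservoirSampler defined above (the trial count is arbitrary); both estimates should come out near 500/1000 = 0.5.

[ ]:
trials = 100
inclusion = torch.zeros(1_000)
for _ in range(trials):
    s = ReservoirSampler(500, 1)
    s.update(
        torch.arange(0, 1_000).reshape(-1, 1).float(),
        torch.zeros(1_000, dtype=torch.long),
    )
    # The stored feature values double as item indices because each stream
    # item 0..999 appears at most once in the reservoir
    inclusion[s.reservoir_x[:, 0].long()] += 1
inclusion /= trials

# Both should be close to k / n = 500 / 1000 = 0.5
print(f"early items kept with probability ~{inclusion[:100].mean().item():.2f}")
print(f"late items kept with probability  ~{inclusion[-100:].mean().item():.2f}")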

10.2 Experience Replay#

The ExperienceReplay classifier below wraps a PyTorch model: every incoming batch is added to the reservoir, and each training step optimizes the model on that batch concatenated with a batch sampled from the reservoir.

[4]:
from capymoa.base import BatchClassifier
from capymoa.instance import Instance
from capymoa.stream import Schema
from capymoa.type_alias import LabelProbabilities
from torch.nn.functional import cross_entropy
import numpy as np
from torch import nn


class ExperienceReplay(BatchClassifier):
    def __init__(
        self,
        schema: Schema,
        model: nn.Module,
        reservoir_size: int,
        batch_size: int,
        learning_rate: float,
        device: str = "cpu",
    ):
        super().__init__(schema=schema, batch_size=batch_size)
        self.reservoir = ReservoirSampler(reservoir_size, schema.get_num_attributes())
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        self.model = model.to(device)
        self.device = device
        self.batch_size = batch_size

    def batch_train(self, x: np.ndarray, y: np.ndarray):
        x: Tensor = torch.from_numpy(x)
        y: Tensor = torch.from_numpy(y).long()

        # Add the incoming batch to the reservoir of past examples
        self.reservoir.update(x, y)

        # Train on the incoming batch concatenated with a replayed batch
        replay_x, replay_y = self.reservoir.sample_n(self.batch_size)
        train_x = torch.cat((x, replay_x), dim=0).to(self.device)
        train_y = torch.cat((y, replay_y), dim=0).to(self.device)

        self.optimizer.zero_grad()
        y_hat = self.model(train_x)
        loss = cross_entropy(y_hat, train_y)
        loss.backward()
        self.optimizer.step()

    @torch.no_grad
    def predict_proba(self, instance: Instance) -> LabelProbabilities:
        x = torch.from_numpy(instance.x).to(self.device)
        y_hat: Tensor = self.model.forward(x)
        return y_hat.softmax(dim=0).cpu().numpy()

    def __str__(self) -> str:
        return "ExperienceReplay"

10.3 Multi-Layer Perceptron#

We create a simple multi-layer perceptron (MLP) with a single hidden layer to demonstrate continual learning.

The output layer of a neural network is often problematic in continual learning because of the extreme and shifting class imbalance between tasks. Lesort et al. (2021) suggest mitigating this with a variant of weight normalization that factorizes each output weight vector into a magnitude and a direction, with the magnitudes frozen so that together they form a unit vector.

  • Lesort, T., George, T., & Rish, I. (2021). Continual Learning in Deep Networks: An Analysis of the Last Layer.

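Concretely, weight normalization rewrites each class’s output weight vector as a direction scaled by a magnitude (a sketch of the parameterization used in the code below, with C denoting the number of classes):

$$ w_c = g_c \, \frac{v_c}{\lVert v_c \rVert}, \qquad g_c = \frac{1}{\sqrt{C}} $$

Freezing every magnitude g_c at 1/√C means only the directions are trained, so no class can dominate the logits purely through the scale of its weights.
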
[5]:
class SimpleMLP(nn.Module):
    def __init__(self, schema: Schema, hidden_size: int):
        super().__init__()
        num_classes = schema.get_num_classes()

        self.fc1 = nn.Linear(schema.get_num_attributes(), hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes, bias=False)
        self.fc2 = nn.utils.parametrizations.weight_norm(self.fc2, name="weight")
        weight_g = self.fc2.parametrizations.weight.original0
        # Freeze the magnitudes at 1 / sqrt(num_classes) so that together they form a unit vector
        weight_g.requires_grad_(False).fill_(1.0 / (num_classes**0.5))

    def forward(self, x: Tensor) -> Tensor:
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
[6]:
from capymoa.evaluation.ocl import ocl_train_eval_loop
from capymoa.datasets.ocl import SplitMNIST

stream = SplitMNIST()
mlp = SimpleMLP(stream.schema, 64)
learner = ExperienceReplay(
    stream.schema,
    mlp,
    reservoir_size=200,
    batch_size=64,
    learning_rate=0.01,
    device="cpu",
)
r = ocl_train_eval_loop(
    learner,
    stream.train_streams,
    stream.test_streams,
    continual_evaluations=10,
    progress_bar=True,
)
print(f"Forward Transfer  {r.forward_transfer:.2f}")
print(f"Backward Transfer {r.backward_transfer:.2f}")
print(f"Accuracy          {r.accuracy_final:.2f}")
print(f"Online Accuracy   {r.prequential_cumulative_accuracy:.2f}")
Forward Transfer  0.00
Backward Transfer -0.18
Accuracy          0.83
Online Accuracy   0.90

10.4 Evaluation#

The first plot shows per-task accuracy as training progresses through the five tasks:

  • Task 0 starts high and decreases as the model gradually forgets it.

  • Tasks 1–4 start at zero because the model has not seen those tasks yet.

[9]:
from matplotlib import pyplot as plt

fig, ax = plt.subplots(figsize=(8, 4))

cmap = plt.get_cmap("tab10")
for t in range(5):
    ax.scatter(r.task_index, r.accuracy_matrix[:, t], color=cmap(t), label=f"Task {t}")
    ax.plot(r.anytime_task_index, r.anytime_accuracy_matrix[:, t], color=cmap(t))

ax.set_xlabel("Task")
ax.set_xticks(range(6))
ax.set_ylabel("Accuracy")
ax.set_title("SplitMNIST Per-Task Accuracy Over Tasks")
ax.legend(frameon=False)
pass
../_images/notebooks_10_ocl_12_0.png
[8]:
def hline(ax, y, label, color_id):
    ax.hlines(y, 0, 5, linestyles="--", label=label, color=cmap(color_id))


fig, ax = plt.subplots(figsize=(8, 4))
# Plot the accuracy on all tasks over the course of tasks
ax.scatter(r.task_index, r.accuracy_all, label="Acc. (all)")
ax.plot(r.anytime_task_index, r.anytime_accuracy_all, label="Anytime Acc. (all)")
hline(ax, r.anytime_accuracy_all_avg, "Avg. Anytime Acc. (all)", 0)

# Plot the accuracy on previously seen tasks over the course of tasks
ax.scatter(r.task_index, r.accuracy_seen, label="Acc. (seen)")
ax.plot(r.anytime_task_index, r.anytime_accuracy_seen, label="Anytime Acc. (seen)")
hline(ax, r.anytime_accuracy_seen_avg, "Avg. Anytime Acc. (seen)", 1)

ax.legend(ncol=2, frameon=False)
ax.set_xlabel("Task")
ax.set_xticks(range(6))
ax.set_ylabel("Accuracy")
ax.set_title("SplitMNIST Accuracy Over Tasks")
ax.set_ylim(0, 1.05)
pass
../_images/notebooks_10_ocl_13_0.png