ExperienceReplay
- class capymoa.ocl.strategy.ExperienceReplay
Bases: BatchClassifier, TrainTaskAware, TestTaskAware
Experience Replay.
Experience Replay (ER) [1] is a replay-based continual learning strategy.
It stores past examples in a replay buffer and samples from that buffer during training to mitigate catastrophic forgetting.
The replay buffer is filled using reservoir sampling, so every example seen so far in the stream has an equal chance of being held in the buffer [2].
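Reservoir sampling (the classic Algorithm R) can be sketched in a few lines of plain Python. This is a minimal illustration, not capymoa's implementation: the class `ReservoirBuffer`, its `add` and `sample` methods, and the fixed seed are hypothetical. Once the buffer is full, the n-th item overwrites a random slot with probability capacity / n, which keeps every item seen so far equally likely to be stored.

```python
import random
from typing import Generic, List, TypeVar

T = TypeVar("T")


class ReservoirBuffer(Generic[T]):
    """Fixed-capacity buffer holding a uniform sample of all items offered so far.

    Hypothetical helper for illustration only; not part of capymoa's API.
    """

    def __init__(self, capacity: int, seed: int = 0) -> None:
        self.capacity = capacity
        self.items: List[T] = []
        self.seen = 0  # total number of items offered to the buffer
        self.rng = random.Random(seed)

    def add(self, item: T) -> None:
        self.seen += 1
        if len(self.items) < self.capacity:
            # Buffer not full yet: always keep the item.
            self.items.append(item)
        else:
            # Algorithm R: draw j uniformly from [0, seen); the new item is kept
            # (overwriting slot j) only when j < capacity, i.e. w.p. capacity / seen.
            j = self.rng.randrange(self.seen)
            if j < self.capacity:
                self.items[j] = item

    def sample(self, k: int) -> List[T]:
        # Draw up to k stored items without replacement for replay.
        return self.rng.sample(self.items, min(k, len(self.items)))


# Usage: stream 10,000 items through a buffer of size 200, then draw a replay batch.
buffer: ReservoirBuffer[int] = ReservoirBuffer(capacity=200)
for i in range(10_000):
    buffer.add(i)
replayed = buffer.sample(32)
print(len(buffer.items), len(replayed))
```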
The strategy itself does not use the task information received through capymoa.ocl.base.TrainTaskAware or capymoa.ocl.base.TestTaskAware; it forwards that information to the wrapped learner.
>>> from capymoa.ann import Perceptron
>>> from capymoa.classifier import Finetune
>>> from capymoa.ocl.strategy import ExperienceReplay
>>> from capymoa.ocl.datasets import TinySplitMNIST
>>> from capymoa.ocl.evaluation import ocl_train_eval_loop
>>> import torch
>>> _ = torch.manual_seed(0)
>>> scenario = TinySplitMNIST()
>>> model = Perceptron(scenario.schema)
>>> learner = ExperienceReplay(Finetune(scenario.schema, model))
>>> results = ocl_train_eval_loop(
...     learner,
...     scenario.train_loaders(32),
...     scenario.test_loaders(32),
... )
>>> print(f"{results.accuracy_final*100:.1f}%")
33.0%
- __init__(learner: BatchClassifier, buffer_size: int = 200, repeat: int = 1)
Initialize the Experience Replay strategy.
- Parameters:
learner – The learner to be wrapped for experience replay.
buffer_size – The size of the replay buffer, defaults to 200.
repeat – The number of times to repeat the training data in each batch, defaults to 1.
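A single ER update step, in its common textbook form, mixes the incoming batch with examples drawn from the buffer, trains the wrapped learner on the mix, and then stores the new examples. This sketch is an illustration under stated assumptions, not the capymoa code: `CountingLearner` and `er_train_step` are hypothetical, the replay sample size is assumed to equal the incoming batch size, the buffer is a FIFO `collections.deque` (swapped in for brevity instead of reservoir sampling), and the `repeat` parameter is not modelled.

```python
import random
from collections import deque
from typing import Deque, List, Tuple

Example = Tuple[List[float], int]


class CountingLearner:
    """Hypothetical stand-in for a wrapped batch learner; records batch sizes."""

    def __init__(self) -> None:
        self.batch_sizes: List[int] = []

    def batch_train(self, xs: List[List[float]], ys: List[int]) -> None:
        self.batch_sizes.append(len(xs))


def er_train_step(
    learner: CountingLearner,
    buffer: Deque[Example],
    xs: List[List[float]],
    ys: List[int],
    rng: random.Random,
) -> None:
    # 1. Draw as many replayed examples as the incoming batch holds
    #    (fewer while the buffer is still filling up). Assumed ratio.
    replay = rng.sample(list(buffer), min(len(xs), len(buffer)))
    # 2. Train the wrapped learner on new and replayed examples together.
    mixed_x = xs + [x for x, _ in replay]
    mixed_y = ys + [y for _, y in replay]
    learner.batch_train(mixed_x, mixed_y)
    # 3. Store the new examples; deque(maxlen=...) evicts the oldest (FIFO).
    buffer.extend(zip(xs, ys))


learner = CountingLearner()
buffer: Deque[Example] = deque(maxlen=200)  # maxlen plays the role of buffer_size
rng = random.Random(0)
for step in range(3):
    xs = [[float(step), float(i)] for i in range(32)]
    ys = [step] * 32
    er_train_step(learner, buffer, xs, ys, rng)
print(learner.batch_sizes)  # [32, 64, 64]
print(len(buffer))  # 96
```

On the first step the buffer is empty, so the learner sees only the 32 new examples; afterwards each update is half new data and half replay.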
- batch_predict_proba(x: Tensor) → Tensor
Predict the probabilities of the classes for a batch of instances.
- predict(instance: Instance) → int | None
Predict the label of an instance.
The base implementation calls predict_proba() and returns the label with the highest probability.
- Parameters:
instance – The instance to predict the label for.
- Returns:
The predicted label, or None if the classifier is unable to make a prediction.
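The documented behaviour of predict() amounts to an argmax over the class probabilities. A minimal sketch, assuming the probabilities arrive as a plain sequence (or None when no prediction is possible); `predict_from_proba` is a hypothetical helper, and its tie-breaking (lowest index wins) is this sketch's choice rather than documented behaviour.

```python
from typing import Optional, Sequence


def predict_from_proba(proba: Optional[Sequence[float]]) -> Optional[int]:
    """Return the index of the most probable class, or None if no prediction."""
    if proba is None or len(proba) == 0:
        return None
    # max() returns the first maximal index, so ties go to the lowest index.
    return max(range(len(proba)), key=lambda i: proba[i])


print(predict_from_proba([0.1, 0.7, 0.2]))  # 1
print(predict_from_proba(None))  # None
```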
- predict_proba(instance: Instance)
Calls batch_predict_proba() with a batch of size 1.
- train(instance: LabeledInstance) → None
Calls batch_train() with a batch of size 1.
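The batch-of-one pattern shared by predict_proba() and train() can be sketched as follows. `BatchOfOneSketch` and `LabelFrequencySketch` are hypothetical classes that use Python lists in place of tensors; with torch tensors the wrapping step would typically be `x.unsqueeze(0)`.

```python
from typing import List


class BatchOfOneSketch:
    """Hypothetical base class: single-instance calls delegate to batch methods."""

    def batch_predict_proba(self, xs: List[List[float]]) -> List[List[float]]:
        raise NotImplementedError

    def batch_train(self, xs: List[List[float]], ys: List[int]) -> None:
        raise NotImplementedError

    def predict_proba(self, x: List[float]) -> List[float]:
        # Wrap x as a batch of size 1, then unwrap the single row of probabilities.
        return self.batch_predict_proba([x])[0]

    def train(self, x: List[float], y: int) -> None:
        # Wrap the labelled instance as a batch of size 1.
        self.batch_train([x], [y])


class LabelFrequencySketch(BatchOfOneSketch):
    """Toy learner: predicts the label frequencies seen so far."""

    def __init__(self, n_classes: int = 3) -> None:
        self.counts = [0] * n_classes

    def batch_train(self, xs: List[List[float]], ys: List[int]) -> None:
        for y in ys:
            self.counts[y] += 1

    def batch_predict_proba(self, xs: List[List[float]]) -> List[List[float]]:
        total = sum(self.counts)
        if total == 0:
            probs = [1 / len(self.counts)] * len(self.counts)
        else:
            probs = [c / total for c in self.counts]
        return [probs for _ in xs]


learner = LabelFrequencySketch()
learner.train([0.0, 1.0], 2)
learner.train([1.0, 0.0], 2)
learner.train([0.5, 0.5], 0)
print(learner.predict_proba([0.2, 0.8]))
```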
- device: torch.device = device(type='cpu')
Device on which the batch will be processed.
- learner
The wrapped learner to be trained with experience replay.
- random_seed: int
The random seed for reproducibility.
When implementing a classifier, ensure that its random number generators are seeded.
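One way to follow this advice is to give each learner its own generator seeded from random_seed rather than relying on global random state. A minimal sketch with Python's standard `random.Random`; `SeededComponent` is a hypothetical class. For torch-based learners, a `torch.Generator` seeded with `manual_seed` plays the same role.

```python
import random
from typing import List, Sequence


class SeededComponent:
    """Hypothetical component that owns its RNG instead of using global state."""

    def __init__(self, random_seed: int = 1) -> None:
        self.random_seed = random_seed
        # A private generator keeps results reproducible and independent of
        # other code that touches the global `random` module.
        self.rng = random.Random(random_seed)

    def pick(self, population: Sequence[int], k: int) -> List[int]:
        return self.rng.sample(population, k)


a = SeededComponent(random_seed=42)
b = SeededComponent(random_seed=42)
first = a.pick(range(100), 5)
second = b.pick(range(100), 5)
print(first == second)  # True: same seed, same draws
```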
- x_dtype: torch.dtype = torch.float32
Data type for the input features.
- y_dtype: torch.dtype = torch.int64
Data type for the target value/labels.