ExperienceReplay
- class capymoa.ocl.strategy.ExperienceReplay
Bases: BatchClassifier, TrainTaskAware, TestTaskAware
Experience Replay (ER) strategy for continual learning.
It stores past examples in a replay buffer and mixes samples from that buffer into training to mitigate catastrophic forgetting.
The replay buffer is filled using reservoir sampling, which keeps a uniform random sample of all examples seen so far in the stream [vitter1985].
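A minimal sketch of a fixed-capacity reservoir buffer (Vitter's Algorithm R) in plain Python follows. The class name ReservoirBuffer and its methods are hypothetical and do not reflect CapyMOA's internal buffer, which works with tensors; the sketch only illustrates why every example seen so far has the same probability, capacity / seen, of being held in the buffer.

```python
import random
from typing import Generic, List, TypeVar

T = TypeVar("T")


class ReservoirBuffer(Generic[T]):
    """Hypothetical fixed-capacity buffer keeping a uniform sample of a stream (Algorithm R)."""

    def __init__(self, capacity: int, seed: int = 0) -> None:
        self.capacity = capacity
        self.items: List[T] = []
        self.seen = 0  # number of stream items offered so far
        self._rng = random.Random(seed)  # private, seeded generator for reproducibility

    def add(self, item: T) -> None:
        self.seen += 1
        if len(self.items) < self.capacity:
            # Fill phase: keep everything until the buffer is full.
            self.items.append(item)
        else:
            # Draw j uniformly from [0, seen); replacing slot j when j < capacity
            # keeps the new item with probability capacity / seen.
            j = self._rng.randrange(self.seen)
            if j < self.capacity:
                self.items[j] = item

    def sample(self, n: int) -> List[T]:
        """Draw up to n stored items without replacement."""
        return self._rng.sample(self.items, min(n, len(self.items)))


buf = ReservoirBuffer(capacity=200, seed=1)
for i in range(10_000):
    buf.add(i)
print(len(buf.items), buf.seen)  # 200 10000
```

Memory stays bounded by the capacity regardless of stream length, which is what makes this kind of buffer suitable for an unbounded stream.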
This strategy does not itself use the task information provided through capymoa.ocl.base.TrainTaskAware or capymoa.ocl.base.TestTaskAware, but it proxies that information to the wrapped learner.
[vitter1985] Jeffrey S. Vitter. 1985. Random sampling with a reservoir. ACM Trans. Math. Softw. 11, 1 (March 1985), 37–57. https://doi.org/10.1145/3147.3165
- __init__(learner: BatchClassifier, buffer_size: int = 200, repeat: int = 1)
Initialize the Experience Replay strategy.
- Parameters:
learner – The learner to be wrapped for experience replay.
buffer_size – The maximum number of examples held in the replay buffer, defaults to 200.
repeat – The number of times to repeat the training data in each batch, defaults to 1.
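The sketch below shows one plausible way buffer_size and repeat could interact during training. It assumes, without confirmation from this page, that each training call first trains the wrapped learner repeat times on the incoming batch concatenated with an equal-sized draw from the buffer, and only then adds the new examples to the reservoir. ReplayTrainer and train_fn are hypothetical stand-ins for the strategy and for the wrapped learner's batch_train(); plain lists replace tensors.

```python
import random
from typing import Callable, List, Tuple

Example = Tuple[List[float], int]  # (features, label)


class ReplayTrainer:
    """Hypothetical experience-replay loop; not CapyMOA's actual implementation."""

    def __init__(self, train_fn: Callable[[List[Example]], None],
                 buffer_size: int = 200, repeat: int = 1, seed: int = 0) -> None:
        self.train_fn = train_fn  # stands in for the wrapped learner's batch_train
        self.buffer_size = buffer_size
        self.repeat = repeat
        self.buffer: List[Example] = []
        self.seen = 0
        self.rng = random.Random(seed)

    def _store(self, example: Example) -> None:
        # Reservoir sampling: keep a uniform sample of every example seen so far.
        self.seen += 1
        if len(self.buffer) < self.buffer_size:
            self.buffer.append(example)
        else:
            j = self.rng.randrange(self.seen)
            if j < self.buffer_size:
                self.buffer[j] = example

    def batch_train(self, batch: List[Example]) -> None:
        for _ in range(self.repeat):
            # Mix the new batch with an equal-sized draw of past examples
            # (fewer if the buffer does not hold enough yet).
            k = min(len(batch), len(self.buffer))
            replayed = self.rng.sample(self.buffer, k)
            self.train_fn(batch + replayed)
        # Store the new examples only after training on them.
        for example in batch:
            self._store(example)


calls: List[List[Example]] = []
trainer = ReplayTrainer(train_fn=calls.append, buffer_size=4, repeat=2)
trainer.batch_train([([0.0], 0), ([1.0], 1)])  # empty buffer: 2 examples, twice
trainer.batch_train([([2.0], 0), ([3.0], 1)])  # 2 new + 2 replayed, twice
print([len(c) for c in calls])  # [2, 2, 4, 4]
```

Sampling before storing keeps replayed examples strictly from earlier batches; storing first would be an equally reasonable design choice.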
- batch_predict_proba(x: Tensor) → Tensor
Predict the probabilities of the classes for a batch of instances.
- predict(instance: Instance) → int | None
Predict the label of an instance.
The base implementation calls predict_proba() and returns the label with the highest probability.
- Parameters:
instance – The instance to predict the label for.
- Returns:
The predicted label, or None if the classifier is unable to make a prediction.
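The argmax rule described for predict() can be sketched as follows, assuming that a missing or empty probability vector is what signals that no prediction is possible. predict_from_proba is a hypothetical helper, not part of CapyMOA.

```python
from typing import Optional, Sequence


def predict_from_proba(proba: Optional[Sequence[float]]) -> Optional[int]:
    """Return the index of the most probable class, or None if there are no probabilities."""
    if proba is None or len(proba) == 0:
        # None means "unable to make a prediction", matching the documented return value.
        return None
    # max() keeps the first maximal index, so ties resolve to the lowest class index.
    return max(range(len(proba)), key=lambda i: proba[i])


print(predict_from_proba([0.1, 0.7, 0.2]))  # 1
print(predict_from_proba(None))  # None
```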
- predict_proba(instance: Instance)
Calls batch_predict_proba() with a batch of size 1.
- train(instance: LabeledInstance) → None
Calls batch_train() with a batch of size 1.
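Both predict_proba() and train() wrap a single instance in a batch of size 1 and delegate to the corresponding batch method. The helper names below are hypothetical, and plain lists stand in for the Instance objects and torch tensors that the real methods handle.

```python
from typing import Callable, List, Sequence, Tuple

Batch = List[List[float]]


def proba_single(batch_predict_proba: Callable[[Batch], List[List[float]]],
                 x: Sequence[float]) -> List[float]:
    """Predict class probabilities for one instance via a batch of size 1."""
    rows = batch_predict_proba([list(x)])  # wrap the instance as a one-row batch
    return rows[0]                         # unwrap the single result row


def train_single(batch_train: Callable[[Batch, List[int]], None],
                 x: Sequence[float], y: int) -> None:
    """Train on one labelled instance via a batch of size 1."""
    batch_train([list(x)], [y])


# Toy batch methods used only to demonstrate the delegation.
seen_batches: List[Tuple[Batch, List[int]]] = []


def toy_batch_train(xs: Batch, ys: List[int]) -> None:
    seen_batches.append((xs, ys))


def toy_batch_predict_proba(xs: Batch) -> List[List[float]]:
    return [[0.25, 0.75] for _ in xs]


train_single(toy_batch_train, [1.0, 2.0], 1)
print(seen_batches)  # [([[1.0, 2.0]], [1])]
print(proba_single(toy_batch_predict_proba, [1.0, 2.0]))  # [0.25, 0.75]
```

Routing single-instance calls through the batch path means a subclass only has to implement the batch methods.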
- device: torch.device = device(type='cpu')
Device on which the batch will be processed.
- learner
The wrapped learner to be trained with experience replay.
- random_seed: int
The random seed for reproducibility.
When implementing a classifier, ensure that its random number generators are seeded.
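One way to follow this advice is to give each component its own generator seeded from random_seed instead of relying on global random state. SeededSampler is a hypothetical helper; learners built on PyTorch would typically also need torch.manual_seed or a seeded torch.Generator, which this stdlib-only sketch does not cover.

```python
import random
from typing import List, Sequence


class SeededSampler:
    """Hypothetical helper that owns a private RNG derived from random_seed."""

    def __init__(self, random_seed: int) -> None:
        self.random_seed = random_seed
        # A private generator is isolated from other code that uses the global RNG.
        self._rng = random.Random(random_seed)

    def draw(self, buffer: Sequence[int], n: int) -> List[int]:
        """Draw up to n items without replacement."""
        return self._rng.sample(list(buffer), min(n, len(buffer)))


first = SeededSampler(random_seed=7).draw(range(100), 5)
random.seed(123)  # reseeding the global RNG elsewhere does not affect the sampler
second = SeededSampler(random_seed=7).draw(range(100), 5)
print(first == second)  # True
```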
- x_dtype: torch.dtype = torch.float32
Data type for the input features.
- y_dtype: torch.dtype = torch.int64
Data type for the target value/labels.