SLDA
- class capymoa.ocl.strategy.SLDA
Bases: BatchClassifier
Streaming Linear Discriminant Analysis.
Streaming Linear Discriminant Analysis (SLDA) [1] is a prototype classifier that incrementally accumulates the mean of each class as well as the mean and covariance across all classes. Because these statistics are never decayed, the method does not forget gracefully and may not handle concept drift well.
[1] uses SLDA on top of a pre-trained model to perform continual learning. See [2] and [3] for more details on incremental LDA outside of continual learning.
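The accumulation described above can be sketched with running sums. This is a minimal NumPy sketch, not the capymoa implementation: the class name `StreamingLDAStats` is hypothetical, labels are assumed to be integers in `[0, num_classes)`, and the shared covariance is assumed to be the pooled within-class covariance (the textbook LDA choice), which the library may compute differently.

```python
import numpy as np


class StreamingLDAStats:
    """Hypothetical running sufficient statistics for a streaming LDA classifier."""

    def __init__(self, num_features: int, num_classes: int):
        self.counts = np.zeros(num_classes)                    # samples seen per class
        self.sums = np.zeros((num_classes, num_features))      # per-class feature sums
        self.outer = np.zeros((num_features, num_features))    # sum of x x^T over all samples

    def update(self, x: np.ndarray, y: np.ndarray) -> None:
        """Accumulate a batch x of shape (n, d) with integer labels y of shape (n,)."""
        np.add.at(self.counts, y, 1)   # unbuffered, so repeated labels are all counted
        np.add.at(self.sums, y, x)
        self.outer += x.T @ x

    def means(self) -> np.ndarray:
        # Classes that have not been seen yet keep a zero mean.
        return self.sums / np.maximum(self.counts, 1)[:, None]

    def covariance(self) -> np.ndarray:
        # Pooled within-class covariance (assumption):
        # (sum_i x_i x_i^T - sum_c n_c mu_c mu_c^T) / N
        mu = self.means()
        between = (self.counts[:, None] * mu).T @ mu
        return (self.outer - between) / max(self.counts.sum(), 1)
```

Accumulating raw outer products keeps the update to a few additions per batch, but it can lose precision when feature means are large relative to their spread; an incremental (Welford-style) covariance update avoids that at the cost of slightly more bookkeeping.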
>>> from capymoa.ocl.strategy import SLDA
>>> from capymoa.ocl.datasets import TinySplitMNIST
>>> from capymoa.ocl.evaluation import ocl_train_eval_loop
>>> scenario = TinySplitMNIST()
>>> learner = SLDA(scenario.schema)
>>> results = ocl_train_eval_loop(
...     learner,
...     scenario.train_loaders(32),
...     scenario.test_loaders(32),
... )
>>> print(f"{results.accuracy_final*100:.1f}%")
75.5%
- __init__(schema: Schema, pre_processor: Module = nn.Identity(), num_features: int | None = None, ridge: float = 1e-6, device: device | str = torch.device('cpu'))
Initialize an SLDA classifier.
- Parameters:
schema – Describes the shape and type of the data.
pre_processor – A pre-processing module to apply to the input data, defaults to an identity module.
num_features – Number of features once pre-processed, defaults to the number of attributes in the schema.
ridge – Ridge regularization term to avoid a singular covariance matrix, defaults to 1e-6.
device – Device to run the model on, defaults to CPU.
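To see why the ridge term matters, consider seeing fewer samples than features: the sample covariance is then rank-deficient and cannot be inverted. The sketch below assumes, as is standard for ridge regularization, that `ridge` is added to the covariance diagonal before inversion; the data are random placeholders.

```python
import numpy as np

# Placeholder data: 3 samples in 5 dimensions, so the centred sample covariance
# has rank at most 2 and is singular.
rng = np.random.default_rng(0)
x = rng.normal(size=(3, 5))
centered = x - x.mean(axis=0)
cov = centered.T @ centered / len(x)

ridge = 1e-6
# Adding ridge * I shifts every eigenvalue up by `ridge`, so the smallest
# eigenvalue becomes (approximately) `ridge` and the matrix is invertible.
regularized = cov + ridge * np.eye(cov.shape[0])
precision = np.linalg.inv(regularized)
```

A small ridge barely changes well-conditioned covariances, which is why a default as small as 1e-6 is reasonable; a larger value trades some fidelity to the data for numerical stability.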
- batch_predict_proba(x: Tensor) → Tensor
Predict the probabilities of the classes for a batch of instances.
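A shared-covariance LDA model turns class means into linear scores, score_c(x) = x^T Λ μ_c - ½ μ_c^T Λ μ_c with Λ = (Σ + ridge·I)^{-1}. The following NumPy sketch assumes equal class priors and a softmax over these scores to obtain probabilities; the function `lda_predict_proba` is hypothetical, and the capymoa implementation, which works on torch tensors after `pre_processor`, may differ in these details.

```python
import numpy as np


def lda_predict_proba(x, means, cov, ridge=1e-6):
    """Class probabilities for a batch x of shape (n, d) under a shared-covariance LDA model."""
    precision = np.linalg.inv(cov + ridge * np.eye(cov.shape[0]))
    weights = means @ precision                       # (C, d); row c is (Lambda mu_c)^T
    bias = -0.5 * np.sum(weights * means, axis=1)     # (C,); -1/2 mu_c^T Lambda mu_c
    scores = x @ weights.T + bias                     # (n, C) linear discriminant scores
    scores -= scores.max(axis=1, keepdims=True)       # stabilise the softmax
    exp = np.exp(scores)
    return exp / exp.sum(axis=1, keepdims=True)
```

Since the covariance is shared, the quadratic term x^T Λ x is identical for every class and cancels, which is what makes the decision boundaries linear.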
- predict(instance: Instance) → int | None
Predict the label of an instance.
The base implementation calls predict_proba() and returns the label with the highest probability.
- Parameters:
instance – The instance to predict the label for.
- Returns:
The predicted label or None if the classifier is unable to make a prediction.
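The base behaviour described above amounts to an argmax over the probability vector. A plain-Python sketch with the hypothetical helper `predict_from_proba`; treating a missing or empty probability vector as the "unable to make a prediction" case is an assumption.

```python
from typing import Optional, Sequence


def predict_from_proba(proba: Optional[Sequence[float]]) -> Optional[int]:
    """Return the index of the most probable class, or None if no probabilities exist."""
    if proba is None or len(proba) == 0:
        return None
    # max() keeps the first index on ties, so the lowest label wins a tie.
    return max(range(len(proba)), key=lambda i: proba[i])
```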
- predict_proba(instance: Instance)
Calls batch_predict_proba() with a batch of size 1.
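Wrapping a single instance into a batch of size 1 can be sketched as follows. The helper `predict_proba_single` is hypothetical and works on a plain feature vector; converting a capymoa `Instance` into a tensor is left out.

```python
import numpy as np


def predict_proba_single(batch_predict_proba, x):
    """Wrap one feature vector x of shape (d,) into a (1, d) batch and unwrap the result."""
    batch = np.asarray(x, dtype=float)[np.newaxis, :]   # shape (1, d)
    return batch_predict_proba(batch)[0]                 # shape (num_classes,)
```

Routing single-instance calls through the batch path means only one prediction routine has to be written and tested.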
- train(instance: LabeledInstance) → None
Calls batch_train() with a batch of size 1.
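Because class means are built from sums and counts, training one instance at a time, as train() does, should yield the same statistics as one large batch, up to floating-point rounding. A NumPy sketch with a hypothetical `accumulate` helper and random placeholder data:

```python
import numpy as np


def accumulate(counts, sums, x, y):
    """Add a batch x (n, d) with integer labels y (n,) to per-class counts and sums in place."""
    np.add.at(counts, y, 1)
    np.add.at(sums, y, x)


rng = np.random.default_rng(1)
x = rng.normal(size=(20, 4))
y = np.arange(20) % 3

# One call per instance, i.e. batches of size 1.
counts_a, sums_a = np.zeros(3), np.zeros((3, 4))
for i in range(len(x)):
    accumulate(counts_a, sums_a, x[i : i + 1], y[i : i + 1])

# A single call with the whole batch.
counts_b, sums_b = np.zeros(3), np.zeros((3, 4))
accumulate(counts_b, sums_b, x, y)
```

This order independence is convenient for streaming, but it is also the reason noted earlier that the method does not forget: every past instance keeps the same weight.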
- random_seed: int
The random seed for reproducibility.
When implementing a classifier, ensure random number generators are seeded.
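One common pattern is to create every random number generator from `random_seed` when the learner is constructed, so two learners with the same seed behave identically. A hypothetical NumPy sketch (`SeededLearner` and `sample_indices` are illustrative names, not capymoa API):

```python
import numpy as np


class SeededLearner:
    """Hypothetical learner that derives all of its randomness from random_seed."""

    def __init__(self, random_seed: int):
        self.random_seed = random_seed
        # A private generator avoids touching NumPy's global random state.
        self._rng = np.random.default_rng(random_seed)

    def sample_indices(self, n: int, k: int) -> np.ndarray:
        # e.g. a random subsample of a batch; reproducible for a fixed seed
        return self._rng.choice(n, size=k, replace=False)
```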