SLDA
- class capymoa.ocl.strategy.SLDA
Bases: BatchClassifier
Streaming Linear Discriminant Analysis (SLDA).
SLDA incrementally accumulates a running mean for each class, plus a joint mean and a covariance matrix computed over all classes. It has no mechanism for gracefully forgetting old data, so it may not handle concept drift well.
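The statistics described above can be maintained one example at a time. The following NumPy sketch keeps a running mean per class, plus a joint mean and covariance, using Welford's online update. The class name `StreamingLDAStats` is hypothetical and not part of capymoa. The actual implementation works on PyTorch tensors and batches and may use a different update rule.

```python
import numpy as np


class StreamingLDAStats:
    """Running statistics for a streaming LDA (hypothetical sketch).

    Tracks one mean per class plus a joint mean and covariance over all
    examples, updated one example at a time with Welford's algorithm.
    """

    def __init__(self, num_features: int, num_classes: int):
        self.class_means = np.zeros((num_classes, num_features))
        self.class_counts = np.zeros(num_classes, dtype=np.int64)
        self.mean = np.zeros(num_features)
        # Sum of outer products of deviations; covariance = scatter / n.
        self.scatter = np.zeros((num_features, num_features))
        self.n = 0

    def update(self, x: np.ndarray, y: int) -> None:
        x = np.asarray(x, dtype=float)
        self.n += 1
        # Joint mean and scatter: use the deviation from the old mean and
        # from the updated mean (Welford's covariance update).
        delta_old = x - self.mean
        self.mean += delta_old / self.n
        self.scatter += np.outer(delta_old, x - self.mean)
        # Running mean of the example's own class.
        self.class_counts[y] += 1
        self.class_means[y] += (x - self.class_means[y]) / self.class_counts[y]

    @property
    def covariance(self) -> np.ndarray:
        # Population (biased) covariance of everything seen so far.
        return self.scatter / max(self.n, 1)


# Feed a small random stream through the sketch, one example at a time.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = rng.integers(0, 2, size=200)
stats = StreamingLDAStats(num_features=3, num_classes=2)
for xi, yi in zip(X, y):
    stats.update(xi, int(yi))
```

Because only means, counts, and one covariance matrix are stored, memory stays constant in the number of examples seen. This is also why old data is never forgotten: every example contributes to these sums with equal weight.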
[Hayes20] uses SLDA on top of a pre-trained model to perform continual learning. See [WikipediaILDA] and [Ghassabeh2015] for more details on incremental LDA outside of continual learning.
[Hayes20] Hayes, T. L., & Kanan, C. (2020). Lifelong Machine Learning with Deep Streaming Linear Discriminant Analysis. CLVision Workshop at CVPR 2020, 1–15.
[Ghassabeh2015] Ghassabeh, Y. A., Rudzicz, F., & Moghaddam, H. A. (2015). Fast incremental LDA feature extraction. Pattern Recognition, 48(6), 1999–2012.
- __init__(schema: Schema, pre_processor: Module = nn.Identity(), num_features: int | None = None, ridge: float = 1e-6, device: device | str = torch.device('cpu'))
Initialize an SLDA classifier.
- Parameters:
schema – Describes the shape and type of the data.
pre_processor – A pre-processing module to apply to the input data, defaults to an identity module.
num_features – Number of features once pre-processed, defaults to the number of attributes in the schema.
ridge – Ridge regularization term that prevents the covariance matrix from being singular, defaults to 1e-6.
device – Device to run the model on, defaults to CPU.
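To see why the ridge term matters, consider a case with fewer examples than features. The sample covariance is then rank-deficient and has no inverse. This sketch assumes the ridge is added to the diagonal of the covariance (covariance + ridge · I). That is a common choice, but it is an assumption about how SLDA applies the term.

```python
import numpy as np

# Two examples in three dimensions: the centered data spans a single
# direction, so the sample covariance has rank 1 and is singular.
X = np.array([[1.0, 2.0, 0.0],
              [3.0, 1.0, 1.0]])
cov = np.cov(X, rowvar=False, bias=True)
rank_before = np.linalg.matrix_rank(cov)

ridge = 1e-6
# Adding ridge * I shifts every eigenvalue of the (positive semi-definite)
# covariance up by `ridge`, making it positive definite and invertible.
regularized = cov + ridge * np.eye(cov.shape[0])
rank_after = np.linalg.matrix_rank(regularized)
precision = np.linalg.inv(regularized)
```

A small ridge barely changes well-conditioned covariances but guarantees that the inverse exists early in the stream, when only a handful of examples have been seen.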
- batch_predict_proba(x: Tensor) → Tensor
Predict the probabilities of the classes for a batch of instances.
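The following is a minimal NumPy sketch of how class probabilities can be obtained from the accumulated statistics. It assumes the standard LDA discriminant with a shared, ridge-regularized covariance, followed by a softmax over the class scores, which corresponds to equal class priors. The function name `lda_predict_proba` is hypothetical. Whether `batch_predict_proba()` normalizes its scores in exactly this way is also an assumption.

```python
import numpy as np


def lda_predict_proba(x, class_means, covariance, ridge=1e-6):
    """Class probabilities for a batch ``x`` of shape (n, d) (sketch)."""
    d = covariance.shape[0]
    # Regularized precision matrix (inverse of the shared covariance).
    precision = np.linalg.inv(covariance + ridge * np.eye(d))
    # Linear discriminant: score_c(x) = x^T P mu_c - 0.5 * mu_c^T P mu_c.
    weights = class_means @ precision                      # (C, d)
    bias = -0.5 * np.sum(weights * class_means, axis=1)    # (C,)
    scores = x @ weights.T + bias                          # (n, C)
    # Softmax over classes, shifted by the row max for numerical stability.
    scores = scores - scores.max(axis=1, keepdims=True)
    exp_scores = np.exp(scores)
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)


class_means = np.array([[0.0, 0.0], [4.0, 4.0]])
covariance = np.eye(2)
x = np.array([[0.2, -0.1], [3.9, 4.2], [2.0, 2.0]])
proba = lda_predict_proba(x, class_means, covariance)
```

The third point lies exactly halfway between the two class means, so it receives a probability of about 0.5 for each class.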
- predict(instance: Instance) → int | None
Predict the label of an instance.
The base implementation calls predict_proba() and returns the label with the highest probability.
- Parameters:
instance – The instance to predict the label for.
- Returns:
The predicted label, or None if the classifier is unable to make a prediction.
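The argmax behaviour of the base `predict()` can be sketched as follows. `predict_from_proba` is a hypothetical helper. Treating an empty probability vector as "unable to predict" is an assumption made for illustration.

```python
from typing import Optional, Sequence


def predict_from_proba(proba: Optional[Sequence[float]]) -> Optional[int]:
    """Return the index of the most probable class, or None (sketch)."""
    if proba is None or len(proba) == 0:
        # No probabilities available, e.g. nothing has been learned yet.
        return None
    # Index of the highest probability; max() keeps the first maximum,
    # so ties resolve to the lowest class index.
    return max(range(len(proba)), key=lambda i: proba[i])
```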
- predict_proba(instance: Instance)
Calls batch_predict_proba() with a batch of size 1.
- train(instance: LabeledInstance) → None
Calls batch_train() with a batch of size 1.
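Both `predict_proba()` and `train()` delegate to their batch counterparts by wrapping a single instance in a batch of size 1. The sketch below shows that pattern on plain NumPy arrays. `BatchOfOneAdapter` and its placeholder batch methods are hypothetical. The real methods receive capymoa `Instance`/`LabeledInstance` objects rather than raw arrays.

```python
import numpy as np


class BatchOfOneAdapter:
    """Route single-instance calls through batch methods (hypothetical sketch)."""

    def __init__(self):
        self.seen = []  # records the batch shapes passed to batch_train

    def batch_predict_proba(self, x: np.ndarray) -> np.ndarray:
        # Placeholder: uniform probabilities over 2 classes for every row.
        return np.full((x.shape[0], 2), 0.5)

    def batch_train(self, x: np.ndarray, y: np.ndarray) -> None:
        # Placeholder: only remember the shapes it was called with.
        self.seen.append((x.shape, y.shape))

    def predict_proba(self, features: np.ndarray) -> np.ndarray:
        # Add a leading batch axis, (d,) -> (1, d), then take row 0 back out.
        return self.batch_predict_proba(features[np.newaxis, :])[0]

    def train(self, features: np.ndarray, label: int) -> None:
        # Wrap the single example and its label as batches of size 1.
        self.batch_train(features[np.newaxis, :], np.array([label]))


adapter = BatchOfOneAdapter()
p = adapter.predict_proba(np.array([1.0, 2.0, 3.0]))
adapter.train(np.array([1.0, 2.0, 3.0]), 1)
```

Routing through the batch methods means the learning logic only has to be written once, in its vectorized batch form.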
- random_seed: int
The random seed for reproducibility.
When implementing a classifier, ensure that its random number generators are seeded.
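One way to satisfy this is to create a single generator from `random_seed` in the constructor and draw all randomness from it. `SeededClassifierSketch` and its `sample_indices` method are hypothetical. For PyTorch-based classifiers, a `torch.Generator` seeded with `manual_seed()` plays the same role.

```python
import numpy as np


class SeededClassifierSketch:
    """Hypothetical classifier that derives all randomness from random_seed."""

    def __init__(self, random_seed: int = 1):
        self.random_seed = random_seed
        # One generator, seeded once, used for every random decision.
        self._rng = np.random.default_rng(random_seed)

    def sample_indices(self, n: int, k: int) -> np.ndarray:
        # Example random decision: pick k of n items without replacement.
        return self._rng.choice(n, size=k, replace=False)


a = SeededClassifierSketch(random_seed=42)
b = SeededClassifierSketch(random_seed=42)
```

Two instances built with the same seed then make identical random choices, which keeps experiments reproducible.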