SKClassifier
- class capymoa.base.SKClassifier
Bases: Classifier
A wrapper class for using scikit-learn classifiers in CapyMOA.
Some scikit-learn classifiers that are compatible with online learning have already been wrapped and tested in CapyMOA (see capymoa.classifier). However, if you want to use a scikit-learn classifier that has not been wrapped yet, you can use this class to wrap it yourself. This requires that the scikit-learn classifier implements the partial_fit and predict methods.
For example, the following code demonstrates how to use a scikit-learn classifier in CapyMOA:
>>> from sklearn.linear_model import SGDClassifier
>>> from capymoa.base import SKClassifier
>>> from capymoa.datasets import ElectricityTiny
>>> stream = ElectricityTiny()
>>> sklearner = SGDClassifier(random_state=1)
>>> learner = SKClassifier(sklearner, stream.schema)
>>> for _ in range(10):
...     instance = stream.next_instance()
...     prediction = learner.predict(instance)
...     print(f"True: {instance.y_index}, Predicted: {prediction}")
...     learner.train(instance)
True: 1, Predicted: None
True: 1, Predicted: 1
True: 1, Predicted: 1
True: 1, Predicted: 1
True: 0, Predicted: 1
True: 0, Predicted: 1
True: 0, Predicted: 0
True: 0, Predicted: 0
True: 0, Predicted: 0
True: 0, Predicted: 0
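The example predicts on each instance before training on it, which is why the first prediction is None: nothing has been learned yet. The sketch below reproduces that test-then-train order using only the standard library. MajorityClass, TinyWrapper and prequential are hypothetical names invented for this illustration; they are not part of CapyMOA or scikit-learn and do not show how SKClassifier is implemented.

```python
from collections import Counter


class MajorityClass:
    """Toy stand-in for an incremental estimator (hypothetical, stdlib only)."""

    def __init__(self):
        self.counts = Counter()

    def partial_fit(self, X, y, classes=None):
        # Update the label counts incrementally, one mini-batch at a time.
        self.counts.update(y)
        return self

    def predict(self, X):
        # Predict the most frequent label seen so far for every row.
        label = self.counts.most_common(1)[0][0]
        return [label for _ in X]


class TinyWrapper:
    """Hypothetical wrapper that returns None until the first training call."""

    def __init__(self, estimator):
        self.estimator = estimator
        self.trained = False

    def train(self, x, y):
        # Feed a single labelled example as a mini-batch of size one.
        self.estimator.partial_fit([x], [y])
        self.trained = True

    def predict(self, x):
        if not self.trained:
            return None
        return self.estimator.predict([x])[0]


def prequential(wrapper, stream):
    """Test-then-train: predict each item first, then learn from it."""
    results = []
    for x, y in stream:
        results.append((y, wrapper.predict(x)))
        wrapper.train(x, y)
    return results


# Made-up data: (features, label) pairs.
stream = [([0.1], 1), ([0.2], 1), ([0.3], 0), ([0.4], 0)]
results = prequential(TinyWrapper(MajorityClass()), stream)
for truth, predicted in results:
    print(f"True: {truth}, Predicted: {predicted}")
```

Because every prediction is made before the matching label is seen, the recorded results estimate performance on unseen data without a separate hold-out set.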
A word of caution: even compatible scikit-learn classifiers are not necessarily designed for online learning and might require some tweaking to work well in an online setting.
See also capymoa.base.SKRegressor for scikit-learn regressors.
- __init__(sklearner: ClassifierMixin, schema: Schema = None, random_seed: int = 1)
Construct a scikit-learn classifier wrapper.
- Parameters:
sklearner – A scikit-learn classifier object to wrap; it must implement partial_fit and predict.
schema – Describes the structure of the data stream.
random_seed – Random seed for reproducibility.
- Raises:
ValueError – If the scikit-learn algorithm does not implement partial_fit or predict.
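The requirement behind this ValueError can be checked up front with a duck-typing test. The following is a minimal sketch under the assumption that only the presence of callable partial_fit and predict attributes matters; require_online_api, BatchOnly and Incremental are hypothetical names for illustration, and the check CapyMOA actually performs may differ.

```python
def require_online_api(estimator):
    """Raise ValueError unless `estimator` exposes partial_fit and predict."""
    missing = [
        name
        for name in ("partial_fit", "predict")
        if not callable(getattr(estimator, name, None))
    ]
    if missing:
        raise ValueError(
            f"{type(estimator).__name__} is missing: {', '.join(missing)}"
        )
    return estimator


class BatchOnly:
    """Hypothetical estimator with fit/predict but no partial_fit."""

    def fit(self, X, y):
        return self

    def predict(self, X):
        return [0 for _ in X]


class Incremental(BatchOnly):
    """Hypothetical estimator that also supports incremental updates."""

    def partial_fit(self, X, y, classes=None):
        return self


# A batch-only estimator fails the check; an incremental one passes.
try:
    require_online_api(BatchOnly())
    error_message = None
except ValueError as exc:
    error_message = str(exc)

checked = require_online_api(Incremental())
print(error_message)
```

Running a check like this before building the wrapper surfaces an incompatible estimator at construction time rather than in the middle of a stream.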
- sklearner: ClassifierMixin
The underlying scikit-learn object.
- train(instance: LabeledInstance)