Save and Load a Model

In this tutorial, we illustrate the process of saving and loading a model using CapyMOA.

  • We use the SEA synthetic generator as the data source, and the AdaptiveRandomForestClassifier as the learner.

  • The trained model is saved to a file, specifically ‘capymoa_model.pkl’.

  • Subsequently, we reload the model from the file and resume training and evaluation on the SEA data.

  • As a final step, we delete the model file.

1. Training and saving the model

  • We train the model on 5k instances from SEA using the prequential_evaluation function.

  • We then save the model with save_model(learner, "capymoa_ARF_model.pkl").

from capymoa.classifier import AdaptiveRandomForestClassifier
from capymoa.evaluation import prequential_evaluation
from capymoa.stream.generator import SEA
from capymoa.misc import save_model, load_model

stream = SEA()
learner = AdaptiveRandomForestClassifier(schema=stream.get_schema(), ensemble_size=10)

results = prequential_evaluation(stream=stream, learner=learner, max_instances=5000)

print(f"Accuracy: {results['cumulative'].accuracy():.2f}")
save_model(learner, "capymoa_ARF_model.pkl")  # Save the model to capymoa_ARF_model.pkl
Accuracy: 87.96
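prequential_evaluation follows the test-then-train scheme: each incoming instance is first used to test the current model and only then to train it, so every prediction is made on data the model has not yet learned from. The sketch below illustrates that loop in plain Python. It assumes a simplified SEA-style concept (three features uniform in [0, 10], label 0 when the first two sum to at most a threshold of 8, third feature irrelevant) and uses a toy online perceptron in place of AdaptiveRandomForestClassifier. The names sea_like_stream, OnlinePerceptron and prequential_accuracy are hypothetical; this is not CapyMOA's implementation, and its accuracy is not comparable to the result above.

```python
import random


def sea_like_stream(n, threshold=8.0, seed=1):
    """Hypothetical SEA-style generator: three features uniform in [0, 10];
    label is 0 if f1 + f2 <= threshold, else 1 (f3 is irrelevant)."""
    rng = random.Random(seed)
    for _ in range(n):
        x = [rng.uniform(0, 10) for _ in range(3)]
        yield x, 0 if x[0] + x[1] <= threshold else 1


class OnlinePerceptron:
    """Hypothetical toy learner standing in for AdaptiveRandomForestClassifier."""

    def __init__(self, n_features, lr=0.1):
        self.w = [0.0] * n_features
        self.b = 0.0
        self.lr = lr

    def predict(self, x):
        score = sum(wi * xi for wi, xi in zip(self.w, x)) + self.b
        return 1 if score > 0 else 0

    def train(self, x, y):
        # Perceptron update: weights change only when the prediction is wrong.
        error = y - self.predict(x)
        if error:
            self.w = [wi + self.lr * error * xi for wi, xi in zip(self.w, x)]
            self.b += self.lr * error


def prequential_accuracy(stream, learner):
    """Test-then-train: each instance is used for testing first, then for training.
    Returns (accuracy in percent, number of instances seen)."""
    correct = seen = 0
    for x, y in stream:
        correct += learner.predict(x) == y  # test first ...
        learner.train(x, y)                 # ... then train on the same instance
        seen += 1
    return (100.0 * correct / seen if seen else 0.0), seen


acc, n = prequential_accuracy(sea_like_stream(5000), OnlinePerceptron(n_features=3))
print(f"Accuracy: {acc:.2f} over {n} instances")
```

Because the model is always tested before it trains on an instance, the cumulative accuracy reflects performance on unseen data without needing a separate hold-out set.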

2. Loading and resuming training

  • We use os.path.getsize() to inspect the size of the saved file, reported in KB.

  • We reload the model from the file with load_model().

  • We don’t restart the synthetic stream; we continue processing it with another call to prequential_evaluation.

  • Finally, we report the updated accuracy.

import os

model_file = 'capymoa_ARF_model.pkl'

model_size = os.path.getsize(model_file)
print(f"The saved model size: {model_size / 1024:.2f} KB")

restored_learner = load_model(model_file)  # Load the model from capymoa_ARF_model.pkl

# Continue training the restored model on the next 5k instances of the same stream
results = prequential_evaluation(stream=stream, learner=restored_learner, max_instances=5000)

print(f"Updated accuracy: {results['cumulative'].accuracy():.2f}")
The saved model size: 616.66 KB
Updated accuracy: 89.32
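This step relies on both the model and the stream carrying their state forward: the restored learner keeps what it learned from the first 5k instances, and because the stream is not restarted, the second run sees the next 5k instances rather than the same ones again. The sketch below mimics that with the standard pickle module, a hypothetical majority-class learner and a hypothetical endless SEA-style generator; itertools.islice takes the next chunk from the same generator without rewinding it. It assumes only that the learner object is picklable and does not claim that save_model and load_model work this way internally.

```python
import itertools
import os
import pickle
import random
import tempfile
from collections import Counter


class MajorityClassLearner:
    """Hypothetical toy learner: predicts the most frequent label seen so far."""

    def __init__(self):
        self.counts = Counter()

    def predict(self, x):
        # Features are ignored; return None until a label has been seen.
        return self.counts.most_common(1)[0][0] if self.counts else None

    def train(self, x, y):
        self.counts[y] += 1


def sea_like_stream(threshold=8.0, seed=1):
    """Hypothetical endless SEA-style stream: label 0 if f1 + f2 <= threshold, else 1."""
    rng = random.Random(seed)
    while True:
        x = [rng.uniform(0, 10) for _ in range(3)]
        yield x, 0 if x[0] + x[1] <= threshold else 1


def prequential_run(stream, learner, max_instances):
    """Test-then-train over the *next* max_instances items of stream (no rewind)."""
    correct = seen = 0
    for x, y in itertools.islice(stream, max_instances):
        correct += learner.predict(x) == y
        learner.train(x, y)
        seen += 1
    return 100.0 * correct / seen if seen else 0.0


stream = sea_like_stream()
learner = MajorityClassLearner()
print(f"Accuracy: {prequential_run(stream, learner, 5000):.2f}")

with tempfile.TemporaryDirectory() as tmp:
    model_file = os.path.join(tmp, "toy_model.pkl")
    with open(model_file, "wb") as f:
        pickle.dump(learner, f)  # save the learner together with its learned state
    print(f"The saved model size: {os.path.getsize(model_file) / 1024:.2f} KB")
    with open(model_file, "rb") as f:
        restored = pickle.load(f)  # a new object with the same state

# Same generator object, not restarted: these are instances 5,001 to 10,000.
print(f"Updated accuracy: {prequential_run(stream, restored, 5000):.2f}")
print(sum(restored.counts.values()))  # 10000: labels from both runs
```

If the stream were restarted instead, the restored model would be re-evaluated on instances it had already trained on, which would overstate its accuracy.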

3. Cleanup

  • As a last step, we delete the saved model file.

if os.path.exists(model_file):
    os.remove(model_file)  # Delete the saved model file
    print(f"File {model_file} has been deleted.")
else:
    print(f"File {model_file} not found.")
File capymoa_ARF_model.pkl has been deleted.
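A minimal alternative for the cleanup step, using only the standard library: pathlib.Path.unlink(missing_ok=True) (Python 3.8+) deletes a file without a separate existence check, and tempfile.TemporaryDirectory removes everything it contains when the with-block exits. The placeholder bytes below are a hypothetical stand-in for the file that save_model writes in the tutorial.

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical stand-in for the file that save_model writes in the tutorial.
    model_file = Path(tmp) / "capymoa_ARF_model.pkl"
    model_file.write_bytes(b"placeholder")
    print(model_file.exists())  # True

    # Delete the file; missing_ok=True means no error if it is already gone.
    model_file.unlink(missing_ok=True)
    print(model_file.exists())  # False
    model_file.unlink(missing_ok=True)  # a second call is a harmless no-op

# Leaving the with-block deletes the temporary directory and anything left in it.
print(Path(tmp).exists())  # False
```

Writing the model into a temporary directory is a design choice that guarantees cleanup even if an exception interrupts the tutorial before the explicit delete step.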