
How to serve a Machine Learning model through a Flask API?

In this post, a kind of ML model deployment 101, we will use the Python microframework Flask to serve a machine learning model through an API.

Part I: The Training

Before deploying a machine learning model, you need… a machine learning model. Well, let’s say this has already been done, and you have built (using scikit-learn) a state-of-the-art model to address the very challenging and original task of classifying Iris flower species using features such as sepal and petal length and width.

Below is the notebook we used to achieve this:

[Embedded notebook: training the Iris classifier]

The model we have developed has been persisted using the joblib library (but we could of course have used pickle, Python’s standard serialization module).

I also usually persist an ordered list of all the features used by the model, to make sure we use the same features in the same order during inference.
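The training and persistence steps described above can be sketched as follows. This is a minimal sketch, not the original notebook: the choice of `LogisticRegression`, the use of scikit-learn's bundled Iris dataset, and the `train_and_persist` helper are assumptions made for illustration. The file names match those loaded by the Flask app later in the post.

```python
from pathlib import Path

import joblib
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression


def train_and_persist(output_dir="."):
    """Train an Iris classifier and persist it with its ordered feature list."""
    iris = load_iris(as_frame=True)
    # Ordered list of feature names; the same order must be used at inference.
    features = list(iris.data.columns)
    X = iris.data[features]
    # Use species names (strings) as labels instead of the integer codes.
    y = iris.target_names[iris.target.to_numpy()]

    # Hypothetical model choice; any scikit-learn classifier could be used here.
    classifier = LogisticRegression(max_iter=1000)
    classifier.fit(X, y)

    output_dir = Path(output_dir)
    joblib.dump(classifier, output_dir / "iris_classifier.joblib")
    joblib.dump(features, output_dir / "iris_classifier_features.joblib")
    return classifier, features
```

Calling `train_and_persist()` writes both files to the current directory, which is where the Flask app expects to find them.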

With this persisted model, what we could technically do to analyze a new flower is load the model in a notebook and get the predicted class from it:

[Embedded notebook: loading the persisted model and predicting a new flower's class]
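Loading the persisted model and predicting the class of a new flower can be sketched as below. The `predict_flower` and `load_and_predict` helpers are hypothetical names, and the measurements dictionary is assumed to be keyed by the persisted feature names.

```python
import joblib
import pandas as pd


def predict_flower(classifier, features, measurements):
    """Return the predicted class for one flower.

    `measurements` maps each feature name to a numeric value.
    """
    # One-row DataFrame with columns in the exact order used for training.
    flower = pd.DataFrame([measurements])[features]
    return str(classifier.predict(flower)[0])


def load_and_predict(measurements,
                     model_path="iris_classifier.joblib",
                     features_path="iris_classifier_features.joblib"):
    """Load the persisted model and feature list, then predict."""
    classifier = joblib.load(model_path)
    features = joblib.load(features_path)
    return predict_flower(classifier, features, measurements)
```

Calling `load_and_predict` with one value per persisted feature returns the predicted class as a string.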

But this might not be the best user experience. Instead, what we would probably want to do is to develop a user interface, in which an end-user can fill in a form with all the new flower’s characteristics, and then get the predicted class.

In this post we won’t talk much about the development of that user interface. But we understand that we need to make our ML model available somehow, so that the user interface can send the new flowers’ characteristics and get predictions in return. This is done by encapsulating the model in an API.

Part II: Building the API

More specifically, we will use Flask to build a very simple API that exposes a /predict-species endpoint. UIs will then be able to make POST requests to that endpoint, passing in the flower’s characteristics and getting the predicted species in return.

Flask API

Below is the code of our Flask app.

from flask import Flask
from flask import request
import joblib
import pandas as pd
import json

# Load the persisted model and the ordered list of features it expects.
with open("iris_classifier.joblib", "rb") as f:
    iris_classifier = joblib.load(f)
with open("iris_classifier_features.joblib", "rb") as f:
    iris_classifier_features = joblib.load(f)

app = Flask(__name__)


@app.route('/predict-species', methods=['POST'])
def predict_species():
    # Build a one-row DataFrame from the form fields.
    # Form values arrive as strings, so cast each one to float.
    flower = {}
    for feature in iris_classifier_features:
        flower[feature] = [float(request.form[feature])]
    flower = pd.DataFrame(flower)
    species = iris_classifier.predict(flower[iris_classifier_features])
    # Convert the NumPy scalar to a plain string so Flask can return it.
    return str(species[0])

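To launch the Flask server, run the command python -m flask run from the directory containing the app. Flask looks for a file named app.py or wsgi.py by default; for any other file name, point Flask to it with the FLASK_APP environment variable.

And here we are! Our model is now served on localhost and we can test the endpoint using Postman: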

Postman screenshot
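The same request can also be sent from code instead of Postman. Here is a minimal sketch using only Python's standard library, assuming the server runs on Flask's default address `http://127.0.0.1:5000`. The helper names are hypothetical, and the form fields must match whatever feature names were persisted in `iris_classifier_features.joblib`.

```python
import urllib.parse
import urllib.request


def build_predict_request(flower, base_url="http://127.0.0.1:5000"):
    """Build a form-encoded POST request for the /predict-species endpoint."""
    # urlencode produces "name=value&..." pairs, which Flask exposes via request.form.
    data = urllib.parse.urlencode(flower).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/predict-species", data=data, method="POST"
    )


def predict_species(flower, base_url="http://127.0.0.1:5000"):
    """Send the request to a running server and return the response body."""
    with urllib.request.urlopen(build_predict_request(flower, base_url)) as response:
        return response.read().decode("utf-8")
```

With the server running, calling `predict_species` with a dictionary holding one entry per feature returns the predicted species as text.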

We might also need the probability associated with each class, so we can define a new endpoint at the end of the file:

@app.route('/predict-species-proba', methods=['POST'])
def predict_species_proba():
    # Same input handling as /predict-species: one row, values cast to float.
    flower = {}
    for feature in iris_classifier_features:
        flower[feature] = [float(request.form[feature])]
    flower = pd.DataFrame(flower)
    probas = iris_classifier.predict_proba(flower[iris_classifier_features])[0, :].tolist()
    # The hardcoded order below must match iris_classifier.classes_.
    species_proba = {}
    for idx, species in enumerate(['setosa', 'versicolor', 'virginica']):
        species_proba[species] = probas[idx]
    return json.dumps(species_proba)

And results will look like this:

Postman screenshot
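The hardcoded species list above works only if its order matches the order of the classes the model was trained on. scikit-learn classifiers expose that order in their `classes_` attribute, and the columns of `predict_proba` follow it, so a slightly safer variant is to build the mapping from `classes_` directly. A minimal sketch, with `proba_by_class` as a hypothetical helper name:

```python
import json


def proba_by_class(classes, probas):
    """Pair each class label with its probability, preserving the model's order."""
    # str() and float() turn NumPy scalars into JSON-serializable Python values.
    return {str(label): float(proba) for label, proba in zip(classes, probas)}


# Illustrative values, not real model output.
example = proba_by_class(["setosa", "versicolor", "virginica"], [0.9, 0.08, 0.02])
print(json.dumps(example))
```

In the endpoint, this would be called as `proba_by_class(iris_classifier.classes_, probas)`. If the model was trained on integer targets, the keys will be the integer codes rather than the species names.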


This post shows how you can use Flask to serve a machine learning model.

While I think this is a good example to understand how ML deployment can be achieved, keep in mind that this is not production-ready yet.

Indeed, the built-in Flask server is a development server and should not be used in production (see more here).

Moreover, for now our API runs locally. In a future post, I will explain how you can actually deploy this API on a server (probably using AWS EC2).