A trained model is only useful if it can make predictions on new data. The most common way to make a model accessible to other applications (like a web front-end or a mobile app) is by exposing it through a REST API. The application sends new data to an API endpoint via an HTTP request, and the API returns the model's prediction in the response.
Option 1: Flask (The Classic Microframework)
Flask is a lightweight and straightforward Python web framework, making it a popular choice for simple model serving. The basic idea is to load your model once when the server starts and then create a route that uses the loaded model to make predictions.
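That pattern can be sketched in a few lines. This is a minimal illustration, not production code: it assumes a scikit-learn pipeline saved as `sentiment_model.joblib` (the same hypothetical file used later in this tutorial), and the route and field names are illustrative.

```python
from flask import Flask, jsonify, request
import joblib

app = Flask(__name__)

# Load the model once at startup, not on every request.
try:
    model = joblib.load("sentiment_model.joblib")
except FileNotFoundError:
    model = None

@app.route("/predict", methods=["POST"])
def predict():
    if model is None:
        return jsonify({"error": "Model not loaded"}), 503
    data = request.get_json()
    # The model pipeline expects a list of raw text strings.
    prediction = model.predict([data["text"]])[0]
    return jsonify({"sentiment": prediction})
```

Run it with `flask run` (or `app.run()`), and clients can POST JSON like `{"text": "..."}` to `/predict`.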
Option 2: FastAPI (The Modern, High-Performance Choice)
FastAPI is a modern web framework that has become a favorite for building ML APIs. It offers several key advantages:
- High Performance: As the name suggests, it's one of the fastest Python frameworks available, thanks to its asynchronous foundation (it is built on Starlette and the ASGI standard).
- Automatic Interactive Docs: It automatically generates interactive API documentation (using Swagger UI and ReDoc), which is incredibly helpful for testing and for other developers using your API.
- Data Validation: It uses Python type hints and a library called Pydantic to enforce data validation. If a request sends data in the wrong format, FastAPI automatically rejects it with a clear error message.
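The validation point is worth seeing in isolation. The snippet below uses Pydantic directly, the same mechanism FastAPI applies to incoming requests, with a `TextInput` model mirroring the one defined in the API code later in this section:

```python
from pydantic import BaseModel, ValidationError

class TextInput(BaseModel):
    text: str  # a required string field

# Valid input is parsed and passed through unchanged.
ok = TextInput(text="great movie")

# Invalid input (here, a missing required field) raises a ValidationError,
# which FastAPI would turn into a descriptive 422 response automatically.
try:
    TextInput()
except ValidationError as e:
    print("rejected:", e.errors())
```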
Code Snippet: A Sentiment Analysis API with FastAPI
Let's build an API for the sentiment model we created in a previous tutorial.
```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import joblib

# 1. Initialize FastAPI app
app = FastAPI(
    title="Sentiment Analysis API",
    description="An API to predict the sentiment of a given text.",
)

# 2. Define the request data model using Pydantic
class TextInput(BaseModel):
    text: str

# 3. Load the trained model pipeline
try:
    model = joblib.load('sentiment_model.joblib')
except FileNotFoundError:
    model = None

# 4. Define the prediction endpoint
@app.post("/predict")
async def predict_sentiment(text_input: TextInput):
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # The input text is accessed via text_input.text
        prediction = model.predict([text_input.text])[0]
        return {"sentiment": prediction}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# To run this app: `uvicorn app:app --reload`
# Then navigate to http://127.0.0.1:8000/docs to see the interactive API docs.
```
Performance Optimization: Request Batching
When you deploy a model on a powerful accelerator like a GPU, making predictions one by one is extremely inefficient. There is a significant overhead for each call, and you don't utilize the GPU's massive parallel processing capability.
Request batching is the solution. The idea is to collect multiple requests that arrive in a short time window and feed them to the model as a single, larger batch. This dramatically increases throughput (predictions per second).
Conceptual Implementation: Implementing a robust batching system is complex. It usually involves a multi-threaded or asynchronous architecture:
- API Thread(s): The API endpoint (/predict) doesn't call the model directly. Instead, it places the request data into a shared queue. It then waits for a signal that its specific request has been processed.
- Worker Thread: A separate, single background thread runs a continuous loop.
- The Loop:
  - It sleeps for a very short time (e.g., 10 milliseconds).
  - It wakes up and pulls all available requests from the queue (up to a maximum batch size).
  - If there are requests, it stacks their data into a single batch tensor.
  - It performs one inference call on the entire batch.
  - It distributes the results from the output batch back to the corresponding waiting API threads.
This pattern ensures that the GPU is always fed large batches of data, maximizing its utilization and your API's overall throughput. Dedicated model-serving platforms (such as TensorFlow Serving and TorchServe) have this functionality built in.
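The worker-loop pattern described above can be sketched with Python's standard library alone: a shared `queue.Queue`, one background thread, and a `concurrent.futures.Future` per request so each caller can block until its result arrives. Here `fake_model_predict` is a stand-in for a real batched inference call, and all names, batch sizes, and timings are illustrative.

```python
import queue
import threading
import time
from concurrent.futures import Future

MAX_BATCH_SIZE = 8
WAIT_SECONDS = 0.01  # ~10 ms between batches

request_queue = queue.Queue()

def fake_model_predict(batch):
    # Stand-in for one batched inference call on the whole list of inputs.
    return ["positive" if "good" in text else "negative" for text in batch]

def worker_loop(stop_event):
    while not stop_event.is_set():
        time.sleep(WAIT_SECONDS)
        # Pull everything currently waiting, up to the batch limit.
        batch, futures = [], []
        while len(batch) < MAX_BATCH_SIZE:
            try:
                text, fut = request_queue.get_nowait()
            except queue.Empty:
                break
            batch.append(text)
            futures.append(fut)
        if not batch:
            continue
        results = fake_model_predict(batch)  # one call for the entire batch
        for fut, result in zip(futures, results):
            fut.set_result(result)  # wake up the waiting caller

def predict(text):
    # What an API handler would do: enqueue the request, then block
    # until the batch worker publishes this request's result.
    fut = Future()
    request_queue.put((text, fut))
    return fut.result(timeout=2.0)

stop = threading.Event()
threading.Thread(target=worker_loop, args=(stop,), daemon=True).start()

result_good = predict("this is good")
result_bad = predict("this is bad")
print(result_good, result_bad)
stop.set()
```

A production system would add error propagation (`Future.set_exception`), a maximum wait deadline per request, and backpressure when the queue grows too large, which is exactly what the dedicated serving platforms provide out of the box.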