Deploy Flux Models with Pruna AI for 8x Faster Inference on Koyeb
Large AI models can be over-parameterized, memory-intensive, and slow at inference time. When deployed in production, these inefficiencies translate to higher latency, increased GPU costs, and wasted compute. Optimizing these models before deployment isn't just a nice-to-have; it's essential for scalable, cost-effective AI.
Pruna AI is a model compression and optimization framework that reduces model size and compute requirements without sacrificing performance. Pruna supports structured pruning, quantization, and advanced compression strategies that preserve accuracy while dramatically improving efficiency.
In this tutorial, we will show how to use pruna_pro to optimize your models and deploy them on Koyeb. The same workflow applies to the open-source pruna package: simply adjust the installation command and the compression configuration.
Requirements
To successfully follow and complete this guide, you need:
- Python (version 3.9 or higher) installed on your local development environment
- pruna_pro installed (or pruna if you are using the open-source version)
- A Koyeb account to deploy the optimized Flux model
- The Koyeb CLI installed to interact with Koyeb from the command line
Steps
To optimize the Flux model with Pruna and deploy it on Koyeb, you need to follow these steps:
- Install Pruna
- Load the Baseline Model
- Configure Pruna for Optimization
- Optimize the Model
- Use the Optimized Model
- Deploy the Optimized Model on Koyeb
Install Pruna
To run pruna on Koyeb, you'll need Python 3.9 or later and access to one of Koyeb's Nvidia GPU instances. In this demo, we will use FastAPI to build the application, serve the model, and perform predictions.
Use uv to install and manage project dependencies. Get started by initializing a new project using uv:
uv init pruna-on-koyeb
Then, install the dependencies that will be required by our application:
uv add fastapi diffusers torch pruna_pro
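If you are using the open-source version, install pruna in place of pruna_pro:
uv add fastapi diffusers torch pruna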
Next, create a new file server.py containing the following complete implementation of our application. In the next section, we will break down the different steps used to optimize the model with Pruna.
import base64
import io
import logging
import os
import time
from contextlib import asynccontextmanager
from typing import List, Literal, Optional

import torch
from diffusers import FluxPipeline
from fastapi import FastAPI, HTTPException
from pruna_pro import SmashConfig, smash
from pydantic import BaseModel, Field, field_validator

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "black-forest-labs/FLUX.1-dev"


class ModelManager:
    """Manages the loading, unloading and access to the Flux model pipeline."""

    def __init__(self):
        self.device = DEVICE

    async def load_model(self):
        logger.info(f"Loading model {MODEL_ID} on device {self.device}...")
        try:
            # Load the baseline FLUX.1 [dev] pipeline
            base_pipe = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev",
                torch_dtype=torch.bfloat16,
            ).to(self.device)

            # Configure Pruna: caching plus torch.compile compilation
            smash_config = SmashConfig()
            smash_config["cacher"] = "taylor_auto"
            smash_config["compiler"] = "torch_compile"
            smash_config._prepare_saving = False

            # Optimize the pipeline with Pruna
            self.pipe = smash(
                model=base_pipe,
                token=os.getenv("PRUNA_API_KEY"),
                smash_config=smash_config,
            )
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise RuntimeError(f"Failed to load model: {str(e)}")

    async def unload_model(self):
        """Cleanup method to properly unload models"""
        try:
            del self.pipe
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Model components unloaded successfully")
        except Exception as e:
            logger.error(f"Error unloading model components: {str(e)}")


model_manager = ModelManager()


class GenerationRequest(BaseModel):
    """Request model for image generation containing all parameters."""

    prompt: str = Field(..., min_length=1, max_length=1000)
    prompt_2: Optional[str] = Field(None, min_length=1, max_length=1000)
    height: Optional[int] = Field(1024, ge=64, le=2048)
    width: Optional[int] = Field(1024, ge=64, le=2048)
    num_inference_steps: int = Field(50, ge=1, le=100)
    guidance_scale: float = Field(3.5, ge=0.0, le=10.0)
    max_sequence_length: int = Field(256, ge=1, le=256)
    num_images_per_prompt: int = Field(1, ge=1, le=5)
    seed: Optional[int] = Field(None)
    speed_mode: Literal[
        "Lightly Juiced (more consistent)",
        "Juiced (default)",
        "Extra Juiced (more speed)",
    ] = Field(default="Juiced (default)")

    @field_validator("height", "width")
    @classmethod
    def validate_dimensions(cls, v: Optional[int]) -> Optional[int]:
        if v is not None and v % 8 != 0:
            raise ValueError("Height and width must be divisible by 8")
        return v

    class Config:
        json_schema_extra = {
            "example": {
                "prompt": "A beautiful landscape",
                "height": 1024,
                "width": 1024,
                "num_inference_steps": 30,
            }
        }


class GenerationResponse(BaseModel):
    images: List[str]
    seed: int


def encode_images(images, quality: int = 85):
    """
    Encode PIL images to base64 JPEG strings.

    Args:
        images: List of PIL Image objects
        quality: JPEG quality (1-100)

    Returns:
        List of base64 encoded image strings
    """
    encoded_images = []
    for img in images:
        with io.BytesIO() as buffered:
            img = img.convert("RGB")
            img.save(buffered, format="JPEG", quality=quality, optimize=True)
            img_str = base64.b64encode(buffered.getvalue()).decode()
            encoded_images.append(f"data:image/jpeg;base64,{img_str}")
    return encoded_images


@asynccontextmanager
async def lifespan(_: FastAPI):
    await model_manager.load_model()
    yield
    await model_manager.unload_model()
    logger.info("Application shut down successfully")


app = FastAPI(
    title="flux.1-juiced API",
    description="API for generating images using FLUX.1 [dev] with Pruna AI",
    version="1.0.0",
    lifespan=lifespan,
)


@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "device": model_manager.device,
        "model_loaded": hasattr(model_manager, "pipe"),
    }


@app.post("/predict", response_model=GenerationResponse)
async def predict(request: GenerationRequest):
    try:
        logger.info(f"Starting image generation with prompt: {request.prompt[:50]}...")
        logger.debug(f"Generation parameters: {request.model_dump()}")
        start_time = time.time()

        # Seed the generator so the returned seed reproduces this generation
        if request.seed is None:
            request.seed = torch.randint(0, 2**32 - 1, (1,)).item()
        torch.manual_seed(request.seed)

        pipeline = model_manager.pipe

        if hasattr(pipeline, "cache_helper"):
            # Reset the cache helper, then tune its speed factor for the requested mode
            pipeline.cache_helper.disable()
            pipeline.cache_helper.enable()
            if request.speed_mode == "Lightly Juiced (more consistent)":
                speed_factor = 0.5 if request.num_inference_steps > 20 else 0.6
            elif request.speed_mode == "Extra Juiced (more speed)":
                speed_factor = 0.3 if request.num_inference_steps > 20 else 0.4
            else:  # "Juiced (default)"
                speed_factor = 0.4 if request.num_inference_steps > 20 else 0.5
            logger.info(f"Setting cache speed factor: {speed_factor}")
            pipeline.cache_helper.set_params(speed_factor=speed_factor)
        else:
            logger.warning("Selected pipeline does not have cache_helper.")

        images = pipeline(
            prompt=request.prompt,
            prompt_2=request.prompt_2,
            height=request.height,
            width=request.width,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            max_sequence_length=request.max_sequence_length,
            num_images_per_prompt=request.num_images_per_prompt,
        ).images

        generation_time = time.time() - start_time
        logger.info(f"Image generation completed in {generation_time:.2f} seconds")
        logger.info(f"Generated {len(images)} images")

        encoded_images = encode_images(images)
        return GenerationResponse(images=encoded_images, seed=request.seed)
    except Exception as e:
        logger.error(f"Error during image generation: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
Let's walk through the different steps used in the example: loading the baseline model, configuring Pruna for optimization, optimizing the model, and using it to check that it is ready for production.
Load the Baseline Model
In this demo, we load the FLUX.1-dev model through FluxPipeline as the baseline before optimization.
from diffusers import FluxPipeline
import torch
# Load the baseline model
base_pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to(self.device)
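Before optimizing, you can optionally time the baseline pipeline so you have a reference point for the speedup you get from Pruna. A minimal sketch, using an arbitrary test prompt and step count:

import time

# Time a single generation with the unoptimized pipeline
start = time.time()
baseline_image = base_pipe(
    prompt="A beautiful landscape",
    num_inference_steps=30,
).images[0]
print(f"Baseline generation took {time.time() - start:.2f} seconds")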
Configure Pruna for Optimization
Next, we create a SmashConfig object that specifies how the model should be optimized. pruna allows you to customize parameters like caching and compilation.
# Configure Pruna Smash
smash_config = SmashConfig()
smash_config["cacher"] = "taylor_auto"
smash_config["compiler"] = "torch_compile"
smash_config._prepare_saving = False
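If you are using the open-source pruna package instead of pruna_pro, the configuration step looks the same, but the set of available algorithms differs; the taylor_auto cacher used above may not be available outside Pruna Pro. Here is a minimal sketch that only enables compilation (an assumption; check the Pruna documentation for the cachers and compilers available in the open-source release):

from pruna import SmashConfig, smash

# Open-source configuration: compilation only (adjust to the algorithms your pruna version supports)
smash_config = SmashConfig()
smash_config["compiler"] = "torch_compile"

smashed_pipe = smash(model=base_pipe, smash_config=smash_config)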
Optimize the Model
Pass your model and configuration to the smash() function, which applies the optimizations.
# Optimize the model
self.pipe = smash(
    model=base_pipe,
    token=os.getenv("PRUNA_API_KEY"),  # Pruna Pro token, read from the PRUNA_API_KEY environment variable
    smash_config=smash_config,
)
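To verify the speedup before wiring the optimized pipeline into the API, you can rerun the same test generation used for the baseline. A minimal sketch; note that the first call after smash() includes torch.compile warm-up, so time a second call for a representative number:

# Time a single generation with the optimized pipeline
start = time.time()
optimized_image = self.pipe(
    prompt="A beautiful landscape",
    num_inference_steps=30,
).images[0]
print(f"Optimized generation took {time.time() - start:.2f} seconds")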
Use the Optimized Model
After optimization, the model is ready for prediction.
# Generate output
images = pipeline(
    prompt=request.prompt,
    prompt_2=request.prompt_2,
    height=request.height,
    width=request.width,
    num_inference_steps=request.num_inference_steps,
    guidance_scale=request.guidance_scale,
    max_sequence_length=request.max_sequence_length,
    num_images_per_prompt=request.num_images_per_prompt,
).images
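The pipeline returns standard PIL images. In the API they are base64-encoded by encode_images(), but outside the server you can simply write them to disk to inspect the results:

# Save each generated image locally for a quick visual check
for i, img in enumerate(images):
    img.save(f"output_{i}.png")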
Deploy the Optimized Model on Koyeb
Create a Dockerfile in your project directory with the following content:
FROM nvidia/cuda:12.8.1-cudnn-runtime-ubuntu22.04
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
ENV DEBIAN_FRONTEND=noninteractive \
UV_COMPILE_BYTECODE=1 \
UV_LINK_MODE=copy \
HF_HOME=/workspace/model-cache \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PATH="/workspace/.venv/bin:$PATH" \
PORT=8000
RUN apt-get update && \
apt-get install -y build-essential && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
WORKDIR /workspace
COPY . ./
RUN uv python pin 3.12.0 && \
uv sync
ENTRYPOINT uvicorn server:app --host 0.0.0.0 --port ${PORT:-8000}
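Optionally, if you have Docker and the NVIDIA Container Toolkit available locally, you can build and run the image to smoke-test the service before deploying (the image tag below is arbitrary). Note that FLUX.1-dev is a gated model on Hugging Face, so you may also need to provide a Hugging Face token (for example through an HF_TOKEN environment variable) for the weights to download:

docker build -t pruna-on-koyeb .
docker run --gpus all -p 8000:8000 -e PRUNA_API_KEY=<your_pruna_api_key> pruna-on-koyeb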
You can deploy to Koyeb using their control panel or via the Koyeb CLI. In this guide, we will deploy using the CLI.
koyeb deploy . pruna-on-koyeb \
--instance-type gpu-nvidia-l40s \
--region na \
--type web \
--port 8000:http \
--archive-builder \
--env PRUNA_API_KEY=<your_pruna_api_key>
Once the deployment completes, your service will be running on Koyeb and you will be able to perform your first predictions using your Koyeb domain ending with .koyeb.app. Check out the /docs endpoint to access the interactive API documentation and run your first prediction.
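As an end-to-end check, here is a small example client, assuming the requests package is installed and replacing the URL with your own Koyeb domain, that calls /predict and decodes the returned base64 data URIs into JPEG files:

import base64

import requests

# Replace with your own Koyeb app URL
API_URL = "https://<your-app>.koyeb.app"

payload = {
    "prompt": "A beautiful landscape",
    "height": 1024,
    "width": 1024,
    "num_inference_steps": 30,
}

response = requests.post(f"{API_URL}/predict", json=payload, timeout=600)
response.raise_for_status()
data = response.json()

# Each image is returned as a "data:image/jpeg;base64,..." string
for i, image_data in enumerate(data["images"]):
    _, encoded = image_data.split(",", 1)
    with open(f"prediction_{i}.jpg", "wb") as f:
        f.write(base64.b64decode(encoded))

print(f"Saved {len(data['images'])} image(s) with seed {data['seed']}")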
Conclusion
In this guide, we showcased how to optimize your models with Pruna AI and deploy them on Koyeb using the Koyeb CLI. Optimizing with Pruna and deploying on Koyeb gives you production-ready performance without compromising on cost or scalability.
To learn more, read the Pruna AI x Koyeb documentation or check out our one-click deploy for Pruna AI on Koyeb.