Jun 04, 2025
7 min read

Deploy Flux Models with Pruna AI for 8x Faster Inference on Koyeb

Large AI models can be over-parameterized, memory-intensive, and slow at inference time. When deployed in production, these inefficiencies translate to higher latency, increased GPU costs, and wasted compute. Optimizing these models before deployment isn't just a nice-to-have: it's essential for scalable, cost-effective AI.

Pruna AI is a model compression and optimization framework that reduces model size and compute requirements without sacrificing performance. Pruna supports structured pruning, quantization, caching, compilation, and other compression strategies that preserve accuracy while dramatically improving efficiency.

In this tutorial, we will show how to use pruna_pro to optimize your models and deploy them on Koyeb. The same workflow applies to the open-source pruna package: simply adjust the installation command and the compression configuration.


Requirements

To successfully follow and complete this guide, you need:

  • Python (version 3.9 or higher) installed on your local development environment
  • pruna_pro installed (or pruna if you are using the open-source version)
  • A Koyeb account to deploy the optimized Flux model
  • The Koyeb CLI installed to interact with Koyeb from the command line

Steps

To optimize the Flux model with Pruna and deploy it to Koyeb, you need to follow these steps:

  1. Install Pruna
  2. Load the Baseline Model
  3. Configure Pruna for Optimization
  4. Optimize the Model
  5. Deploy the Optimized Model on Koyeb

Install Pruna

To run Pruna on Koyeb, you'll need Python 3.9 or later and access to an Nvidia GPU instance on Koyeb. In this demo, we will use FastAPI to build the application, serve the model, and handle prediction requests.

We will use uv to install and manage project dependencies. Get started by initializing a new project:

uv init pruna-on-koyeb

Then, install the dependencies that will be required by our application:

uv add fastapi uvicorn torch diffusers transformers accelerate sentencepiece hf_transfer pruna_pro

Next, create a new file named server.py containing the complete implementation of our application below. In the next sections, we will break down the steps used to optimize the model with Pruna.

import base64
import io
import logging
import os
import time
from contextlib import asynccontextmanager
from typing import List, Literal, Optional

import torch
from diffusers import FluxPipeline
from fastapi import FastAPI, HTTPException
from pruna_pro import SmashConfig, smash
from pydantic import BaseModel, Field, field_validator

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "black-forest-labs/FLUX.1-dev"

class ModelManager:
    """Manages the loading, unloading and access to the Flux model pipeline."""

    def __init__(self):
        self.device = DEVICE

    async def load_model(self):
        logger.info(f"Loading model {MODEL_ID} on device {self.device}...")
        try:
            base_pipe = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev",
                torch_dtype=torch.bfloat16,
            ).to(self.device)

            smash_config = SmashConfig()
            smash_config["cacher"] = "taylor_auto"
            smash_config["compiler"] = "torch_compile"
            smash_config._prepare_saving = False

            self.pipe = smash(
                model=base_pipe,
                token=os.getenv("PRUNA_API_KEY"),
                smash_config=smash_config,
            )

            logger.info("Model loaded successfully")

        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise RuntimeError(f"Failed to load model: {str(e)}")

    async def unload_model(self):
        """Cleanup method to properly unload models"""
        try:
            del self.pipe

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info("Model components unloaded successfully")
        except Exception as e:
            logger.error(f"Error unloading model components: {str(e)}")

model_manager = ModelManager()

class GenerationRequest(BaseModel):
    """Request model for image generation containing all parameters."""

    prompt: str = Field(..., min_length=1, max_length=1000)
    prompt_2: Optional[str] = Field(None, min_length=1, max_length=1000)
    height: Optional[int] = Field(1024, ge=64, le=2048)
    width: Optional[int] = Field(1024, ge=64, le=2048)
    num_inference_steps: int = Field(50, ge=1, le=100)
    guidance_scale: float = Field(3.5, ge=0.0, le=10.0)
    max_sequence_length: int = Field(256, ge=1, le=256)
    num_images_per_prompt: int = Field(1, ge=1, le=5)
    seed: Optional[int] = Field(None)
    speed_mode: Literal[
        "Lightly Juiced 🍊 (more consistent)",
        "Juiced 🔥 (default)",
        "Extra Juiced 🔥 (more speed)",
    ] = Field(default="Juiced 🔥 (default)")

    @field_validator("height", "width")
    @classmethod
    def validate_dimensions(cls, v: Optional[int]) -> Optional[int]:
        if v is not None and v % 8 != 0:
            raise ValueError("Height and width must be divisible by 8")
        return v

    class Config:
        json_schema_extra = {
            "example": {
                "prompt": "A beautiful landscape",
                "height": 1024,
                "width": 1024,
                "num_inference_steps": 30,
            }
        }

class GenerationResponse(BaseModel):
    images: List[str]
    seed: int

def encode_images(images, quality: int = 85):
    """
    Encode PIL images to base64 JPEG strings.

    Args:
        images: List of PIL Image objects
        quality: JPEG quality (1-100)
    Returns:
        List of base64 encoded image strings
    """

    encoded_images = []
    for img in images:
        with io.BytesIO() as buffered:
            img = img.convert("RGB")
            img.save(buffered, format="JPEG", quality=quality, optimize=True)
            img_str = base64.b64encode(buffered.getvalue()).decode()
            encoded_images.append(f"data:image/jpeg;base64,{img_str}")
    return encoded_images

@asynccontextmanager
async def lifespan(_: FastAPI):
    await model_manager.load_model()
    yield
    await model_manager.unload_model()
    logger.info("Application shut down successfully")

app = FastAPI(
    title="flux.1-juiced API",
    description="API for generating images using FLUX.1 [dev] with Pruna AI",
    version="1.0.0",
    lifespan=lifespan,
)

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "device": model_manager.device,
        "model_loaded": hasattr(model_manager, "pipe"),
    }

@app.post("/predict", response_model=GenerationResponse)
async def predict(request: GenerationRequest):
    try:
        logger.info(f"Starting image generation with prompt: {request.prompt[:50]}...")
        logger.debug(f"Generation parameters: {request.model_dump()}")

        start_time = time.time()

        if request.seed is not None:
            torch.manual_seed(request.seed)
        else:
            request.seed = torch.randint(0, 2**32 - 1, (1,)).item()

        pipeline = model_manager.pipe

        if hasattr(pipeline, "cache_helper"):
            pipeline.cache_helper.disable()
            pipeline.cache_helper.enable()
            # Map the selected speed mode to a cache speed factor
            if request.speed_mode == "Lightly Juiced 🍊 (more consistent)":
                speed_factor = 0.5 if request.num_inference_steps > 20 else 0.6
            elif request.speed_mode == "Extra Juiced 🔥 (more speed)":
                speed_factor = 0.3 if request.num_inference_steps > 20 else 0.4
            else:  # "Juiced 🔥 (default)"
                speed_factor = 0.4 if request.num_inference_steps > 20 else 0.5
            logger.info(f"Setting cache speed factor: {speed_factor}")
            pipeline.cache_helper.set_params(speed_factor=speed_factor)
        else:
            logger.warning("Selected pipeline does not have cache_helper.")

        images = pipeline(
            prompt=request.prompt,
            prompt_2=request.prompt_2,
            height=request.height,
            width=request.width,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            max_sequence_length=request.max_sequence_length,
            num_images_per_prompt=request.num_images_per_prompt,
        ).images

        generation_time = time.time() - start_time
        logger.info(f"Image generation completed in {generation_time:.2f} seconds")
        logger.info(f"Generated {len(images)} images")

        encoded_images = encode_images(images)

        return GenerationResponse(images=encoded_images, seed=request.seed)

    except Exception as e:
        logger.error(f"Error during image generation: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

Let's walk through the different steps used in the example: loading the baseline model, configuring Pruna for optimization, optimizing the model, and using the optimized model for predictions.

Load the Baseline Model

In this demo, we are using FluxPipeline as the baseline model before optimization.

from diffusers import FluxPipeline
import torch

# Load the baseline model
base_pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to(self.device)

Configure Pruna for Optimization

Next, we create a SmashConfig object that specifies how the model should be optimized. pruna allows you to customize parameters like caching and compilation.

# Configure Pruna Smash
smash_config = SmashConfig()
smash_config["cacher"] = "taylor_auto"
smash_config["compiler"] = "torch_compile"
smash_config._prepare_saving = False

Optimize the Model

Pass your model and configuration to the smash() function, which applies the optimizations.

# Optimize the model

self.pipe = smash(
    model=base_pipe,
    token=os.getenv("PRUNA_API_KEY"),  # Pruna Pro token, read from the PRUNA_API_KEY environment variable
    smash_config=smash_config,
)
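
If you are using the open-source pruna package instead of pruna_pro, the workflow is identical, except that no token is required and the set of available algorithms differs. Below is a minimal sketch under those assumptions; the exact algorithm names (for example the "deepcache" cacher) depend on the pruna version you have installed, so treat it as illustrative:

from pruna import SmashConfig, smash

# Open-source configuration: no token is needed
smash_config = SmashConfig()
smash_config["cacher"] = "deepcache"        # illustrative caching algorithm
smash_config["compiler"] = "torch_compile"

optimized_pipe = smash(model=base_pipe, smash_config=smash_config)

In that case, also install pruna instead of pruna_pro in the uv add command shown earlier.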

Use the Optimized Model

After optimization, the model is ready for prediction.

# Generate output
images = pipeline(
    prompt=request.prompt,
    prompt_2=request.prompt_2,
    height=request.height,
    width=request.width,
    num_inference_steps=request.num_inference_steps,
    guidance_scale=request.guidance_scale,
    max_sequence_length=request.max_sequence_length,
    num_images_per_prompt=request.num_images_per_prompt,
).images
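
If you want to sanity-check the optimized pipeline before wiring it into the FastAPI app, a short standalone script such as the following sketch can help. It reuses the same model and configuration as above and assumes a local GPU plus the PRUNA_API_KEY environment variable:

# smoke_test.py: quick local check of the optimized pipeline
import os
import time

import torch
from diffusers import FluxPipeline
from pruna_pro import SmashConfig, smash

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

smash_config = SmashConfig()
smash_config["cacher"] = "taylor_auto"
smash_config["compiler"] = "torch_compile"
pipe = smash(model=pipe, token=os.getenv("PRUNA_API_KEY"), smash_config=smash_config)

# Time a single generation (the first call includes torch.compile warm-up)
start = time.time()
image = pipe("A beautiful landscape", num_inference_steps=30).images[0]
print(f"Generated one image in {time.time() - start:.2f} seconds")
image.save("landscape.jpg")

Keep in mind that the first call pays the compilation cost, so time a second generation to see the steady-state speed-up.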

Deploy the Optimized Model on Koyeb

Create a Dockerfile in your project's root directory with the following content:

FROM nvidia/cuda:12.8.1-cudnn-runtime-ubuntu22.04

COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

ENV DEBIAN_FRONTEND=noninteractive \
    UV_COMPILE_BYTECODE=1 \
    UV_LINK_MODE=copy \
    HF_HOME=/workspace/model-cache \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PATH="/workspace/.venv/bin:$PATH" \
    PORT=8000

RUN apt-get update && \
    apt-get install -y build-essential && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

WORKDIR /workspace

COPY . ./

RUN uv python pin 3.12.0 && \
    uv sync

ENTRYPOINT uvicorn server:app --host 0.0.0.0 --port ${PORT:-8000}

You can deploy to Koyeb using the control panel or the Koyeb CLI. In this guide, we will deploy using the CLI.

koyeb deploy . pruna-on-koyeb \
   --instance-type gpu-nvidia-l40s \
   --region na \
   --type web \
   --port 8000:http \
   --archive-builder \
   --env PRUNA_API_KEY=<your_pruna_api_key>

Once the build and deployment complete, your service will be running on Koyeb and you will be able to perform your first predictions using your Koyeb domain ending with .koyeb.app. Check out the /docs endpoint to access the interactive API documentation and run your first prediction.
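
For example, assuming your service is reachable at https://your-app.koyeb.app (replace it with your actual domain) and the requests library is installed locally, a minimal client could look like this:

import base64

import requests

response = requests.post(
    "https://your-app.koyeb.app/predict",
    json={
        "prompt": "A beautiful landscape",
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 30,
    },
    timeout=600,
)
response.raise_for_status()
payload = response.json()

# Each image is returned as a "data:image/jpeg;base64,..." string
for i, data_uri in enumerate(payload["images"]):
    image_bytes = base64.b64decode(data_uri.split(",", 1)[1])
    with open(f"output_{i}.jpg", "wb") as f:
        f.write(image_bytes)

print(f"Saved {len(payload['images'])} image(s) with seed {payload['seed']}")

You can also hit the /health endpoint first to confirm that the model has finished loading before sending generation requests.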

Conclusion

In this guide, we showcased how to optimize your models with Pruna AI and deploy them on Koyeb using the Koyeb CLI. Optimizing with Pruna and deploying on Koyeb gives you production-ready performance without compromising on cost or scalability.

To learn more, read the Pruna AI x Koyeb documentation or check out our one-click deploy for Pruna AI on Koyeb.

