Running Diffusion Models
In this example, we demonstrate how you can write your own Stable Diffusion tool, run it from your terminal, and deploy it to fal as a scalable and production-grade HTTP API without changing a single line of code.
Starting small
We can start by importing fal and defining a global variable called MODEL_NAME to denote which model we want to use from the HF model hub (it can be any SD/SDXL model or any fine-tuned variant):
import fal
MODEL_NAME = "runwayml/stable-diffusion-v1-5"
Next, we define a cached function[link] that loads the model onto the GPU when it's not already present. This should save us a lot of time when we invoke our tool multiple times in a row (or when lots of API requests hit in a short time frame).
@fal.cached
def load_model(model_name: str = MODEL_NAME) -> object:
    """Load the stable diffusion model into the GPU and return the
    diffusers pipeline."""
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(model_name)
    return pipe.to("cuda")
Implementation notes
The application can load multiple different SD models lazily by taking the model name as an input parameter and passing it to load_model. @fal.cached is input-aware, so if you pass a different input (e.g. SD1.5 vs SDXL) it will actually load the new model (the old one will still be kept in cache in case you need it again).
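For example, here is a hypothetical snippet illustrating the caching behavior (the SDXL model id below is an assumption; any diffusers-compatible Hugging Face id works):

# First call loads SD 1.5 onto the GPU and caches the pipeline.
sd15 = load_model("runwayml/stable-diffusion-v1-5")

# A different argument means a different cache entry, so SDXL is loaded
# separately while the SD 1.5 pipeline above stays cached.
sdxl = load_model("stabilityai/stable-diffusion-xl-base-1.0")  # assumed model id

# Same argument as the first call: returned straight from the cache.
sd15_again = load_model("runwayml/stable-diffusion-v1-5")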
Taking inputs and returning outputs
To enable the automatic web endpoint fal offers through serve=True, we'll have to define our inputs and outputs in a structured way through Pydantic. Although this looks like a web-only concern, nothing prevents you from using the same I/O models for the CLI as well, which is exactly what we are going to do.
import fal
from fal.toolkit import Image
from pydantic import BaseModel
# [...]
class DiffusionOptions(BaseModel):
    prompt: str
    steps: int = 30


class Result(BaseModel):
    image: Image
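Since these are plain Pydantic models, you can construct and validate them anywhere, not just behind the web endpoint. A minimal sketch of how the defaults and validation behave:

from pydantic import ValidationError

# "steps" falls back to its default of 30 when omitted.
options = DiffusionOptions(prompt="a cat on the moon")
print(options.steps)  # 30

# Pydantic validates field types for us; bad input raises a ValidationError.
try:
    DiffusionOptions(prompt="a cat on the moon", steps="not-a-number")
except ValidationError as error:
    print(error)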
Stable Diffusion App
We can annotate our inference function with the necessary packages (just diffusers and transformers for this example) and mark it as a served function by setting serve=True. Although this workload can run on both a T4 and an A100, we'll prefer the latter for its performance (but depending on your use case and cost parameters, this might change).
import fal


@fal.function(
    "virtualenv",
    requirements=[
        "diffusers[torch]",
        "transformers",
    ],
    serve=True,
    machine_type="GPU",
    keep_alive=60,
)
def run_stable_diffusion(options: DiffusionOptions) -> Result:
    # Load the Diffusers pipeline
    pipe = load_model()

    # Perform the inference process
    result = pipe(options.prompt, num_inference_steps=options.steps)

    # Upload the image and return it
    image = Image.from_pil(result.images[0])
    return Result(image=image)
The inference logic itself should be quite self-explanatory, but to summarize it in three steps, at each invocation this function does the following (a plain-diffusers equivalent is sketched right after the list):
- Gets the diffusers pipeline, which includes the actual model. Although the first invocation will be a bit expensive (~15 seconds), all subsequent cached invocations come at practically no cost.
- Runs the pipeline with the given options to perform inference and generate the image.
- Uploads the image to fal's storage servers and returns a result object.
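For reference, here is roughly what the same three steps look like with plain diffusers outside of fal: no input-aware caching and no upload, just a local file. This is only an illustrative sketch, not part of the app:

from diffusers import DiffusionPipeline

# Step 1: load the pipeline (no @fal.cached here, so this cost is paid on every run).
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")

# Step 2: run inference with a prompt and a step count.
result = pipe("a cat on the moon", num_inference_steps=30)

# Step 3: instead of uploading to fal's storage, save the PIL image locally.
result.images[0].save("cat_on_the_moon.png")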
Using the app in the CLI
To try out and play with your new Stable Diffusion app, you can write a very small command-line interface for it and start running it locally.
from argparse import ArgumentParser

[...]

def main(argv: list[str] | None = None) -> None:
    parser = ArgumentParser()
    parser.add_argument("prompt", type=str)
    parser.add_argument("--steps", type=int, default=40)
    args = parser.parse_args(argv)

    local_diffusion = run_stable_diffusion.on(serve=False)
    result = local_diffusion(
        DiffusionOptions(
            prompt=args.prompt,
            steps=args.steps,
        )
    )
    print(
        f"Image generation is complete. Access your "
        f"image through the URL: {result.image.url}"
    )


if __name__ == "__main__":
    main()
As you might have noticed, we create a new function called local_diffusion by setting the serve property to False when performing our invocations through Python. This ensures that our application works both as a web app when run through run_stable_diffusion() (or deployed) and as a regular function that can be called directly from Python.
$ python app.py "a cat on the moon" --steps=50
[...]
Image generation is complete. Access your image through the URL: $URL
Productionizing your API
To share this app with others in the form of an HTTP API, all you have to do is call fal's serve command and let it deploy the function for you to the serverless runtime. Each HTTP request will automatically wake up a server (if there isn't one already), process the request, hang around for a while in case other requests follow (within the defined keep_alive), and finally shut itself down to prevent incurring costs when idle.
$ fal fn serve t.py run_stable_diffusion --alias stable-diffusion
Registered a new revision for function 'stable-diffusion' (revision='[...]').
URL: https://$USER-stable-diffusion.gateway.alpha.fal.ai
As long as you have the credentials[link] set, you can invoke this API from anywhere (from your terminal or from your own frontend):
$ curl $APP_URL \
  -H 'Authorization: Key $FAL_KEY_ID:$FAL_KEY_SECRET' \
-H 'Content-Type: application/json' \
-H 'Accept: application/json, */*;q=0.5' \
-d '{"prompt":"a cat on the moon"}'
{
  "image": {
    "url": "...",
    "content_type": "image/png",
    "file_name": "...",
    "file_size": ...,
    "width": 512,
    "height": 512
  }
}
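If you'd rather call the endpoint from Python than curl, the same request can be made with the requests library. This is a sketch mirroring the curl call above; the APP_URL, FAL_KEY_ID and FAL_KEY_SECRET environment variables are assumed to hold the same values as in that example:

import os

import requests

# Assumed to be exported in your shell, matching the curl example above.
APP_URL = os.environ["APP_URL"]
FAL_KEY_ID = os.environ["FAL_KEY_ID"]
FAL_KEY_SECRET = os.environ["FAL_KEY_SECRET"]

response = requests.post(
    APP_URL,
    headers={
        "Authorization": f"Key {FAL_KEY_ID}:{FAL_KEY_SECRET}",
        "Accept": "application/json, */*;q=0.5",
    },
    json={"prompt": "a cat on the moon"},
)
response.raise_for_status()
print(response.json()["image"]["url"])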