paint-brush
AI-aangedreven Image Generation API-service met FLUX, Python en Diffusers: een snelle handleidingdoor@herahavenai
259 lezingen

AI-aangedreven Image Generation API-service met FLUX, Python en Diffusers: een snelle handleiding

door HeraHaven AI11m2024/11/29
Read on Terminal Reader

Te lang; Lezen

In dit artikel leiden we je door het maken van je eigen FLUX-server met Python. Met deze server kun je afbeeldingen genereren op basis van tekstprompts via een eenvoudige API. Of je deze server nu voor persoonlijk gebruik gebruikt of als onderdeel van een productietoepassing, deze gids helpt je op weg.
featured image - AI-aangedreven Image Generation API-service met FLUX, Python en Diffusers: een snelle handleiding
HeraHaven AI HackerNoon profile picture


In dit artikel leiden we je door het maken van je eigen FLUX-server met Python. Met deze server kun je afbeeldingen genereren op basis van tekstprompts via een eenvoudige API. Of je deze server nu voor persoonlijk gebruik gebruikt of als onderdeel van een productietoepassing, deze gids helpt je op weg.


FLUX (van Black Forest Labs ) heeft de wereld van AI-beeldgeneratie de afgelopen maanden stormenderhand veroverd. Het heeft niet alleen Stable Diffusion (de vorige open-source koning) verslagen op veel benchmarks, maar het heeft ook propriëtaire modellen zoals Dall-E of Midjourney in sommige statistieken overtroffen.


Maar hoe zou je FLUX gebruiken op een van je apps? Je zou kunnen denken aan serverloze hosts zoals Replicate en anderen, maar deze kunnen heel snel heel duur worden en bieden mogelijk niet de flexibiliteit die je nodig hebt. Dat is waar het maken van je eigen aangepaste FLUX-server van pas komt.

Vereisten

Voordat we in de code duiken, moeten we controleren of u de benodigde tools en bibliotheken hebt ingesteld:

  • Python: Python 3 moet op uw computer geïnstalleerd zijn, bij voorkeur versie 3.10.
  • torch : Het deep learning-framework dat we gebruiken om FLUX uit te voeren.
  • diffusers : Biedt toegang tot het FLUX-model.
  • transformers : Vereiste afhankelijkheid van diffusers.
  • sentencepiece : Vereist om de FLUX-tokenizer uit te voeren
  • protobuf : Vereist om FLUX uit te voeren
  • accelerate : helpt in sommige gevallen het FLUX-model efficiënter te laden.
  • fastapi : Framework voor het creëren van een webserver die verzoeken voor het genereren van afbeeldingen kan accepteren.
  • uvicorn : vereist om de FastAPI-server te laten draaien.
  • psutil : Hiermee kunnen we controleren hoeveel RAM er op onze machine aanwezig is.

U kunt alle bibliotheken installeren door de volgende opdracht uit te voeren: pip install torch diffusers transformers sentencepiece protobuf accelerate fastapi uvicorn .

Als u een Mac met een M1- of M2-chip gebruikt, moet u PyTorch met Metal instellen voor optimale prestaties. Volg de officiële PyTorch met Metal-gids voordat u verdergaat.

Je moet er ook voor zorgen dat je minimaal 12 GB VRAM hebt als je van plan bent om FLUX op een GPU-apparaat te draaien. Of minimaal 12 GB RAM voor het draaien op CPU/MPS (wat langzamer zal zijn).

Stap 1: De omgeving instellen

Laten we het script beginnen met het kiezen van het juiste apparaat om de inferentie uit te voeren, gebaseerd op de hardware die we gebruiken.

 device = 'cuda' # can also be 'cpu' or 'mps' import os # MPS support in PyTorch is not yet fully implemented if device == 'mps': os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" import torch if device == 'mps' and not torch.backends.mps.is_available(): raise Exception("Device set to MPS, but MPS is not available") elif device == 'cuda' and not torch.cuda.is_available(): raise Exception("Device set to CUDA, but CUDA is not available")

U kunt cpu , cuda (voor NVIDIA GPU's) of mps (voor Apple's Metal Performance Shaders) opgeven. Het script controleert vervolgens of het geselecteerde apparaat beschikbaar is en genereert een uitzondering als dat niet het geval is.

Stap 2: Het FLUX-model laden

Vervolgens laden we het FLUX-model. We laden het model in fp16-precisie, wat ons wat geheugen bespaart zonder veel kwaliteitsverlies.

Op dit punt wordt u mogelijk gevraagd om te authenticeren met HuggingFace, aangezien het FLUX-model is gegated. Om succesvol te authenticeren, moet u een HuggingFace-account aanmaken, naar de modelpagina gaan, de voorwaarden accepteren en vervolgens een HuggingFace-token maken vanuit uw accountinstellingen en deze toevoegen aan uw machine als de HF_TOKEN omgevingsvariabele.

 from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline import psutil model_name = "black-forest-labs/FLUX.1-dev" print(f"Loading {model_name} on {device}") pipeline = FluxPipeline.from_pretrained( model_name, # Diffusion models are generally trained on fp32, but fp16 # gets us 99% there in terms of quality, with just half the (V)RAM torch_dtype=torch.float16, # Ensure we don't load any dangerous binary code use_safetensors=True # We are using Euler here, but you can also use other samplers scheduler=FlowMatchEulerDiscreteScheduler() ).to(device)

Hier laden we het FLUX-model met behulp van de diffusers-bibliotheek. Het model dat we gebruiken is black-forest-labs/FLUX.1-dev , geladen in fp16-precisie.


Er is ook een timestep-distilled model genaamd FLUX Schnell dat snellere inferentie heeft, maar minder gedetailleerde beelden oplevert, evenals een FLUX Pro-model dat closed-source is. We gebruiken hier de Euler-scheduler, maar u kunt hiermee experimenteren. U kunt hier meer lezen over schedulers. Omdat het genereren van afbeeldingen veel resources kan kosten, is het cruciaal om het geheugengebruik te optimaliseren, vooral wanneer u op een CPU of een apparaat met beperkt geheugen draait.


 # Recommended if running on MPS or CPU with < 64 GB of RAM total_memory = psutil.virtual_memory().total total_memory_gb = total_memory / (1024 ** 3) if (device == 'cpu' or device == 'mps') and total_memory_gb < 64: print("Enabling attention slicing") pipeline.enable_attention_slicing()

Deze code controleert het totale beschikbare geheugen en schakelt attention slicing in als het systeem minder dan 64 GB RAM heeft. Attention slicing vermindert het geheugengebruik tijdens het genereren van afbeeldingen, wat essentieel is voor apparaten met beperkte bronnen.

Stap 3: De API maken met FastAPI

Vervolgens stellen we de FastAPI-server in. Deze server biedt een API om afbeeldingen te genereren.

 from fastapi import FastAPI, HTTPException from pydantic import BaseModel, Field, conint, confloat from fastapi.middleware.gzip import GZipMiddleware from io import BytesIO import base64 app = FastAPI() # We will be returning the image as a base64 encoded string # which we will want compressed app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=7)

FastAPI is een populair framework voor het bouwen van web-API's met Python. In dit geval gebruiken we het om een server te maken die verzoeken voor het genereren van afbeeldingen kan accepteren. We gebruiken ook GZip-middleware om de respons te comprimeren, wat vooral handig is bij het terugsturen van afbeeldingen in base64-formaat.

In een productieomgeving wilt u de gegenereerde afbeeldingen mogelijk opslaan in een S3-bucket of andere cloudopslag en de URL's retourneren in plaats van de base64-gecodeerde tekenreeksen, om te profiteren van een CDN en andere optimalisaties.

Stap 4: Het aanvraagmodel definiëren

Nu moeten we een model definiëren voor de verzoeken die onze API zal accepteren.

 class GenerateRequest(BaseModel): prompt: str seed: conint(ge=0) = Field(..., description="Seed for random number generation") height: conint(gt=0) = Field(..., description="Height of the generated image, must be a positive integer and a multiple of 8") width: conint(gt=0) = Field(..., description="Width of the generated image, must be a positive integer and a multiple of 8") cfg: confloat(gt=0) = Field(..., description="CFG (classifier-free guidance scale), must be a positive integer or 0") steps: conint(ge=0) = Field(..., description="Number of steps") batch_size: conint(gt=0) = Field(..., description="Number of images to generate in a batch")

Dit GenerateRequest -model definieert de parameters die nodig zijn om een afbeelding te genereren. Het prompt is de tekstuele beschrijving van de afbeelding die u wilt maken. Andere velden zijn de afbeeldingsafmetingen, het aantal inferentiestappen en de batchgrootte.

Stap 5: Het eindpunt voor het genereren van afbeeldingen maken

Laten we nu het eindpunt maken dat de verzoeken voor het genereren van afbeeldingen verwerkt.

 @app.post("/") async def generate_image(request: GenerateRequest): # Validate that height and width are multiples of 8 # as required by FLUX if request.height % 8 != 0 or request.width % 8 != 0: raise HTTPException(status_code=400, detail="Height and width must both be multiples of 8") # Always calculate the seed on CPU for deterministic RNG # For a batch of images, seeds will be sequential like n, n+1, n+2, ... generator = [torch.Generator(device="cpu").manual_seed(i) for i in range(request.seed, request.seed + request.batch_size)] images = pipeline( height=request.height, width=request.width, prompt=request.prompt, generator=generator, num_inference_steps=request.steps, guidance_scale=request.cfg, num_images_per_prompt=request.batch_size ).images # Convert images to base64 strings # (for a production app, you might want to store the # images in an S3 bucket and return the URLs instead) base64_images = [] for image in images: buffered = BytesIO() image.save(buffered, format="PNG") img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") base64_images.append(img_str) return { "images": base64_images, }

Dit eindpunt verwerkt het proces van het genereren van afbeeldingen. Het valideert eerst dat de hoogte en breedte veelvouden van 8 zijn, zoals vereist door FLUX. Vervolgens genereert het afbeeldingen op basis van de opgegeven prompt en retourneert deze als base64-gecodeerde strings.

Stap 6: De server starten

Tot slot voegen we wat code toe om de server te starten wanneer het script wordt uitgevoerd.

 @app.on_event("startup") async def startup_event(): print("Image generation server running") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)

Deze code start de FastAPI-server op poort 8000, waardoor deze niet alleen toegankelijk is vanaf http://localhost:8000 , maar ook vanaf andere apparaten op hetzelfde netwerk die gebruikmaken van het IP-adres van de hostcomputer, dankzij de binding 0.0.0.0 .

Stap 7: Uw server lokaal testen

Nu uw FLUX-server draait, is het tijd om deze te testen. U kunt curl gebruiken, een opdrachtregeltool voor het maken van HTTP-verzoeken, om te communiceren met uw server:

 curl -X POST "http://localhost:8000/" \ -H "Content-Type: application/json" \ -d '{ "prompt": "A futuristic cityscape at sunset", "seed": 42, "height": 1024, "width": 1024, "cfg": 3.5, "steps": 50, "batch_size": 1 }' | jq -r '.images[0]' | base64 -d > test.png

Deze opdracht werkt alleen op UNIX-gebaseerde systemen met de curl , jq en base64 hulpprogramma's geïnstalleerd. Het kan ook een paar minuten duren om te voltooien, afhankelijk van de hardware die de FLUX-server host.

Conclusie

Gefeliciteerd! U hebt met succes uw eigen FLUX-server gemaakt met Python. Met deze opstelling kunt u afbeeldingen genereren op basis van tekstprompts via een eenvoudige API. Als u niet tevreden bent met de resultaten van het basis-FLUX-model, kunt u overwegen het model te verfijnen voor nog betere prestaties in specifieke use cases .

Volledige code

Hieronder vindt u de volledige code die in deze handleiding wordt gebruikt:

 device = 'cuda' # can also be 'cpu' or 'mps' import os # MPS support in PyTorch is not yet fully implemented if device == 'mps': os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" import torch if device == 'mps' and not torch.backends.mps.is_available(): raise Exception("Device set to MPS, but MPS is not available") elif device == 'cuda' and not torch.cuda.is_available(): raise Exception("Device set to CUDA, but CUDA is not available") from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline import psutil model_name = "black-forest-labs/FLUX.1-dev" print(f"Loading {model_name} on {device}") pipeline = FluxPipeline.from_pretrained( model_name, # Diffusion models are generally trained on fp32, but fp16 # gets us 99% there in terms of quality, with just half the (V)RAM torch_dtype=torch.float16, # Ensure we don't load any dangerous binary code use_safetensors=True, # We are using Euler here, but you can also use other samplers scheduler=FlowMatchEulerDiscreteScheduler() ).to(device) # Recommended if running on MPS or CPU with < 64 GB of RAM total_memory = psutil.virtual_memory().total total_memory_gb = total_memory / (1024 ** 3) if (device == 'cpu' or device == 'mps') and total_memory_gb < 64: print("Enabling attention slicing") pipeline.enable_attention_slicing() from fastapi import FastAPI, HTTPException from pydantic import BaseModel, Field, conint, confloat from fastapi.middleware.gzip import GZipMiddleware from io import BytesIO import base64 app = FastAPI() # We will be returning the image as a base64 encoded string # which we will want compressed app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=7) class GenerateRequest(BaseModel): prompt: str seed: conint(ge=0) = Field(..., description="Seed for random number generation") height: conint(gt=0) = Field(..., description="Height of the generated image, must be a positive integer and a multiple of 8") width: conint(gt=0) = Field(..., description="Width of the generated image, must be a positive integer and a multiple of 8") cfg: confloat(gt=0) = Field(..., description="CFG (classifier-free guidance scale), must be a positive integer or 0") steps: conint(ge=0) = Field(..., description="Number of steps") batch_size: conint(gt=0) = Field(..., description="Number of images to generate in a batch") @app.post("/") async def generate_image(request: GenerateRequest): # Validate that height and width are multiples of 8 # as required by FLUX if request.height % 8 != 0 or request.width % 8 != 0: raise HTTPException(status_code=400, detail="Height and width must both be multiples of 8") # Always calculate the seed on CPU for deterministic RNG # For a batch of images, seeds will be sequential like n, n+1, n+2, ... generator = [torch.Generator(device="cpu").manual_seed(i) for i in range(request.seed, request.seed + request.batch_size)] images = pipeline( height=request.height, width=request.width, prompt=request.prompt, generator=generator, num_inference_steps=request.steps, guidance_scale=request.cfg, num_images_per_prompt=request.batch_size ).images # Convert images to base64 strings # (for a production app, you might want to store the # images in an S3 bucket and return the URL's instead) base64_images = [] for image in images: buffered = BytesIO() image.save(buffered, format="PNG") img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") base64_images.append(img_str) return { "images": base64_images, } @app.on_event("startup") async def startup_event(): print("Image generation server running") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)