"""FastAPI service that generates images with Stable Diffusion.

The service exposes ``POST /generate`` (render one image, optionally writing
progress into a PostgreSQL row) and ``GET /health``.  Generated images are
written to ``outputs/`` next to this file and served under ``/outputs``.
"""

import os
import re
import threading
import time
import uuid
from contextlib import asynccontextmanager
from typing import Optional

import psycopg2
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.staticfiles import StaticFiles
from PIL import Image
from psycopg2 import sql
from pydantic import BaseModel, Field

# Configuration
MODEL_ID = "runwayml/stable-diffusion-v1-5"
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
SERVICE_PORT = 8080  # used both by uvicorn below and in the returned image URL
NUM_INFERENCE_STEPS = 150
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Global variables for the model
pipe = None
# Serialises lazy model loading and pipeline use: generation runs in worker
# threads, and the pipeline's scheduler keeps per-run state.
_PIPE_LOCK = threading.Lock()


def load_model():
    """Load the Stable Diffusion pipeline into the module-level ``pipe``.

    Uses CUDA with float16 weights when available, otherwise CPU with
    float32 weights and attention slicing enabled.
    """
    global pipe
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading Stable Diffusion model on {device}...")
    pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    # Use a high-quality scheduler for better detail
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.to(device)
    if device == "cpu":
        pipe.enable_attention_slicing()
    print("Model loaded successfully.")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook; intentionally does no work at startup."""
    # Model loading is heavy, we'll do it on first request to avoid timeout at startup
    yield


app = FastAPI(lifespan=lifespan)

# Serve generated images
app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")


class ImageRequest(BaseModel):
    """Request body for ``POST /generate``."""

    prompt: str
    lesson_id: str
    # NOTE(review): the connection string comes from the request body, so any
    # caller can make this service connect to an arbitrary host. Consider
    # moving it to server-side configuration.
    database_url: Optional[str] = None
    table_name: str = "lessons"
    progress_column: str = "generation_progress"
    # The Stable Diffusion pipeline rejects sizes that are not multiples of 8;
    # validating here turns that into a request-validation error up front.
    width: int = Field(512, gt=0, multiple_of=8)
    height: int = Field(512, gt=0, multiple_of=8)


class _ProgressReporter:
    """Best-effort writer of generation progress to the caller's database.

    One connection is opened lazily and reused for the whole request.  Any
    database error is logged and the connection is dropped, so the next
    report reconnects; errors never interrupt image generation.
    """

    def __init__(self, request: ImageRequest):
        self._enabled = bool(request.database_url and request.lesson_id)
        self._dsn = request.database_url
        self._lesson_id = request.lesson_id
        self._conn = None
        self._last_reported = None
        # Table and column names arrive in the request body, so they are
        # composed as quoted identifiers instead of being formatted into the
        # SQL text.  A dotted table name is treated as schema-qualified.
        # Quoting makes the names case-sensitive.
        self._query = sql.SQL("UPDATE {table} SET {column} = %s WHERE id = %s").format(
            table=sql.Identifier(*request.table_name.split(".")),
            column=sql.Identifier(request.progress_column),
        )

    def report(self, progress: int) -> None:
        """Write ``progress`` (0-100) unless disabled or already written."""
        if not self._enabled or progress == self._last_reported:
            return
        try:
            if self._conn is None:
                self._conn = psycopg2.connect(self._dsn)
            with self._conn.cursor() as cur:
                cur.execute(self._query, (progress, self._lesson_id))
            self._conn.commit()
            self._last_reported = progress
        except Exception as db_e:
            print(f"Database update error: {db_e}")
            # Drop the possibly-broken connection; the next report reconnects.
            self.close()

    def close(self) -> None:
        """Close the database connection if one is open."""
        if self._conn is None:
            return
        try:
            self._conn.close()
        except Exception as db_e:
            print(f"Database close error: {db_e}")
        finally:
            self._conn = None


def _safe_filename_part(value: str) -> str:
    """Replace every character outside ``[A-Za-z0-9_-]`` with ``_``."""
    return re.sub(r"[^A-Za-z0-9_-]", "_", value)


def _generate_sync(request: ImageRequest) -> dict:
    """Run one blocking generation and return the response payload.

    Raises:
        HTTPException: status 500 with the error text if model loading,
            generation or saving the image fails.
    """
    reporter = _ProgressReporter(request)

    def on_step_end(pipeline, step_index, timestep, callback_kwargs):
        # Called at the end of each step with a 0-based index, so
        # step_index + 1 steps are complete at this point.
        done = min(step_index + 1, NUM_INFERENCE_STEPS)
        reporter.report(int((done / NUM_INFERENCE_STEPS) * 100))
        return callback_kwargs

    try:
        quality_prompt = f"{request.prompt}, highly detailed, high quality, masterpiece, 8k, realistic, photographic, sharp focus, perfect anatomy"
        negative_prompt = "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, low quality, low resolution, bad hands, extra fingers, cartoon, anime, illustration, draft, grainy"

        with _PIPE_LOCK:
            if pipe is None:
                load_model()

            print(f"Generating image ({request.width}x{request.height}) for prompt: {quality_prompt}")

            # Generation with custom resolution
            image = pipe(
                quality_prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=NUM_INFERENCE_STEPS,
                guidance_scale=8.5,
                width=request.width,
                height=request.height,
                callback_on_step_end=on_step_end,
            ).images[0]

        # Ensure progress is 100% at the end
        reporter.report(100)

        # lesson_id is caller-supplied; sanitise it before it becomes part of
        # a filesystem path and a URL.
        safe_id = _safe_filename_part(request.lesson_id)
        image_filename = f"image_{safe_id}_{uuid.uuid4().hex[:8]}.png"
        image_path = os.path.join(OUTPUT_DIR, image_filename)
        image.save(image_path)

        # Return the absolute URL pointing to t-800 so the frontend can find it
        hostname = os.getenv("BRIDGE_HOSTNAME", "t-800")
        full_url = f"http://{hostname}:{SERVICE_PORT}/outputs/{image_filename}"

        return {"status": "completed", "url": full_url}
    except Exception as e:
        print(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        reporter.close()


@app.post("/generate")
async def generate_image(request: ImageRequest):
    """Generate an image for ``request.prompt`` and return its URL.

    The blocking pipeline call runs in a worker thread so the event loop
    (and therefore ``/health``) stays responsive while an image is rendered.

    Returns:
        ``{"status": "completed", "url": <absolute image URL>}``.

    Raises:
        HTTPException: status 500 if generation fails.
    """
    return await run_in_threadpool(_generate_sync, request)


@app.get("/health")
async def health():
    """Report liveness and whether the model has been loaded yet."""
    return {"status": "ok", "model_loaded": pipe is not None}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=SERVICE_PORT)