feat: Introduce course marketing features with dedicated metadata, image generation, and UI in both studio and experience apps.
This commit is contained in:
@@ -3,7 +3,7 @@ import time
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
|
||||
from PIL import Image
|
||||
import uuid
|
||||
|
||||
@@ -20,14 +20,20 @@ pipe = None
|
||||
|
||||
def load_model():
|
||||
global pipe
|
||||
print("Loading Stable Diffusion model on CPU...")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"Loading Stable Diffusion model on {device}...")
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
MODEL_ID,
|
||||
torch_dtype=torch.float32,
|
||||
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
||||
)
|
||||
pipe.to("cpu")
|
||||
# pipe.enable_model_cpu_offload()
|
||||
pipe.enable_attention_slicing()
|
||||
# Use a high-quality scheduler for better detail
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to(device)
|
||||
|
||||
if device == "cpu":
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
print("Model loaded successfully.")
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -43,9 +49,16 @@ app = FastAPI(lifespan=lifespan)
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")
|
||||
|
||||
import psycopg2
|
||||
|
||||
class ImageRequest(BaseModel):
|
||||
prompt: str
|
||||
lesson_id: str
|
||||
database_url: str = None
|
||||
table_name: str = "lessons"
|
||||
progress_column: str = "generation_progress"
|
||||
width: int = 512
|
||||
height: int = 512
|
||||
|
||||
@app.post("/generate")
|
||||
async def generate_image(request: ImageRequest):
|
||||
@@ -53,24 +66,66 @@ async def generate_image(request: ImageRequest):
|
||||
if pipe is None:
|
||||
load_model()
|
||||
|
||||
num_steps = 150
|
||||
|
||||
def progress_callback(step: int, timestep: int, latents: torch.FloatTensor):
|
||||
if request.database_url and request.lesson_id:
|
||||
try:
|
||||
progress = int((step / num_steps) * 100)
|
||||
conn = psycopg2.connect(request.database_url)
|
||||
cur = conn.cursor()
|
||||
# Use psycopg2.sql for safe table/column names if possible,
|
||||
# but here we'll just format since we control the backend values
|
||||
query = f"UPDATE {request.table_name} SET {request.progress_column} = %s WHERE id = %s"
|
||||
cur.execute(query, (progress, request.lesson_id))
|
||||
conn.commit()
|
||||
cur.close()
|
||||
conn.close()
|
||||
except Exception as db_e:
|
||||
print(f"Database update error: {db_e}")
|
||||
|
||||
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
|
||||
if progress_callback:
|
||||
progress_callback(step_index, timestep, None)
|
||||
return callback_kwargs
|
||||
|
||||
try:
|
||||
print(f"Generating image for prompt: {request.prompt}")
|
||||
quality_prompt = f"{request.prompt}, highly detailed, high quality, masterpiece, 8k, realistic, photographic, sharp focus, perfect anatomy"
|
||||
negative_prompt = "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, low quality, low resolution, bad hands, extra fingers, cartoon, anime, illustration, draft, grainy"
|
||||
|
||||
print(f"Generating image ({request.width}x{request.height}) for prompt: {quality_prompt}")
|
||||
|
||||
generator = torch.manual_seed(42)
|
||||
# Using a small number of steps for speed since it's on CPU
|
||||
# Generation with custom resolution
|
||||
image = pipe(
|
||||
request.prompt,
|
||||
num_inference_steps=20,
|
||||
generator=generator
|
||||
quality_prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=num_steps,
|
||||
guidance_scale=8.5,
|
||||
width=request.width,
|
||||
height=request.height,
|
||||
callback_on_step_end=callback_dynamic_cfg
|
||||
).images[0]
|
||||
|
||||
# Ensure progress is 100% at the end
|
||||
if request.database_url and request.lesson_id:
|
||||
try:
|
||||
conn = psycopg2.connect(request.database_url)
|
||||
cur = conn.cursor()
|
||||
query = f"UPDATE {request.table_name} SET {request.progress_column} = 100 WHERE id = %s"
|
||||
cur.execute(query, (request.lesson_id,))
|
||||
conn.commit()
|
||||
cur.close()
|
||||
conn.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
image_filename = f"image_{request.lesson_id}_{uuid.uuid4().hex[:8]}.png"
|
||||
image_path = os.path.join(OUTPUT_DIR, image_filename)
|
||||
|
||||
image.save(image_path)
|
||||
|
||||
# Return the absolute URL pointing to t-800 so the frontend can find it
|
||||
hostname = os.getenv("BRIDGE_HOSTNAME", "localhost")
|
||||
hostname = os.getenv("BRIDGE_HOSTNAME", "t-800")
|
||||
full_url = f"http://{hostname}:8080/outputs/{image_filename}"
|
||||
|
||||
return {"status": "completed", "url": full_url}
|
||||
|
||||
Reference in New Issue
Block a user