feat: Introduce course marketing features with dedicated metadata, image generation, and UI in both studio and experience apps.

This commit is contained in:
2026-03-04 15:41:34 -03:00
parent 4458decd22
commit 01c54429a0
25 changed files with 1453 additions and 401 deletions
+68 -13
View File
@@ -3,7 +3,7 @@ import time
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image
import uuid
@@ -20,14 +20,20 @@ pipe = None
def load_model():
global pipe
print("Loading Stable Diffusion model on CPU...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading Stable Diffusion model on {device}...")
pipe = StableDiffusionPipeline.from_pretrained(
MODEL_ID,
torch_dtype=torch.float32,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
pipe.to("cpu")
# pipe.enable_model_cpu_offload()
pipe.enable_attention_slicing()
# Use a high-quality scheduler for better detail
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to(device)
if device == "cpu":
pipe.enable_attention_slicing()
print("Model loaded successfully.")
from contextlib import asynccontextmanager
@@ -43,9 +49,16 @@ app = FastAPI(lifespan=lifespan)
from fastapi.staticfiles import StaticFiles
app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")
import psycopg2
class ImageRequest(BaseModel):
prompt: str
lesson_id: str
database_url: str = None
table_name: str = "lessons"
progress_column: str = "generation_progress"
width: int = 512
height: int = 512
@app.post("/generate")
async def generate_image(request: ImageRequest):
@@ -53,24 +66,66 @@ async def generate_image(request: ImageRequest):
if pipe is None:
load_model()
num_steps = 150
def progress_callback(step: int, timestep: int, latents: torch.FloatTensor):
if request.database_url and request.lesson_id:
try:
progress = int((step / num_steps) * 100)
conn = psycopg2.connect(request.database_url)
cur = conn.cursor()
# Use psycopg2.sql for safe table/column names if possible,
# but here we'll just format since we control the backend values
query = f"UPDATE {request.table_name} SET {request.progress_column} = %s WHERE id = %s"
cur.execute(query, (progress, request.lesson_id))
conn.commit()
cur.close()
conn.close()
except Exception as db_e:
print(f"Database update error: {db_e}")
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
if progress_callback:
progress_callback(step_index, timestep, None)
return callback_kwargs
try:
print(f"Generating image for prompt: {request.prompt}")
quality_prompt = f"{request.prompt}, highly detailed, high quality, masterpiece, 8k, realistic, photographic, sharp focus, perfect anatomy"
negative_prompt = "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, low quality, low resolution, bad hands, extra fingers, cartoon, anime, illustration, draft, grainy"
print(f"Generating image ({request.width}x{request.height}) for prompt: {quality_prompt}")
generator = torch.manual_seed(42)
# Using a small number of steps for speed since it's on CPU
# Generation with custom resolution
image = pipe(
request.prompt,
num_inference_steps=20,
generator=generator
quality_prompt,
negative_prompt=negative_prompt,
num_inference_steps=num_steps,
guidance_scale=8.5,
width=request.width,
height=request.height,
callback_on_step_end=callback_dynamic_cfg
).images[0]
# Ensure progress is 100% at the end
if request.database_url and request.lesson_id:
try:
conn = psycopg2.connect(request.database_url)
cur = conn.cursor()
query = f"UPDATE {request.table_name} SET {request.progress_column} = 100 WHERE id = %s"
cur.execute(query, (request.lesson_id,))
conn.commit()
cur.close()
conn.close()
except:
pass
image_filename = f"image_{request.lesson_id}_{uuid.uuid4().hex[:8]}.png"
image_path = os.path.join(OUTPUT_DIR, image_filename)
image.save(image_path)
# Return the absolute URL pointing to t-800 so the frontend can find it
hostname = os.getenv("BRIDGE_HOSTNAME", "localhost")
hostname = os.getenv("BRIDGE_HOSTNAME", "t-800")
full_url = f"http://{hostname}:8080/outputs/{image_filename}"
return {"status": "completed", "url": full_url}