actualizaciones
This commit is contained in:
@@ -12,7 +12,7 @@ DECLARE
|
||||
v_org_id UUID;
|
||||
BEGIN
|
||||
-- Find or create organization
|
||||
IF p_org_name IS NULL OR p_org_name = '' OR p_org_name = 'Default Organization' THEN
|
||||
IF p_org_name IS NULL OR p_org_name = '' OR p_org_name = 'Default Organization' OR p_org_name = 'Organización por Defecto' THEN
|
||||
v_org_id := '00000000-0000-0000-0000-000000000001';
|
||||
ELSE
|
||||
INSERT INTO organizations (name)
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
-- Add video_generation_status and video_generation_error to lessons table
|
||||
ALTER TABLE lessons ADD COLUMN IF NOT EXISTS video_generation_status VARCHAR(20) DEFAULT 'idle';
|
||||
ALTER TABLE lessons ADD COLUMN IF NOT EXISTS video_generation_error TEXT;
|
||||
@@ -0,0 +1,9 @@
|
||||
fastapi
|
||||
uvicorn
|
||||
torch
|
||||
diffusers
|
||||
transformers
|
||||
accelerate
|
||||
pillow
|
||||
imageio-ffmpeg
|
||||
pydantic
|
||||
@@ -0,0 +1,88 @@
|
||||
import os
|
||||
import time
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from PIL import Image
|
||||
import uuid
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Configuration
|
||||
MODEL_ID = "runwayml/stable-diffusion-v1-5"
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# Global variables for the model
|
||||
pipe = None
|
||||
|
||||
def load_model():
|
||||
global pipe
|
||||
print("Loading Stable Diffusion model on CPU...")
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
MODEL_ID,
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
pipe.to("cpu")
|
||||
# pipe.enable_model_cpu_offload()
|
||||
pipe.enable_attention_slicing()
|
||||
print("Model loaded successfully.")
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Model loading is heavy, we'll do it on first request to avoid timeout at startup
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Serve generated images
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")
|
||||
|
||||
class ImageRequest(BaseModel):
|
||||
prompt: str
|
||||
lesson_id: str
|
||||
|
||||
@app.post("/generate")
|
||||
async def generate_image(request: ImageRequest):
|
||||
global pipe
|
||||
if pipe is None:
|
||||
load_model()
|
||||
|
||||
try:
|
||||
print(f"Generating image for prompt: {request.prompt}")
|
||||
|
||||
generator = torch.manual_seed(42)
|
||||
# Using a small number of steps for speed since it's on CPU
|
||||
image = pipe(
|
||||
request.prompt,
|
||||
num_inference_steps=20,
|
||||
generator=generator
|
||||
).images[0]
|
||||
|
||||
image_filename = f"image_{request.lesson_id}_{uuid.uuid4().hex[:8]}.png"
|
||||
image_path = os.path.join(OUTPUT_DIR, image_filename)
|
||||
|
||||
image.save(image_path)
|
||||
|
||||
# Return the absolute URL pointing to t-800 so the frontend can find it
|
||||
hostname = os.getenv("BRIDGE_HOSTNAME", "localhost")
|
||||
full_url = f"http://{hostname}:8080/outputs/{image_filename}"
|
||||
|
||||
return {"status": "completed", "url": full_url}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Generation error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "model_loaded": pipe is not None}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8080)
|
||||
@@ -1283,6 +1283,125 @@ pub async fn generate_quiz(
|
||||
Ok(Json(quiz_blocks))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct VideoAIRequest {
|
||||
pub prompt: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn generate_image(
|
||||
Org(org_ctx): Org,
|
||||
claims: common::auth::Claims,
|
||||
State(pool): State<PgPool>,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(payload): Json<VideoAIRequest>,
|
||||
) -> Result<Json<Lesson>, StatusCode> {
|
||||
if let Some(prompt) = &payload.prompt {
|
||||
tracing::info!("Received prompt for video generation: {}", prompt);
|
||||
}
|
||||
|
||||
// 1. Fetch lesson
|
||||
let _lesson =
|
||||
sqlx::query_as::<_, Lesson>("SELECT * FROM lessons WHERE id = $1 AND organization_id = $2")
|
||||
.bind(id)
|
||||
.bind(org_ctx.id)
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.map_err(|_| StatusCode::NOT_FOUND)?;
|
||||
|
||||
// 2. Set status to queued
|
||||
let updated_lesson = sqlx::query_as::<_, Lesson>(
|
||||
"UPDATE lessons SET video_generation_status = 'queued', video_generation_error = NULL WHERE id = $1 RETURNING *",
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Database update failed (video queued): {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
log_action(
|
||||
&pool,
|
||||
org_ctx.id,
|
||||
claims.sub,
|
||||
"VIDEO_GENERATION_QUEUED",
|
||||
"Lesson",
|
||||
id,
|
||||
json!({ "status": "queued" }),
|
||||
)
|
||||
.await;
|
||||
|
||||
// 3. Spawn background task
|
||||
let pool_clone = pool.clone();
|
||||
let prompt_to_task = payload.prompt.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_image_generation_task(pool_clone, id, prompt_to_task).await {
|
||||
tracing::error!("Image generation task failed for lesson {}: {}", id, e);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Json(updated_lesson))
|
||||
}
|
||||
|
||||
pub async fn run_image_generation_task(pool: PgPool, lesson_id: Uuid, custom_prompt: Option<String>) -> Result<(), String> {
|
||||
// 1. Set status to processing
|
||||
sqlx::query("UPDATE lessons SET video_generation_status = 'processing' WHERE id = $1")
|
||||
.bind(lesson_id)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.map_err(|e| format!("Update to processing failed: {}", e))?;
|
||||
|
||||
// 2. Call Local Video Bridge (Python)
|
||||
let client = reqwest::Client::new();
|
||||
let bridge_base_url = std::env::var("LOCAL_VIDEO_BRIDGE_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:8080".to_string());
|
||||
let bridge_url = format!("{}/generate", bridge_base_url);
|
||||
|
||||
// Fallback logic for prompt: Custom Prompt > Title
|
||||
let final_prompt = match custom_prompt {
|
||||
Some(p) if !p.is_empty() => p,
|
||||
_ => {
|
||||
sqlx::query_scalar("SELECT title FROM lessons WHERE id = $1")
|
||||
.bind(lesson_id)
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to fetch fallback prompt: {}", e))?
|
||||
}
|
||||
};
|
||||
|
||||
let response = client.post(bridge_url)
|
||||
.json(&serde_json::json!({
|
||||
"prompt": final_prompt,
|
||||
"lesson_id": lesson_id.to_string()
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to call video bridge: {}", e))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let err_text = response.text().await.unwrap_or_default();
|
||||
return Err(format!("Video bridge error: {}", err_text));
|
||||
}
|
||||
|
||||
let result: serde_json::Value = response.json().await
|
||||
.map_err(|e| format!("Failed to parse video bridge response: {}", e))?;
|
||||
|
||||
let content_url = result["url"].as_str()
|
||||
.ok_or_else(|| "Video bridge response missing URL".to_string())?;
|
||||
|
||||
// 3. Complete task
|
||||
sqlx::query(
|
||||
"UPDATE lessons SET video_generation_status = 'completed', content_url = $1, content_type = 'image' WHERE id = $2"
|
||||
)
|
||||
.bind(content_url)
|
||||
.bind(lesson_id)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.map_err(|e| format!("Update to completed failed: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_lesson(
|
||||
Org(org_ctx): Org,
|
||||
State(pool): State<PgPool>,
|
||||
|
||||
@@ -14,6 +14,7 @@ pub struct BackgroundTask {
|
||||
pub title: String,
|
||||
pub course_title: Option<String>,
|
||||
pub transcription_status: Option<String>,
|
||||
pub video_generation_status: Option<String>,
|
||||
pub updated_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
@@ -34,11 +35,13 @@ pub async fn get_background_tasks(
|
||||
l.title,
|
||||
c.title as course_title,
|
||||
l.transcription_status,
|
||||
l.video_generation_status,
|
||||
l.updated_at
|
||||
FROM lessons l
|
||||
JOIN modules m ON l.module_id = m.id
|
||||
JOIN courses c ON m.course_id = c.id
|
||||
WHERE l.transcription_status IN ('queued', 'processing', 'failed')
|
||||
OR l.video_generation_status IN ('queued', 'processing', 'failed')
|
||||
ORDER BY l.updated_at DESC
|
||||
"#;
|
||||
|
||||
|
||||
@@ -74,6 +74,36 @@ async fn main() {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for queued video generations
|
||||
let queued_video_lessons: Vec<sqlx::types::Uuid> = match sqlx::query_scalar(
|
||||
"SELECT id FROM lessons WHERE video_generation_status = 'queued' LIMIT 5",
|
||||
)
|
||||
.fetch_all(&worker_pool)
|
||||
.await
|
||||
{
|
||||
Ok(ids) => ids,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to fetch queued video lessons: {}", e);
|
||||
tokio::time::sleep(Duration::from_secs(10)).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
for lesson_id in queued_video_lessons {
|
||||
tracing::info!("Processing video generation for lesson: {}", lesson_id);
|
||||
if let Err(e) =
|
||||
handlers::run_image_generation_task(worker_pool.clone(), lesson_id, None).await
|
||||
{
|
||||
tracing::error!("Image generation task failed for lesson {}: {}", lesson_id, e);
|
||||
let _ = sqlx::query(
|
||||
"UPDATE lessons SET video_generation_status = 'failed' WHERE id = $1",
|
||||
)
|
||||
.bind(lesson_id)
|
||||
.execute(&worker_pool)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
}
|
||||
});
|
||||
@@ -145,6 +175,7 @@ async fn main() {
|
||||
.route("/lessons/{id}/vtt", get(handlers::get_lesson_vtt))
|
||||
.route("/lessons/{id}/summarize", post(handlers::summarize_lesson))
|
||||
.route("/lessons/{id}/generate-quiz", post(handlers::generate_quiz))
|
||||
.route("/lessons/{id}/generate-image", post(handlers::generate_image))
|
||||
.route("/courses/generate", post(handlers::generate_course))
|
||||
.route("/courses/{id}/export", get(handlers::export_course))
|
||||
.route("/courses/import", post(handlers::import_course))
|
||||
|
||||
Reference in New Issue
Block a user