feat: implementing embedding AI
This commit is contained in:
@@ -2608,28 +2608,92 @@ pub async fn chat_with_tutor(
|
||||
}
|
||||
}
|
||||
|
||||
// 2.2 Knowledge Base Retrieval (RAG)
|
||||
let search_results = sqlx::query(
|
||||
r#"
|
||||
SELECT content_chunk
|
||||
FROM knowledge_base
|
||||
WHERE organization_id = $1
|
||||
AND search_vector @@ plainto_tsquery('english', $2)
|
||||
LIMIT 3
|
||||
"#,
|
||||
)
|
||||
.bind(org_ctx.id)
|
||||
.bind(&payload.message)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
// 2.2 Knowledge Base Retrieval (RAG) - Hybrid Search
|
||||
// First try semantic search with embeddings (more accurate)
|
||||
// Fall back to full-text search if embeddings not available
|
||||
|
||||
use common::ai::{self, generate_embedding};
|
||||
|
||||
let mut kb_context = String::new();
|
||||
if !search_results.is_empty() {
|
||||
kb_context.push_str("\n--- CONTEXTO ADICIONAL DE LA BASE DE CONOCIMIENTOS ---\n");
|
||||
for row in search_results {
|
||||
let chunk: String = row.get("content_chunk");
|
||||
kb_context.push_str(&format!("Relevant Snippet: {}\n\n", chunk));
|
||||
|
||||
// Try semantic search with embeddings first
|
||||
// Create client that accepts invalid certificates (for dev with self-signed certs)
|
||||
let client = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(true)
|
||||
.danger_accept_invalid_hostnames(true)
|
||||
.build()
|
||||
.map_err(|e| {
|
||||
tracing::warn!("Failed to create HTTP client for embeddings: {}", e);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, format!("HTTP client error: {}", e))
|
||||
})?;
|
||||
|
||||
let ollama_url = ai::get_ollama_url();
|
||||
let model = ai::get_embedding_model();
|
||||
|
||||
match generate_embedding(&client, &ollama_url, &model, &payload.message).await {
|
||||
Ok(response) => {
|
||||
let pgvector = ai::embedding_to_pgvector(&response.embedding);
|
||||
|
||||
// Semantic search with pgvector
|
||||
let search_results = sqlx::query(
|
||||
r#"
|
||||
SELECT content_chunk, 1 - (embedding <=> $1::vector) AS similarity
|
||||
FROM knowledge_base
|
||||
WHERE organization_id = $2
|
||||
AND embedding IS NOT NULL
|
||||
ORDER BY embedding <=> $1::vector
|
||||
LIMIT 5
|
||||
"#,
|
||||
)
|
||||
.bind(&pgvector)
|
||||
.bind(org_ctx.id)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
// Filter by similarity threshold (0.5)
|
||||
let relevant_results: Vec<_> = search_results
|
||||
.into_iter()
|
||||
.filter(|row| {
|
||||
let similarity: f64 = row.get("similarity");
|
||||
similarity >= 0.5
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !relevant_results.is_empty() {
|
||||
kb_context.push_str("\n--- CONTEXTO DE LA BASE DE CONOCIMIENTOS (Búsqueda Semántica) ---\n");
|
||||
for row in relevant_results {
|
||||
let chunk: String = row.get("content_chunk");
|
||||
kb_context.push_str(&format!("Relevant Snippet: {}\n\n", chunk));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Semantic search failed, falling back to full-text search: {}", e);
|
||||
|
||||
// Fall back to full-text search
|
||||
let search_results = sqlx::query(
|
||||
r#"
|
||||
SELECT content_chunk
|
||||
FROM knowledge_base
|
||||
WHERE organization_id = $1
|
||||
AND search_vector @@ plainto_tsquery('english', $2)
|
||||
LIMIT 3
|
||||
"#,
|
||||
)
|
||||
.bind(org_ctx.id)
|
||||
.bind(&payload.message)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
if !search_results.is_empty() {
|
||||
kb_context.push_str("\n--- CONTEXTO DE LA BASE DE CONOCIMIENTOS (Búsqueda Full-Text) ---\n");
|
||||
for row in search_results {
|
||||
let chunk: String = row.get("content_chunk");
|
||||
kb_context.push_str(&format!("Relevant Snippet: {}\n\n", chunk));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,287 @@
|
||||
//! Handlers for PGVector embeddings in Knowledge Base (LMS)
|
||||
//! Enables semantic search for AI tutor chat with RAG
|
||||
|
||||
use axum::{
|
||||
Json,
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
};
|
||||
use common::ai::{self, generate_embedding};
|
||||
use common::middleware::Org;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
// ==================== Query Parameters ====================
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct KnowledgeSearchFilters {
|
||||
pub query: String,
|
||||
pub course_id: Option<Uuid>,
|
||||
pub lesson_id: Option<Uuid>,
|
||||
pub limit: Option<i32>,
|
||||
pub threshold: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct KnowledgeSearchResult {
|
||||
pub id: Uuid,
|
||||
pub course_id: Uuid,
|
||||
pub lesson_id: Option<Uuid>,
|
||||
pub block_id: Option<Uuid>,
|
||||
pub content_chunk: String,
|
||||
pub similarity: f64, // PostgreSQL vector similarity returns double precision
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct GenerateKnowledgeEmbeddingsResult {
|
||||
pub processed: i32,
|
||||
pub failed: i32,
|
||||
pub duration_ms: u64,
|
||||
}
|
||||
|
||||
// ==================== Generate Embeddings ====================
|
||||
|
||||
/// POST /api/knowledge-base/embeddings/generate - Generate embeddings for all knowledge base entries
|
||||
pub async fn generate_knowledge_embeddings(
|
||||
Org(org_ctx): Org,
|
||||
State(pool): State<PgPool>,
|
||||
) -> Result<Json<GenerateKnowledgeEmbeddingsResult>, (StatusCode, String)> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
// Create client that accepts invalid certificates (for dev with self-signed certs)
|
||||
let client = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(true)
|
||||
.danger_accept_invalid_hostnames(true)
|
||||
.build()
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("HTTP client error: {}", e)))?;
|
||||
|
||||
let ollama_url = ai::get_ollama_url();
|
||||
let model = ai::get_embedding_model();
|
||||
|
||||
// Get knowledge base entries without embeddings
|
||||
let entries: Vec<KnowledgeBaseEntry> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT * FROM knowledge_base
|
||||
WHERE organization_id = $1
|
||||
AND (embedding IS NULL OR embedding_updated_at IS NULL)
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 100
|
||||
"#
|
||||
)
|
||||
.bind(org_ctx.id)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let total = entries.len();
|
||||
let mut processed = 0;
|
||||
let mut failed = 0;
|
||||
|
||||
for entry in entries {
|
||||
// Generate embedding from content chunk
|
||||
match generate_embedding(&client, &ollama_url, &model, &entry.content_chunk).await {
|
||||
Ok(response) => {
|
||||
let pgvector = ai::embedding_to_pgvector(&response.embedding);
|
||||
|
||||
// Update entry with embedding
|
||||
let result: Result<(i64,), sqlx::Error> = sqlx::query_as(
|
||||
r#"
|
||||
UPDATE knowledge_base
|
||||
SET embedding = $1::vector,
|
||||
embedding_updated_at = NOW()
|
||||
WHERE id = $2
|
||||
RETURNING 1
|
||||
"#
|
||||
)
|
||||
.bind(&pgvector)
|
||||
.bind(entry.id)
|
||||
.fetch_one(&pool)
|
||||
.await;
|
||||
|
||||
if result.is_ok() {
|
||||
processed += 1;
|
||||
} else {
|
||||
failed += 1;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"Failed to generate embedding for knowledge entry {}: {}",
|
||||
entry.id,
|
||||
e
|
||||
);
|
||||
failed += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let duration_ms = start.elapsed().as_millis() as u64;
|
||||
|
||||
tracing::info!(
|
||||
"Generated knowledge embeddings: {} processed, {} failed in {}ms",
|
||||
processed,
|
||||
failed,
|
||||
duration_ms
|
||||
);
|
||||
|
||||
Ok(Json(GenerateKnowledgeEmbeddingsResult {
|
||||
processed,
|
||||
failed,
|
||||
duration_ms,
|
||||
}))
|
||||
}
|
||||
|
||||
/// POST /api/knowledge-base/{id}/embedding/regenerate - Regenerate embedding for a specific entry
|
||||
pub async fn regenerate_knowledge_embedding(
|
||||
Org(org_ctx): Org,
|
||||
Path(entry_id): Path<Uuid>,
|
||||
State(pool): State<PgPool>,
|
||||
) -> Result<StatusCode, (StatusCode, String)> {
|
||||
// Create client that accepts invalid certificates
|
||||
let client = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(true)
|
||||
.danger_accept_invalid_hostnames(true)
|
||||
.build()
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("HTTP client error: {}", e)))?;
|
||||
|
||||
let ollama_url = ai::get_ollama_url();
|
||||
let model = ai::get_embedding_model();
|
||||
|
||||
// Get entry
|
||||
let entry: KnowledgeBaseEntry = sqlx::query_as(
|
||||
"SELECT * FROM knowledge_base WHERE id = $1 AND organization_id = $2"
|
||||
)
|
||||
.bind(entry_id)
|
||||
.bind(org_ctx.id)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
|
||||
.ok_or((StatusCode::NOT_FOUND, "Knowledge base entry not found".to_string()))?;
|
||||
|
||||
// Generate embedding
|
||||
let response = generate_embedding(&client, &ollama_url, &model, &entry.content_chunk)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("AI error: {}", e)))?;
|
||||
|
||||
let pgvector = ai::embedding_to_pgvector(&response.embedding);
|
||||
|
||||
// Update entry
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE knowledge_base
|
||||
SET embedding = $1::vector,
|
||||
embedding_updated_at = NOW()
|
||||
WHERE id = $2
|
||||
"#
|
||||
)
|
||||
.bind(&pgvector)
|
||||
.bind(entry_id)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
|
||||
// ==================== Semantic Search ====================
|
||||
|
||||
/// GET /api/knowledge-base/semantic-search - Search knowledge base by semantic similarity
|
||||
pub async fn semantic_search_knowledge(
|
||||
Org(org_ctx): Org,
|
||||
State(pool): State<PgPool>,
|
||||
Query(filters): Query<KnowledgeSearchFilters>,
|
||||
) -> Result<Json<Vec<KnowledgeSearchResult>>, (StatusCode, String)> {
|
||||
// Create client that accepts invalid certificates
|
||||
let client = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(true)
|
||||
.danger_accept_invalid_hostnames(true)
|
||||
.build()
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("HTTP client error: {}", e)))?;
|
||||
|
||||
let ollama_url = ai::get_ollama_url();
|
||||
let model = ai::get_embedding_model();
|
||||
|
||||
// Generate embedding for query
|
||||
let embedding_response = generate_embedding(&client, &ollama_url, &model, &filters.query)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("AI error: {}", e)))?;
|
||||
|
||||
let pgvector = ai::embedding_to_pgvector(&embedding_response.embedding);
|
||||
|
||||
let limit = filters.limit.unwrap_or(10);
|
||||
let threshold = filters.threshold.unwrap_or(0.5);
|
||||
|
||||
// Build query with optional filters
|
||||
let mut query = String::from(
|
||||
r#"
|
||||
SELECT
|
||||
id,
|
||||
course_id,
|
||||
lesson_id,
|
||||
block_id,
|
||||
content_chunk,
|
||||
1 - (embedding <=> $1::vector) AS similarity,
|
||||
metadata
|
||||
FROM knowledge_base
|
||||
WHERE organization_id = $2
|
||||
AND embedding IS NOT NULL
|
||||
AND 1 - (embedding <=> $1::vector) >= $3
|
||||
"#
|
||||
);
|
||||
|
||||
let mut param_idx = 3;
|
||||
|
||||
if let Some(course_id) = filters.course_id {
|
||||
param_idx += 1;
|
||||
query.push_str(&format!(" AND course_id = ${}", param_idx));
|
||||
}
|
||||
|
||||
if let Some(lesson_id) = filters.lesson_id {
|
||||
param_idx += 1;
|
||||
query.push_str(&format!(" AND lesson_id = ${}", param_idx));
|
||||
}
|
||||
|
||||
param_idx += 1;
|
||||
query.push_str(&format!(" ORDER BY embedding <=> $1::vector LIMIT ${}", param_idx));
|
||||
|
||||
let mut sql_query = sqlx::query_as::<_, KnowledgeSearchResult>(&query)
|
||||
.bind(&pgvector)
|
||||
.bind(org_ctx.id)
|
||||
.bind(threshold);
|
||||
|
||||
if let Some(course_id) = filters.course_id {
|
||||
sql_query = sql_query.bind(course_id);
|
||||
}
|
||||
|
||||
if let Some(lesson_id) = filters.lesson_id {
|
||||
sql_query = sql_query.bind(lesson_id);
|
||||
}
|
||||
|
||||
sql_query = sql_query.bind(limit);
|
||||
|
||||
let results = sql_query
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
Ok(Json(results))
|
||||
}
|
||||
|
||||
// ==================== Helper Structs ====================
|
||||
|
||||
#[derive(Debug, sqlx::FromRow, Clone)]
|
||||
struct KnowledgeBaseEntry {
|
||||
id: Uuid,
|
||||
organization_id: Uuid,
|
||||
course_id: Uuid,
|
||||
lesson_id: Option<Uuid>,
|
||||
block_id: Option<Uuid>,
|
||||
content_chunk: String,
|
||||
chunk_order: i32,
|
||||
metadata: Option<serde_json::Value>,
|
||||
#[allow(dead_code)]
|
||||
created_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
@@ -6,6 +6,7 @@ mod handlers_discussions;
|
||||
mod handlers_notes;
|
||||
mod handlers_payments;
|
||||
mod handlers_peer_review;
|
||||
mod handlers_embeddings;
|
||||
mod lti;
|
||||
mod jwks;
|
||||
mod predictive;
|
||||
@@ -149,6 +150,19 @@ async fn main() {
|
||||
"/notifications/{id}/read",
|
||||
post(handlers::mark_notification_as_read),
|
||||
)
|
||||
// Knowledge Base Embedding Routes for Semantic RAG
|
||||
.route(
|
||||
"/knowledge-base/embeddings/generate",
|
||||
post(handlers_embeddings::generate_knowledge_embeddings),
|
||||
)
|
||||
.route(
|
||||
"/knowledge-base/semantic-search",
|
||||
get(handlers_embeddings::semantic_search_knowledge),
|
||||
)
|
||||
.route(
|
||||
"/knowledge-base/{id}/embedding/regenerate",
|
||||
post(handlers_embeddings::regenerate_knowledge_embedding),
|
||||
)
|
||||
// Discussion Forums Routes
|
||||
.route(
|
||||
"/courses/{id}/discussions",
|
||||
|
||||
Reference in New Issue
Block a user