feat: implement AI embeddings

This commit is contained in:
2026-03-18 17:15:39 -03:00
parent e8cdf61468
commit 64d3d5be91
32 changed files with 3568 additions and 174 deletions
+1
View File
@@ -19,3 +19,4 @@ sha2.workspace = true
hex.workspace = true
tracing.workspace = true
openidconnect.workspace = true
thiserror.workspace = true
+146
View File
@@ -0,0 +1,146 @@
//! AI Utilities for OpenCCB
//! Provides embedding generation and other AI helper functions
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Default embedding model for Ollama.
///
/// Returned by `get_embedding_model` when the `EMBEDDING_MODEL` environment
/// variable cannot be read.
pub const DEFAULT_EMBEDDING_MODEL: &str = "nomic-embed-text";
/// Default Ollama base URL.
///
/// Returned by `get_ollama_url` when the `LOCAL_OLLAMA_URL` environment
/// variable cannot be read.
pub const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
/// Embedding dimensions for nomic-embed-text.
// NOTE(review): not referenced in the visible part of this module; nothing here
// checks a returned vector against this length — confirm callers do if it matters.
pub const EMBEDDING_DIMENSIONS: usize = 768;
/// Errors produced by the AI helper functions in this module.
///
/// Every variant carries a human-readable detail string that is interpolated
/// into the `Display` message.
#[derive(Error, Debug)]
pub enum AiError {
/// The HTTP client could not be built, the request could not be sent, or
/// Ollama answered with a non-success status code.
#[error("Ollama request failed: {0}")]
OllamaRequest(String),
/// The response body could not be parsed as an embedding response.
#[error("Invalid embedding response: {0}")]
InvalidResponse(String),
/// The requested model is not available.
// NOTE(review): no visible code in this module constructs this variant —
// presumably used by callers elsewhere; confirm.
#[error("Model not available: {0}")]
ModelNotAvailable(String),
}
/// Embedding payload returned by `generate_embedding`.
///
/// Deserialized from the JSON body of the Ollama embeddings response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
/// The embedding vector.
pub embedding: Vec<f32>,
/// Model name from the response; falls back to an empty string when the
/// field is absent from the JSON (`#[serde(default)]`).
#[serde(default)]
pub model: String,
}
/// Resolve the Ollama base URL.
///
/// Reads the `LOCAL_OLLAMA_URL` environment variable; if it cannot be read
/// (unset or not valid Unicode), falls back to [`DEFAULT_OLLAMA_URL`].
pub fn get_ollama_url() -> String {
    match std::env::var("LOCAL_OLLAMA_URL") {
        Ok(url) => url,
        Err(_) => DEFAULT_OLLAMA_URL.to_string(),
    }
}
/// Resolve the embedding model name.
///
/// Reads the `EMBEDDING_MODEL` environment variable; if it cannot be read
/// (unset or not valid Unicode), falls back to [`DEFAULT_EMBEDDING_MODEL`].
pub fn get_embedding_model() -> String {
    match std::env::var("EMBEDDING_MODEL") {
        Ok(model) => model,
        Err(_) => DEFAULT_EMBEDDING_MODEL.to_string(),
    }
}
/// Build a `reqwest::Client` with TLS verification disabled.
///
/// Intended for development against endpoints using self-signed certificates.
///
/// **Security:** the client accepts invalid certificates *and* mismatched
/// hostnames, so connections made with it are open to man-in-the-middle
/// attacks. Do not use it for production traffic.
///
/// # Errors
/// Returns [`AiError::OllamaRequest`] if the underlying client cannot be built.
fn create_insecure_client() -> Result<reqwest::Client, AiError> {
    let builder = reqwest::Client::builder()
        .danger_accept_invalid_certs(true)
        .danger_accept_invalid_hostnames(true);

    match builder.build() {
        Ok(client) => Ok(client),
        Err(e) => Err(AiError::OllamaRequest(format!(
            "Failed to create HTTP client: {}",
            e
        ))),
    }
}
/// Request an embedding for a single piece of text from Ollama.
///
/// Sends a JSON body of the form `{"model": ..., "prompt": ...}` via POST to
/// `<ollama_url>/api/embeddings` (trailing slashes on `ollama_url` are
/// stripped before the path is appended).
///
/// # Arguments
/// * `client` - reqwest::Client instance used to send the request
/// * `ollama_url` - Base URL for Ollama (e.g., "http://localhost:11434")
/// * `model` - Embedding model name (e.g., "nomic-embed-text")
/// * `text` - Text to embed
///
/// # Errors
/// * [`AiError::OllamaRequest`] - the request could not be sent, or the
///   server replied with a non-success status (the status and response body
///   are included in the message).
/// * [`AiError::InvalidResponse`] - the response body could not be
///   deserialized into an [`EmbeddingResponse`].
pub async fn generate_embedding(
    client: &reqwest::Client,
    ollama_url: &str,
    model: &str,
    text: &str,
) -> Result<EmbeddingResponse, AiError> {
    let base = ollama_url.trim_end_matches('/');
    let endpoint = format!("{}/api/embeddings", base);
    let payload = serde_json::json!({
        "model": model,
        "prompt": text
    });

    let response = match client.post(&endpoint).json(&payload).send().await {
        Ok(resp) => resp,
        Err(e) => return Err(AiError::OllamaRequest(format!("Request failed: {}", e))),
    };

    // Capture the status first: reading the body below consumes the response.
    let status = response.status();
    if !status.is_success() {
        // Best effort: an unreadable error body becomes an empty string.
        let error_text = response.text().await.unwrap_or_default();
        return Err(AiError::OllamaRequest(format!(
            "Ollama API error ({}): {}",
            status, error_text
        )));
    }

    response
        .json::<EmbeddingResponse>()
        .await
        .map_err(|e| AiError::InvalidResponse(format!("Failed to parse response: {}", e)))
}
/// Generate embeddings for several texts.
///
/// Calls [`generate_embedding`] once per entry of `texts`, sequentially and in
/// input order, so the returned vector lines up index-for-index with `texts`.
///
/// # Errors
/// Stops at the first failing request and returns its error; embeddings
/// already obtained for earlier texts are discarded.
pub async fn generate_embeddings_batch(
    client: &reqwest::Client,
    ollama_url: &str,
    model: &str,
    texts: Vec<&str>,
) -> Result<Vec<EmbeddingResponse>, AiError> {
    let mut results = Vec::with_capacity(texts.len());

    for &input in &texts {
        let response = generate_embedding(client, ollama_url, model, input).await?;
        results.push(response);
    }

    Ok(results)
}
/// Render an embedding as a pgvector text literal.
///
/// Each component is written with exactly seven digits after the decimal
/// point, components are separated by commas with no spaces, and the whole
/// list is wrapped in square brackets: `"[0.1000000,0.2000000,0.3000000]"`.
/// An empty slice yields `"[]"`.
pub fn embedding_to_pgvector(embedding: &[f32]) -> String {
    let mut literal = String::from("[");

    for (index, component) in embedding.iter().enumerate() {
        // Separator goes before every component except the first.
        if index > 0 {
            literal.push(',');
        }
        literal.push_str(&format!("{:.7}", component));
    }

    literal.push(']');
    literal
}
/// Parse a pgvector text literal (e.g. `"[0.1,0.2,0.3]"`) back into a `Vec<f32>`.
///
/// Surrounding whitespace and the enclosing square brackets are stripped, the
/// remainder is split on commas, and every piece is trimmed and parsed as
/// `f32`. A literal whose bracket interior is blank (such as `"[]"`, which is
/// what `embedding_to_pgvector` produces for an empty slice) parses to an
/// empty vector.
///
/// # Errors
/// Returns `Err("Parse error: ...")` for the first component that is not a
/// valid `f32` (including an empty component, as in `"[1,,2]"`).
pub fn pgvector_to_embedding(pgvector: &str) -> Result<Vec<f32>, String> {
    let inner = pgvector
        .trim()
        .trim_start_matches('[')
        .trim_end_matches(']')
        .trim();

    // `"".split(',')` yields a single empty token, which would fail to parse;
    // treat a blank interior as the empty vector so `"[]"` round-trips.
    if inner.is_empty() {
        return Ok(Vec::new());
    }

    inner
        .split(',')
        .map(|s| s.trim().parse::<f32>().map_err(|e| format!("Parse error: {}", e)))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Formatting pads every component to seven decimal places.
    #[test]
    fn test_embedding_to_pgvector() {
        let values = vec![0.1, 0.2, 0.3];
        assert_eq!(
            embedding_to_pgvector(&values),
            "[0.1000000,0.2000000,0.3000000]"
        );
    }

    /// Parsing a bracketed, comma-separated literal recovers the components.
    #[test]
    fn test_pgvector_to_embedding() {
        let parsed = pgvector_to_embedding("[0.1000000,0.2000000,0.3000000]").unwrap();
        assert_eq!(parsed, vec![0.1, 0.2, 0.3]);
    }
}
+1
View File
@@ -1,3 +1,4 @@
pub mod ai;
pub mod auth;
pub mod middleware;
pub mod models;
+17 -12
View File
@@ -1176,18 +1176,18 @@ pub struct PublicProfile {
// ==================== Test Templates ====================
/// Course difficulty level, mapped to the `course_level` database type via
/// the `sqlx(type_name = ...)` attribute.
// NOTE(review): this span appears to interleave removed and added lines of a
// diff hunk — there are two `#[sqlx]`/`#[serde]` attribute pairs
// (`snake_case` and `lowercase`) and both `Beginner1`-style and
// `Beginner_1`-style variants. Confirm which set is current before relying on
// the serialized names.
#[derive(Debug, Serialize, Deserialize, sqlx::Type, Clone, PartialEq)]
#[sqlx(type_name = "course_level", rename_all = "snake_case")]
#[serde(rename_all = "snake_case")]
#[sqlx(type_name = "course_level", rename_all = "lowercase")]
#[serde(rename_all = "lowercase")]
pub enum CourseLevel {
Beginner,
Beginner1,
Beginner2,
// NOTE(review): underscore-named variants are not UpperCamelCase; presumably
// chosen so `rename_all = "lowercase"` yields names like `beginner_1` — verify
// against the database enum labels.
Beginner_1,
Beginner_2,
Intermediate,
Intermediate1,
Intermediate2,
Intermediate_1,
Intermediate_2,
Advanced,
Advanced1,
Advanced2,
Advanced_1,
Advanced_2,
}
#[derive(Debug, Serialize, Deserialize, sqlx::Type, Clone, PartialEq)]
@@ -1229,10 +1229,11 @@ impl std::fmt::Display for TestType {
pub struct TestTemplate {
pub id: Uuid,
pub organization_id: Uuid,
pub mysql_course_id: Option<i32>, // Reference to imported MySQL course
pub name: String,
pub description: Option<String>,
pub level: CourseLevel,
pub course_type: CourseType,
pub level: Option<CourseLevel>, // Deprecated: use mysql_course_id instead
pub course_type: Option<CourseType>, // Deprecated: use mysql_course_id instead
pub test_type: TestType,
pub duration_minutes: i32,
pub passing_score: i32, // 0-100 percentage
@@ -1280,8 +1281,9 @@ pub struct TestTemplateQuestion {
pub struct CreateTestTemplatePayload {
pub name: String,
pub description: Option<String>,
pub level: CourseLevel,
pub course_type: CourseType,
pub mysql_course_id: Option<i32>, // Reference to imported MySQL course (preferred)
pub level: Option<CourseLevel>, // Fallback if mysql_course_id not provided
pub course_type: Option<CourseType>, // Fallback if mysql_course_id not provided
pub test_type: TestType,
pub duration_minutes: i32,
pub passing_score: i32,
@@ -1295,6 +1297,7 @@ pub struct CreateTestTemplatePayload {
pub struct UpdateTestTemplatePayload {
pub name: Option<String>,
pub description: Option<String>,
pub mysql_course_id: Option<i32>,
pub level: Option<CourseLevel>,
pub course_type: Option<CourseType>,
pub test_type: Option<TestType>,
@@ -1394,6 +1397,8 @@ pub struct QuestionBank {
pub created_by: Option<Uuid>,
pub created_at: chrono::DateTime<chrono::Utc>,
pub updated_at: chrono::DateTime<chrono::Utc>,
pub embedding: Option<String>, // PGVector embedding for semantic search
pub embedding_updated_at: Option<chrono::DateTime<chrono::Utc>>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]