feat: implementing embedding AI
This commit is contained in:
@@ -19,3 +19,4 @@ sha2.workspace = true
|
||||
hex.workspace = true
|
||||
tracing.workspace = true
|
||||
openidconnect.workspace = true
|
||||
thiserror.workspace = true
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
//! AI Utilities for OpenCCB
|
||||
//! Provides embedding generation and other AI helper functions
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Name of the Ollama embedding model used when no other model is configured.
pub const DEFAULT_EMBEDDING_MODEL: &str = "nomic-embed-text";

/// Base URL of the Ollama server used when no other URL is configured.
pub const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";

/// Number of components in an embedding produced by `nomic-embed-text`.
pub const EMBEDDING_DIMENSIONS: usize = 768;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum AiError {
|
||||
#[error("Ollama request failed: {0}")]
|
||||
OllamaRequest(String),
|
||||
#[error("Invalid embedding response: {0}")]
|
||||
InvalidResponse(String),
|
||||
#[error("Model not available: {0}")]
|
||||
ModelNotAvailable(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddingResponse {
|
||||
pub embedding: Vec<f32>,
|
||||
#[serde(default)]
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
/// Get Ollama URL from environment or default
|
||||
pub fn get_ollama_url() -> String {
|
||||
std::env::var("LOCAL_OLLAMA_URL").unwrap_or_else(|_| DEFAULT_OLLAMA_URL.to_string())
|
||||
}
|
||||
|
||||
/// Get embedding model from environment or default
|
||||
pub fn get_embedding_model() -> String {
|
||||
std::env::var("EMBEDDING_MODEL").unwrap_or_else(|_| DEFAULT_EMBEDDING_MODEL.to_string())
|
||||
}
|
||||
|
||||
/// Create a reqwest client that accepts invalid certificates (for dev with self-signed certs)
|
||||
fn create_insecure_client() -> Result<reqwest::Client, AiError> {
|
||||
reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(true)
|
||||
.danger_accept_invalid_hostnames(true)
|
||||
.build()
|
||||
.map_err(|e| AiError::OllamaRequest(format!("Failed to create HTTP client: {}", e)))
|
||||
}
|
||||
|
||||
/// Generate embedding for text using Ollama
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `client` - reqwest::Client instance
|
||||
/// * `ollama_url` - Base URL for Ollama (e.g., "http://localhost:11434")
|
||||
/// * `model` - Embedding model name (default: "nomic-embed-text")
|
||||
/// * `text` - Text to embed
|
||||
pub async fn generate_embedding(
|
||||
client: &reqwest::Client,
|
||||
ollama_url: &str,
|
||||
model: &str,
|
||||
text: &str,
|
||||
) -> Result<EmbeddingResponse, AiError> {
|
||||
let endpoint = format!("{}/api/embeddings", ollama_url.trim_end_matches('/'));
|
||||
|
||||
let response = client
|
||||
.post(&endpoint)
|
||||
.json(&serde_json::json!({
|
||||
"model": model,
|
||||
"prompt": text
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AiError::OllamaRequest(format!("Request failed: {}", e)))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
return Err(AiError::OllamaRequest(
|
||||
format!("Ollama API error ({}): {}", status, error_text)
|
||||
));
|
||||
}
|
||||
|
||||
let embedding_response: EmbeddingResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AiError::InvalidResponse(format!("Failed to parse response: {}", e)))?;
|
||||
|
||||
Ok(embedding_response)
|
||||
}
|
||||
|
||||
/// Generate embeddings for multiple texts in batch
|
||||
pub async fn generate_embeddings_batch(
|
||||
client: &reqwest::Client,
|
||||
ollama_url: &str,
|
||||
model: &str,
|
||||
texts: Vec<&str>,
|
||||
) -> Result<Vec<EmbeddingResponse>, AiError> {
|
||||
let mut embeddings = Vec::with_capacity(texts.len());
|
||||
|
||||
for text in texts {
|
||||
let embedding = generate_embedding(client, ollama_url, model, text).await?;
|
||||
embeddings.push(embedding);
|
||||
}
|
||||
|
||||
Ok(embeddings)
|
||||
}
|
||||
|
||||
/// Renders an embedding in pgvector's text format.
///
/// The result is a bracketed, comma-separated list with no spaces, each
/// component printed with exactly seven digits after the decimal point,
/// e.g. `[0.1000000,0.2000000,0.3000000]`. An empty slice renders as `[]`.
pub fn embedding_to_pgvector(embedding: &[f32]) -> String {
    let mut rendered = String::from("[");

    for (idx, component) in embedding.iter().enumerate() {
        // Separator goes between components only, never before the first.
        if idx > 0 {
            rendered.push(',');
        }
        rendered.push_str(&format!("{:.7}", component));
    }

    rendered.push(']');
    rendered
}
|
||||
|
||||
/// Parses pgvector's text format (`"[0.1,0.2,0.3]"`) back into a `Vec<f32>`.
///
/// Surrounding whitespace and the enclosing square brackets are stripped,
/// then the remainder is split on commas and each component is trimmed and
/// parsed as an `f32`. Input without brackets is accepted as well.
///
/// An input with no components (`"[]"`, `"[ ]"`, or an empty string) yields
/// an empty vector, so the output of `embedding_to_pgvector` for an empty
/// slice parses back successfully.
///
/// # Errors
///
/// Returns `Err("Parse error: ...")` for the first component that is not a
/// valid `f32` literal, including empty components such as in `"[1,,2]"`.
pub fn pgvector_to_embedding(pgvector: &str) -> Result<Vec<f32>, String> {
    let inner = pgvector
        .trim()
        .trim_start_matches('[')
        .trim_end_matches(']')
        .trim();

    // `"".split(',')` yields one empty component, which would fail to parse;
    // treat "nothing between the brackets" as an empty vector instead.
    if inner.is_empty() {
        return Ok(Vec::new());
    }

    inner
        .split(',')
        .map(|component| {
            component
                .trim()
                .parse::<f32>()
                .map_err(|e| format!("Parse error: {}", e))
        })
        .collect()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Each component is rendered with seven fractional digits, comma-separated.
    #[test]
    fn test_embedding_to_pgvector() {
        let rendered = embedding_to_pgvector(&[0.1, 0.2, 0.3]);
        assert_eq!(rendered, "[0.1000000,0.2000000,0.3000000]");
    }

    /// The bracketed text form parses back into the original components.
    #[test]
    fn test_pgvector_to_embedding() {
        let parsed = pgvector_to_embedding("[0.1000000,0.2000000,0.3000000]").unwrap();
        assert_eq!(parsed, vec![0.1, 0.2, 0.3]);
    }
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod ai;
|
||||
pub mod auth;
|
||||
pub mod middleware;
|
||||
pub mod models;
|
||||
|
||||
+17
-12
@@ -1176,18 +1176,18 @@ pub struct PublicProfile {
|
||||
// ==================== Test Templates ====================
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, sqlx::Type, Clone, PartialEq)]
|
||||
#[sqlx(type_name = "course_level", rename_all = "snake_case")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[sqlx(type_name = "course_level", rename_all = "lowercase")]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum CourseLevel {
|
||||
Beginner,
|
||||
Beginner1,
|
||||
Beginner2,
|
||||
Beginner_1,
|
||||
Beginner_2,
|
||||
Intermediate,
|
||||
Intermediate1,
|
||||
Intermediate2,
|
||||
Intermediate_1,
|
||||
Intermediate_2,
|
||||
Advanced,
|
||||
Advanced1,
|
||||
Advanced2,
|
||||
Advanced_1,
|
||||
Advanced_2,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, sqlx::Type, Clone, PartialEq)]
|
||||
@@ -1229,10 +1229,11 @@ impl std::fmt::Display for TestType {
|
||||
pub struct TestTemplate {
|
||||
pub id: Uuid,
|
||||
pub organization_id: Uuid,
|
||||
pub mysql_course_id: Option<i32>, // Reference to imported MySQL course
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub level: CourseLevel,
|
||||
pub course_type: CourseType,
|
||||
pub level: Option<CourseLevel>, // Deprecated: use mysql_course_id instead
|
||||
pub course_type: Option<CourseType>, // Deprecated: use mysql_course_id instead
|
||||
pub test_type: TestType,
|
||||
pub duration_minutes: i32,
|
||||
pub passing_score: i32, // 0-100 percentage
|
||||
@@ -1280,8 +1281,9 @@ pub struct TestTemplateQuestion {
|
||||
pub struct CreateTestTemplatePayload {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub level: CourseLevel,
|
||||
pub course_type: CourseType,
|
||||
pub mysql_course_id: Option<i32>, // Reference to imported MySQL course (preferred)
|
||||
pub level: Option<CourseLevel>, // Fallback if mysql_course_id not provided
|
||||
pub course_type: Option<CourseType>, // Fallback if mysql_course_id not provided
|
||||
pub test_type: TestType,
|
||||
pub duration_minutes: i32,
|
||||
pub passing_score: i32,
|
||||
@@ -1295,6 +1297,7 @@ pub struct CreateTestTemplatePayload {
|
||||
pub struct UpdateTestTemplatePayload {
|
||||
pub name: Option<String>,
|
||||
pub description: Option<String>,
|
||||
pub mysql_course_id: Option<i32>,
|
||||
pub level: Option<CourseLevel>,
|
||||
pub course_type: Option<CourseType>,
|
||||
pub test_type: Option<TestType>,
|
||||
@@ -1394,6 +1397,8 @@ pub struct QuestionBank {
|
||||
pub created_by: Option<Uuid>,
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
pub updated_at: chrono::DateTime<chrono::Utc>,
|
||||
pub embedding: Option<String>, // PGVector embedding for semantic search
|
||||
pub embedding_updated_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
|
||||
Reference in New Issue
Block a user