Add Ollama embedding/querying support #6

Status: Open. Wants to merge 2 commits into base: main.

142 changes: 135 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default; the Cargo.lock diff is omitted here.

25 changes: 11 additions & 14 deletions Cargo.toml
@@ -4,28 +4,27 @@ version = "1.3.1"
edition = "2024"

[dependencies]
rmcp = { version = "0.1.5", features = ["tower", "transport-io", "transport-sse-server", "macros", "server"] } # Add macros, server, schemars
rmcp = { version = "0.1.5", features = ["tower", "transport-io", "transport-sse-server", "macros", "server"] }
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
dotenvy = "0.15"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "2.0.12"
walkdir = "2.5.0"
scraper = "0.23.1"
ndarray = { version = "0.16.1", features = ["serde"] } # Enable serde feature
async-openai = "0.28.0"
# async-trait = "0.1.88" # Removed, likely no longer needed
ndarray = { version = "0.16.1", features = ["serde"] }
async-openai = "0.28.0" # Keep for chat completion (optional)
ollama-rs = "0.2.0" # Add Ollama client
url = "2.4" # Add url parsing for Ollama client construction
futures = "0.3"
bincode = { version = "2.0.1", features = ["serde"] } # Enable serde integration
bincode = { version = "2.0.1", features = ["serde"] }
tiktoken-rs = "0.6.0"
# Configure cargo crate to vendor openssl to avoid system mismatches
cargo = { version = "0.87.1", default-features = false, features = ["vendored-openssl"] }
tempfile = "3.19.1"
anyhow = "1.0.97"
schemars = "0.8.22"
clap = { version = "4.5.34", features = ["cargo", "derive", "env"] }


# --- Platform Specific Dependencies ---

[target.'cfg(not(target_os = "windows"))'.dependencies]
@@ -34,12 +33,10 @@ xdg = { version = "2.5.2", features = ["serde"] }
[target.'cfg(target_os = "windows")'.dependencies]
dirs = "6.0.0"


# Optimize release builds for size
[profile.release]
opt-level = "z" # Optimize for size
lto = true # Enable Link Time Optimization
codegen-units = 1 # Maximize size reduction opportunities
panic = "abort" # Abort on panic to remove unwinding code
strip = true # Strip symbols from binary

opt-level = "z"
lto = true
codegen-units = 1
panic = "abort"
strip = true
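
The dependency change keeps async-openai for chat completion and adds ollama-rs (plus url for endpoint construction) so embeddings can be generated locally. As a minimal sketch of the call the new crate enables, using the same API surface that appears later in this diff (the model name nomic-embed-text and the default localhost:11434 endpoint are the PR's own choices, and the model must already be pulled with `ollama pull nomic-embed-text`):

    use ollama_rs::{generation::embeddings::request::GenerateEmbeddingsRequest, Ollama};

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Connects to the default Ollama endpoint (http://localhost:11434).
        let ollama = Ollama::default();

        // One input in, one embedding vector back.
        let request = GenerateEmbeddingsRequest::new(
            "nomic-embed-text".to_string(),
            "fn main() { println!(\"hello\"); }".to_string().into(),
        );
        let response = ollama.generate_embeddings(request).await?;

        if let Some(embedding) = response.embeddings.first() {
            eprintln!("embedding dimension: {}", embedding.len());
        }
        Ok(())
    }
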
221 changes: 170 additions & 51 deletions src/embeddings.rs
@@ -4,14 +4,18 @@ use async_openai::{
Client as OpenAIClient,
};
use ndarray::{Array1, ArrayView1};
use ollama_rs::{
generation::embeddings::request::GenerateEmbeddingsRequest,
Ollama,
};
use std::sync::OnceLock;
use std::sync::Arc;
use tiktoken_rs::cl100k_base;
use futures::stream::{self, StreamExt};

// Static OnceLock for the OpenAI client
// Static OnceLocks for both clients
pub static OPENAI_CLIENT: OnceLock<OpenAIClient<OpenAIConfig>> = OnceLock::new();

pub static OLLAMA_CLIENT: OnceLock<Ollama> = OnceLock::new();

use bincode::{Encode, Decode};
use serde::{Serialize, Deserialize};
@@ -20,11 +24,10 @@ use serde::{Serialize, Deserialize};
#[derive(Serialize, Deserialize, Debug, Encode, Decode)]
pub struct CachedDocumentEmbedding {
pub path: String,
pub content: String, // Add the extracted document content
pub content: String,
pub vector: Vec<f32>,
}


/// Calculates the cosine similarity between two vectors.
pub fn cosine_similarity(v1: ArrayView1<f32>, v2: ArrayView1<f32>) -> f32 {
let dot_product = v1.dot(&v2);
@@ -38,60 +41,160 @@ pub fn cosine_similarity(v1: ArrayView1<f32>, v2: ArrayView1<f32>) -> f32 {
}
}

/// Generates embeddings for a list of documents using the OpenAI API.
pub async fn generate_embeddings(
client: &OpenAIClient<OpenAIConfig>,
/// Generates embeddings using Ollama with the nomic-embed-text model
pub async fn generate_ollama_embeddings(
ollama_client: &Ollama,
documents: &[Document],
model: &str,
) -> Result<(Vec<(String, Array1<f32>)>, usize), ServerError> { // Return tuple: (embeddings, total_tokens)
// eprintln!("Generating embeddings for {} documents...", documents.len());
) -> Result<Vec<(String, Array1<f32>)>, ServerError> {
eprintln!("Generating embeddings for {} documents using Ollama...", documents.len());

const CONCURRENCY_LIMIT: usize = 4; // Lower concurrency for Ollama
const TOKEN_LIMIT: usize = 8000; // Adjust based on your model's limits

// Get the tokenizer for the model and wrap in Arc
// Get the tokenizer (we'll use this for approximate token counting)
let bpe = Arc::new(cl100k_base().map_err(|e| ServerError::Tiktoken(e.to_string()))?);

const CONCURRENCY_LIMIT: usize = 8; // Number of concurrent requests
const TOKEN_LIMIT: usize = 8000; // Keep a buffer below the 8192 limit
let results = stream::iter(documents.iter().enumerate())
.map(|(index, doc)| {
let ollama_client = ollama_client.clone();
let model = model.to_string();
let doc = doc.clone();
let bpe = Arc::clone(&bpe);

async move {
// Approximate token count for filtering
let token_count = bpe.encode_with_special_tokens(&doc.content).len();

if token_count > TOKEN_LIMIT {
eprintln!(
" Skipping document {}: Approximate tokens ({}) exceed limit ({}). Path: {}",
index + 1,
token_count,
TOKEN_LIMIT,
doc.path
);
return Ok::<Option<(String, Array1<f32>)>, ServerError>(None);
}

eprintln!(
" Processing document {} (approx {} tokens)... Path: {}",
index + 1,
token_count,
doc.path
);

// Create embeddings request for Ollama
let request = GenerateEmbeddingsRequest::new(
model,
doc.content.clone().into(),
);

match ollama_client.generate_embeddings(request).await {
Ok(response) => {
if let Some(embedding) = response.embeddings.first() {
let embedding_array = Array1::from(embedding.clone());
eprintln!(" Received response for document {}.", index + 1);
Ok(Some((doc.path.clone(), embedding_array)))
} else {
Err(ServerError::Config(format!(
"No embeddings returned for document {}",
index + 1
)))
}
}
Err(e) => Err(ServerError::Config(format!(
"Ollama embedding error for document {}: {}",
index + 1, e
)))
}
}
})
.buffer_unordered(CONCURRENCY_LIMIT)
.collect::<Vec<Result<Option<(String, Array1<f32>)>, ServerError>>>()
.await;

// Process collected results
let mut embeddings_vec = Vec::new();
for result in results {
match result {
Ok(Some((path, embedding))) => {
embeddings_vec.push((path, embedding));
}
Ok(None) => {} // Skipped document
Err(e) => {
eprintln!("Error during Ollama embedding generation: {}", e);
return Err(e);
}
}
}

eprintln!(
"Finished generating Ollama embeddings. Successfully processed {} documents.",
embeddings_vec.len()
);
Ok(embeddings_vec)
}

/// Generates embeddings for a single text using Ollama (for questions)
pub async fn generate_single_ollama_embedding(
ollama_client: &Ollama,
text: &str,
model: &str,
) -> Result<Array1<f32>, ServerError> {
let request = GenerateEmbeddingsRequest::new(
model.to_string(),
text.to_string().into(),
);

match ollama_client.generate_embeddings(request).await {
Ok(response) => {
if let Some(embedding) = response.embeddings.first() {
Ok(Array1::from(embedding.clone()))
} else {
Err(ServerError::Config("No embedding returned".to_string()))
}
}
Err(e) => Err(ServerError::Config(format!(
"Ollama embedding error: {}",
e
)))
}
}

/// Legacy OpenAI embedding generation (kept for fallback)
pub async fn generate_openai_embeddings(
client: &OpenAIClient<OpenAIConfig>,
documents: &[Document],
model: &str,
) -> Result<(Vec<(String, Array1<f32>)>, usize), ServerError> {
// Keep the original OpenAI implementation for fallback
let bpe = Arc::new(cl100k_base().map_err(|e| ServerError::Tiktoken(e.to_string()))?);

const CONCURRENCY_LIMIT: usize = 8;
const TOKEN_LIMIT: usize = 8000;

let results = stream::iter(documents.iter().enumerate())
.map(|(index, doc)| {
// Clone client, model, doc, and Arc<BPE> for the async block
let client = client.clone();
let model = model.to_string();
let doc = doc.clone();
let bpe = Arc::clone(&bpe); // Clone the Arc pointer
let bpe = Arc::clone(&bpe);

async move {
// Calculate token count for this document
let token_count = bpe.encode_with_special_tokens(&doc.content).len();

if token_count > TOKEN_LIMIT {
// eprintln!(
// " Skipping document {}: Actual tokens ({}) exceed limit ({}). Path: {}",
// index + 1,
// token_count,
// TOKEN_LIMIT,
// doc.path
// );
// Return Ok(None) to indicate skipping, with 0 tokens processed for this doc
return Ok::<Option<(String, Array1<f32>, usize)>, ServerError>(None); // Include token count type
return Ok::<Option<(String, Array1<f32>, usize)>, ServerError>(None);
}

// Prepare input for this single document
let inputs: Vec<String> = vec![doc.content.clone()];

let request = CreateEmbeddingRequestArgs::default()
.model(&model) // Use cloned model string
.model(&model)
.input(inputs)
.build()?; // Propagates OpenAIError
.build()?;

// eprintln!(
// " Sending request for document {} ({} tokens)... Path: {}",
// index + 1,
// token_count, // Use correct variable name
// doc.path
// );
let response = client.embeddings().create(request).await?; // Propagates OpenAIError
// eprintln!(" Received response for document {}.", index + 1);
let response = client.embeddings().create(request).await?;

if response.data.len() != 1 {
return Err(ServerError::OpenAI(
@@ -107,39 +210,55 @@ pub async fn generate_embeddings(
));
}

// Process result
let embedding_data = response.data.first().unwrap(); // Safe unwrap due to check above
let embedding_data = response.data.first().unwrap();
let embedding_array = Array1::from(embedding_data.embedding.clone());
// Return Ok(Some(...)) for successful embedding, include token count
Ok(Some((doc.path.clone(), embedding_array, token_count))) // Include token count
Ok(Some((doc.path.clone(), embedding_array, token_count)))
}
})
.buffer_unordered(CONCURRENCY_LIMIT) // Run up to CONCURRENCY_LIMIT futures concurrently
.collect::<Vec<Result<Option<(String, Array1<f32>, usize)>, ServerError>>>() // Update collected result type
.buffer_unordered(CONCURRENCY_LIMIT)
.collect::<Vec<Result<Option<(String, Array1<f32>, usize)>, ServerError>>>()
.await;

// Process collected results, filtering out errors and skipped documents, summing tokens
let mut embeddings_vec = Vec::new();
let mut total_processed_tokens: usize = 0;
for result in results {
match result {
Ok(Some((path, embedding, tokens))) => {
embeddings_vec.push((path, embedding)); // Keep successful embeddings
total_processed_tokens += tokens; // Add tokens for successful ones
embeddings_vec.push((path, embedding));
total_processed_tokens += tokens;
}
Ok(None) => {} // Ignore skipped documents
Ok(None) => {}
Err(e) => {
// Log error but potentially continue? Or return the first error?
// For now, let's return the first error encountered.
eprintln!("Error during concurrent embedding generation: {}", e);
eprintln!("Error during OpenAI embedding generation: {}", e);
return Err(e);
}
}
}

eprintln!(
"Finished generating embeddings. Successfully processed {} documents ({} tokens).",
"Finished generating OpenAI embeddings. Successfully processed {} documents ({} tokens).",
embeddings_vec.len(), total_processed_tokens
);
Ok((embeddings_vec, total_processed_tokens)) // Return tuple
Ok((embeddings_vec, total_processed_tokens))
}

/// Main embedding generation function that tries Ollama first, falls back to OpenAI
pub async fn generate_embeddings(
documents: &[Document],
model: &str,
) -> Result<(Vec<(String, Array1<f32>)>, usize), ServerError> {
// Check if Ollama is available
if let Some(ollama_client) = OLLAMA_CLIENT.get() {
eprintln!("Using Ollama for embedding generation with model: {}", model);
// For Ollama, we don't track tokens the same way, so return 0 for token count
let embeddings = generate_ollama_embeddings(ollama_client, documents, model).await?;
Ok((embeddings, 0))
} else if let Some(openai_client) = OPENAI_CLIENT.get() {
eprintln!("Fallback to OpenAI for embedding generation with model: {}", model);
generate_openai_embeddings(openai_client, documents, model).await
} else {
Err(ServerError::Config(
"No embedding client available (neither Ollama nor OpenAI)".to_string()
))
}
}
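
The new generate_embeddings wrapper dispatches on whichever client has been installed: Ollama if OLLAMA_CLIENT is set, otherwise OpenAI via OPENAI_CLIENT. A rough usage sketch, not part of the PR, with module paths following the imports used in main.rs and the model name being the Ollama default this PR hardcodes:

    use ollama_rs::Ollama;
    use crate::doc_loader::Document;
    use crate::embeddings::{generate_embeddings, OLLAMA_CLIENT};
    use crate::error::ServerError;

    async fn index_documents(docs: &[Document]) -> Result<(), ServerError> {
        // Install the Ollama client once; the wrapper will then prefer it over OpenAI.
        let _ = OLLAMA_CLIENT.set(Ollama::default());

        // For the Ollama path the returned token count is always 0.
        let (embeddings, tokens) = generate_embeddings(docs, "nomic-embed-text").await?;
        eprintln!("embedded {} documents ({} tokens billed)", embeddings.len(), tokens);
        Ok(())
    }
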
12 changes: 6 additions & 6 deletions src/error.rs
@@ -1,25 +1,25 @@
use rmcp::ServiceError; // Assuming ServiceError is the correct top-level error
use rmcp::ServiceError;
use thiserror::Error;
use crate::doc_loader::DocLoaderError; // Need to import DocLoaderError from the sibling module
use crate::doc_loader::DocLoaderError;

#[derive(Debug, Error)]
pub enum ServerError {
#[error("Environment variable not set: {0}")]
MissingEnvVar(String),
// MissingArgument removed as clap handles this now
#[error("Configuration Error: {0}")]
Config(String),

#[error("MCP Service Error: {0}")]
Mcp(#[from] ServiceError), // Use ServiceError
Mcp(#[from] ServiceError),
#[error("IO Error: {0}")]
Io(#[from] std::io::Error),
#[error("Document Loading Error: {0}")]
DocLoader(#[from] DocLoaderError),
#[error("OpenAI Error: {0}")]
OpenAI(#[from] async_openai::error::OpenAIError),
#[error("Ollama Error: {0}")]
Ollama(#[from] ollama_rs::error::OllamaError), // Add Ollama error handling
#[error("JSON Error: {0}")]
Json(#[from] serde_json::Error), // Add error for JSON deserialization
Json(#[from] serde_json::Error),
#[error("Tiktoken Error: {0}")]
Tiktoken(String),
#[error("XDG Directory Error: {0}")]
219 changes: 150 additions & 69 deletions src/main.rs
@@ -2,33 +2,34 @@
mod doc_loader;
mod embeddings;
mod error;
mod server; // Keep server module as RustDocsServer is defined there
mod server;

// Use necessary items from modules and crates
use crate::{
doc_loader::Document,
embeddings::{generate_embeddings, CachedDocumentEmbedding, OPENAI_CLIENT},
embeddings::{generate_embeddings, CachedDocumentEmbedding, OPENAI_CLIENT, OLLAMA_CLIENT},
error::ServerError,
server::RustDocsServer, // Import the updated RustDocsServer
server::RustDocsServer,
};
use async_openai::{Client as OpenAIClient, config::OpenAIConfig};
use bincode::config;
use cargo::core::PackageIdSpec;
use clap::Parser; // Import clap Parser
use clap::Parser;
use ndarray::Array1;
// Import rmcp items needed for the new approach
use ollama_rs::Ollama;
use rmcp::{
transport::io::stdio, // Use the standard stdio transport
ServiceExt, // Import the ServiceExt trait for .serve() and .waiting()
transport::io::stdio,
ServiceExt,
};
use std::{
collections::hash_map::DefaultHasher,
env,
fs::{self, File},
hash::{Hash, Hasher}, // Import hashing utilities
hash::{Hash, Hasher},
io::BufReader,
path::PathBuf,
};
// Removed unused url import
#[cfg(not(target_os = "windows"))]
use xdg::BaseDirectories;

@@ -38,12 +39,24 @@ use xdg::BaseDirectories;
#[command(author, version, about, long_about = None)]
struct Cli {
/// The package ID specification (e.g., "serde@^1.0", "tokio").
#[arg()] // Positional argument
#[arg()]
package_spec: String,

/// Optional features to enable for the crate when generating documentation.
#[arg(short = 'F', long, value_delimiter = ',', num_args = 0..)] // Allow multiple comma-separated values
#[arg(short = 'F', long, value_delimiter = ',', num_args = 0..)]
features: Option<Vec<String>>,

/// Use OpenAI instead of Ollama for embeddings (fallback mode)
#[arg(long)]
use_openai: bool,

/// Specify Ollama host (default: localhost)
#[arg(long, default_value = "localhost")]
ollama_host: String,

/// Specify Ollama port (default: 11434)
#[arg(long, default_value_t = 11434)]
ollama_port: u16,
}

// Helper function to create a stable hash from features
@@ -52,12 +65,12 @@ fn hash_features(features: &Option<Vec<String>>) -> String {
.as_ref()
.map(|f| {
let mut sorted_features = f.clone();
sorted_features.sort_unstable(); // Sort for consistent hashing
sorted_features.sort_unstable();
let mut hasher = DefaultHasher::new();
sorted_features.hash(&mut hasher);
format!("{:x}", hasher.finish()) // Return hex representation of hash
format!("{:x}", hasher.finish())
})
.unwrap_or_else(|| "no_features".to_string()) // Use a specific string if no features
.unwrap_or_else(|| "no_features".to_string())
}

#[tokio::main]
@@ -67,9 +80,9 @@ async fn main() -> Result<(), ServerError> {

// --- Parse CLI Arguments ---
let cli = Cli::parse();
let specid_str = cli.package_spec.trim().to_string(); // Trim whitespace
let specid_str = cli.package_spec.trim().to_string();
let features = cli.features.map(|f| {
f.into_iter().map(|s| s.trim().to_string()).collect() // Trim each feature
f.into_iter().map(|s| s.trim().to_string()).collect()
});

// Parse the specid string
@@ -91,19 +104,81 @@ async fn main() -> Result<(), ServerError> {
specid_str, crate_name, crate_version_req, features
);

// --- Determine Paths (incorporating features) ---
// --- Initialize Clients ---

// Initialize Ollama client (unless forced to use OpenAI)
if !cli.use_openai {
// Use the simpler approach: default() for localhost:11434, or construct URL for custom hosts
let ollama_client = if cli.ollama_host == "localhost" && cli.ollama_port == 11434 {
// Use the default for the most common case
eprintln!("Initializing Ollama client with default settings (localhost:11434)");
Ollama::default()
} else {
// For custom hosts, construct the URL properly
let scheme_and_host = if cli.ollama_host.contains("://") {
cli.ollama_host.clone()
} else {
format!("http://{}", cli.ollama_host)
};
eprintln!("Initializing Ollama client at {}:{}", scheme_and_host, cli.ollama_port);
Ollama::new(scheme_and_host, cli.ollama_port)
};

// Test Ollama connection
match ollama_client.show_model_info("nomic-embed-text".to_string()).await {
Ok(_) => {
eprintln!("✓ Connected to Ollama, nomic-embed-text model available");
OLLAMA_CLIENT.set(ollama_client)
.map_err(|_| ServerError::Config("Failed to set Ollama client".to_string()))?;
}
Err(e) => {
eprintln!("⚠ Failed to connect to Ollama or nomic-embed-text not available: {}", e);
eprintln!("Make sure Ollama is running and pull the model with: ollama pull nomic-embed-text");
return Err(ServerError::Config(format!(
"Ollama connection failed: {}. Try using --use-openai flag as fallback.",
e
)));
}
}
}

// Initialize OpenAI client (for chat completion and fallback)
let openai_client = if let Ok(api_base) = env::var("OPENAI_API_BASE") {
let config = OpenAIConfig::new().with_api_base(api_base);
OpenAIClient::with_config(config)
} else {
OpenAIClient::new()
};

// Always set OpenAI client for chat completion
OPENAI_CLIENT.set(openai_client.clone())
.map_err(|_| ServerError::Config("Failed to set OpenAI client".to_string()))?;

// Check if we have any embedding client
if OLLAMA_CLIENT.get().is_none() && !cli.use_openai {
eprintln!("No Ollama client available and not forced to use OpenAI");
let _openai_api_key = env::var("OPENAI_API_KEY")
.map_err(|_| ServerError::MissingEnvVar("OPENAI_API_KEY".to_string()))?;
eprintln!("Falling back to OpenAI for embeddings");
}

// Sanitize the version requirement string
// --- Determine Paths (incorporating features and model type) ---
let sanitized_version_req = crate_version_req
.replace(|c: char| !c.is_alphanumeric() && c != '.' && c != '-', "_");

// Generate a stable hash for the features to use in the path
let features_hash = hash_features(&features);

// Include model type in cache path to avoid conflicts
let model_type = if OLLAMA_CLIENT.get().is_some() && !cli.use_openai {
"ollama"
} else {
"openai"
};

// Construct the relative path component including features hash
let embeddings_relative_path = PathBuf::from(&crate_name)
.join(&sanitized_version_req)
.join(&features_hash) // Add features hash as a directory level
.join(&features_hash)
.join(model_type) // Separate cache for different embedding models
.join("embeddings.bin");

#[cfg(not(target_os = "windows"))]
@@ -121,7 +196,6 @@ async fn main() -> Result<(), ServerError> {
ServerError::Config("Could not determine cache directory on Windows".to_string())
})?;
let app_cache_dir = cache_dir.join("rustdocs-mcp-server");
// Ensure the base app cache directory exists
fs::create_dir_all(&app_cache_dir).map_err(ServerError::Io)?;
app_cache_dir.join(embeddings_relative_path)
};
@@ -181,17 +255,6 @@ async fn main() -> Result<(), ServerError> {
let mut generation_cost: Option<f64> = None;
let mut documents_for_server: Vec<Document> = loaded_documents_from_cache.unwrap_or_default();

// --- Initialize OpenAI Client (needed for question embedding even if cache hit) ---
let openai_client = if let Ok(api_base) = env::var("OPENAI_API_BASE") {
let config = OpenAIConfig::new().with_api_base(api_base);
OpenAIClient::with_config(config)
} else {
OpenAIClient::new()
};
OPENAI_CLIENT
.set(openai_client.clone()) // Clone the client for the OnceCell
.expect("Failed to set OpenAI client");

let final_embeddings = match loaded_embeddings {
Some(embeddings) => {
eprintln!("Using embeddings and documents loaded from cache.");
@@ -200,33 +263,39 @@ async fn main() -> Result<(), ServerError> {
None => {
eprintln!("Proceeding with documentation loading and embedding generation.");

let _openai_api_key = env::var("OPENAI_API_KEY")
.map_err(|_| ServerError::MissingEnvVar("OPENAI_API_KEY".to_string()))?;

eprintln!(
"Loading documents for crate: {} (Version Req: {}, Features: {:?})",
crate_name, crate_version_req, features
);
// Pass features to load_documents
let loaded_documents =
doc_loader::load_documents(&crate_name, &crate_version_req, features.as_ref())?; // Pass features here
doc_loader::load_documents(&crate_name, &crate_version_req, features.as_ref())?;
eprintln!("Loaded {} documents.", loaded_documents.len());
documents_for_server = loaded_documents.clone();

eprintln!("Generating embeddings...");
let embedding_model: String = env::var("EMBEDDING_MODEL")
.unwrap_or_else(|_| "text-embedding-3-small".to_string());
let (generated_embeddings, total_tokens) =
generate_embeddings(&openai_client, &loaded_documents, &embedding_model).await?;
let embedding_model = if OLLAMA_CLIENT.get().is_some() && !cli.use_openai {
"nomic-embed-text".to_string()
} else {
env::var("EMBEDDING_MODEL")
.unwrap_or_else(|_| "text-embedding-3-small".to_string())
};

let cost_per_million = 0.02;
let estimated_cost = (total_tokens as f64 / 1_000_000.0) * cost_per_million;
eprintln!(
"Embedding generation cost for {} tokens: ${:.6}",
total_tokens, estimated_cost
);
generated_tokens = Some(total_tokens);
generation_cost = Some(estimated_cost);
let (generated_embeddings, total_tokens) =
generate_embeddings(&loaded_documents, &embedding_model).await?;

// Only calculate cost for OpenAI
if cli.use_openai || OLLAMA_CLIENT.get().is_none() {
let cost_per_million = 0.02;
let estimated_cost = (total_tokens as f64 / 1_000_000.0) * cost_per_million;
eprintln!(
"Embedding generation cost for {} tokens: ${:.6}",
total_tokens, estimated_cost
);
generated_tokens = Some(total_tokens);
generation_cost = Some(estimated_cost);
} else {
eprintln!("Generated embeddings using Ollama (local, no cost)");
}

eprintln!(
"Saving generated documents and embeddings to: {:?}",
@@ -293,50 +362,62 @@ async fn main() -> Result<(), ServerError> {
.map(|f| format!(" Features: {:?}", f))
.unwrap_or_default();

let model_info = if OLLAMA_CLIENT.get().is_some() && !cli.use_openai {
"using Ollama/nomic-embed-text".to_string()
} else {
"using OpenAI".to_string()
};

let startup_message = if loaded_from_cache {
format!(
"Server for crate '{}' (Version Req: '{}'{}) initialized. Loaded {} embeddings from cache.",
crate_name, crate_version_req, features_str, final_embeddings.len()
"Server for crate '{}' (Version Req: '{}'{}) initialized. Loaded {} embeddings from cache ({}).",
crate_name, crate_version_req, features_str, final_embeddings.len(), model_info
)
} else {
let tokens = generated_tokens.unwrap_or(0);
let cost = generation_cost.unwrap_or(0.0);
format!(
"Server for crate '{}' (Version Req: '{}'{}) initialized. Generated {} embeddings for {} tokens (Est. Cost: ${:.6}).",
crate_name,
crate_version_req,
features_str,
final_embeddings.len(),
tokens,
cost
)
if OLLAMA_CLIENT.get().is_some() && !cli.use_openai {
format!(
"Server for crate '{}' (Version Req: '{}'{}) initialized. Generated {} embeddings using Ollama (local).",
crate_name,
crate_version_req,
features_str,
final_embeddings.len(),
)
} else {
format!(
"Server for crate '{}' (Version Req: '{}'{}) initialized. Generated {} embeddings for {} tokens (Est. Cost: ${:.6}) using OpenAI.",
crate_name,
crate_version_req,
features_str,
final_embeddings.len(),
tokens,
cost
)
}
};

// Create the service instance using the updated ::new()
let service = RustDocsServer::new(
crate_name.clone(), // Pass crate_name directly
crate_name.clone(),
documents_for_server,
final_embeddings,
startup_message,
)?;

// --- Use standard stdio transport and ServiceExt ---
eprintln!("Rust Docs MCP server starting via stdio...");

// Serve the server using the ServiceExt trait and standard stdio transport
let server_handle = service.serve(stdio()).await.map_err(|e| {
eprintln!("Failed to start server: {:?}", e);
ServerError::McpRuntime(e.to_string()) // Use the new McpRuntime variant
ServerError::McpRuntime(e.to_string())
})?;

eprintln!("{} Docs MCP server running...", &crate_name);
eprintln!("{} Docs MCP server running {} ...", &crate_name, model_info);

// Wait for the server to complete (e.g., stdin closed)
server_handle.waiting().await.map_err(|e| {
eprintln!("Server encountered an error while running: {:?}", e);
ServerError::McpRuntime(e.to_string()) // Use the new McpRuntime variant
ServerError::McpRuntime(e.to_string())
})?;

eprintln!("Rust Docs MCP server stopped.");
Ok(())
}
}
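
The cache layout in main.rs now keys on the embedding backend as well as the crate, version requirement, and features hash, so Ollama and OpenAI vectors never collide. As an illustrative sketch of the resulting relative path (the concrete values below are assumptions, not taken from the PR):

    use std::path::PathBuf;

    // Illustrative values; the real components come from the parsed package spec.
    let crate_name = "serde";
    let sanitized_version_req = "_1.0";   // "^1.0" with disallowed chars replaced by '_'
    let features_hash = "no_features";    // or a hex hash of the sorted feature list
    let model_type = "ollama";            // or "openai" when falling back

    let relative = PathBuf::from(crate_name)
        .join(sanitized_version_req)
        .join(features_hash)
        .join(model_type)
        .join("embeddings.bin");
    // e.g. serde/_1.0/no_features/ollama/embeddings.bin under the app cache directory
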
334 changes: 207 additions & 127 deletions src/server.rs

Large diffs are not rendered by default.
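
Since the src/server.rs diff is not rendered above, the querying half of this PR is not visible here. Purely as a hedged sketch of what that path presumably looks like given the helpers added in embeddings.rs, and not the PR's actual server.rs code, answering a question against the cached embeddings could be wired up roughly like this (the helper name and ranking details are assumptions):

    use ndarray::Array1;
    use crate::doc_loader::Document;
    use crate::embeddings::{cosine_similarity, generate_single_ollama_embedding, OLLAMA_CLIENT};
    use crate::error::ServerError;

    // Hypothetical helper: embeds the question with the same model used for the
    // documents, then returns the document whose cached vector is most similar.
    async fn best_matching_document<'a>(
        question: &str,
        doc_embeddings: &[(String, Array1<f32>)],
        documents: &'a [Document],
    ) -> Result<Option<&'a Document>, ServerError> {
        let ollama = OLLAMA_CLIENT
            .get()
            .ok_or_else(|| ServerError::Config("Ollama client not initialized".to_string()))?;
        let question_vec =
            generate_single_ollama_embedding(ollama, question, "nomic-embed-text").await?;

        // Rank cached document vectors by cosine similarity to the question vector.
        let best_path = doc_embeddings
            .iter()
            .map(|(path, v)| (path, cosine_similarity(question_vec.view(), v.view())))
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(path, _)| path.clone());

        Ok(best_path.and_then(|p| documents.iter().find(|d| d.path == p)))
    }
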