diff --git a/Cargo.lock b/Cargo.lock
index aa41fa4..87e1539 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2,6 +2,12 @@
 # It is not intended for manual editing.
 version = 4
 
+[[package]]
+name = "accelerate-src"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "415ed64958754dbe991900f3940677e6a7eefb4d7367afd70d642677b0c7d19d"
+
 [[package]]
 name = "addr2line"
 version = "0.24.2"
@@ -151,9 +157,11 @@ version = "0.8.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d1e306c8a4276ba57ce9fac76d823cc8c8a7fca14bf222ac20ad8b12c4273152"
 dependencies = [
+ "accelerate-src",
  "byteorder",
  "gemm",
  "half",
+ "libc",
  "memmap2",
  "num-traits",
  "num_cpus",
@@ -173,6 +181,7 @@ version = "0.8.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "39d417059c44d719fd03a0f711ccfe148d341469c9273d4b5731ebe965b2c97e"
 dependencies = [
+ "accelerate-src",
  "candle-core",
  "half",
  "num-traits",
@@ -188,6 +197,7 @@ version = "0.8.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "962a277e90dea20968164175138b836bba8b51b57505579fd628d79933da2b70"
 dependencies = [
+ "accelerate-src",
  "byteorder",
  "candle-core",
  "candle-nn",
@@ -307,6 +317,7 @@ version = "0.0.0"
 dependencies = [
  "anyhow",
  "serde",
+ "tracing",
 ]
 
 [[package]]
diff --git a/README.md b/README.md
index 2fe43bf..062a5b0 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,7 @@ Cydonia is a library based on [candle][candle] for developing modern AI applicat
 
 ```rust
 use cydonia::Model;
+
 fn main() {
     let model = Model::new("gemma2").tag("latest");
     let response = model.invoke("Hello, world!");
@@ -13,12 +14,20 @@ fn main() {
 
 We support quantized models only derived from `gemma` and `llama` family.
 
-## Special Thanks
+## TODOs
 
-- [candle][candle]
-- [ollama][ollama]
+- [x] Support chat interface (history prompts)
+- [ ] Function encoder for llama3 tools (static)
+- [ ] Cydonia as service
+  - [ ] RPC support for llama3 tools (remote)
+  - [ ] GraphQL support for llama3 tools (remote)
+- [ ] RAG support
+- [ ] Agent interface
+- [ ] Multi-agent support (single-node)
+- [ ] An application based on the tools
+- [ ] p2p for the decentralized cydonia network (multi-node)
+- [ ] Test GPU
 
 [candle]: https://github.com/huggingface/candle
-[ollama]: https://github.com/ollama/ollama
diff --git a/crates/candle/Cargo.toml b/crates/candle/Cargo.toml
index aa10ee2..8a93a75 100644
--- a/crates/candle/Cargo.toml
+++ b/crates/candle/Cargo.toml
@@ -15,3 +15,15 @@ rand.workspace = true
 serde.workspace = true
 tokenizers.workspace = true
 tracing.workspace = true
+
+[features]
+accelerate = [
+  "candle-core/accelerate",
+  "candle-nn/accelerate",
+  "candle-transformers/accelerate",
+]
+
+[target.'cfg(target = "aarch64-apple-darwin")'.dependencies]
+candle-core = { workspace = true, features = ["accelerate"] }
+candle-nn = { workspace = true, features = ["accelerate"] }
+candle-transformers = { workspace = true, features = ["accelerate"] }
diff --git a/crates/candle/examples/llama.rs b/crates/candle/examples/llama.rs
new file mode 100644
index 0000000..41288fd
--- /dev/null
+++ b/crates/candle/examples/llama.rs
@@ -0,0 +1,38 @@
+use ccore::{Message, Release};
+use cydonia_candle::{Llama, ProcessorConfig};
+use std::io::Write;
+
+fn main() {
+    let mut model = Llama::new(ProcessorConfig::default(), Release::default()).unwrap();
+    let mut init = true;
+    loop {
+        print!("> ");
+        std::io::stdout().flush().unwrap();
+
+        // Read input
+        let mut input = String::new();
+        std::io::stdin().read_line(&mut input).unwrap();
+        if input.ends_with('\n') {
+            input.pop();
+            if input.ends_with('\r') {
+                input.pop();
+            }
+        }
+
+        // Generate response
+        let mut response = String::new();
+        let message = Message::user(input);
+        let stream = model
+            .complete(&[message], init)
+            .expect("failed to generate response");
+        for token in stream {
+            response.push_str(&token);
+
+            print!("{}", token);
+            std::io::stdout().flush().unwrap();
+        }
+        println!();
+
+        init = false;
+    }
+}
diff --git a/crates/candle/src/inference.rs b/crates/candle/src/inference.rs
index 00b9db6..9075b32 100644
--- a/crates/candle/src/inference.rs
+++ b/crates/candle/src/inference.rs
@@ -1,15 +1,33 @@
 //! Cydonia inference interface
 
-use std::fs::File;
-
 use anyhow::Result;
 use candle_core::{quantized::gguf_file::Content, Device, Tensor};
 use candle_transformers::models::quantized_llama;
+use ccore::{chat, Message};
+use std::fs::File;
 
 /// The inference interface for language models
 pub trait Inference: Sized {
     /// The max sequence length
     const MAX_SEQ_LEN: usize;
 
+    /// The formatter for the model
+    type Formatter: chat::Formatter;
+
+    /// The end of stream token
+    fn eos_token() -> &'static str {
+        <Self::Formatter as chat::Formatter>::EOS_TOKEN
+    }
+
+    /// Format the messages into a prompt
+    fn prompt(messages: &[Message]) -> Result<String> {
+        <Self::Formatter as chat::Formatter>::format(messages)
+    }
+
+    /// Complete the messages
+    fn complete(messages: &[Message]) -> Result<String> {
+        <Self::Formatter as chat::Formatter>::complete(messages)
+    }
+
     /// Load model from gguf file
     fn gguf(device: &Device, file: &mut File) -> Result<Self>;
 
@@ -20,14 +38,16 @@ pub trait Inference: Sized {
 impl Inference for quantized_llama::ModelWeights {
     const MAX_SEQ_LEN: usize = quantized_llama::MAX_SEQ_LEN;
 
+    type Formatter = chat::Llama3;
+
     fn gguf(device: &Device, file: &mut File) -> Result<Self> {
         let content = Content::read(file)?;
         let model = Self::from_gguf(content, file, device)?;
         Ok(model)
     }
 
-    fn forward(&mut self, input: &Tensor, squeeze: usize) -> Result<Tensor> {
-        quantized_llama::ModelWeights::forward(self, input, squeeze)
+    fn forward(&mut self, input: &Tensor, pos: usize) -> Result<Tensor> {
+        quantized_llama::ModelWeights::forward(self, input, pos)
             .map_err(|e| anyhow::anyhow!("failed to forward: {e}"))
     }
 }
diff --git a/crates/candle/src/lib.rs b/crates/candle/src/lib.rs
index 38b2b5b..1ddb7cb 100644
--- a/crates/candle/src/lib.rs
+++ b/crates/candle/src/lib.rs
@@ -5,7 +5,7 @@ mod inference;
 mod loader;
 mod model;
 mod processor;
-mod stream;
+mod token;
 
 pub use {
     device::detect as device,
@@ -13,7 +13,7 @@
     loader::Loader,
     model::Model,
     processor::{Processor, ProcessorConfig, SampleBuilder},
-    stream::TokenStream,
+    token::{TokenStream, Tokenizer},
 };
 
 /// The Llama model
diff --git a/crates/candle/src/loader.rs b/crates/candle/src/loader.rs
index 1adadd9..1903072 100644
--- a/crates/candle/src/loader.rs
+++ b/crates/candle/src/loader.rs
@@ -1,12 +1,11 @@
 //! Model loader
 
-use crate::{Inference, TokenStream};
+use crate::{Inference, Tokenizer};
 use anyhow::Result;
 use candle_core::Device;
-use ccore::{Manifest, TOKENIZER};
+use ccore::{Release, TOKENIZER};
 use hf_hub::api::sync::Api;
 use std::fs::File;
-use tokenizers::Tokenizer;
 
 /// Huggingface model loader
 ///
@@ -16,30 +15,30 @@ pub struct Loader {
     api: Api,
 
     /// The manifest of the model
-    manifest: Manifest,
+    release: Release,
 }
 
 impl Loader {
     /// Load the model
-    pub fn new(manifest: Manifest) -> Result<Self> {
+    pub fn new(release: Release) -> Result<Self> {
         Ok(Self {
-            manifest,
+            release,
             api: Api::new()?,
         })
     }
 
     /// Load the tokenizer
-    pub fn tokenizer(&self) -> Result<TokenStream> {
+    pub fn tokenizer<I: Inference>(&self) -> Result<Tokenizer> {
         let trepo = self.api.model(TOKENIZER.into());
-        let tokenizer = Tokenizer::from_file(trepo.get(self.manifest.release.tokenizer())?)
+        let tokenizer = tokenizers::Tokenizer::from_file(trepo.get(self.release.tokenizer())?)
             .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;
-        Ok(TokenStream::new(tokenizer))
+        Tokenizer::new::<I>(tokenizer)
     }
 
     /// Load the model
     pub fn model<M: Inference>(&self, device: &Device) -> Result<M> {
-        let mrepo = self.api.model(self.manifest.release.repo()?.into());
-        let model = mrepo.get(&self.manifest.release.model(self.manifest.quantization))?;
+        let mrepo = self.api.model(self.release.repo().into());
+        let model = mrepo.get(&self.release.model())?;
         let mut file = File::open(model)?;
         let model = M::gguf(device, &mut file)?;
         Ok(model)
diff --git a/crates/candle/src/model.rs b/crates/candle/src/model.rs
index 166064e..c2d45ba 100644
--- a/crates/candle/src/model.rs
+++ b/crates/candle/src/model.rs
@@ -1,14 +1,13 @@
 //! Model interface
 
-use crate::{Inference, Loader, Processor, ProcessorConfig, TokenStream};
+use crate::{Inference, Loader, Processor, ProcessorConfig, TokenStream, Tokenizer};
 use anyhow::Result;
-use ccore::{Manifest, Message};
-use std::io::Write;
+use ccore::{Message, Release};
 
 /// Language Model interface
 pub struct Model<I: Inference> {
     /// The tokenizer of the model
-    tokenizer: TokenStream,
+    tokenizer: Tokenizer,
 
     /// The weights of the model
     weights: I,
@@ -19,10 +18,10 @@
 
 impl<I: Inference> Model<I> {
     /// Create a new model
-    pub fn new(config: ProcessorConfig, manifest: Manifest) -> Result<Self> {
-        let loader = Loader::new(manifest)?;
-        let tokenizer = loader.tokenizer()?;
+    pub fn new(config: ProcessorConfig, release: Release) -> Result<Self> {
         let processor = config.build();
+        let loader = Loader::new(release)?;
+        let tokenizer = loader.tokenizer::<I>()?;
         let weights = loader.model::<I>(&processor.device)?;
 
         Ok(Self {
@@ -33,54 +32,18 @@
     }
 
     /// Complete the chat
-    pub fn complete(&mut self, messages: &mut [Message]) -> Result<String> {
-        let message = messages
-            .first()
-            .ok_or_else(|| anyhow::anyhow!("no messages"))?;
-
-        let to_sample = self.processor.sample_len.saturating_sub(1);
-        let prompt_tokens = self
-            .tokenizer
-            .prompt(&message.content)?
-            .sample_len(to_sample)
-            .max_seq_len::<I>()
-            .encode()?;
-
-        // process the prompt tokens
-        let mut next_token = self
-            .processor
-            .sample_tokens(&prompt_tokens)
-            .sample(&mut self.weights)?;
-
-        // process the tokens
-        let mut all_tokens = vec![next_token];
-        let eos_token = self
-            .tokenizer
-            .token("</s>")
-            .ok_or_else(|| anyhow::anyhow!("eos token not found"))?;
-
-        let response = String::new();
-        let pos = prompt_tokens.len();
-        for index in 0..to_sample {
-            next_token = self
-                .processor
-                .sample_tokens(&[next_token])
-                .all_tokens(&all_tokens)
-                .pos(pos + index)
-                .sample(&mut self.weights)?;
-
-            all_tokens.push(next_token);
-            if let Some(t) = self.tokenizer.next_token(next_token)? {
-                print!("{t}");
-                std::io::stdout().flush()?;
-            }
-
-            if next_token == eos_token {
-                break;
-            }
-        }
-
-        println!();
-        Ok(response)
+    pub fn complete<'ts>(
+        &'ts mut self,
+        messages: &[Message],
+        init: bool,
+    ) -> Result<TokenStream<'ts, I>> {
+        let formatted = if init {
+            I::prompt(messages)?
+        } else {
+            I::complete(messages)?
+        };
+
+        self.tokenizer
+            .stream(&mut self.weights, &mut self.processor, formatted)
     }
 }
diff --git a/crates/candle/src/processor/config.rs b/crates/candle/src/processor/config.rs
index cb6c567..a7168af 100644
--- a/crates/candle/src/processor/config.rs
+++ b/crates/candle/src/processor/config.rs
@@ -102,6 +102,8 @@
     }
 
     /// Set the sample length
+    ///
+    /// TODO: is there a way to embed the sample length in the system prompt?
     pub fn sample_len(mut self, sample_len: usize) -> Self {
         self.sample_len = sample_len;
         self
@@ -112,11 +114,11 @@
     fn default() -> Self {
         Self {
             gpu: false,
-            seed: Some(1_024_243_212),
+            seed: None,
             temperature: Some(0.6),
             top_p: Some(0.9),
             top_k: Some(50),
-            sample_len: 256,
+            sample_len: 1024,
             repeat_penalty: 1.0,
             repeat_last_n: 64,
         }
diff --git a/crates/candle/src/processor/mod.rs b/crates/candle/src/processor/mod.rs
index f7f5485..6e7dfd9 100644
--- a/crates/candle/src/processor/mod.rs
+++ b/crates/candle/src/processor/mod.rs
@@ -40,8 +40,8 @@
     }
 
     /// Sample tokens
-    pub fn sample_tokens<'s>(&'s mut self, tokens: &'s [u32]) -> SampleBuilder<'s> {
-        SampleBuilder::new(self, tokens)
+    pub fn sample_token(&mut self, token: u32) -> SampleBuilder<'_> {
+        SampleBuilder::new(self, token)
     }
 
     /// Apply repeat penalty
diff --git a/crates/candle/src/processor/sample.rs b/crates/candle/src/processor/sample.rs
index 2188e8e..dcae4c7 100644
--- a/crates/candle/src/processor/sample.rs
+++ b/crates/candle/src/processor/sample.rs
@@ -3,30 +3,30 @@
 use crate::{Inference, Processor};
 
 /// Sample builder
 pub struct SampleBuilder<'s> {
-    tokens: &'s [u32],
+    token: u32,
     unsqueeze: usize,
     pos: usize,
     squeeze: usize,
-    all_tokens: &'s [u32],
+    cur_tokens: &'s [u32],
     processor: &'s mut Processor,
 }
 
 impl<'s> SampleBuilder<'s> {
     /// Create a new sample builder
-    pub fn new(processor: &'s mut Processor, tokens: &'s [u32]) -> Self {
+    pub fn new(processor: &'s mut Processor, token: u32) -> Self {
         Self {
-            tokens,
+            token,
             unsqueeze: 0,
             pos: 0,
             squeeze: 0,
-            all_tokens: &[],
+            cur_tokens: &[],
             processor,
         }
     }
 
     /// Set the all tokens
-    pub fn all_tokens(mut self, all_tokens: &'s [u32]) -> Self {
-        self.all_tokens = all_tokens;
+    pub fn cur_tokens(mut self, cur_tokens: &'s [u32]) -> Self {
+        self.cur_tokens = cur_tokens;
         self
     }
 
@@ -50,10 +50,10 @@
 
     /// Build the sample
     pub fn sample(self, model: &mut impl Inference) -> anyhow::Result<u32> {
-        let input = self.processor.tensor(self.tokens, self.unsqueeze)?;
+        let input = self.processor.tensor(&[self.token], self.unsqueeze)?;
         let mut logits = model.forward(&input, self.pos)?.squeeze(self.squeeze)?;
-        if !self.all_tokens.is_empty() {
-            logits = self.processor.repeat_penalty(logits, self.all_tokens)?;
+        if !self.cur_tokens.is_empty() {
+            logits = self.processor.repeat_penalty(logits, self.cur_tokens)?;
         }
 
         self.processor
diff --git a/crates/candle/src/stream/mod.rs b/crates/candle/src/stream/mod.rs
deleted file mode 100644
index 7931444..0000000
--- a/crates/candle/src/stream/mod.rs
+++ /dev/null
@@ -1,81 +0,0 @@
-//! Token stream handler
-
-use anyhow::Result;
-use prompt::PromptBuilder;
-use tokenizers::Tokenizer;
-
-mod prompt;
-
-/// A token stream handler
-pub struct TokenStream {
-    tokenizer: Tokenizer,
-    tokens: Vec<u32>,
-    prev_index: usize,
-    current_index: usize,
-}
-
-impl TokenStream {
-    /// Create a new token stream
-    pub fn new(tokenizer: Tokenizer) -> Self {
-        Self {
-            tokenizer,
-            tokens: Vec::new(),
-            prev_index: 0,
-            current_index: 0,
-        }
-    }
-
-    /// Clear the token stream
-    pub fn clear(&mut self) {
-        self.tokens.clear();
-        self.prev_index = 0;
-        self.current_index = 0;
-    }
-
-    fn decode(&self, tokens: &[u32]) -> Result<String> {
-        match self.tokenizer.decode(tokens, true) {
-            Ok(str) => Ok(str),
-            Err(err) => anyhow::bail!("cannot decode: {err}"),
-        }
-    }
-
-    /// Encode the input text
-    pub fn encode(&self, text: &str, special_tokens: bool) -> Result<Vec<u32>> {
-        self.tokenizer
-            .encode(text, special_tokens)
-            .map(|e| e.get_ids().to_vec())
-            .map_err(|e| anyhow::anyhow!("failed to encode: {e}"))
-    }
-
-    /// Get the next token
-    ///
-    ///
-    pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
-        let prev_text = if self.tokens.is_empty() {
-            String::new()
-        } else {
-            let tokens = &self.tokens[self.prev_index..self.current_index];
-            self.decode(tokens)?
-        };
-        self.tokens.push(token);
-        let text = self.decode(&self.tokens[self.prev_index..])?;
-        if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
-            let text = text.split_at(prev_text.len());
-            self.prev_index = self.current_index;
-            self.current_index = self.tokens.len();
-            Ok(Some(text.1.to_string()))
-        } else {
-            Ok(None)
-        }
-    }
-
-    /// Encode the prompt string
-    pub fn prompt<'p>(&'p self, text: &'p str) -> Result<PromptBuilder<'p>> {
-        Ok(PromptBuilder::new(self, text))
-    }
-
-    /// Get token from the input string
-    pub fn token(&self, token_s: &str) -> Option<u32> {
-        self.tokenizer.get_vocab(true).get(token_s).copied()
-    }
-}
diff --git a/crates/candle/src/token/mod.rs b/crates/candle/src/token/mod.rs
new file mode 100644
index 0000000..56c7074
--- /dev/null
+++ b/crates/candle/src/token/mod.rs
@@ -0,0 +1,92 @@
+//! Token stream handler
+
+use crate::{Inference, Processor};
+use anyhow::Result;
+pub use {prompt::PromptBuilder, stream::TokenStream};
+
+mod prompt;
+mod stream;
+
+/// A token stream handler
+pub struct Tokenizer {
+    /// The tokenizer
+    tokenizer: tokenizers::Tokenizer,
+    /// The full context including the tokens inferred by the model
+    /// and the user's input
+    tokens: Vec<u32>,
+
+    /// The end of stream token
+    pub eos: u32,
+}
+
+impl Tokenizer {
+    /// Create a new token stream
+    pub fn new<I: Inference>(tokenizer: tokenizers::Tokenizer) -> Result<Self> {
+        Ok(Self {
+            tokens: Vec::new(),
+            eos: tokenizer
+                .get_vocab(true)
+                .get(I::eos_token())
+                .copied()
+                .ok_or_else(|| anyhow::anyhow!("eos token not found"))?,
+            tokenizer,
+        })
+    }
+
+    /// Get the count of the tokens
+    pub fn tokens(&self) -> usize {
+        self.tokens.len()
+    }
+
+    /// Add tokens to the context
+    pub fn sampled(&mut self, tokens: &[u32]) {
+        self.tokens.extend(tokens);
+    }
+
+    /// Embed a token into the context
+    pub fn embed(&mut self, token: u32) -> Result<String> {
+        match self.tokenizer.decode(&[token], true) {
+            Ok(str) => {
+                self.tokens.push(token);
+                Ok(str)
+            }
+            Err(err) => anyhow::bail!("cannot decode: {err}"),
+        }
+    }
+
+    /// Decode the tokens to string
+    pub fn decode(&self, tokens: &[u32]) -> Result<String> {
+        match self.tokenizer.decode(tokens, true) {
+            Ok(str) => Ok(str),
+            Err(err) => anyhow::bail!("cannot decode: {err}"),
+        }
+    }
+
+    /// Encode the input text
+    pub fn encode(&self, text: &str, special_tokens: bool) -> Result<Vec<u32>> {
+        self.tokenizer
+            .encode(text, special_tokens)
+            .map(|e| e.get_ids().to_vec())
+            .map_err(|e| anyhow::anyhow!("failed to encode: {e}"))
+    }
+
+    /// Encode the prompt string
+    pub fn prompt<'p>(&'p mut self, text: &'p str) -> Result<PromptBuilder<'p>> {
+        Ok(PromptBuilder::new(self, text))
+    }
+
+    /// Get token from the input string
+    pub fn token(&self, token_s: &str) -> Option<u32> {
+        self.tokenizer.get_vocab(true).get(token_s).copied()
+    }
+
+    /// Get the token stream
+    pub fn stream<'ts, I: Inference>(
+        &'ts mut self,
+        weights: &'ts mut I,
+        processor: &'ts mut Processor,
+        prompt: String,
+    ) -> Result<TokenStream<'ts, I>> {
+        TokenStream::new(weights, processor, self, prompt)
+    }
+}
diff --git a/crates/candle/src/stream/prompt.rs b/crates/candle/src/token/prompt.rs
similarity index 72%
rename from crates/candle/src/stream/prompt.rs
rename to crates/candle/src/token/prompt.rs
index 7be0877..4b49280 100644
--- a/crates/candle/src/stream/prompt.rs
+++ b/crates/candle/src/token/prompt.rs
@@ -1,12 +1,12 @@
 //! Prompt builder
 
-use crate::{Inference, TokenStream};
+use crate::{Inference, Tokenizer};
 use anyhow::Result;
 
 /// Prompt builder
 pub struct PromptBuilder<'t> {
     /// The token stream
-    tos: &'t TokenStream,
+    tos: &'t mut Tokenizer,
 
     /// The text
     text: &'t str,
@@ -23,11 +23,11 @@
 
 impl<'t> PromptBuilder<'t> {
     /// Create a new prompt builder
-    pub fn new(tos: &'t TokenStream, text: &'t str) -> Self {
+    pub fn new(tos: &'t mut Tokenizer, text: &'t str) -> Self {
         Self {
             tos,
             text,
-            special_tokens: true,
+            special_tokens: false,
             sample_len: None,
             max_seq_len: None,
         }
@@ -52,16 +52,17 @@
     }
 
     /// Encode the text to tokens
-    pub fn encode(self) -> Result<Vec<u32>> {
+    pub fn encode<I: Inference>(self) -> Result<Vec<u32>> {
         let mut tokens = self.tos.encode(self.text, self.special_tokens)?;
         if let (Some(max_seq_len), Some(sample_len)) = (self.max_seq_len, self.sample_len) {
-            // NOTE: we need to subtract 10 to account for the eos token
-            if tokens.len() + sample_len > max_seq_len.saturating_sub(10) {
-                // TODO: handle the case where the tokens are too long
-                tokens = tokens[tokens.len().saturating_sub(sample_len)..].to_vec();
+            let eos_token_len = I::eos_token().len();
+            if tokens.len() + sample_len > max_seq_len.saturating_sub(eos_token_len) {
+                let to_remove = tokens.len() + sample_len + eos_token_len - max_seq_len;
+                tokens = tokens[tokens.len().saturating_sub(to_remove)..].to_vec();
             }
         }
 
+        self.tos.sampled(&tokens);
         Ok(tokens)
     }
 }
diff --git a/crates/candle/src/token/stream.rs b/crates/candle/src/token/stream.rs
new file mode 100644
index 0000000..bf9dfd3
--- /dev/null
+++ b/crates/candle/src/token/stream.rs
@@ -0,0 +1,101 @@
+//! Token output stream
+
+use crate::{Inference, Processor, Tokenizer};
+use anyhow::Result;
+
+/// Token output stream
+pub struct TokenStream<'ts, I: Inference> {
+    /// The current tokens
+    cur_tokens: Vec<u32>,
+
+    /// The next token
+    next: u32,
+
+    /// The position
+    pos: usize,
+
+    /// The processor
+    processor: &'ts mut Processor,
+
+    /// The sampled tokens
+    sampled: usize,
+
+    /// The tokenizer
+    tokenizer: &'ts mut Tokenizer,
+
+    /// The model weights
+    weights: &'ts mut I,
+}
+
+impl<'ts, I: Inference> TokenStream<'ts, I> {
+    /// Create a new token stream
+    pub fn new(
+        weights: &'ts mut I,
+        processor: &'ts mut Processor,
+        tokenizer: &'ts mut Tokenizer,
+        prompt: String,
+    ) -> Result<Self> {
+        let mut this = Self {
+            cur_tokens: vec![],
+            pos: tokenizer.tokens(),
+            next: 0,
+            sampled: 0,
+            processor,
+            tokenizer,
+            weights,
+        };
+
+        this.sample_prompt(&prompt)?;
+        Ok(this)
+    }
+
+    /// Sample the prompt
+    ///
+    /// This function should only be called on the start of the stream.
+    fn sample_prompt(&mut self, prompt: &str) -> Result<()> {
+        let tokens = self
+            .tokenizer
+            .prompt(prompt)?
+            .sample_len(self.processor.sample_len)
+            .max_seq_len::<I>()
+            .encode::<I>()?;
+
+        for token in tokens.iter() {
+            self.sample_token(*token)?;
+        }
+
+        Ok(())
+    }
+
+    /// Sample a token
+    fn sample_token(&mut self, token: u32) -> Result<()> {
+        self.next = self
+            .processor
+            .sample_token(token)
+            .cur_tokens(&self.cur_tokens)
+            .pos(self.pos)
+            .sample(self.weights)?;
+
+        self.cur_tokens.push(self.next);
+        self.pos += 1;
+        Ok(())
+    }
+}
+
+impl<I: Inference> Iterator for TokenStream<'_, I> {
+    type Item = String;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.pos == self.tokenizer.tokens() {
+            return self.tokenizer.embed(self.next).ok();
+        }
+
+        if self.next == self.tokenizer.eos || self.sampled >= self.processor.sample_len {
+            return None;
+        }
+
+        self.sample_token(self.next).ok()?;
+        self.sampled += 1;
+        self.tokenizer.embed(self.next).ok()
+    }
+}
diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml
index ed9e48a..3a9a579 100644
--- a/crates/core/Cargo.toml
+++ b/crates/core/Cargo.toml
@@ -6,3 +6,4 @@ edition = "2021"
 [dependencies]
 anyhow.workspace = true
 serde.workspace = true
+tracing.workspace = true
diff --git a/crates/core/src/chat/llama.rs b/crates/core/src/chat/llama.rs
new file mode 100644
index 0000000..a821671
--- /dev/null
+++ b/crates/core/src/chat/llama.rs
@@ -0,0 +1,81 @@
+//! Llama3 prompt formatter
+
+use crate::chat::{Formatter, Message};
+
+/// Llama3 prompt formatter
+#[derive(Default)]
+pub struct Llama3 {
+    output: String,
+}
+
+impl Llama3 {
+    /// Emit the begin of text token
+    fn emit_begin(&mut self) {
+        self.output.push_str("<|begin_of_text|>");
+    }
+
+    /// Add a default system message
+    fn default_system(&mut self) {
+        self.output
+            .push_str("<|start_header_id|>system<|end_header_id|>\n");
+        self.output
+            .push_str("You are a helpful assistant.<|eot_id|>");
+    }
+
+    fn emit_system(&mut self, system: &str) {
+        self.output.push_str(&format!(
+            "<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
+        ));
+    }
+
+    fn emit_user(&mut self, user: &str) {
+        self.output.push_str(&format!(
+            "<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>"
+        ));
+    }
+
+    fn emit_assistant(&mut self, assistant: &str) {
+        self.output.push_str(&format!(
+            "<|start_header_id|>assistant<|end_header_id|>\n\n{assistant}<|eot_id|>"
+        ));
+    }
+
+    fn emit_complete(&mut self) {
+        self.output
+            .push_str("<|start_header_id|>assistant<|end_header_id|>\n\n");
+    }
+}
+
+impl Formatter for Llama3 {
+    const EOS_TOKEN: &str = "<|eot_id|>";
+
+    fn format(messages: &[Message]) -> anyhow::Result<String> {
+        let mut formatter = Llama3::default();
+        formatter.emit_begin();
+        formatter.default_system();
+
+        for message in messages {
+            match message {
+                Message::System(system) => formatter.emit_system(system),
+                Message::User(user) => formatter.emit_user(user),
+                Message::Assistant(assistant) => formatter.emit_assistant(assistant),
+            }
+        }
+
+        formatter.emit_complete();
+        Ok(formatter.output)
+    }
+
+    fn complete(messages: &[Message]) -> anyhow::Result<String> {
+        let mut formatter = Llama3::default();
+        for message in messages {
+            match message {
+                Message::System(system) => formatter.emit_system(system),
+                Message::User(user) => formatter.emit_user(user),
+                Message::Assistant(assistant) => formatter.emit_assistant(assistant),
+            }
+        }
+        formatter.emit_complete();
+        Ok(formatter.output)
+    }
+}
diff --git a/crates/core/src/chat/mod.rs b/crates/core/src/chat/mod.rs
index 8eb9c69..675e943 100644
--- a/crates/core/src/chat/mod.rs
+++ b/crates/core/src/chat/mod.rs
@@ -1,30 +1,60 @@
 //! Chat interfaces
 
+pub use llama::Llama3;
 use std::{fmt::Display, str::FromStr};
 
+mod llama;
+
 /// A message in a chat.
-#[derive(Debug, Clone, Default)]
-pub struct Message {
-    /// The role of the message.
-    pub role: Role,
-    /// The content of the message.
-    ///
-    /// NOTE: only supports string atm
-    pub content: String,
+#[derive(Debug, Clone)]
+pub enum Message {
+    /// The assistant message
+    Assistant(String),
+    /// The user message
+    User(String),
+    /// The system message
+    System(String),
+}
+
+impl Message {
+    /// Get the text of the message
+    pub fn text(&self) -> &str {
+        match self {
+            Self::Assistant(content) => content,
+            Self::User(content) => content,
+            Self::System(content) => content,
+        }
+    }
+
+    /// Create a new user message
+    pub fn user(content: impl Into<String>) -> Self {
+        Self::User(content.into())
+    }
+
+    /// Create a new assistant message
+    pub fn assistant(content: impl Into<String>) -> Self {
+        Self::Assistant(content.into())
+    }
+
+    /// Create a new system message
+    pub fn system(content: impl Into<String>) -> Self {
+        Self::System(content.into())
+    }
 }
 
 impl Display for Message {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{:?}: {}", self.role, self.content)
+        match self {
+            Self::Assistant(content) => write!(f, "assistant: {}", content),
+            Self::User(content) => write!(f, "user: {}", content),
+            Self::System(content) => write!(f, "system: {}", content),
+        }
     }
 }
 
 impl From<&str> for Message {
     fn from(s: &str) -> Self {
-        Self {
-            role: Role::User,
-            content: s.to_string(),
-        }
+        Self::User(s.to_string())
     }
 }
 
@@ -35,40 +65,23 @@
         let (role, content) = s
             .split_once(": ")
            .ok_or_else(|| anyhow::anyhow!("invalid message format"))?;
-        Ok(Self {
-            role: Role::from_str(role)?,
-            content: content.to_string(),
+        Ok(match role.to_lowercase().trim() {
+            "assistant" => Self::Assistant(content.to_string()),
+            "user" => Self::User(content.to_string()),
+            "system" => Self::System(content.to_string()),
+            _ => anyhow::bail!("invalid role: {role}"),
         })
     }
 }
 
-/// The role of the message.
-#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
-pub enum Role {
-    /// The assistant.
-    ///
-    /// The content is the assistant's message.
-    Assistant,
-    /// The user.
-    ///
-    /// The content is the user's message.
-    #[default]
-    User,
-    /// The system.
-    ///
-    /// The content is the system's message.
-    System,
-}
+/// A formatter for chat messages
+pub trait Formatter {
+    /// The end of stream token
+    const EOS_TOKEN: &str;
 
-impl FromStr for Role {
-    type Err = anyhow::Error;
+    /// Format the messages
+    fn format(messages: &[Message]) -> anyhow::Result<String>;
 
-    fn from_str(s: &str) -> Result<Self, Self::Err> {
-        Ok(match s {
-            "assistant" => Self::Assistant,
-            "user" => Self::User,
-            "system" => Self::System,
-            _ => anyhow::bail!("invalid role: {s}"),
-        })
-    }
+    /// Format a single message
+    fn complete(messages: &[Message]) -> anyhow::Result<String>;
 }
diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs
index 345f982..ae6911b 100644
--- a/crates/core/src/lib.rs
+++ b/crates/core/src/lib.rs
@@ -2,10 +2,13 @@
 //!
 //! This library gathers the user interfaces of cydonia
 
-mod chat;
-mod manifest;
+pub mod chat;
+pub mod manifest;
 
-pub use {chat::Message, manifest::Manifest};
+pub use {
+    chat::Message,
+    manifest::{Family, Quantization, Release},
+};
 
 /// The tokenizer repo of cydonia in huggingface.
 pub const TOKENIZER: &str = "clearloop/tokenizer";
diff --git a/crates/core/src/manifest/family.rs b/crates/core/src/manifest/family.rs
index d4d1b85..e977b70 100644
--- a/crates/core/src/manifest/family.rs
+++ b/crates/core/src/manifest/family.rs
@@ -1,124 +1,151 @@
-use crate::manifest::Quantization;
+//! Model family
+
 use anyhow::Result;
+pub use llama::LlamaVer;
 use std::{fmt::Display, str::FromStr};
 
-/// Release info of a model
-#[derive(Debug)]
-pub struct Release {
-    /// The family of the model
-    pub family: Family,
-
-    /// The version of the model
-    pub version: u8,
+/// The family of the model
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum Family {
+    /// Llama from Meta
+    Llama {
+        /// The version of the model
+        version: LlamaVer,
 
-    /// The parameters of the model in billions
-    pub parameters: f32,
+        /// The parameters of the model
+        params: Params,
 
-    /// The tag of the model
-    pub tag: Option<String>,
+        /// The tag of the model
+        tag: Tag,
+    },
 }
 
-impl Release {
-    /// Create a new release from a model name
-    pub fn new(model: &str) -> Result<Self> {
-        match model {
-            "llama2" | "llama2-7b" | "llama2-7b-chat" => Ok(Self {
-                family: Family::Llama,
-                version: 2,
-                parameters: 6.74,
-                tag: Some("chat".into()),
-            }),
-            _ => anyhow::bail!("invalid model: {model}"),
+impl Default for Family {
+    fn default() -> Self {
+        Self::Llama {
+            version: LlamaVer::V3_2,
+            params: Params::V3B,
+            tag: Tag::Instruct,
         }
     }
+}
 
-    /// Get the repo of the model
-    pub fn repo(&self) -> Result<&str> {
-        match self.family.as_ref() {
-            "llama" => Ok("TheBloke/Llama-2-7B-Chat-GGUF"),
-            _ => anyhow::bail!("invalid family: {}", self.family),
-        }
+impl From<&str> for Family {
+    fn from(s: &str) -> Self {
+        Self::from_str(s).unwrap_or_default()
     }
+}
 
-    /// Get the tokenizer path from the tokenizer repo
-    pub fn tokenizer(&self) -> &str {
-        "llama2/tokenizer.json"
+impl Display for Family {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "{}",
+            match self {
+                Family::Llama {
+                    version,
+                    params,
+                    tag,
+                } => format!("Llama-{version}-{params}-{tag}"),
+            }
+        )
     }
+}
+
+impl FromStr for Family {
+    type Err = anyhow::Error;
 
-    /// Get the model path of the model
-    ///
-    /// NOTE: only support llama2 for now
-    pub fn model(&self, quant: Quantization) -> String {
-        match self.family {
-            Family::Llama => format!(
-                "llama-2-{}b-{}.{}.gguf",
-                self.parameters.ceil() as u8,
-                self.tag.as_deref().unwrap_or("chat"),
-                quant
-            ),
+    fn from_str(s: &str) -> Result<Self> {
+        let model = s
+            .trim()
+            .to_lowercase()
+            .replace('-', "")
+            .replace("instruct", "");
+
+        match model.as_ref() {
+            "llama3.18b" => Ok(Family::Llama {
+                version: LlamaVer::V3_1,
+                params: Params::V8B,
+                tag: Tag::Instruct,
+            }),
+            "llama3.21b" => Ok(Family::Llama {
+                version: LlamaVer::V3_2,
+                params: Params::V1B,
+                tag: Tag::Instruct,
+            }),
+            "llama3.23b" => Ok(Family::Llama {
+                version: LlamaVer::V3_2,
+                params: Params::V3B,
+                tag: Tag::Instruct,
+            }),
+            _ => {
+                tracing::warn!("invalid family {s}, using default llama-3.2-1B-Instruct");
+                Ok(Family::Llama {
+                    version: LlamaVer::V3_2,
+                    params: Params::V1B,
+                    tag: Tag::Instruct,
+                })
+            }
         }
     }
 }
 
-impl Default for Release {
-    fn default() -> Self {
-        Self {
-            family: Family::Llama,
-            version: 2,
-            parameters: 6.74,
-            tag: Some("chat".into()),
-        }
-    }
+/// The parameters of the model
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum Params {
+    V1B,
+    V3B,
+    V8B,
 }
 
-impl Display for Release {
+impl Display for Params {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(
             f,
-            "{}{}-{}b-{}",
-            self.family.as_ref(),
-            self.version,
-            self.parameters.ceil() as u8,
-            self.tag.as_deref().unwrap_or("chat")
+            "{}",
+            match self {
+                Params::V1B => "1B",
+                Params::V3B => "3B",
+                Params::V8B => "8B",
+            }
         )
     }
 }
 
-/// The family of the model
+/// The tag of the model
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
-pub enum Family {
-    /// Llama from Meta
+pub enum Tag {
     #[default]
-    Llama,
+    Instruct,
 }
 
-impl AsRef<str> for Family {
-    fn as_ref(&self) -> &str {
-        match self {
-            Family::Llama => "llama",
-        }
-    }
-}
-
-impl Display for Family {
+impl Display for Tag {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}", self.as_ref())
+        write!(f, "{:?}", self)
     }
 }
 
-impl FromStr for Family {
-    type Err = anyhow::Error;
+mod llama {
+    use std::fmt::Display;
 
-    fn from_str(_: &str) -> Result<Self> {
-        Ok(Family::Llama)
+    /// The version of the llama model
+    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
+    pub enum LlamaVer {
+        V3_1,
+        #[default]
+        V3_2,
     }
-}
 
-#[test]
-fn test_fmt_release() {
-    assert_eq!(Release::default().to_string(), "llama2-7b-chat");
-    assert_eq!(
-        Release::default().model(Quantization::Q4_0),
-        "llama-2-7b-chat.Q4_0.gguf"
-    );
+    impl Display for LlamaVer {
+        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+            write!(
+                f,
+                "{}",
+                match self {
+                    LlamaVer::V3_1 => "3.1",
+                    LlamaVer::V3_2 => "3.2",
+                }
+            )
+        }
+    }
 }
diff --git a/crates/core/src/manifest/mod.rs b/crates/core/src/manifest/mod.rs
index f3ad583..eb39c14 100644
--- a/crates/core/src/manifest/mod.rs
+++ b/crates/core/src/manifest/mod.rs
@@ -1,57 +1,11 @@
 //! Model manifest
 
-use std::collections::HashMap;
 pub use {
-    family::{Family, Release},
+    family::{Family, LlamaVer, Params, Tag},
     quant::Quantization,
+    release::Release,
 };
 
 mod family;
 mod quant;
-
-/// Manifest of a quantized model
-#[derive(Debug)]
-pub struct Manifest {
-    /// The name of the model
-    pub name: String,
-
-    /// The release of the model
-    pub release: Release,
-
-    /// The K-quantization of the model
-    pub quantization: Quantization,
-
-    /// The revision of the model
-    pub revision: [u8; 12],
-
-    /// The parameters of the model
-    pub params: HashMap,
-}
-
-impl Manifest {
-    /// Create a new manifest from a model name
-    pub fn new(name: &str) -> anyhow::Result<Self> {
-        let release = Release::new(name)?;
-        Ok(Self {
-            name: name.into(),
-            quantization: match release.family {
-                Family::Llama => Quantization::Q4_0,
-            },
-            release,
-            revision: [0; 12],
-            params: HashMap::new(),
-        })
-    }
-}
-
-impl Default for Manifest {
-    fn default() -> Self {
-        Self {
-            name: "llama2".into(),
-            release: Release::default(),
-            quantization: Quantization::Q4_0,
-            revision: [0; 12],
-            params: HashMap::new(),
-        }
-    }
-}
+mod release;
diff --git a/crates/core/src/manifest/quant.rs b/crates/core/src/manifest/quant.rs
index 2fbbe0a..5e97439 100644
--- a/crates/core/src/manifest/quant.rs
+++ b/crates/core/src/manifest/quant.rs
@@ -45,7 +45,6 @@ pub enum Quantization {
     Q5_K,
     /// 4-bit round-to-nearest quantization (q). Each block has 32 weights.
     /// Weight formula: w = q * block_scale. Legacy quantization method not used widely as of today.
-    #[default]
     Q4_0,
     /// 4-bit round-to-nearest quantization (q). Each block has 32 weights.
     /// Weight formula: w = q * block_scale + block_minimum. Legacy quantization method not used widely as of today.
@@ -62,6 +61,7 @@ pub enum Quantization {
     /// Weight formula: w = q * block_scale(6-bit) + block_min(6-bit), resulting in 4.5 bits-per-weight.
     ///
     /// in medium size
+    #[default]
     Q4_K_M,
     /// 3-bit quantization (q). Super-blocks with 16 blocks, each block has 16 weights.
     /// Weight formula: w = q * block_scale(6-bit), resulting in 3.4375 bits-per-weight.
diff --git a/crates/core/src/manifest/release.rs b/crates/core/src/manifest/release.rs
new file mode 100644
index 0000000..756b2e7
--- /dev/null
+++ b/crates/core/src/manifest/release.rs
@@ -0,0 +1,66 @@
+//! Model release
+
+use crate::manifest::{Family, LlamaVer, Params, Quantization};
+use std::fmt::Display;
+
+/// Release info of a model
+#[derive(Debug, Default)]
+pub struct Release {
+    /// The family of the model
+    pub family: Family,
+
+    /// The quantization of the model
+    pub quant: Quantization,
+}
+
+impl Release {
+    /// Create a new release from a model name
+    pub fn new(model: &str) -> Self {
+        let family = Family::from(model);
+        Self {
+            family,
+            quant: Quantization::Q4_K_M,
+        }
+    }
+
+    /// Get the repo of the model
+    pub fn repo(&self) -> &str {
+        let Family::Llama {
+            version, params, ..
+        } = self.family;
+
+        match (version, params) {
+            (LlamaVer::V3_1, _) => "MaziyarPanahi/Meta-Llama-3.1-8B-Instruct-GGUF",
+            (LlamaVer::V3_2, Params::V3B) => "MaziyarPanahi/Llama-3.2-3B-Instruct-GGUF",
+            _ => "MaziyarPanahi/Llama-3.2-1B-Instruct-GGUF",
+        }
+    }
+
+    /// Get the tokenizer path from the tokenizer repo
+    pub fn tokenizer(&self) -> &str {
+        "llama3/tokenizer.json"
+    }
+
+    /// Get the model path of the model
+    pub fn model(&self) -> String {
+        format!("{}.gguf", self)
+    }
+}
+
+impl Display for Release {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}.{}", self.family, self.quant)
+    }
+}
+
+#[test]
+fn test_fmt_release() {
+    assert_eq!(
+        Release::default().to_string(),
+        "Llama-3.2-3B-Instruct.Q4_K_M"
+    );
+    assert_eq!(
+        Release::default().model(),
+        "Llama-3.2-3B-Instruct.Q4_K_M.gguf"
+    );
+}
diff --git a/docs/src/README.md b/docs/src/README.md
index 2fe43bf..428b894 100644
--- a/docs/src/README.md
+++ b/docs/src/README.md
@@ -4,19 +4,17 @@ Cydonia is a library based on [candle][candle] for developing modern AI applicat
 
 ```rust
 use cydonia::Model;
+
 fn main() {
-    let model = Model::new("gemma2").tag("latest");
+    let model = Model::new("llama3.2-1b");
     let response = model.invoke("Hello, world!");
     println!("{}", response);
 }
 ```
 
-We support quantized models only derived from `gemma` and `llama` family.
-
-## Special Thanks
+## LICENSE
 
-- [candle][candle]
-- [ollama][ollama]
+[GPL-3.0](LICENSE)