The recommended way to load models and run inference:
from chuk_lazarus.inference import UnifiedPipeline, UnifiedPipelineConfig, DType
# One-liner model loading - auto-detects family
pipeline = UnifiedPipeline.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
# With configuration
config = UnifiedPipelineConfig(
    dtype=DType.BFLOAT16,
    use_lora=True,
)
pipeline = UnifiedPipeline.from_pretrained("model-id", config=config)
# Generate text
result = pipeline.chat("Hello!")
print(result.text)
print(result.stats.summary) # "25 tokens in 0.42s (59.5 tok/s)"
# Streaming
for chunk in pipeline.generate_stream("Once upon a time"):
    print(chunk, end="", flush=True)

Stateful inference engines for multi-turn generation with explicit KV store management.
from chuk_lazarus.inference.context import KVDirectGenerator, make_kv_generator
# Auto-detect factory
gen = make_kv_generator(model) # Gemma, Llama, Mistral, ...
# Or use named constructors
gen = KVDirectGenerator.from_gemma_rs(rs_model)
gen = KVDirectGenerator.from_llama(llama_model)
# Lifecycle
logits, kv_store = gen.prefill(input_ids) # Full forward pass
logits, kv_store = gen.step(token_ids, kv_store, seq_len) # Single token (compiled)
logits, kv_store = gen.extend(new_ids, kv_store, abs_start) # Batch new tokens
kv_store = gen.slide(kv_store, evict_count) # Evict oldest tokens
bytes_used = gen.kv_bytes(seq_len)                   # Memory accounting

kv_store format: list[tuple[mx.array, mx.array]], one (K, V) pair per layer.
K.shape = V.shape = (batch, num_kv_heads, seq_len, head_dim)
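A minimal sketch of walking the store returned by prefill; input_ids is assumed to already be a token-id array, and each layer's K and V follow the layout above:
logits, kv_store = gen.prefill(input_ids)
for layer_idx, (k, v) in enumerate(kv_store):
    # (batch, num_kv_heads, seq_len, head_dim) for both tensors
    print(layer_idx, k.shape, v.shape)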
from chuk_lazarus.inference.unified import EngineMode
from chuk_lazarus.inference import UnifiedPipeline, UnifiedPipelineConfig
config = UnifiedPipelineConfig(engine=EngineMode.KV_DIRECT)
pipeline = UnifiedPipeline.from_pretrained("model-id", config=config)
# Get generator from loaded pipeline
kv_gen = pipeline.make_engine()

Values: EngineMode.STANDARD (default), EngineMode.KV_DIRECT
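On top of that generator, a greedy decode loop is only a few lines. This is a sketch: prompt_ids and max_new_tokens are assumed to be defined, and the indexing into logits and the exact token-id shape step() expects are assumptions that may need adjusting.
import mlx.core as mx

logits, kv_store = kv_gen.prefill(prompt_ids)          # prompt_ids: token-id array for the prompt
seq_len = prompt_ids.shape[-1]
new_tokens = []
for _ in range(max_new_tokens):
    next_id = mx.argmax(logits[..., -1, :], axis=-1)   # greedy pick at the last position
    new_tokens.append(next_id)
    logits, kv_store = kv_gen.step(next_id, kv_store, seq_len)
    seq_len += 1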
from chuk_lazarus.inference.context import (
    ModelBackboneProtocol,      # Interface for backbone adapters
    TransformerLayerProtocol,   # Per-layer interface
    GemmaBackboneAdapter,       # Wraps GemmaResidualStreamForCausalLM
    GemmaLayerAdapter,
    LlamaBackboneAdapter,       # Wraps LlamaForCausalLM / Mistral
    LlamaLayerAdapter,
)

To support a new architecture, implement ModelBackboneProtocol and pass it to KVDirectGenerator(backbone).
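The bundled adapters can also be passed in explicitly; a short sketch, assuming the adapter takes the wrapped model as its constructor argument and llama_model is a loaded LlamaForCausalLM:
from chuk_lazarus.inference.context import KVDirectGenerator, LlamaBackboneAdapter

backbone = LlamaBackboneAdapter(llama_model)   # any ModelBackboneProtocol implementation works here
gen = KVDirectGenerator(backbone)              # same prefill/step/extend/slide lifecycle as above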
from chuk_lazarus.utils.tokenizer_loader import load_tokenizer
tokenizer = load_tokenizer("model-path-or-id")

All trainers inherit from BaseTrainer:
class BaseTrainer:
    def train(dataset, num_epochs, eval_dataset=None, callback=None)
    def evaluate(dataset) -> Dict[str, float]
    def save_checkpoint(name: str)
    def load_checkpoint(path: str)

from chuk_lazarus.training import SFTTrainer, SFTConfig
config = SFTConfig(
    num_epochs=3,
    batch_size=4,
    learning_rate=1e-5,
)
trainer = SFTTrainer(model, tokenizer, config)
trainer.train(dataset)
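The same instance also exposes the BaseTrainer helpers; a brief sketch, assuming eval_dataset is another SFTDataset split:
metrics = trainer.evaluate(eval_dataset)   # Dict[str, float] of evaluation metrics
trainer.save_checkpoint("sft-final")       # checkpoint name; storage presumably follows checkpoint_dir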
from chuk_lazarus.training import DPOTrainer, DPOTrainerConfig

trainer = DPOTrainer(
    policy_model,
    reference_model,
    tokenizer,
    config,              # a DPOTrainerConfig instance
)
trainer.train(dataset)

from chuk_lazarus.data import SFTDataset
dataset = SFTDataset(
    path: str,               # JSONL file path
    tokenizer,
    max_length: int = 512,
)
len(dataset) # Number of samples
dataset[0] # Get sample by index
for batch in dataset.iter_batches(batch_size=4):
    # batch["input_ids"], batch["labels"], batch["loss_mask"]

from chuk_lazarus.data import PreferenceDataset
dataset = PreferenceDataset(
    path: str,               # JSONL with prompt/chosen/rejected
    tokenizer,
    max_length: int = 512,
)
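Each JSONL line carries the three fields named above; an illustrative record (contents invented for the example):
{"prompt": "What is 2 + 2?", "chosen": "4", "rejected": "5"}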
from chuk_lazarus.data.generators import MathProblemGenerator

gen = MathProblemGenerator(seed=42)
samples = gen.generate_batch(100, difficulty_range=(1, 3))
# Export
gen.save_sft_dataset(samples, "train_sft.jsonl")
gen.save_dpo_dataset(samples, "train_dpo.jsonl")
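Putting the pieces together, a rough end-to-end sketch that generates synthetic math data, writes it to JSONL, and runs SFT on it; model and tokenizer are assumed to be loaded elsewhere, and the file name is arbitrary:
from chuk_lazarus.data import SFTDataset
from chuk_lazarus.data.generators import MathProblemGenerator
from chuk_lazarus.training import SFTTrainer, SFTConfig

gen = MathProblemGenerator(seed=0)
samples = gen.generate_batch(500, difficulty_range=(1, 3))
gen.save_sft_dataset(samples, "math_sft.jsonl")

dataset = SFTDataset("math_sft.jsonl", tokenizer, max_length=512)
trainer = SFTTrainer(model, tokenizer, SFTConfig(num_epochs=1))
trainer.train(dataset)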
@dataclass
class ModelConfig:
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    intermediate_size: int
    vocab_size: int
    max_position_embeddings: int = 2048
    hidden_act: str = "silu"
    rms_norm_eps: float = 1e-6
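An illustrative instantiation with arbitrary small-model values (the import path for ModelConfig is not shown here):
config = ModelConfig(
    hidden_size=512,
    num_hidden_layers=8,
    num_attention_heads=8,
    intermediate_size=2048,
    vocab_size=32000,
)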
@dataclass
class LoRAConfig:
    rank: int = 8
    alpha: float = 16.0
    dropout: float = 0.0
    target_modules: List[str] = ("q_proj", "v_proj")
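An illustrative override of the defaults (how the config is wired into a trainer or pipeline is not covered here):
lora = LoRAConfig(
    rank=16,
    alpha=32.0,
    target_modules=("q_proj", "k_proj", "v_proj", "o_proj"),
)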
@dataclass
class SFTConfig:
    num_epochs: int = 3
    batch_size: int = 4
    learning_rate: float = 1e-5
    warmup_steps: int = 100
    max_grad_norm: float = 1.0
    log_interval: int = 10
    checkpoint_interval: int = 500
    checkpoint_dir: str = "./checkpoints/sft"