API Reference — chuk_lazarus

API Reference

Inference

UnifiedPipeline

The recommended way to load models and run inference:

from chuk_lazarus.inference import UnifiedPipeline, UnifiedPipelineConfig, DType

# One-liner model loading - auto-detects family
pipeline = UnifiedPipeline.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# With configuration
config = UnifiedPipelineConfig(
    dtype=DType.BFLOAT16,
    use_lora=True,
)
pipeline = UnifiedPipeline.from_pretrained("model-id", config=config)

# Generate text
result = pipeline.chat("Hello!")
print(result.text)
print(result.stats.summary)  # "25 tokens in 0.42s (59.5 tok/s)"

# Streaming
for chunk in pipeline.generate_stream("Once upon a time"):
    print(chunk, end="", flush=True)

Context Engines

Stateful inference engines for multi-turn generation with explicit KV store management.

KVDirectGenerator

from chuk_lazarus.inference.context import KVDirectGenerator, make_kv_generator

# Auto-detect factory
gen = make_kv_generator(model)        # Gemma, Llama, Mistral, ...

# Or use named constructors
gen = KVDirectGenerator.from_gemma_rs(rs_model)
gen = KVDirectGenerator.from_llama(llama_model)

# Lifecycle
logits, kv_store = gen.prefill(input_ids)                            # Full forward pass
logits, kv_store = gen.step(token_ids, kv_store, seq_len)            # Single token (compiled)
logits, kv_store = gen.extend(new_ids, kv_store, abs_start)          # Batch new tokens
kv_store = gen.slide(kv_store, evict_count)                          # Evict oldest tokens
bytes_used = gen.kv_bytes(seq_len)                                   # Memory accounting

kv_store format: list[tuple[mx.array, mx.array]] — one (K, V) pair per layer. K.shape = V.shape = (batch, num_kv_heads, seq_len, head_dim)

EngineMode

from chuk_lazarus.inference.unified import EngineMode
from chuk_lazarus.inference import UnifiedPipeline, UnifiedPipelineConfig

config = UnifiedPipelineConfig(engine=EngineMode.KV_DIRECT)
pipeline = UnifiedPipeline.from_pretrained("model-id", config=config)

# Get generator from loaded pipeline
kv_gen = pipeline.make_engine()

Values: EngineMode.STANDARD (default), EngineMode.KV_DIRECT

Protocols and Adapters

from chuk_lazarus.inference.context import (
    ModelBackboneProtocol,      # Interface for backbone adapters
    TransformerLayerProtocol,   # Per-layer interface
    GemmaBackboneAdapter,       # Wraps GemmaResidualStreamForCausalLM
    GemmaLayerAdapter,
    LlamaBackboneAdapter,       # Wraps LlamaForCausalLM / Mistral
    LlamaLayerAdapter,
)

To support a new architecture, implement ModelBackboneProtocol and pass it to KVDirectGenerator(backbone).

load_tokenizer

from chuk_lazarus.utils.tokenizer_loader import load_tokenizer

tokenizer = load_tokenizer("model-path-or-id")

Training

BaseTrainer

All trainers inherit from BaseTrainer:

class BaseTrainer:
    def train(self, dataset, num_epochs, eval_dataset=None, callback=None)
    def evaluate(self, dataset) -> Dict[str, float]
    def save_checkpoint(self, name: str)
    def load_checkpoint(self, path: str)

SFTTrainer

from chuk_lazarus.training import SFTTrainer, SFTConfig

config = SFTConfig(
    num_epochs=3,
    batch_size=4,
    learning_rate=1e-5,
)

trainer = SFTTrainer(model, tokenizer, config)
trainer.train(dataset)

DPOTrainer

from chuk_lazarus.training import DPOTrainer, DPOTrainerConfig

trainer = DPOTrainer(
    policy_model,
    reference_model,
    tokenizer,
    config,
)
trainer.train(dataset)

Data

SFTDataset

from chuk_lazarus.data import SFTDataset

dataset = SFTDataset(
    path,             # str: JSONL file path
    tokenizer,
    max_length=512,   # int, default 512
)

len(dataset)  # Number of samples
dataset[0]    # Get sample by index

for batch in dataset.iter_batches(batch_size=4):
    # batch["input_ids"], batch["labels"], batch["loss_mask"]
    ...

PreferenceDataset

from chuk_lazarus.data import PreferenceDataset

dataset = PreferenceDataset(
    path,             # str: JSONL with prompt/chosen/rejected fields
    tokenizer,
    max_length=512,   # int, default 512
)

MathProblemGenerator

from chuk_lazarus.data.generators import MathProblemGenerator

gen = MathProblemGenerator(seed=42)
samples = gen.generate_batch(100, difficulty_range=(1, 3))

# Export
gen.save_sft_dataset(samples, "train_sft.jsonl")
gen.save_dpo_dataset(samples, "train_dpo.jsonl")

Configuration Classes

ModelConfig

@dataclass
class ModelConfig:
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    intermediate_size: int
    vocab_size: int
    max_position_embeddings: int = 2048
    hidden_act: str = "silu"
    rms_norm_eps: float = 1e-6

LoRAConfig

@dataclass
class LoRAConfig:
    rank: int = 8
    alpha: float = 16.0
    dropout: float = 0.0
    target_modules: Tuple[str, ...] = ("q_proj", "v_proj")  # immutable default avoids the mutable-default-argument trap

SFTConfig

@dataclass
class SFTConfig:
    num_epochs: int = 3
    batch_size: int = 4
    learning_rate: float = 1e-5
    warmup_steps: int = 100
    max_grad_norm: float = 1.0
    log_interval: int = 10
    checkpoint_interval: int = 500
    checkpoint_dir: str = "./checkpoints/sft"