feat(gateway): skill-aware slash commands, paginated /commands, Telegram 100-cap (#3934)

* feat(gateway): skill-aware slash commands, paginated /commands, Telegram 100-cap

Map active skills to Telegram's slash command menu so users can
discover and invoke skills directly. Three changes:

1. Telegram menu now includes active skill commands alongside built-in
   commands, capped at 100 entries (Telegram Bot API limit). Overflow
   commands remain callable but hidden from the picker. Logged at
   startup when cap is hit.

2. New /commands [page] gateway command for paginated browsing of all
   commands + skills. /help now shows first 10 skill commands and
   points to /commands for the full list.

3. When a user types a slash command that matches a disabled or
   uninstalled skill, they get actionable guidance:
   - Disabled: 'Enable it with: hermes skills config'
   - Optional (not installed): 'Install with: hermes skills install official/<path>'

Built on ideas from PR #3921 by @kshitijk4poor.

* chore: move 21 niche skills to optional-skills

Move specialized/niche skills from built-in (skills/) to optional
(optional-skills/) to reduce the default skill count. Users can
install them with: hermes skills install official/<category>/<name>

Moved skills (21):
- mlops: accelerate, chroma, faiss, flash-attention,
  hermes-atropos-environments, huggingface-tokenizers, instructor,
  lambda-labs, llava, nemo-curator, pinecone, pytorch-lightning,
  qdrant, saelens, simpo, slime, tensorrt-llm, torchtitan
- research: domain-intel, duckduckgo-search
- devops: inference-sh-cli

Built-in skills: 96 → 75
Optional skills: 22 → 43

* fix: only include repo built-in skills in Telegram menu, not user-installed

User-installed skills (from hub or manually added) stay accessible via
/skills and by typing the command directly, but don't get registered
in the Telegram slash command picker. Only skills whose SKILL.md is
under the repo's skills/ directory are included in the menu.

This keeps the Telegram menu focused on the curated built-in set while
user-installed skills remain discoverable through /skills and /commands.
Teknium committed 2026-03-30 10:57:30 -07:00 (via GitHub)
parent 97d6813f51
commit 5ceed021dc
73 changed files with 163 additions and 4 deletions


@ -0,0 +1,155 @@
---
name: inference-sh-cli
description: "Run 150+ AI apps via inference.sh CLI (infsh) — image generation, video creation, LLMs, search, 3D, social automation. Uses the terminal tool. Triggers: inference.sh, infsh, ai apps, flux, veo, image generation, video generation, seedream, seedance, tavily"
version: 1.0.0
author: okaris
license: MIT
metadata:
  hermes:
    tags: [AI, image-generation, video, LLM, search, inference, FLUX, Veo, Claude]
    related_skills: []
---
# inference.sh CLI
Run 150+ AI apps in the cloud with a simple CLI. No GPU required.
All commands use the **terminal tool** to run `infsh` commands.
## When to Use
- User asks to generate images (FLUX, Reve, Seedream, Grok, Gemini image)
- User asks to generate video (Veo, Wan, Seedance, OmniHuman)
- User asks about inference.sh or infsh
- User wants to run AI apps without managing individual provider APIs
- User asks for AI-powered search (Tavily, Exa)
- User needs avatar/lipsync generation
## Prerequisites
The `infsh` CLI must be installed and authenticated. Check with:
```bash
infsh me
```
If not installed:
```bash
curl -fsSL https://cli.inference.sh | sh
infsh login
```
See `references/authentication.md` for full setup details.
## Workflow
### 1. Always Search First
Never guess app names — always search to find the correct app ID:
```bash
infsh app list --search flux
infsh app list --search video
infsh app list --search image
```
### 2. Run an App
Use the exact app ID from the search results. Always use `--json` for machine-readable output:
```bash
infsh app run <app-id> --input '{"prompt": "your prompt here"}' --json
```
### 3. Parse the Output
The JSON output contains URLs to generated media. Present these to the user with `MEDIA:<url>` for inline display.
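A minimal way to pull media URLs out of that JSON, assuming `jq` is available in the terminal environment and that the output nests URLs under an `images` array as in the example in `references/running-apps.md` (other apps use different field names, so adjust the path accordingly):
```bash
# Extract image URLs from the JSON result and prefix them for inline display
infsh app run falai/flux-dev-lora --input '{"prompt": "sunset over mountains"}' --json \
  | jq -r '.images[].url' \
  | sed 's/^/MEDIA:/'
```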
## Common Commands
### Image Generation
```bash
# Search for image apps
infsh app list --search image
# FLUX Dev with LoRA
infsh app run falai/flux-dev-lora --input '{"prompt": "sunset over mountains", "num_images": 1}' --json
# Gemini image generation
infsh app run google/gemini-2-5-flash-image --input '{"prompt": "futuristic city", "num_images": 1}' --json
# Seedream (ByteDance)
infsh app run bytedance/seedream-5-lite --input '{"prompt": "nature scene"}' --json
# Grok Imagine (xAI)
infsh app run xai/grok-imagine-image --input '{"prompt": "abstract art"}' --json
```
### Video Generation
```bash
# Search for video apps
infsh app list --search video
# Veo 3.1 (Google)
infsh app run google/veo-3-1-fast --input '{"prompt": "drone shot of coastline"}' --json
# Seedance (ByteDance)
infsh app run bytedance/seedance-1-5-pro --input '{"prompt": "dancing figure", "resolution": "1080p"}' --json
# Wan 2.5
infsh app run falai/wan-2-5 --input '{"prompt": "person walking through city"}' --json
```
### Local File Uploads
The CLI automatically uploads local files when you provide a path:
```bash
# Upscale a local image
infsh app run falai/topaz-image-upscaler --input '{"image": "/path/to/photo.jpg", "upscale_factor": 2}' --json
# Image-to-video from local file
infsh app run falai/wan-2-5-i2v --input '{"image": "/path/to/image.png", "prompt": "make it move"}' --json
# Avatar with audio
infsh app run bytedance/omnihuman-1-5 --input '{"audio": "/path/to/audio.mp3", "image": "/path/to/face.jpg"}' --json
```
### Search & Research
```bash
infsh app list --search search
infsh app run tavily/tavily-search --input '{"query": "latest AI news"}' --json
infsh app run exa/exa-search --input '{"query": "machine learning papers"}' --json
```
### Other Categories
```bash
# 3D generation
infsh app list --search 3d
# Audio / TTS
infsh app list --search tts
# Twitter/X automation
infsh app list --search twitter
```
## Pitfalls
1. **Never guess app IDs** — always run `infsh app list --search <term>` first. App IDs change and new apps are added frequently.
2. **Always use `--json`** — raw output is hard to parse. The `--json` flag gives structured output with URLs.
3. **Check authentication** — if commands fail with auth errors, run `infsh login` or verify `INFSH_API_KEY` is set.
4. **Long-running apps** — video generation can take 30-120 seconds. The terminal tool timeout should be sufficient, but warn the user it may take a moment.
5. **Input format** — the `--input` flag takes a JSON string. Make sure to properly escape quotes (see the example below).
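A quick illustration of pitfall 5, assuming a POSIX shell: wrap the JSON in single quotes so its double quotes reach `infsh` untouched, and close/escape/reopen the quoting for any literal single quote inside the prompt.
```bash
# Single quotes protect the JSON's double quotes from the shell
infsh app run falai/flux-dev-lora --input '{"prompt": "sunset over mountains"}' --json

# A literal apostrophe inside the prompt must be written as '\''
infsh app run falai/flux-dev-lora --input '{"prompt": "a pirate'\''s cove at dusk"}' --json
```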
## Reference Docs
- `references/authentication.md` — Setup, login, API keys
- `references/app-discovery.md` — Searching and browsing the app catalog
- `references/running-apps.md` — Running apps, input formats, output handling
- `references/cli-reference.md` — Complete CLI command reference


@ -0,0 +1,112 @@
# Discovering Apps
## List All Apps
```bash
infsh app list
```
## Pagination
```bash
infsh app list --page 2
```
## Filter by Category
```bash
infsh app list --category image
infsh app list --category video
infsh app list --category audio
infsh app list --category text
infsh app list --category other
```
## Search
```bash
infsh app search "flux"
infsh app search "video generation"
infsh app search "tts" -l
infsh app search "image" --category image
```
Or use the flag form:
```bash
infsh app list --search "flux"
infsh app list --search "video generation"
infsh app list --search "tts"
```
## Featured Apps
```bash
infsh app list --featured
```
## Newest First
```bash
infsh app list --new
```
## Detailed View
```bash
infsh app list -l
```
Shows table with app name, category, description, and featured status.
## Save to File
```bash
infsh app list --save apps.json
```
## Your Apps
List apps you've deployed:
```bash
infsh app my
infsh app my -l # detailed
```
## Get App Details
```bash
infsh app get falai/flux-dev-lora
infsh app get falai/flux-dev-lora --json
```
Shows full app info including input/output schema.
## Popular Apps by Category
### Image Generation
- `falai/flux-dev-lora` - FLUX.2 Dev (high quality)
- `falai/flux-2-klein-lora` - FLUX.2 Klein (fastest)
- `infsh/sdxl` - Stable Diffusion XL
- `google/gemini-3-pro-image-preview` - Gemini 3 Pro
- `xai/grok-imagine-image` - Grok image generation
### Video Generation
- `google/veo-3-1-fast` - Veo 3.1 Fast
- `google/veo-3` - Veo 3
- `bytedance/seedance-1-5-pro` - Seedance 1.5 Pro
- `infsh/ltx-video-2` - LTX Video 2 (with audio)
- `bytedance/omnihuman-1-5` - OmniHuman avatar
### Audio
- `infsh/dia-tts` - Conversational TTS
- `infsh/kokoro-tts` - Kokoro TTS
- `infsh/fast-whisper-large-v3` - Fast transcription
- `infsh/diffrythm` - Music generation
## Documentation
- [Browsing the Grid](https://inference.sh/docs/apps/browsing-grid) - Visual app browsing
- [Apps Overview](https://inference.sh/docs/apps/overview) - Understanding apps
- [Running Apps](https://inference.sh/docs/apps/running) - How to run apps


@ -0,0 +1,59 @@
# Authentication & Setup
## Install the CLI
```bash
curl -fsSL https://cli.inference.sh | sh
```
## Login
```bash
infsh login
```
This opens a browser for authentication. After login, credentials are stored locally.
## Check Authentication
```bash
infsh me
```
Shows your user info if authenticated.
## Environment Variable
For CI/CD or scripts, set your API key:
```bash
export INFSH_API_KEY=your-api-key
```
The environment variable overrides the config file.
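A minimal non-interactive sketch for a CI job, assuming the key is injected as a secret environment variable named `INFSH_API_KEY` (no `infsh login` step is needed when the variable is set):
```bash
export INFSH_API_KEY="${INFSH_API_KEY:?INFSH_API_KEY is not set}"
infsh me   # sanity-check that the key is accepted
infsh app run tavily/tavily-search --input '{"query": "release notes"}' --json
```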
## Update CLI
```bash
infsh update
```
Or reinstall:
```bash
curl -fsSL https://cli.inference.sh | sh
```
## Troubleshooting
| Error | Solution |
|-------|----------|
| "not authenticated" | Run `infsh login` |
| "command not found" | Reinstall CLI or add to PATH |
| "API key invalid" | Check `INFSH_API_KEY` or re-login |
## Documentation
- [CLI Setup](https://inference.sh/docs/extend/cli-setup) - Complete CLI installation guide
- [API Authentication](https://inference.sh/docs/api/authentication) - API key management
- [Secrets](https://inference.sh/docs/secrets/overview) - Managing credentials


@ -0,0 +1,104 @@
# CLI Reference
## Installation
```bash
curl -fsSL https://cli.inference.sh | sh
```
## Global Commands
| Command | Description |
|---------|-------------|
| `infsh help` | Show help |
| `infsh version` | Show CLI version |
| `infsh update` | Update CLI to latest |
| `infsh login` | Authenticate |
| `infsh me` | Show current user |
## App Commands
### Discovery
| Command | Description |
|---------|-------------|
| `infsh app list` | List available apps |
| `infsh app list --category <cat>` | Filter by category (image, video, audio, text, other) |
| `infsh app search <query>` | Search apps |
| `infsh app list --search <query>` | Search apps (flag form) |
| `infsh app list --featured` | Show featured apps |
| `infsh app list --new` | Sort by newest |
| `infsh app list --page <n>` | Pagination |
| `infsh app list -l` | Detailed table view |
| `infsh app list --save <file>` | Save to JSON file |
| `infsh app my` | List your deployed apps |
| `infsh app get <app>` | Get app details |
| `infsh app get <app> --json` | Get app details as JSON |
### Execution
| Command | Description |
|---------|-------------|
| `infsh app run <app> --input <file>` | Run app with input file |
| `infsh app run <app> --input '<json>'` | Run with inline JSON |
| `infsh app run <app> --input <file> --no-wait` | Run without waiting for completion |
| `infsh app sample <app>` | Show sample input |
| `infsh app sample <app> --save <file>` | Save sample to file |
## Task Commands
| Command | Description |
|---------|-------------|
| `infsh task get <task-id>` | Get task status and result |
| `infsh task get <task-id> --json` | Get task as JSON |
| `infsh task get <task-id> --save <file>` | Save task result to file |
### Development
| Command | Description |
|---------|-------------|
| `infsh app init` | Create new app (interactive) |
| `infsh app init <name>` | Create new app with name |
| `infsh app test --input <file>` | Test app locally |
| `infsh app deploy` | Deploy app |
| `infsh app deploy --dry-run` | Validate without deploying |
| `infsh app pull <id>` | Pull app source |
| `infsh app pull --all` | Pull all your apps |
## Environment Variables
| Variable | Description |
|----------|-------------|
| `INFSH_API_KEY` | API key (overrides config) |
## Shell Completions
```bash
# Bash
infsh completion bash > /etc/bash_completion.d/infsh
# Zsh
infsh completion zsh > "${fpath[1]}/_infsh"
# Fish
infsh completion fish > ~/.config/fish/completions/infsh.fish
```
## App Name Format
Apps use the format `namespace/app-name`:
- `falai/flux-dev-lora` - fal.ai's FLUX 2 Dev
- `google/veo-3` - Google's Veo 3
- `infsh/sdxl` - inference.sh's SDXL
- `bytedance/seedance-1-5-pro` - ByteDance's Seedance
- `xai/grok-imagine-image` - xAI's Grok
Version pinning: `namespace/app-name@version`
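For example, pinning one of the apps above to a specific version (the version number here is illustrative):
```bash
infsh app run falai/flux-dev-lora@1.0.0 --input '{"prompt": "sunset over mountains"}' --json
```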
## Documentation
- [CLI Setup](https://inference.sh/docs/extend/cli-setup) - Complete CLI installation guide
- [Running Apps](https://inference.sh/docs/apps/running) - How to run apps via CLI
- [Creating an App](https://inference.sh/docs/extend/creating-app) - Build your own apps
- [Deploying](https://inference.sh/docs/extend/deploying) - Deploy apps to the cloud


@ -0,0 +1,171 @@
# Running Apps
## Basic Run
```bash
infsh app run user/app-name --input input.json
```
## Inline JSON
```bash
infsh app run falai/flux-dev-lora --input '{"prompt": "a sunset over mountains"}'
```
## Version Pinning
```bash
infsh app run user/app-name@1.0.0 --input input.json
```
## Local File Uploads
The CLI automatically uploads local files when you provide a file path instead of a URL. Any field that accepts a URL also accepts a local path:
```bash
# Upscale a local image
infsh app run falai/topaz-image-upscaler --input '{"image": "/path/to/photo.jpg", "upscale_factor": 2}'
# Image-to-video from local file
infsh app run falai/wan-2-5-i2v --input '{"image": "./my-image.png", "prompt": "make it move"}'
# Avatar with local audio and image
infsh app run bytedance/omnihuman-1-5 --input '{"audio": "/path/to/speech.mp3", "image": "/path/to/face.jpg"}'
# Post tweet with local media
infsh app run x/post-create --input '{"text": "Check this out!", "media": "./screenshot.png"}'
```
Supported paths:
- Absolute paths: `/home/user/images/photo.jpg`
- Relative paths: `./image.png`, `../data/video.mp4`
- Home directory: `~/Pictures/photo.jpg`
## Generate Sample Input
Before running, generate a sample input file:
```bash
infsh app sample falai/flux-dev-lora
```
Save to file:
```bash
infsh app sample falai/flux-dev-lora --save input.json
```
Then edit `input.json` and run:
```bash
infsh app run falai/flux-dev-lora --input input.json
```
## Workflow Example
### Image Generation with FLUX
```bash
# 1. Get app details
infsh app get falai/flux-dev-lora
# 2. Generate sample input
infsh app sample falai/flux-dev-lora --save input.json
# 3. Edit input.json
# {
# "prompt": "a cat astronaut floating in space",
# "num_images": 1,
# "image_size": "landscape_16_9"
# }
# 4. Run
infsh app run falai/flux-dev-lora --input input.json
```
### Video Generation with Veo
```bash
# 1. Generate sample
infsh app sample google/veo-3-1-fast --save input.json
# 2. Edit prompt
# {
# "prompt": "A drone shot flying over a forest at sunset"
# }
# 3. Run
infsh app run google/veo-3-1-fast --input input.json
```
### Text-to-Speech
```bash
# Quick inline run
infsh app run falai/kokoro-tts --input '{"text": "Hello, this is a test."}'
```
## Task Tracking
When you run an app, the CLI shows the task ID:
```
Running falai/flux-dev-lora
Task ID: abc123def456
```
For long-running tasks, you can check status anytime:
```bash
# Check task status
infsh task get abc123def456
# Get result as JSON
infsh task get abc123def456 --json
# Save result to file
infsh task get abc123def456 --save result.json
```
### Run Without Waiting
For very long tasks, run in background:
```bash
# Submit and return immediately
infsh app run google/veo-3 --input input.json --no-wait
# Check later
infsh task get <task-id>
```
## Output
The CLI returns the app output directly. For file outputs (images, videos, audio), you'll receive URLs to download.
Example output:
```json
{
"images": [
{
"url": "https://cloud.inference.sh/...",
"content_type": "image/png"
}
]
}
```
## Error Handling
| Error | Cause | Solution |
|-------|-------|----------|
| "invalid input" | Schema mismatch | Check `infsh app get` for required fields |
| "app not found" | Wrong app name | Check `infsh app list --search` |
| "quota exceeded" | Out of credits | Check account balance |
## Documentation
- [Running Apps](https://inference.sh/docs/apps/running) - Complete running apps guide
- [Streaming Results](https://inference.sh/docs/api/sdk/streaming) - Real-time progress updates
- [Setup Parameters](https://inference.sh/docs/apps/setup-parameters) - Configuring app inputs


@ -0,0 +1,335 @@
---
name: huggingface-accelerate
description: Simplest distributed training API. 4 lines to add distributed support to any PyTorch script. Unified API for DeepSpeed/FSDP/Megatron/DDP. Automatic device placement, mixed precision (FP16/BF16/FP8). Interactive config, single launch command. HuggingFace ecosystem standard.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [accelerate, torch, transformers]
metadata:
  hermes:
    tags: [Distributed Training, HuggingFace, Accelerate, DeepSpeed, FSDP, Mixed Precision, PyTorch, DDP, Unified API, Simple]
---
# HuggingFace Accelerate - Unified Distributed Training
## Quick start
Accelerate simplifies distributed training to 4 lines of code.
**Installation**:
```bash
pip install accelerate
```
**Convert PyTorch script** (4 lines):
```python
import torch
+ from accelerate import Accelerator
+ accelerator = Accelerator()
model = torch.nn.Transformer()
optimizer = torch.optim.Adam(model.parameters())
dataloader = torch.utils.data.DataLoader(dataset)
+ model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
-   loss.backward()
+   accelerator.backward(loss)
    optimizer.step()
```
**Run** (single command):
```bash
accelerate launch train.py
```
## Common workflows
### Workflow 1: From single GPU to multi-GPU
**Original script**:
```python
# train.py
import torch
model = torch.nn.Linear(10, 2).to('cuda')
optimizer = torch.optim.Adam(model.parameters())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
for epoch in range(10):
    for batch in dataloader:
        batch = batch.to('cuda')
        optimizer.zero_grad()
        loss = model(batch).mean()
        loss.backward()
        optimizer.step()
```
**With Accelerate** (4 lines added):
```python
# train.py
import torch
from accelerate import Accelerator # +1
accelerator = Accelerator() # +2
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) # +3
for epoch in range(10):
    for batch in dataloader:
        # No .to('cuda') needed - automatic!
        optimizer.zero_grad()
        loss = model(batch).mean()
        accelerator.backward(loss)  # +4
        optimizer.step()
```
**Configure** (interactive):
```bash
accelerate config
```
**Questions**:
- Which machine? (single/multi GPU/TPU/CPU)
- How many machines? (1)
- Mixed precision? (no/fp16/bf16/fp8)
- DeepSpeed? (no/yes)
**Launch** (works on any setup):
```bash
# Single GPU
accelerate launch train.py
# Multi-GPU (8 GPUs)
accelerate launch --multi_gpu --num_processes 8 train.py
# Multi-node
accelerate launch --multi_gpu --num_processes 16 \
--num_machines 2 --machine_rank 0 \
--main_process_ip $MASTER_ADDR \
train.py
```
### Workflow 2: Mixed precision training
**Enable FP16/BF16**:
```python
from accelerate import Accelerator
# FP16 (with gradient scaling)
accelerator = Accelerator(mixed_precision='fp16')
# BF16 (no scaling, more stable)
accelerator = Accelerator(mixed_precision='bf16')
# FP8 (H100+)
accelerator = Accelerator(mixed_precision='fp8')
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
# Everything else is automatic!
for batch in dataloader:
    with accelerator.autocast():  # Optional, done automatically
        loss = model(batch)
    accelerator.backward(loss)
```
### Workflow 3: DeepSpeed ZeRO integration
**Enable DeepSpeed ZeRO-2**:
```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin
accelerator = Accelerator(
    mixed_precision='bf16',
    deepspeed_plugin=DeepSpeedPlugin(
        zero_stage=2,                    # ZeRO-2
        offload_optimizer_device=None,   # keep optimizer on GPU
        gradient_accumulation_steps=4
    )
)
# Same code as before!
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```
**Or via config**:
```bash
accelerate config
# Select: DeepSpeed → ZeRO-2
```
**deepspeed_config.json**:
```json
{
"fp16": {"enabled": false},
"bf16": {"enabled": true},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {"device": "cpu"},
"allgather_bucket_size": 5e8,
"reduce_bucket_size": 5e8
}
}
```
**Launch**:
```bash
accelerate launch --use_deepspeed --deepspeed_config_file deepspeed_config.json train.py
```
### Workflow 4: FSDP (Fully Sharded Data Parallel)
**Enable FSDP**:
```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin
fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy="FULL_SHARD",           # ZeRO-3 equivalent
    auto_wrap_policy="TRANSFORMER_AUTO_WRAP",
    cpu_offload=False
)
accelerator = Accelerator(
    mixed_precision='bf16',
    fsdp_plugin=fsdp_plugin
)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```
**Or via config**:
```bash
accelerate config
# Select: FSDP → Full Shard → No CPU Offload
```
### Workflow 5: Gradient accumulation
**Accumulate gradients**:
```python
from accelerate import Accelerator
accelerator = Accelerator(gradient_accumulation_steps=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
for batch in dataloader:
    with accelerator.accumulate(model):  # Handles accumulation
        optimizer.zero_grad()
        loss = model(batch)
        accelerator.backward(loss)
        optimizer.step()
```
**Effective batch size**: `batch_size * num_gpus * gradient_accumulation_steps`
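For example, a per-GPU batch size of 16 with `gradient_accumulation_steps=4` on 8 GPUs yields an effective batch of 16 × 8 × 4 = 512.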
## When to use vs alternatives
**Use Accelerate when**:
- Want simplest distributed training
- Need single script for any hardware
- Use HuggingFace ecosystem
- Want flexibility (DDP/DeepSpeed/FSDP/Megatron)
- Need quick prototyping
**Key advantages**:
- **4 lines**: Minimal code changes
- **Unified API**: Same code for DDP, DeepSpeed, FSDP, Megatron
- **Automatic**: Device placement, mixed precision, sharding
- **Interactive config**: No manual launcher setup
- **Single launch**: Works everywhere
**Use alternatives instead**:
- **PyTorch Lightning**: Need callbacks, high-level abstractions
- **Ray Train**: Multi-node orchestration, hyperparameter tuning
- **DeepSpeed**: Direct API control, advanced features
- **Raw DDP**: Maximum control, minimal abstraction
## Common issues
**Issue: Wrong device placement**
Don't manually move to device:
```python
# WRONG
batch = batch.to('cuda')
# CORRECT
# Accelerate handles it automatically after prepare()
```
**Issue: Gradient accumulation not working**
Use context manager:
```python
# CORRECT
with accelerator.accumulate(model):
    optimizer.zero_grad()
    accelerator.backward(loss)
    optimizer.step()
```
**Issue: Checkpointing in distributed**
Use accelerator methods:
```python
# Save only on main process
if accelerator.is_main_process:
    accelerator.save_state('checkpoint/')
# Load on all processes
accelerator.load_state('checkpoint/')
```
**Issue: Different results with FSDP**
Ensure same random seed:
```python
from accelerate.utils import set_seed
set_seed(42)
```
## Advanced topics
**Megatron integration**: See [references/megatron-integration.md](references/megatron-integration.md) for tensor parallelism, pipeline parallelism, and sequence parallelism setup.
**Custom plugins**: See [references/custom-plugins.md](references/custom-plugins.md) for creating custom distributed plugins and advanced configuration.
**Performance tuning**: See [references/performance.md](references/performance.md) for profiling, memory optimization, and best practices.
## Hardware requirements
- **CPU**: Works (slow)
- **Single GPU**: Works
- **Multi-GPU**: DDP (default), DeepSpeed, or FSDP
- **Multi-node**: DDP, DeepSpeed, FSDP, Megatron
- **TPU**: Supported
- **Apple MPS**: Supported
**Launcher requirements**:
- **DDP**: `torch.distributed.run` (built-in)
- **DeepSpeed**: `deepspeed` (pip install deepspeed)
- **FSDP**: PyTorch 1.12+ (built-in)
- **Megatron**: Custom setup
## Resources
- Docs: https://huggingface.co/docs/accelerate
- GitHub: https://github.com/huggingface/accelerate
- Version: 1.11.0+
- Tutorial: "Accelerate your scripts"
- Examples: https://github.com/huggingface/accelerate/tree/main/examples
- Used by: HuggingFace Transformers, TRL, PEFT, all HF libraries


@ -0,0 +1,453 @@
# Custom Plugins for Accelerate
## Overview
Accelerate allows creating **custom plugins** to extend distributed training strategies beyond built-in options (DDP, FSDP, DeepSpeed).
## Plugin Architecture
### Base Plugin Structure
```python
from accelerate.utils import DistributedDataParallelKwargs
from dataclasses import dataclass
@dataclass
class CustomPlugin:
    """Custom training plugin."""
    # Plugin configuration
    param1: int = 1
    param2: str = "default"

    def __post_init__(self):
        # Validation logic
        if self.param1 < 1:
            raise ValueError("param1 must be >= 1")
```
### Using Custom Plugin
```python
from accelerate import Accelerator
# Create plugin
custom_plugin = CustomPlugin(param1=4, param2="value")
# Pass to Accelerator
accelerator = Accelerator(
    custom_plugin=custom_plugin  # Not a real parameter, example only
)
```
## Built-In Plugin Examples
### 1. GradScalerKwargs (FP16 Configuration)
```python
from accelerate.utils import GradScalerKwargs
# Configure gradient scaler for FP16
scaler_kwargs = GradScalerKwargs(
    init_scale=2.**16,      # Initial loss scale
    growth_factor=2.0,      # Scale growth rate
    backoff_factor=0.5,     # Scale backoff rate
    growth_interval=2000,   # Steps between scale increases
    enabled=True            # Enable scaler
)
accelerator = Accelerator(
    mixed_precision='fp16',
    kwargs_handlers=[scaler_kwargs]  # Pass as kwargs handler
)
```
**Use case**: Fine-tune FP16 gradient scaling behavior
### 2. DistributedDataParallelKwargs
```python
from accelerate.utils import DistributedDataParallelKwargs
# Configure DDP behavior
ddp_kwargs = DistributedDataParallelKwargs(
    bucket_cap_mb=25,               # Gradient bucketing size
    find_unused_parameters=False,   # Find unused params (slower)
    check_reduction=False,          # Check gradient reduction
    gradient_as_bucket_view=True,   # Memory optimization
    static_graph=False              # Static computation graph
)
accelerator = Accelerator(
    kwargs_handlers=[ddp_kwargs]
)
```
**Use case**: Optimize DDP performance for specific models
### 3. FP8RecipeKwargs (H100 FP8)
```python
from accelerate.utils import FP8RecipeKwargs
# Configure FP8 training (H100)
fp8_recipe = FP8RecipeKwargs(
    backend="te",             # TransformerEngine backend
    margin=0,                 # Scaling margin
    interval=1,               # Scaling interval
    fp8_format="HYBRID",      # E4M3 + E5M2 hybrid
    amax_history_len=1024,    # AMAX history length
    amax_compute_algo="max"   # AMAX computation algorithm
)
accelerator = Accelerator(
    mixed_precision='fp8',
    kwargs_handlers=[fp8_recipe]
)
```
**Use case**: Ultra-fast training on H100 GPUs
## Custom DeepSpeed Configuration
### ZeRO-3 with CPU Offload
```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin
# Custom DeepSpeed config
ds_plugin = DeepSpeedPlugin(
    zero_stage=3,                     # ZeRO-3
    offload_optimizer_device="cpu",   # CPU offload optimizer
    offload_param_device="cpu",       # CPU offload parameters
    zero3_init_flag=True,             # ZeRO-3 initialization
    zero3_save_16bit_model=True,      # Save FP16 weights
)
accelerator = Accelerator(
    deepspeed_plugin=ds_plugin,
    mixed_precision='bf16'
)
```
### ZeRO-2 with NVMe Offload
```python
ds_plugin = DeepSpeedPlugin(
    zero_stage=2,
    offload_optimizer_device="nvme",   # NVMe offload
    offload_param_device="nvme",
    nvme_path="/local_nvme",           # NVMe mount path
)
```
### Custom JSON Config
```python
import json
# Load custom DeepSpeed config
with open('deepspeed_config.json', 'r') as f:
    ds_config = json.load(f)
ds_plugin = DeepSpeedPlugin(hf_ds_config=ds_config)
accelerator = Accelerator(deepspeed_plugin=ds_plugin)
```
**Example config** (`deepspeed_config.json`):
```json
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": 5e8,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e6,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": true
},
"steps_per_print": 100,
"wall_clock_breakdown": false
}
```
## Custom FSDP Configuration
### FSDP with Custom Auto-Wrap Policy
```python
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
import functools
# Custom wrap policy (size-based)
wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=1e6   # Wrap layers with 1M+ params
)
fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy=ShardingStrategy.FULL_SHARD,     # ZeRO-3 equivalent
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,   # Prefetch strategy
    mixed_precision_policy=None,                       # Use Accelerator's mixed precision
    auto_wrap_policy=wrap_policy,                      # Custom wrapping
    cpu_offload=False,
    ignored_modules=None,                              # Modules to not wrap
    state_dict_type="FULL_STATE_DICT",                 # Save format
    optim_state_dict_config=None,
    limit_all_gathers=False,
    use_orig_params=True,                              # Use original param shapes
)
accelerator = Accelerator(
    fsdp_plugin=fsdp_plugin,
    mixed_precision='bf16'
)
```
### FSDP with Transformer Auto-Wrap
```python
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
# Wrap at transformer block level
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={GPT2Block}   # Wrap GPT2Block layers
)
fsdp_plugin = FullyShardedDataParallelPlugin(
    auto_wrap_policy=wrap_policy
)
```
## Creating Custom Training Strategy
### Example: Custom Gradient Accumulation
```python
from accelerate import Accelerator
class CustomGradientAccumulation:
    def __init__(self, steps=4, adaptive=False, threshold=10.0):
        self.steps = steps
        self.adaptive = adaptive
        self.threshold = threshold   # loss threshold for adaptive sync (illustrative default)
        self.current_step = 0

    def should_sync(self, loss):
        """Decide whether to sync gradients."""
        self.current_step += 1
        # Adaptive: sync on high loss
        if self.adaptive and loss > self.threshold:
            self.current_step = 0
            return True
        # Regular: sync every N steps
        if self.current_step >= self.steps:
            self.current_step = 0
            return True
        return False

# Usage
custom_accum = CustomGradientAccumulation(steps=8, adaptive=True)
accelerator = Accelerator()
for batch in dataloader:
    outputs = model(**batch)
    loss = outputs.loss
    # Scale loss
    loss = loss / custom_accum.steps
    accelerator.backward(loss)
    # Conditional sync
    if custom_accum.should_sync(loss.item()):
        optimizer.step()
        optimizer.zero_grad()
```
### Example: Custom Mixed Precision
```python
import torch
class CustomMixedPrecision:
    """Custom mixed precision with dynamic loss scaling."""
    def __init__(self, init_scale=2**16, scale_window=2000):
        self.scaler = torch.cuda.amp.GradScaler(
            init_scale=init_scale,
            growth_interval=scale_window
        )
        self.scale_history = []

    def scale_loss(self, loss):
        """Scale loss for backward."""
        return self.scaler.scale(loss)

    def unscale_and_clip(self, optimizer, max_norm=1.0):
        """Unscale gradients and clip."""
        self.scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(
            optimizer.param_groups[0]['params'],
            max_norm
        )

    def step(self, optimizer):
        """Optimizer step with scaler update."""
        scale_before = self.scaler.get_scale()
        self.scaler.step(optimizer)
        self.scaler.update()
        scale_after = self.scaler.get_scale()
        # Track scale changes
        if scale_before != scale_after:
            self.scale_history.append(scale_after)

# Usage
custom_mp = CustomMixedPrecision()
for batch in dataloader:
    with torch.cuda.amp.autocast(dtype=torch.float16):
        loss = model(**batch).loss
    scaled_loss = custom_mp.scale_loss(loss)
    scaled_loss.backward()
    custom_mp.unscale_and_clip(optimizer, max_norm=1.0)
    custom_mp.step(optimizer)
    optimizer.zero_grad()
```
## Advanced: Custom Distributed Backend
### Custom AllReduce Strategy
```python
import torch
import torch.distributed as dist

class CustomAllReduce:
    """Custom all-reduce with compression."""
    def __init__(self, compression_ratio=0.1):
        self.compression_ratio = compression_ratio

    def compress_gradients(self, tensor):
        """Top-k gradient compression (keeps signed values)."""
        k = int(tensor.numel() * self.compression_ratio)
        _, indices = torch.topk(tensor.abs().view(-1), k)
        values = tensor.view(-1)[indices]
        return values, indices

    def all_reduce_compressed(self, tensor):
        """All-reduce with gradient compression."""
        # Compress
        values, indices = self.compress_gradients(tensor)
        # All-reduce compressed gradients
        dist.all_reduce(values, op=dist.ReduceOp.SUM)
        # Decompress
        tensor_compressed = torch.zeros_like(tensor).view(-1)
        tensor_compressed[indices] = values / dist.get_world_size()
        return tensor_compressed.view_as(tensor)

# Usage in training loop
custom_ar = CustomAllReduce(compression_ratio=0.1)
for batch in dataloader:
    loss = model(**batch).loss
    loss.backward()
    # Custom all-reduce
    for param in model.parameters():
        if param.grad is not None:
            param.grad.data = custom_ar.all_reduce_compressed(param.grad.data)
    optimizer.step()
    optimizer.zero_grad()
```
## Plugin Best Practices
### 1. Validation in `__post_init__`
```python
@dataclass
class CustomPlugin:
    learning_rate: float = 1e-3
    warmup_steps: int = 1000

    def __post_init__(self):
        # Validate parameters
        if self.learning_rate <= 0:
            raise ValueError("learning_rate must be positive")
        if self.warmup_steps < 0:
            raise ValueError("warmup_steps must be non-negative")
        # Compute derived values
        self.min_lr = self.learning_rate * 0.1
```
### 2. Compatibility Checks
```python
@dataclass
class CustomPlugin:
    feature_enabled: bool = True

    def is_compatible(self, accelerator):
        """Check if plugin is compatible with accelerator config."""
        if self.feature_enabled and accelerator.mixed_precision == 'fp8':
            raise ValueError("Custom plugin not compatible with FP8")
        return True
```
### 3. State Management
```python
@dataclass
class CustomPlugin:
    counter: int = 0
    history: list = None

    def __post_init__(self):
        if self.history is None:
            self.history = []

    def update_state(self, value):
        """Update plugin state during training."""
        self.counter += 1
        self.history.append(value)
```
## Resources
- Accelerate Plugins: https://huggingface.co/docs/accelerate/package_reference/kwargs
- DeepSpeed Config: https://www.deepspeed.ai/docs/config-json/
- FSDP Guide: https://pytorch.org/docs/stable/fsdp.html
- Custom Training Loops: https://huggingface.co/docs/accelerate/usage_guides/training_tpu


@ -0,0 +1,489 @@
# Megatron Integration with Accelerate
## Overview
Accelerate supports Megatron-LM for massive model training with tensor parallelism and pipeline parallelism.
**Megatron capabilities**:
- **Tensor Parallelism (TP)**: Split layers across GPUs
- **Pipeline Parallelism (PP)**: Split model depth across GPUs
- **Data Parallelism (DP)**: Replicate model across GPU groups
- **Sequence Parallelism**: Split sequences for long contexts
## Setup
### Install Megatron-LM
```bash
# Clone Megatron-LM repository
git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM
pip install -e .
# Install Apex (NVIDIA optimizations)
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
--config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
```
### Accelerate Configuration
```bash
accelerate config
```
**Questions**:
```
In which compute environment are you running?
> This machine
Which type of machine are you using?
> Multi-GPU
How many different machines will you use?
> 1
Do you want to use DeepSpeed/FSDP?
> No
Do you want to use Megatron-LM?
> Yes
What is the Tensor Parallelism degree? [1-8]
> 2
Do you want to enable Sequence Parallelism?
> No
What is the Pipeline Parallelism degree? [1-8]
> 2
What is the Data Parallelism degree? [1-8]
> 2
Where to perform activation checkpointing? ['SELECTIVE', 'FULL', 'NONE']
> SELECTIVE
Where to perform activation partitioning? ['SEQUENTIAL', 'UNIFORM']
> SEQUENTIAL
```
**Generated config** (`~/.cache/huggingface/accelerate/default_config.yaml`):
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MEGATRON_LM
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
megatron_lm_config:
  megatron_lm_gradient_clipping: 1.0
  megatron_lm_learning_rate_decay_iters: 320000
  megatron_lm_num_micro_batches: 1
  megatron_lm_pp_degree: 2
  megatron_lm_recompute_activations: true
  megatron_lm_sequence_parallelism: false
  megatron_lm_tp_degree: 2
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
## Parallelism Strategies
### Tensor Parallelism (TP)
**Splits each transformer layer across GPUs**:
```python
# Layer split across 2 GPUs
# GPU 0: First half of attention heads
# GPU 1: Second half of attention heads
# Each GPU computes partial outputs
# All-reduce combines results
```
**TP degree recommendations**:
- **TP=1**: No tensor parallelism (single GPU per layer)
- **TP=2**: 2 GPUs per layer (good for 7-13B models)
- **TP=4**: 4 GPUs per layer (good for 20-40B models)
- **TP=8**: 8 GPUs per layer (good for 70B+ models)
**Benefits**:
- Reduces memory per GPU
- All-reduce communication (fast)
**Drawbacks**:
- Requires fast inter-GPU bandwidth (NVLink)
- Communication overhead per layer
### Pipeline Parallelism (PP)
**Splits model depth across GPUs**:
```python
# 12-layer model, PP=4
# GPU 0: Layers 0-2
# GPU 1: Layers 3-5
# GPU 2: Layers 6-8
# GPU 3: Layers 9-11
```
**PP degree recommendations**:
- **PP=1**: No pipeline parallelism
- **PP=2**: 2 pipeline stages (good for 20-40B models)
- **PP=4**: 4 pipeline stages (good for 70B+ models)
- **PP=8**: 8 pipeline stages (good for 175B+ models)
**Benefits**:
- Linear memory reduction (4× PP = 4× less memory)
- Works across nodes (slower interconnect OK)
**Drawbacks**:
- Pipeline bubbles (idle time)
- Requires micro-batching
### Data Parallelism (DP)
**Replicates model across GPU groups**:
```python
# 8 GPUs, TP=2, PP=2, DP=2
# Group 0 (GPUs 0-3): Full model replica
# Group 1 (GPUs 4-7): Full model replica
```
**DP degree**:
- `DP = total_gpus / (TP × PP)`
- Example: 8 GPUs, TP=2, PP=2 → DP=2
**Benefits**:
- Increases throughput
- Scales batch size
### Sequence Parallelism
**Splits long sequences across GPUs** (extends TP):
```python
# 8K sequence, TP=2, Sequence Parallel=True
# GPU 0: Tokens 0-4095
# GPU 1: Tokens 4096-8191
```
**Benefits**:
- Enables very long sequences (100K+ tokens)
- Reduces activation memory
**Requirements**:
- Must use with TP > 1
- RoPE/ALiBi position encodings work best
## Accelerate Code Example
### Basic Setup
```python
from accelerate import Accelerator
from accelerate.utils import MegatronLMPlugin
# Configure Megatron
megatron_plugin = MegatronLMPlugin(
    tp_degree=2,                          # Tensor parallelism degree
    pp_degree=2,                          # Pipeline parallelism degree
    num_micro_batches=4,                  # Micro-batches for pipeline
    gradient_clipping=1.0,                # Gradient clipping value
    sequence_parallelism=False,           # Enable sequence parallelism
    recompute_activations=True,           # Activation checkpointing
    use_distributed_optimizer=True,       # Distributed optimizer
    custom_prepare_model_function=None,   # Custom model prep
)
# Initialize accelerator
accelerator = Accelerator(
    mixed_precision='bf16',
    megatron_lm_plugin=megatron_plugin
)
# Prepare model and optimizer
model, optimizer, train_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader
)
# Training loop (same as DDP!)
for batch in train_dataloader:
    optimizer.zero_grad()
    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
```
### Full Training Script
```python
import torch
from accelerate import Accelerator
from accelerate.utils import MegatronLMPlugin
from transformers import GPT2Config, GPT2LMHeadModel
def main():
    # Megatron configuration
    megatron_plugin = MegatronLMPlugin(
        tp_degree=2,
        pp_degree=2,
        num_micro_batches=4,
        gradient_clipping=1.0,
    )
    accelerator = Accelerator(
        mixed_precision='bf16',
        gradient_accumulation_steps=8,
        megatron_lm_plugin=megatron_plugin
    )
    # Model
    config = GPT2Config(
        n_layer=24,
        n_head=16,
        n_embd=1024,
    )
    model = GPT2LMHeadModel(config)
    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
    # Prepare
    model, optimizer, train_loader = accelerator.prepare(
        model, optimizer, train_loader
    )
    # Training loop
    for epoch in range(num_epochs):
        for batch in train_loader:
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
        # Save checkpoint
        accelerator.wait_for_everyone()
        accelerator.save_state(f'checkpoint-epoch-{epoch}')

if __name__ == '__main__':
    main()
```
### Launch Command
```bash
# 8 GPUs, TP=2, PP=2, DP=2
accelerate launch --multi_gpu --num_processes 8 train.py
# Multi-node (2 nodes, 8 GPUs each)
# Node 0
accelerate launch --multi_gpu --num_processes 16 \
--num_machines 2 --machine_rank 0 \
--main_process_ip $MASTER_ADDR \
--main_process_port 29500 \
train.py
# Node 1
accelerate launch --multi_gpu --num_processes 16 \
--num_machines 2 --machine_rank 1 \
--main_process_ip $MASTER_ADDR \
--main_process_port 29500 \
train.py
```
## Activation Checkpointing
**Reduces memory by recomputing activations**:
```python
megatron_plugin = MegatronLMPlugin(
    recompute_activations=True,                  # Enable checkpointing
    checkpoint_num_layers=1,                     # Checkpoint every N layers
    distribute_checkpointed_activations=True,    # Distribute across TP
    partition_activations=True,                  # Partition in PP
    check_for_nan_in_loss_and_grad=True,         # Stability check
)
```
**Strategies**:
- `SELECTIVE`: Checkpoint transformer blocks only
- `FULL`: Checkpoint all layers
- `NONE`: No checkpointing
**Memory savings**: 30-50% with 10-15% slowdown
## Distributed Optimizer
**Shards optimizer state across DP ranks**:
```python
megatron_plugin = MegatronLMPlugin(
    use_distributed_optimizer=True,   # Enable sharded optimizer
)
```
**Benefits**:
- Reduces optimizer memory by DP degree
- Example: DP=4 → 4× less optimizer memory per GPU
**Compatible with**:
- AdamW, Adam, SGD
- Mixed precision training
## Performance Tuning
### Micro-Batch Size
```python
# Pipeline parallelism requires micro-batching
megatron_plugin = MegatronLMPlugin(
    pp_degree=4,
    num_micro_batches=16,   # 16 micro-batches per pipeline
)
# Effective batch = num_micro_batches × micro_batch_size × DP
# Example: 16 × 2 × 4 = 128
```
**Recommendations**:
- More micro-batches → less pipeline bubble
- Typical: 4-16 micro-batches
### Sequence Length
```python
# For long sequences, enable sequence parallelism
megatron_plugin = MegatronLMPlugin(
    tp_degree=4,
    sequence_parallelism=True,   # Required: TP > 1
)
# Enables sequences up to TP × normal limit
# Example: TP=4, 8K normal → 32K with sequence parallel
```
### GPU Topology
**NVLink required for TP**:
```bash
# Check NVLink topology
nvidia-smi topo -m
# Good topology (NVLink between all GPUs)
# GPU0 - GPU1: NV12 (fast)
# GPU0 - GPU2: NV12 (fast)
# Bad topology (PCIe only)
# GPU0 - GPU4: PHB (slow, avoid TP across these)
```
**Recommendations**:
- **TP**: Within same node (NVLink)
- **PP**: Across nodes (slower interconnect OK)
- **DP**: Any topology
## Model Size Guidelines
| Model Size | GPUs | TP | PP | DP | Micro-Batches |
|------------|------|----|----|----|--------------|
| 7B | 8 | 1 | 1 | 8 | 1 |
| 13B | 8 | 2 | 1 | 4 | 1 |
| 20B | 16 | 4 | 1 | 4 | 1 |
| 40B | 32 | 4 | 2 | 4 | 4 |
| 70B | 64 | 8 | 2 | 4 | 8 |
| 175B | 128 | 8 | 4 | 4 | 16 |
**Assumptions**: BF16, 2K sequence length, A100 80GB
## Checkpointing
### Save Checkpoint
```python
# Save full model state
accelerator.save_state('checkpoint-1000')
# Megatron saves separate files per rank
# checkpoint-1000/
# pytorch_model_tp_0_pp_0.bin
# pytorch_model_tp_0_pp_1.bin
# pytorch_model_tp_1_pp_0.bin
# pytorch_model_tp_1_pp_1.bin
# optimizer_tp_0_pp_0.bin
# ...
```
### Load Checkpoint
```python
# Resume training
accelerator.load_state('checkpoint-1000')
# Automatically loads correct shard per rank
```
### Convert to Standard PyTorch
```bash
# Merge Megatron checkpoint to single file
python merge_megatron_checkpoint.py \
--checkpoint-dir checkpoint-1000 \
--output pytorch_model.bin
```
## Common Issues
### Issue: OOM with Pipeline Parallelism
**Solution**: Increase micro-batches
```python
megatron_plugin = MegatronLMPlugin(
    pp_degree=4,
    num_micro_batches=16,   # Increase from 4
)
```
### Issue: Slow Training
**Check 1**: Pipeline bubbles (PP too high)
```python
# Reduce PP, increase TP
tp_degree=4 # Increase
pp_degree=2 # Decrease
```
**Check 2**: Micro-batch size too small
```python
num_micro_batches=8 # Increase
```
### Issue: NVLink Not Detected
```bash
# Verify NVLink
nvidia-smi nvlink -s
# If no NVLink, avoid TP > 1
# Use PP or DP instead
```
## Resources
- Megatron-LM: https://github.com/NVIDIA/Megatron-LM
- Accelerate Megatron docs: https://huggingface.co/docs/accelerate/usage_guides/megatron_lm
- Paper: "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism"
- NVIDIA Apex: https://github.com/NVIDIA/apex


@ -0,0 +1,525 @@
# Accelerate Performance Tuning
## Profiling
### Basic Profiling
```python
from accelerate import Accelerator
import time
accelerator = Accelerator()
# Warmup
for _ in range(10):
    batch = next(iter(dataloader))
    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
# Profile training loop
start = time.time()
total_batches = 100
for i, batch in enumerate(dataloader):
    if i >= total_batches:
        break
    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
accelerator.wait_for_everyone() # Sync all processes
elapsed = time.time() - start
# Metrics
batches_per_sec = total_batches / elapsed
samples_per_sec = (total_batches * batch_size * accelerator.num_processes) / elapsed
print(f"Throughput: {samples_per_sec:.2f} samples/sec")
print(f"Batches/sec: {batches_per_sec:.2f}")
```
### PyTorch Profiler Integration
```python
from torch.profiler import profile, ProfilerActivity
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True
) as prof:
    for i, batch in enumerate(dataloader):
        if i >= 10:  # Profile first 10 batches
            break
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
# Print profiling results
print(prof.key_averages().table(
    sort_by="cuda_time_total", row_limit=20
))
# Export to Chrome tracing
prof.export_chrome_trace("trace.json")
# View at chrome://tracing
```
## Memory Optimization
### 1. Gradient Accumulation
**Problem**: Large batch size causes OOM
**Solution**: Accumulate gradients across micro-batches
```python
accelerator = Accelerator(gradient_accumulation_steps=8)
# Effective batch = batch_size × accumulation_steps × num_gpus
# Example: 4 × 8 × 8 = 256
for batch in dataloader:
    with accelerator.accumulate(model):  # Handles accumulation logic
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```
**Memory savings**: 8× less activation memory (with 8 accumulation steps)
### 2. Gradient Checkpointing
**Enable in model**:
```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    use_cache=False   # Required for gradient checkpointing
)
# Enable checkpointing
model.gradient_checkpointing_enable()
# Prepare with Accelerate
model = accelerator.prepare(model)
```
**Memory savings**: 30-50% with 10-15% slowdown
### 3. Mixed Precision
**BF16 (A100/H100)**:
```python
accelerator = Accelerator(mixed_precision='bf16')
# Automatic mixed precision
for batch in dataloader:
    outputs = model(**batch)     # Forward in BF16
    loss = outputs.loss
    accelerator.backward(loss)   # Backward in FP32
    optimizer.step()
```
**FP16 (V100, older GPUs)**:
```python
from accelerate.utils import GradScalerKwargs
scaler_kwargs = GradScalerKwargs(
    init_scale=2.**16,
    growth_interval=2000
)
accelerator = Accelerator(
    mixed_precision='fp16',
    kwargs_handlers=[scaler_kwargs]
)
```
**Memory savings**: 50% compared to FP32
### 4. CPU Offloading (DeepSpeed)
```python
from accelerate.utils import DeepSpeedPlugin
ds_plugin = DeepSpeedPlugin(
    zero_stage=3,
    offload_optimizer_device="cpu",   # Offload optimizer to CPU
    offload_param_device="cpu",       # Offload parameters to CPU
)
accelerator = Accelerator(
    deepspeed_plugin=ds_plugin,
    mixed_precision='bf16'
)
```
**Memory savings**: 10-20× for optimizer state, 5-10× for parameters
**Trade-off**: 20-30% slower due to CPU-GPU transfers
### 5. Flash Attention
```python
# Install flash-attn
# pip install flash-attn
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    attn_implementation="flash_attention_2"   # Enable Flash Attention 2
)
model = accelerator.prepare(model)
```
**Memory savings**: 50% for attention, 2× faster
**Requirements**: A100/H100, sequence length must be multiple of 128
## Communication Optimization
### 1. Gradient Bucketing (DDP)
```python
from accelerate.utils import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(
    bucket_cap_mb=25,               # Bucket size for gradient reduction
    gradient_as_bucket_view=True,   # Reduce memory copies
    static_graph=False              # Set True if model doesn't change
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
```
**Recommended bucket sizes**:
- Small models (<1B): 25 MB
- Medium models (1-10B): 50-100 MB
- Large models (>10B): 100-200 MB
### 2. Find Unused Parameters
```python
# Only enable if model has unused parameters (slower!)
ddp_kwargs = DistributedDataParallelKwargs(
    find_unused_parameters=True
)
```
**Use case**: Models with conditional branches (e.g., mixture of experts)
**Cost**: 10-20% slower
### 3. NCCL Tuning
```bash
# Set environment variables before launch
export NCCL_DEBUG=INFO # Debug info
export NCCL_IB_DISABLE=0 # Enable InfiniBand
export NCCL_SOCKET_IFNAME=eth0 # Network interface
export NCCL_P2P_LEVEL=NVL # Use NVLink
accelerate launch train.py
```
**NCCL_P2P_LEVEL options**:
- `NVL`: NVLink (fastest, within node)
- `PIX`: PCIe (fast, within node)
- `PHB`: PCIe host bridge (slow, cross-node)
## Data Loading Optimization
### 1. DataLoader Workers
```python
from torch.utils.data import DataLoader
train_loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,              # Parallel data loading
    pin_memory=True,            # Pin memory for faster GPU transfer
    prefetch_factor=2,          # Prefetch batches per worker
    persistent_workers=True     # Keep workers alive between epochs
)
train_loader = accelerator.prepare(train_loader)
```
**Recommendations**:
- `num_workers`: 2-4 per GPU (8 GPUs → 16-32 workers)
- `pin_memory`: Always True for GPU training
- `prefetch_factor`: 2-4 (higher for slow data loading)
### 2. Data Preprocessing
```python
from datasets import load_dataset, load_from_disk
# Bad: Preprocess during training (slow)
dataset = load_dataset("openwebtext")
for batch in dataset:
    tokens = tokenizer(batch['text'])  # Slow!
    ...
# Good: Preprocess once, save
dataset = load_dataset("openwebtext")
tokenized = dataset.map(
    lambda x: tokenizer(x['text']),
    batched=True,
    num_proc=8,                # Parallel preprocessing
    remove_columns=['text']
)
tokenized.save_to_disk("preprocessed_data")
# Load preprocessed
dataset = load_from_disk("preprocessed_data")
```
### 3. Faster Tokenization
```python
import os
# Enable Rust-based tokenizers (10× faster)
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "gpt2",
    use_fast=True   # Use fast Rust tokenizer
)
```
## Compilation (PyTorch 2.0+)
### Compile Model
```python
import torch
# Compile model for faster execution
model = torch.compile(
    model,
    mode="reduce-overhead",   # Options: default, reduce-overhead, max-autotune
    fullgraph=False,          # Compile entire graph (stricter)
    dynamic=True              # Support dynamic shapes
)
model = accelerator.prepare(model)
```
**Speedup**: 10-50% depending on model
**Compilation modes**:
- `default`: Balanced (best for most cases)
- `reduce-overhead`: Min overhead (best for small batches)
- `max-autotune`: Max performance (slow compile, best for production)
### Compilation Best Practices
```python
# Bad: Compile after prepare (won't work)
model = accelerator.prepare(model)
model = torch.compile(model) # Error!
# Good: Compile before prepare
model = torch.compile(model)
model = accelerator.prepare(model)
# Training loop
for batch in dataloader:
    # First iteration: slow (compilation)
    # Subsequent iterations: fast (compiled)
    outputs = model(**batch)
    ...
```
## Benchmarking Different Strategies
### Script Template
```python
import time
import torch
from accelerate import Accelerator
def benchmark_strategy(strategy_name, accelerator_kwargs):
    """Benchmark a specific training strategy."""
    accelerator = Accelerator(**accelerator_kwargs)
    # Setup
    model = create_model()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    dataloader = create_dataloader()
    model, optimizer, dataloader = accelerator.prepare(
        model, optimizer, dataloader
    )
    # Warmup
    for i, batch in enumerate(dataloader):
        if i >= 10:
            break
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
    # Benchmark
    accelerator.wait_for_everyone()
    torch.cuda.synchronize()
    start = time.time()
    num_batches = 100
    for i, batch in enumerate(dataloader):
        if i >= num_batches:
            break
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
    accelerator.wait_for_everyone()
    torch.cuda.synchronize()
    elapsed = time.time() - start
    # Metrics
    throughput = (num_batches * batch_size * accelerator.num_processes) / elapsed
    memory_used = torch.cuda.max_memory_allocated() / 1e9  # GB
    if accelerator.is_main_process:
        print(f"\n{strategy_name}:")
        print(f"  Throughput: {throughput:.2f} samples/sec")
        print(f"  Memory: {memory_used:.2f} GB")
        print(f"  Time: {elapsed:.2f} sec")
    torch.cuda.reset_peak_memory_stats()

# Benchmark different strategies
strategies = [
    ("DDP + FP32", {}),
    ("DDP + BF16", {"mixed_precision": "bf16"}),
    ("DDP + BF16 + GradAccum", {"mixed_precision": "bf16", "gradient_accumulation_steps": 4}),
    ("FSDP", {"fsdp_plugin": fsdp_plugin}),
    ("DeepSpeed ZeRO-2", {"deepspeed_plugin": ds_plugin_stage2}),
    ("DeepSpeed ZeRO-3", {"deepspeed_plugin": ds_plugin_stage3}),
]
for name, kwargs in strategies:
    benchmark_strategy(name, kwargs)
```
## Performance Checklist
**Before training**:
- [ ] Use BF16/FP16 mixed precision
- [ ] Enable gradient checkpointing (if OOM)
- [ ] Set appropriate `num_workers` (2-4 per GPU)
- [ ] Enable `pin_memory=True`
- [ ] Preprocess data once, not during training
- [ ] Compile model with `torch.compile` (PyTorch 2.0+)
**For large models**:
- [ ] Use FSDP or DeepSpeed ZeRO-3
- [ ] Enable CPU offloading (if still OOM)
- [ ] Use Flash Attention
- [ ] Increase gradient accumulation
**For multi-node**:
- [ ] Check network topology (InfiniBand > Ethernet)
- [ ] Tune NCCL settings
- [ ] Use larger bucket sizes for DDP
- [ ] Verify NVLink for tensor parallelism
**Profiling**:
- [ ] Profile first 10-100 batches
- [ ] Check GPU utilization (`nvidia-smi dmon`)
- [ ] Check data loading time (should be <5% of iteration)
- [ ] Identify communication bottlenecks
## Common Performance Issues
### Issue: Low GPU Utilization (<80%)
**Cause 1**: Data loading bottleneck
```python
# Solution: Increase workers and prefetch
num_workers=8
prefetch_factor=4
```
**Cause 2**: Small batch size
```python
# Solution: Increase batch size or use gradient accumulation
batch_size=32 # Increase
gradient_accumulation_steps=4 # Or accumulate
```
### Issue: High Memory Usage
**Solution 1**: Gradient checkpointing
```python
model.gradient_checkpointing_enable()
```
**Solution 2**: Reduce batch size, increase accumulation
```python
batch_size=8 # Reduce from 32
gradient_accumulation_steps=16 # Maintain effective batch
```
**Solution 3**: Use FSDP or DeepSpeed ZeRO-3
```python
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```
### Issue: Slow Multi-GPU Training
**Cause**: Communication bottleneck
**Check 1**: Gradient bucket size
```python
ddp_kwargs = DistributedDataParallelKwargs(bucket_cap_mb=100)
```
**Check 2**: NCCL settings
```bash
export NCCL_DEBUG=INFO
# Check for "Using NVLS" (good) vs "Using PHB" (bad)
```
**Check 3**: Network bandwidth
```bash
# Test inter-GPU bandwidth
nvidia-smi nvlink -s
```
## Resources
- Accelerate Performance: https://huggingface.co/docs/accelerate/usage_guides/performance
- PyTorch Profiler: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
- NCCL Tuning: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html
- Flash Attention: https://github.com/Dao-AILab/flash-attention


@ -0,0 +1,409 @@
---
name: chroma
description: Open-source embedding database for AI applications. Store embeddings and metadata, perform vector and full-text search, filter by metadata. Simple 4-function API. Scales from notebooks to production clusters. Use for semantic search, RAG applications, or document retrieval. Best for local development and open-source projects.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [chromadb, sentence-transformers]
metadata:
hermes:
tags: [RAG, Chroma, Vector Database, Embeddings, Semantic Search, Open Source, Self-Hosted, Document Retrieval, Metadata Filtering]
---
# Chroma - Open-Source Embedding Database
The AI-native database for building LLM applications with memory.
## When to use Chroma
**Use Chroma when:**
- Building RAG (retrieval-augmented generation) applications
- Need local/self-hosted vector database
- Want open-source solution (Apache 2.0)
- Prototyping in notebooks
- Semantic search over documents
- Storing embeddings with metadata
**Metrics**:
- **24,300+ GitHub stars**
- **1,900+ forks**
- **v1.3.3** (stable, weekly releases)
- **Apache 2.0 license**
**Use alternatives instead**:
- **Pinecone**: Managed cloud, auto-scaling
- **FAISS**: Pure similarity search, no metadata
- **Weaviate**: Production ML-native database
- **Qdrant**: High performance, Rust-based
## Quick start
### Installation
```bash
# Python
pip install chromadb
# JavaScript/TypeScript
npm install chromadb @chroma-core/default-embed
```
### Basic usage (Python)
```python
import chromadb
# Create client
client = chromadb.Client()
# Create collection
collection = client.create_collection(name="my_collection")
# Add documents
collection.add(
documents=["This is document 1", "This is document 2"],
metadatas=[{"source": "doc1"}, {"source": "doc2"}],
ids=["id1", "id2"]
)
# Query
results = collection.query(
query_texts=["document about topic"],
n_results=2
)
print(results)
```
## Core operations
### 1. Create collection
```python
# Simple collection
collection = client.create_collection("my_docs")
# With custom embedding function
from chromadb.utils import embedding_functions
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
api_key="your-key",
model_name="text-embedding-3-small"
)
collection = client.create_collection(
name="my_docs",
embedding_function=openai_ef
)
# Get existing collection
collection = client.get_collection("my_docs")
# Delete collection
client.delete_collection("my_docs")
```
### 2. Add documents
```python
# Add with auto-generated IDs
collection.add(
documents=["Doc 1", "Doc 2", "Doc 3"],
metadatas=[
{"source": "web", "category": "tutorial"},
{"source": "pdf", "page": 5},
{"source": "api", "timestamp": "2025-01-01"}
],
ids=["id1", "id2", "id3"]
)
# Add with custom embeddings
collection.add(
embeddings=[[0.1, 0.2, ...], [0.3, 0.4, ...]],
documents=["Doc 1", "Doc 2"],
ids=["id1", "id2"]
)
```
### 3. Query (similarity search)
```python
# Basic query
results = collection.query(
query_texts=["machine learning tutorial"],
n_results=5
)
# Query with filters
results = collection.query(
query_texts=["Python programming"],
n_results=3,
where={"source": "web"}
)
# Query with metadata filters
results = collection.query(
query_texts=["advanced topics"],
where={
"$and": [
{"category": "tutorial"},
{"difficulty": {"$gte": 3}}
]
}
)
# Access results
print(results["documents"]) # List of matching documents
print(results["metadatas"]) # Metadata for each doc
print(results["distances"]) # Similarity scores
print(results["ids"]) # Document IDs
```
### 4. Get documents
```python
# Get by IDs
docs = collection.get(
ids=["id1", "id2"]
)
# Get with filters
docs = collection.get(
where={"category": "tutorial"},
limit=10
)
# Get all documents
docs = collection.get()
```
### 5. Update documents
```python
# Update document content
collection.update(
ids=["id1"],
documents=["Updated content"],
metadatas=[{"source": "updated"}]
)
```
### 6. Delete documents
```python
# Delete by IDs
collection.delete(ids=["id1", "id2"])
# Delete with filter
collection.delete(
where={"source": "outdated"}
)
```
## Persistent storage
```python
# Persist to disk
client = chromadb.PersistentClient(path="./chroma_db")
collection = client.create_collection("my_docs")
collection.add(documents=["Doc 1"], ids=["id1"])
# Data persisted automatically
# Reload later with same path
client = chromadb.PersistentClient(path="./chroma_db")
collection = client.get_collection("my_docs")
```
## Embedding functions
### Default (Sentence Transformers)
```python
# Uses sentence-transformers by default
collection = client.create_collection("my_docs")
# Default model: all-MiniLM-L6-v2
```
### OpenAI
```python
from chromadb.utils import embedding_functions
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
api_key="your-key",
model_name="text-embedding-3-small"
)
collection = client.create_collection(
name="openai_docs",
embedding_function=openai_ef
)
```
### HuggingFace
```python
huggingface_ef = embedding_functions.HuggingFaceEmbeddingFunction(
api_key="your-key",
model_name="sentence-transformers/all-mpnet-base-v2"
)
collection = client.create_collection(
name="hf_docs",
embedding_function=huggingface_ef
)
```
### Custom embedding function
```python
from chromadb import Documents, EmbeddingFunction, Embeddings
from sentence_transformers import SentenceTransformer

class MyEmbeddingFunction(EmbeddingFunction):
    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self._model = SentenceTransformer(model_name)

    def __call__(self, input: Documents) -> Embeddings:
        # Your embedding logic; here, encode with sentence-transformers
        return self._model.encode(list(input)).tolist()

my_ef = MyEmbeddingFunction()
collection = client.create_collection(
    name="custom_docs",
    embedding_function=my_ef
)
```
## Metadata filtering
```python
# Exact match
results = collection.query(
query_texts=["query"],
where={"category": "tutorial"}
)
# Comparison operators
results = collection.query(
query_texts=["query"],
where={"page": {"$gt": 10}} # $gt, $gte, $lt, $lte, $ne
)
# Logical operators
results = collection.query(
query_texts=["query"],
where={
"$and": [
{"category": "tutorial"},
{"difficulty": {"$lte": 3}}
]
} # Also: $or
)
# Contains
results = collection.query(
query_texts=["query"],
where={"tags": {"$in": ["python", "ml"]}}
)
```
## LangChain integration
```python
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Split documents
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000)
docs = text_splitter.split_documents(documents)
# Create Chroma vector store
vectorstore = Chroma.from_documents(
documents=docs,
embedding=OpenAIEmbeddings(),
persist_directory="./chroma_db"
)
# Query
results = vectorstore.similarity_search("machine learning", k=3)
# As retriever
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
```
## LlamaIndex integration
```python
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core import VectorStoreIndex, StorageContext
import chromadb
# Initialize Chroma
db = chromadb.PersistentClient(path="./chroma_db")
collection = db.get_or_create_collection("my_collection")
# Create vector store
vector_store = ChromaVectorStore(chroma_collection=collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
# Create index
index = VectorStoreIndex.from_documents(
documents,
storage_context=storage_context
)
# Query
query_engine = index.as_query_engine()
response = query_engine.query("What is machine learning?")
```
## Server mode
```python
# Run Chroma server
# Terminal: chroma run --path ./chroma_db --port 8000
# Connect to server
import chromadb
from chromadb.config import Settings
client = chromadb.HttpClient(
host="localhost",
port=8000,
settings=Settings(anonymized_telemetry=False)
)
# Use as normal
collection = client.get_or_create_collection("my_docs")
```
## Best practices
1. **Use persistent client** - Don't lose data on restart
2. **Add metadata** - Enables filtering and tracking
3. **Batch operations** - Add multiple docs at once (see the sketch after this list)
4. **Choose right embedding model** - Balance speed/quality
5. **Use filters** - Narrow search space
6. **Unique IDs** - Avoid collisions
7. **Regular backups** - Copy chroma_db directory
8. **Monitor collection size** - Scale up if needed
9. **Test embedding functions** - Ensure quality
10. **Use server mode for production** - Better for multi-user
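A minimal sketch combining several of these practices (persistent client, batched add with unique IDs and metadata, filtered query); the collection and metadata names are illustrative:
```python
import uuid
import chromadb

client = chromadb.PersistentClient(path="./chroma_db")   # survives restarts
collection = client.get_or_create_collection("articles")

docs = ["Intro to RAG", "Vector search basics", "Chunking strategies"]
collection.add(
    documents=docs,                                       # one batched call
    metadatas=[{"source": "blog", "lang": "en"}] * len(docs),
    ids=[str(uuid.uuid4()) for _ in docs],                # unique IDs
)

results = collection.query(
    query_texts=["how does retrieval work?"],
    n_results=2,
    where={"source": "blog"},                             # narrow the search space
)
```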
## Performance
| Operation | Latency | Notes |
|-----------|---------|-------|
| Add 100 docs | ~1-3s | With embedding |
| Query (top 10) | ~50-200ms | Depends on collection size |
| Metadata filter | ~10-50ms | Fast with proper indexing |
## Resources
- **GitHub**: https://github.com/chroma-core/chroma ⭐ 24,300+
- **Docs**: https://docs.trychroma.com
- **Discord**: https://discord.gg/MMeYNTmh3x
- **Version**: 1.3.3+
- **License**: Apache 2.0


@ -0,0 +1,38 @@
# Chroma Integration Guide
Integration with LangChain, LlamaIndex, and frameworks.
## LangChain
```python
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings
vectorstore = Chroma.from_documents(
documents=docs,
embedding=OpenAIEmbeddings(),
persist_directory="./chroma_db"
)
# Query
results = vectorstore.similarity_search("query", k=3)
# As retriever
retriever = vectorstore.as_retriever()
```
## LlamaIndex
```python
from llama_index.vector_stores.chroma import ChromaVectorStore
import chromadb
db = chromadb.PersistentClient(path="./chroma_db")
collection = db.get_or_create_collection("docs")
vector_store = ChromaVectorStore(chroma_collection=collection)
```
## Resources
- **Docs**: https://docs.trychroma.com


@ -0,0 +1,224 @@
---
name: faiss
description: Facebook's library for efficient similarity search and clustering of dense vectors. Supports billions of vectors, GPU acceleration, and various index types (Flat, IVF, HNSW). Use for fast k-NN search, large-scale vector retrieval, or when you need pure similarity search without metadata. Best for high-performance applications.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [faiss-cpu, faiss-gpu, numpy]
metadata:
hermes:
tags: [RAG, FAISS, Similarity Search, Vector Search, Facebook AI, GPU Acceleration, Billion-Scale, K-NN, HNSW, High Performance, Large Scale]
---
# FAISS - Efficient Similarity Search
Facebook AI's library for billion-scale vector similarity search.
## When to use FAISS
**Use FAISS when:**
- Need fast similarity search on large vector datasets (millions/billions)
- GPU acceleration required
- Pure vector similarity (no metadata filtering needed)
- High throughput, low latency critical
- Offline/batch processing of embeddings
**Metrics**:
- **31,700+ GitHub stars**
- Meta/Facebook AI Research
- **Handles billions of vectors**
- **C++** with Python bindings
**Use alternatives instead**:
- **Chroma/Pinecone**: Need metadata filtering
- **Weaviate**: Need full database features
- **Annoy**: Simpler, fewer features
## Quick start
### Installation
```bash
# CPU only
pip install faiss-cpu
# GPU support
pip install faiss-gpu
```
### Basic usage
```python
import faiss
import numpy as np
# Create sample data (1000 vectors, 128 dimensions)
d = 128
nb = 1000
vectors = np.random.random((nb, d)).astype('float32')
# Create index
index = faiss.IndexFlatL2(d) # L2 distance
index.add(vectors) # Add vectors
# Search
k = 5 # Find 5 nearest neighbors
query = np.random.random((1, d)).astype('float32')
distances, indices = index.search(query, k)
print(f"Nearest neighbors: {indices}")
print(f"Distances: {distances}")
```
## Index types
### 1. Flat (exact search)
```python
# L2 (Euclidean) distance
index = faiss.IndexFlatL2(d)
# Inner product (cosine similarity if normalized)
index = faiss.IndexFlatIP(d)
# Slowest, most accurate
```
### 2. IVF (inverted file) - Fast approximate
```python
# Create quantizer
quantizer = faiss.IndexFlatL2(d)
# IVF index with 100 clusters
nlist = 100
index = faiss.IndexIVFFlat(quantizer, d, nlist)
# Train on data
index.train(vectors)
# Add vectors
index.add(vectors)
# Search (nprobe = clusters to search)
index.nprobe = 10
distances, indices = index.search(query, k)
```
### 3. HNSW (Hierarchical NSW) - Best quality/speed
```python
# HNSW index
M = 32 # Number of connections per layer
index = faiss.IndexHNSWFlat(d, M)
# No training needed
index.add(vectors)
# Search
distances, indices = index.search(query, k)
```
### 4. Product Quantization - Memory efficient
```python
# PQ reduces memory by 16-32×
m = 8 # Number of subquantizers
nbits = 8
index = faiss.IndexPQ(d, m, nbits)
# Train and add
index.train(vectors)
index.add(vectors)
```
## Save and load
```python
# Save index
faiss.write_index(index, "large.index")
# Load index
index = faiss.read_index("large.index")
# Continue using
distances, indices = index.search(query, k)
```
## GPU acceleration
```python
# Single GPU
res = faiss.StandardGpuResources()
index_cpu = faiss.IndexFlatL2(d)
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0
# Multi-GPU
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
# 10-100× faster than CPU
```
## LangChain integration
```python
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
# Create FAISS vector store
vectorstore = FAISS.from_documents(docs, OpenAIEmbeddings())
# Save
vectorstore.save_local("faiss_index")
# Load
vectorstore = FAISS.load_local(
"faiss_index",
OpenAIEmbeddings(),
allow_dangerous_deserialization=True
)
# Search
results = vectorstore.similarity_search("query", k=5)
```
## LlamaIndex integration
```python
from llama_index.vector_stores.faiss import FaissVectorStore
import faiss
# Create FAISS index
d = 1536
faiss_index = faiss.IndexFlatL2(d)
vector_store = FaissVectorStore(faiss_index=faiss_index)
```
## Best practices
1. **Choose right index type** - Flat for <10K, IVF for 10K-1M, HNSW for quality
2. **Normalize for cosine** - Use IndexFlatIP with normalized vectors (see the sketch after this list)
3. **Use GPU for large datasets** - 10-100× faster
4. **Save trained indices** - Training is expensive
5. **Tune nprobe/ef_search** - Balance speed/accuracy
6. **Monitor memory** - PQ for large datasets
7. **Batch queries** - Better GPU utilization
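A short sketch of practice 2, cosine similarity via `IndexFlatIP` over L2-normalized vectors:
```python
import faiss
import numpy as np

d = 128
vectors = np.random.random((1000, d)).astype('float32')
query = np.random.random((1, d)).astype('float32')

faiss.normalize_L2(vectors)      # in-place L2 normalization
faiss.normalize_L2(query)

index = faiss.IndexFlatIP(d)     # inner product == cosine on unit vectors
index.add(vectors)
scores, indices = index.search(query, 5)
```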
## Performance
| Index Type | Build Time | Search Time | Memory | Accuracy |
|------------|------------|-------------|--------|----------|
| Flat | Fast | Slow | High | 100% |
| IVF | Medium | Fast | Medium | 95-99% |
| HNSW | Slow | Fastest | High | 99% |
| PQ | Medium | Fast | Low | 90-95% |
## Resources
- **GitHub**: https://github.com/facebookresearch/faiss ⭐ 31,700+
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
- **License**: MIT


@ -0,0 +1,280 @@
# FAISS Index Types Guide
Complete guide to choosing and using FAISS index types.
## Index selection guide
| Dataset Size | Index Type | Training | Accuracy | Speed |
|--------------|------------|----------|----------|-------|
| < 10K | Flat | No | 100% | Slow |
| 10K-1M | IVF | Yes | 95-99% | Fast |
| 1M-10M | HNSW | No | 99% | Fastest |
| > 10M | IVF+PQ | Yes | 90-95% | Fast, low memory |
## Flat indices (exact search)
### IndexFlatL2 - L2 (Euclidean) distance
```python
import faiss
import numpy as np
d = 128 # Dimension
index = faiss.IndexFlatL2(d)
# Add vectors
vectors = np.random.random((1000, d)).astype('float32')
index.add(vectors)
# Search
k = 5
query = np.random.random((1, d)).astype('float32')
distances, indices = index.search(query, k)
```
**Use when:**
- Dataset < 10,000 vectors
- Need 100% accuracy
- Serving as baseline
### IndexFlatIP - Inner product (cosine similarity)
```python
# For cosine similarity, normalize vectors first
import faiss
d = 128
index = faiss.IndexFlatIP(d)
# Normalize vectors (required for cosine similarity)
faiss.normalize_L2(vectors)
index.add(vectors)
# Search
faiss.normalize_L2(query)
distances, indices = index.search(query, k)
```
**Use when:**
- Need cosine similarity
- Recommendation systems
- Text embeddings
## IVF indices (inverted file)
### IndexIVFFlat - Cluster-based search
```python
# Create quantizer
quantizer = faiss.IndexFlatL2(d)
# Create IVF index with 100 clusters
nlist = 100 # Number of clusters
index = faiss.IndexIVFFlat(quantizer, d, nlist)
# Train on data (required!)
index.train(vectors)
# Add vectors
index.add(vectors)
# Search (nprobe = clusters to search)
index.nprobe = 10 # Search 10 closest clusters
distances, indices = index.search(query, k)
```
**Parameters:**
- `nlist`: Number of clusters (√N to 4√N recommended)
- `nprobe`: Clusters to search (1-nlist, higher = more accurate)
**Use when:**
- Dataset 10K-1M vectors
- Need fast approximate search
- Can afford training time
### Tuning nprobe
```python
import time

# Test different nprobe values and time each setting
for nprobe in [1, 5, 10, 20, 50]:
    index.nprobe = nprobe
    start = time.time()
    distances, indices = index.search(query, k)
    elapsed_ms = (time.time() - start) * 1000
    print(f"nprobe={nprobe}: {elapsed_ms:.2f} ms")
    # Compare indices against a Flat-index baseline to estimate recall
```
**Guidelines:**
- `nprobe=1`: Fastest, ~50% recall
- `nprobe=10`: Good balance, ~95% recall
- `nprobe=nlist`: Exact search (same as Flat)
## HNSW indices (graph-based)
### IndexHNSWFlat - Hierarchical NSW
```python
# HNSW index
M = 32 # Number of connections per layer (16-64)
index = faiss.IndexHNSWFlat(d, M)
# Optional: Set ef_construction (build time parameter)
index.hnsw.efConstruction = 40 # Higher = better quality, slower build
# Add vectors (no training needed!)
index.add(vectors)
# Search
index.hnsw.efSearch = 16 # Search time parameter
distances, indices = index.search(query, k)
```
**Parameters:**
- `M`: Connections per layer (16-64, default 32)
- `efConstruction`: Build quality (40-200, higher = better)
- `efSearch`: Search quality (16-512, higher = more accurate)
**Use when:**
- Need best quality approximate search
- Can afford higher memory (more connections)
- Dataset 1M-10M vectors
## PQ indices (product quantization)
### IndexPQ - Memory-efficient
```python
# PQ reduces memory by 16-32×
m = 8 # Number of subquantizers (divides d)
nbits = 8 # Bits per subquantizer
index = faiss.IndexPQ(d, m, nbits)
# Train (required!)
index.train(vectors)
# Add vectors
index.add(vectors)
# Search
distances, indices = index.search(query, k)
```
**Parameters:**
- `m`: Subquantizers (d must be divisible by m)
- `nbits`: Bits per code (8 or 16)
**Memory savings:**
- Original: d × 4 bytes (float32)
- PQ: m bytes
- Compression ratio: 4d/m
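A quick check of the 4d/m ratio with the values used above (d=128, m=8, nbits=8):
```python
d, m = 128, 8
original_bytes = d * 4              # float32 vector: 512 bytes
pq_bytes = m                        # 8 one-byte codes
print(original_bytes / pq_bytes)    # 64x compression (4d/m)
```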
**Use when:**
- Limited memory
- Large datasets (> 10M vectors)
- Can accept ~90-95% accuracy
### IndexIVFPQ - IVF + PQ combined
```python
# Best for very large datasets
nlist = 4096
m = 8
nbits = 8
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits)
# Train
index.train(vectors)
index.add(vectors)
# Search
index.nprobe = 32
distances, indices = index.search(query, k)
```
**Use when:**
- Dataset > 10M vectors
- Need fast search + low memory
- Can accept 90-95% accuracy
## GPU indices
### Single GPU
```python
import faiss
# Create CPU index
index_cpu = faiss.IndexFlatL2(d)
# Move to GPU
res = faiss.StandardGpuResources() # GPU resources
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0
# Use normally
index_gpu.add(vectors)
distances, indices = index_gpu.search(query, k)
```
### Multi-GPU
```python
# Use all available GPUs
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
# Or specific GPUs
gpus = [0, 1, 2, 3] # Use GPUs 0-3
index_gpu = faiss.index_cpu_to_gpus_list(index_cpu, gpus=gpus)
```
**Speedup:**
- Single GPU: 10-50× faster than CPU
- Multi-GPU: Near-linear scaling
## Index factory
```python
# Easy index creation with string descriptors
index = faiss.index_factory(d, "IVF100,Flat")
index = faiss.index_factory(d, "HNSW32")
index = faiss.index_factory(d, "IVF4096,PQ8")
# Train and use
index.train(vectors)
index.add(vectors)
```
**Common descriptors:**
- `"Flat"`: Exact search
- `"IVF100,Flat"`: IVF with 100 clusters
- `"HNSW32"`: HNSW with M=32
- `"IVF4096,PQ8"`: IVF + PQ compression
## Performance comparison
### Search speed (1M vectors, k=10)
| Index | Build Time | Search Time | Memory | Recall |
|-------|------------|-------------|--------|--------|
| Flat | 0s | 50ms | 512 MB | 100% |
| IVF100 | 5s | 2ms | 512 MB | 95% |
| HNSW32 | 60s | 1ms | 1GB | 99% |
| IVF4096+PQ8 | 30s | 3ms | 32 MB | 90% |
*CPU (16 cores), 128-dim vectors*
## Best practices
1. **Start with Flat** - Baseline for comparison
2. **Use IVF for medium datasets** - Good balance
3. **Use HNSW for best quality** - If memory allows
4. **Add PQ for memory savings** - Large datasets
5. **GPU for > 100K vectors** - 10-50× speedup
6. **Tune nprobe/efSearch** - Trade-off speed/accuracy
7. **Train on representative data** - Better clustering
8. **Save trained indices** - Avoid retraining
## Resources
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
- **Paper**: https://arxiv.org/abs/1702.08734


@ -0,0 +1,370 @@
---
name: optimizing-attention-flash
description: Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster inference. Supports PyTorch native SDPA, flash-attn library, H100 FP8, and sliding window attention.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [flash-attn, torch, transformers]
metadata:
hermes:
tags: [Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers]
---
# Flash Attention - Fast Memory-Efficient Attention
## Quick start
Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
**PyTorch native (easiest, PyTorch 2.2+)**:
```python
import torch
import torch.nn.functional as F
q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
# Automatically uses Flash Attention if available
out = F.scaled_dot_product_attention(q, k, v)
```
**flash-attn library (more features)**:
```bash
pip install flash-attn --no-build-isolation
```
```python
from flash_attn import flash_attn_func
# q, k, v: [batch, seqlen, nheads, headdim]
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
```
## Common workflows
### Workflow 1: Enable in existing PyTorch model
Copy this checklist:
```
Flash Attention Integration:
- [ ] Step 1: Check PyTorch version (≥2.2)
- [ ] Step 2: Enable Flash Attention backend
- [ ] Step 3: Verify speedup with profiling
- [ ] Step 4: Test accuracy matches baseline
```
**Step 1: Check PyTorch version**
```bash
python -c "import torch; print(torch.__version__)"
# Should be ≥2.2.0
```
If <2.2, upgrade:
```bash
pip install --upgrade torch
```
**Step 2: Enable Flash Attention backend**
Replace standard attention:
```python
# Before (standard attention)
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
out = attn_weights @ v
# After (Flash Attention)
import torch.nn.functional as F
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```
Force Flash Attention backend:
```python
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False
):
out = F.scaled_dot_product_attention(q, k, v)
```
**Step 3: Verify speedup with profiling**
```python
import torch.utils.benchmark as benchmark
def test_attention(use_flash):
q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
if use_flash:
with torch.backends.cuda.sdp_kernel(enable_flash=True):
return F.scaled_dot_product_attention(q, k, v)
else:
attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
return attn @ v
# Benchmark
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
```
Expected: 2-4x speedup for sequences >512 tokens.
**Step 4: Test accuracy matches baseline**
```python
# Compare outputs
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
# Flash Attention
out_flash = F.scaled_dot_product_attention(q, k, v)
# Standard attention
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
out_standard = attn_weights @ v
# Check difference
diff = (out_flash - out_standard).abs().max()
print(f"Max difference: {diff:.6f}")
# Should be <1e-3 for float16
```
### Workflow 2: Use flash-attn library for advanced features
For multi-query attention, sliding window, or H100 FP8.
Copy this checklist:
```
flash-attn Library Setup:
- [ ] Step 1: Install flash-attn library
- [ ] Step 2: Modify attention code
- [ ] Step 3: Enable advanced features
- [ ] Step 4: Benchmark performance
```
**Step 1: Install flash-attn library**
```bash
# NVIDIA GPUs (CUDA 12.0+)
pip install flash-attn --no-build-isolation
# Verify installation
python -c "from flash_attn import flash_attn_func; print('Success')"
```
**Step 2: Modify attention code**
```python
from flash_attn import flash_attn_func
# Input: [batch_size, seq_len, num_heads, head_dim]
# Transpose from [batch, heads, seq, dim] if needed
q = q.transpose(1, 2) # [batch, seq, heads, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = flash_attn_func(
q, k, v,
dropout_p=0.1,
causal=True, # For autoregressive models
window_size=(-1, -1), # No sliding window
softmax_scale=None # Auto-scale
)
out = out.transpose(1, 2) # Back to [batch, heads, seq, dim]
```
**Step 3: Enable advanced features**
Multi-query attention (shared K/V across heads):
```python
from flash_attn import flash_attn_func
# q: [batch, seq, num_q_heads, dim]
# k, v: [batch, seq, num_kv_heads, dim] # Fewer KV heads
out = flash_attn_func(q, k, v) # Automatically handles MQA
```
Sliding window attention (local attention):
```python
# Only attend to window of 256 tokens before/after
out = flash_attn_func(
q, k, v,
window_size=(256, 256), # (left, right) window
causal=True
)
```
**Step 4: Benchmark performance**
```python
import torch
from flash_attn import flash_attn_func
import time
q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
# Warmup
for _ in range(10):
_ = flash_attn_func(q, k, v)
# Benchmark
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
out = flash_attn_func(q, k, v)
torch.cuda.synchronize()
end = time.time()
print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
```
### Workflow 3: H100 FP8 optimization (FlashAttention-3)
For maximum performance on H100 GPUs.
```
FP8 Setup:
- [ ] Step 1: Verify H100 GPU available
- [ ] Step 2: Install flash-attn with FP8 support
- [ ] Step 3: Convert inputs to FP8
- [ ] Step 4: Run with FP8 attention
```
**Step 1: Verify H100 GPU**
```bash
nvidia-smi --query-gpu=name --format=csv
# Should show "H100" or "H800"
```
**Step 2: Install flash-attn with FP8 support**
```bash
pip install flash-attn --no-build-isolation
# FP8 support included for H100
```
**Step 3: Convert inputs to FP8**
```python
import torch
q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
# Convert to float8_e4m3 (FP8)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
```
**Step 4: Run with FP8 attention**
```python
from flash_attn import flash_attn_func
# FlashAttention-3 automatically uses FP8 kernels on H100
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
# Result: ~1.2 PFLOPS, 1.5-2x faster than FP16
```
## When to use vs alternatives
**Use Flash Attention when:**
- Training transformers with sequences >512 tokens
- Running inference with long context (>2K tokens)
- GPU memory constrained (OOM with standard attention)
- Need 2-4x speedup without accuracy loss
- Using PyTorch 2.2+ or can install flash-attn
**Use alternatives instead:**
- **Standard attention**: Sequences <256 tokens (overhead not worth it)
- **xFormers**: Need more attention variants (not just speed)
- **Memory-efficient attention**: CPU inference (Flash Attention needs GPU)
## Common issues
**Issue: ImportError: cannot import flash_attn**
Install with no-build-isolation flag:
```bash
pip install flash-attn --no-build-isolation
```
Or install CUDA toolkit first:
```bash
conda install cuda -c nvidia
pip install flash-attn --no-build-isolation
```
**Issue: Slower than expected (no speedup)**
Flash Attention benefits increase with sequence length:
- <512 tokens: Minimal speedup (10-20%)
- 512-2K tokens: 2-3x speedup
- >2K tokens: 3-4x speedup
Check sequence length is sufficient.
**Issue: RuntimeError: CUDA error**
Verify GPU supports Flash Attention:
```python
import torch
print(torch.cuda.get_device_capability())
# Should be ≥(7, 5) for Turing+
```
Flash Attention requires:
- Ampere (A100, A10): ✅ Full support
- Turing (T4): ✅ Supported
- Volta (V100): ❌ Not supported
**Issue: Accuracy degradation**
Check dtype is float16 or bfloat16 (not float32):
```python
q = q.to(torch.float16) # Or torch.bfloat16
```
Flash Attention uses float16/bfloat16 for speed. Float32 not supported.
## Advanced topics
**Integration with HuggingFace Transformers**: See [references/transformers-integration.md](references/transformers-integration.md) for enabling Flash Attention in BERT, GPT, Llama models.
**Performance benchmarks**: See [references/benchmarks.md](references/benchmarks.md) for detailed speed and memory comparisons across GPUs and sequence lengths.
**Algorithm details**: See [references/algorithm.md](references/algorithm.md) for tiling strategy, recomputation, and IO complexity analysis.
**Advanced features**: See [references/advanced-features.md](references/advanced-features.md) for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
## Hardware requirements
- **GPU**: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
- **VRAM**: No extra requirement (Flash Attention uses less activation memory than standard attention)
- **CUDA**: 11.8 minimum, 12.0+ recommended
- **PyTorch**: 2.2+ for native support
**Not supported**: V100 (Volta), CPU inference
## Resources
- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
- Blog: https://tridao.me/blog/2024/flash3/
- GitHub: https://github.com/Dao-AILab/flash-attention
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html


@ -0,0 +1,215 @@
# Performance Benchmarks
## Contents
- Speed comparisons across GPUs
- Memory usage analysis
- Scaling with sequence length
- Training vs inference performance
- Flash Attention versions comparison
## Speed comparisons across GPUs
### A100 80GB (Ampere)
**Forward pass time** (milliseconds, batch=8, heads=32, dim=64):
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 | Speedup (FA2) |
|------------|----------|--------------|--------------|---------------|
| 512 | 1.2 | 0.9 | N/A | 1.3x |
| 1024 | 3.8 | 1.4 | N/A | 2.7x |
| 2048 | 14.2 | 4.8 | N/A | 3.0x |
| 4096 | 55.1 | 17.3 | N/A | 3.2x |
| 8192 | 218.5 | 66.2 | N/A | 3.3x |
### H100 80GB (Hopper)
**Forward pass time** (milliseconds, same config):
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 (FP16) | Flash Attn 3 (FP8) | Best Speedup |
|------------|----------|--------------|---------------------|--------------------|--------------|
| 512 | 0.8 | 0.6 | 0.4 | 0.3 | 2.7x |
| 1024 | 2.6 | 1.0 | 0.6 | 0.4 | 6.5x |
| 2048 | 9.8 | 3.4 | 2.0 | 1.3 | 7.5x |
| 4096 | 38.2 | 12.5 | 7.2 | 4.8 | 8.0x |
| 8192 | 151.4 | 47.8 | 27.1 | 18.2 | 8.3x |
**Key insight**: Flash Attention 3 on H100 with FP8 achieves ~1.2 PFLOPS (75% of theoretical max).
### A10G 24GB (Ampere)
**Forward pass time** (milliseconds, batch=4):
| Seq Length | Standard | Flash Attn 2 | Speedup |
|------------|----------|--------------|---------|
| 512 | 2.1 | 1.6 | 1.3x |
| 1024 | 6.8 | 2.8 | 2.4x |
| 2048 | 25.9 | 9.4 | 2.8x |
| 4096 | 102.1 | 35.2 | 2.9x |
## Memory usage analysis
### GPU memory consumption (batch=8, heads=32, dim=64)
**Standard attention memory**:
| Seq Length | Attention Matrix | KV Cache | Total | Notes |
|------------|------------------|----------|-------|-------|
| 512 | 8 MB | 32 MB | 40 MB | Manageable |
| 2048 | 128 MB | 128 MB | 256 MB | Growing |
| 8192 | 2048 MB (2 GB) | 512 MB | 2.5 GB | Large |
| 32768 | 32768 MB (32 GB) | 2048 MB | 34 GB | OOM on 24GB GPUs |
**Flash Attention 2 memory**:
| Seq Length | Attention (on-chip) | KV Cache | Total | Reduction |
|------------|---------------------|----------|-------|-----------|
| 512 | 0 MB (recomputed) | 32 MB | 32 MB | 20% |
| 2048 | 0 MB | 128 MB | 128 MB | 50% |
| 8192 | 0 MB | 512 MB | 512 MB | 80% |
| 32768 | 0 MB | 2048 MB | 2 GB | 94% |
**Key insight**: Flash Attention doesn't materialize attention matrix, saving O(N²) memory.
### Memory scaling comparison
**Llama 2 7B model memory** (float16, batch=1):
| Context Length | Standard Attention | Flash Attention 2 | Can Fit 24GB GPU? |
|----------------|-------------------|-------------------|-------------------|
| 2K | 3.2 GB | 2.1 GB | Both: Yes |
| 4K | 5.8 GB | 2.8 GB | Both: Yes |
| 8K | 12.1 GB | 4.2 GB | Both: Yes |
| 16K | 26.3 GB (OOM) | 7.8 GB | Only Flash: Yes |
| 32K | OOM | 14.2 GB | Only Flash: Yes |
### Training memory (Llama 2 7B, batch=4)
| Context | Standard (GB) | Flash Attn (GB) | Reduction |
|---------|---------------|-----------------|-----------|
| 2K | 18.2 | 12.4 | 32% |
| 4K | 34.8 | 16.8 | 52% |
| 8K | OOM (>40GB) | 26.2 | Fits! |
## Scaling with sequence length
### Computational complexity
**Standard attention**:
- Time: O(N² × d)
- Memory: O(N² + N × d)
**Flash Attention**:
- Time: O(N² × d) (same, but with better constants)
- Memory: O(N × d) (linear!)
### Empirical scaling (A100, batch=1, heads=32, dim=64)
**Time per token (milliseconds)**:
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|----------|-----|-----|-----|-----|-----|------|
| Standard | 0.15 | 0.37 | 1.11 | 3.44 | 13.4 | 52.8 |
| Flash Attn 2 | 0.11 | 0.14 | 0.24 | 0.43 | 0.83 | 1.64 |
| Speedup | 1.4x | 2.6x | 4.6x | 8.0x | 16.1x | 32.2x |
**Observation**: The speedup roughly doubles each time the sequence length doubles.
### Memory per token (MB)
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|----------|-----|-----|-----|-----|-----|------|
| Standard | 0.08 | 0.13 | 0.25 | 0.64 | 2.05 | 8.13 |
| Flash Attn 2 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 |
**Observation**: Flash Attention memory per token is constant!
## Training vs inference performance
### Training (forward + backward, Llama 2 7B, A100)
| Batch × Seq | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|-------------|------------------------|--------------------------|---------|
| 4 × 2K | 1.2 | 3.1 | 2.6x |
| 8 × 2K | 2.1 | 5.8 | 2.8x |
| 4 × 4K | 0.4 | 1.3 | 3.3x |
| 8 × 4K | OOM | 2.4 | Enabled |
| 2 × 8K | 0.1 | 0.4 | 4.0x |
### Inference (generation, Llama 2 7B, A100)
| Context Length | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|----------------|----------------------|-------------------------|---------|
| 512 | 48 | 52 | 1.1x |
| 2K | 42 | 62 | 1.5x |
| 4K | 31 | 58 | 1.9x |
| 8K | 18 | 51 | 2.8x |
| 16K | OOM | 42 | Enabled |
**Note**: Inference speedup less dramatic than training because generation is memory-bound (KV cache accesses).
## Flash Attention versions comparison
### Flash Attention 1 vs 2 vs 3 (H100, seq=4096, batch=8)
| Metric | FA1 | FA2 | FA3 (FP16) | FA3 (FP8) |
|--------|-----|-----|------------|-----------|
| Forward time (ms) | 28.4 | 12.5 | 7.2 | 4.8 |
| Memory (GB) | 4.8 | 4.2 | 4.2 | 2.8 |
| TFLOPS | 180 | 420 | 740 | 1150 |
| GPU util % | 35% | 55% | 75% | 82% |
**Key improvements**:
- FA2: 2.3x faster than FA1 (better parallelism)
- FA3 (FP16): 1.7x faster than FA2 (H100 async optimizations)
- FA3 (FP8): 2.6x faster than FA2 (low precision)
### Features by version
| Feature | FA1 | FA2 | FA3 |
|---------|-----|-----|-----|
| Basic attention | ✅ | ✅ | ✅ |
| Causal masking | ✅ | ✅ | ✅ |
| Multi-query attention | ❌ | ✅ | ✅ |
| Sliding window | ❌ | ✅ | ✅ |
| Paged KV cache | ❌ | ✅ | ✅ |
| FP8 support | ❌ | ❌ | ✅ (H100 only) |
| Work partitioning | Basic | Advanced | Optimal |
## Real-world model benchmarks
### Llama 2 models (A100 80GB, batch=4, seq=2048)
| Model | Params | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|-------|--------|------------------------|--------------------------|---------|
| Llama 2 7B | 7B | 1.2 | 3.1 | 2.6x |
| Llama 2 13B | 13B | 0.6 | 1.7 | 2.8x |
| Llama 2 70B | 70B | 0.12 | 0.34 | 2.8x |
### GPT-style models (seq=1024)
| Model | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|-------|----------------------|-------------------------|---------|
| GPT-2 (124M) | 520 | 680 | 1.3x |
| GPT-J (6B) | 42 | 98 | 2.3x |
| GPT-NeoX (20B) | 8 | 22 | 2.75x |
## Recommendations by use case
**Training large models (>7B parameters)**:
- Use Flash Attention 2 on A100
- Use Flash Attention 3 FP8 on H100 for maximum speed
- Expected: 2.5-3x speedup
**Long context inference (>4K tokens)**:
- Flash Attention essential (enables contexts standard attention can't handle)
- Expected: 2-4x speedup, 5-10x memory reduction
**Short sequences (<512 tokens)**:
- Flash Attention provides 1.2-1.5x speedup
- Minimal memory benefit
- Still worth enabling (no downside)
**Multi-user serving**:
- Flash Attention reduces per-request memory
- Allows higher concurrent batch sizes
- Can serve 2-3x more users on same hardware


@ -0,0 +1,293 @@
# HuggingFace Transformers Integration
## Contents
- Enabling Flash Attention in Transformers
- Supported model architectures
- Configuration examples
- Performance comparisons
- Troubleshooting model-specific issues
## Enabling Flash Attention in Transformers
HuggingFace Transformers (v4.36+) supports Flash Attention 2 natively.
**Simple enable for any supported model**:
```python
from transformers import AutoModel
model = AutoModel.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
device_map="auto"
)
```
**Install requirements**:
```bash
pip install transformers>=4.36
pip install flash-attn --no-build-isolation
```
## Supported model architectures
As of Transformers 4.40:
**Fully supported**:
- Llama / Llama 2 / Llama 3
- Mistral / Mixtral
- Falcon
- GPT-NeoX
- Phi / Phi-2 / Phi-3
- Qwen / Qwen2
- Gemma
- Starcoder2
- GPT-J
- OPT
- BLOOM
**Partially supported** (encoder-decoder):
- BART
- T5 / Flan-T5
- Whisper
**Check support**:
```python
from transformers import AutoConfig
config = AutoConfig.from_pretrained("model-name")
print(config._attn_implementation_internal)
# 'flash_attention_2' if supported
```
## Configuration examples
### Llama 2 with Flash Attention
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_id = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
model_id,
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Generate
inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_length=100)
print(tokenizer.decode(outputs[0]))
```
### Mistral with Flash Attention for long context
```python
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1",
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16, # Better for long context
device_map="auto",
max_position_embeddings=32768 # Extended context
)
# Process long document (32K tokens)
long_text = "..." * 10000
inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to("cuda")
outputs = model.generate(**inputs, max_new_tokens=512)
```
### Fine-tuning with Flash Attention
```python
from transformers import Trainer, TrainingArguments
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16
)
training_args = TrainingArguments(
output_dir="./results",
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
num_train_epochs=3,
fp16=True, # Must match model dtype
optim="adamw_torch_fused" # Fast optimizer
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset
)
trainer.train()
```
### Multi-GPU training
```python
from transformers import AutoModelForCausalLM
import torch
# Model parallelism with Flash Attention
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-13b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
device_map="auto", # Automatic multi-GPU placement
max_memory={0: "20GB", 1: "20GB"} # Limit per GPU
)
```
## Performance comparisons
### Memory usage (Llama 2 7B, batch=1)
| Sequence Length | Standard Attention | Flash Attention 2 | Reduction |
|-----------------|-------------------|-------------------|-----------|
| 512 | 1.2 GB | 0.9 GB | 25% |
| 2048 | 3.8 GB | 1.4 GB | 63% |
| 8192 | 14.2 GB | 3.2 GB | 77% |
| 32768 | OOM (>24GB) | 10.8 GB | Fits! |
### Speed (tokens/sec, A100 80GB)
| Model | Standard | Flash Attn 2 | Speedup |
|-------|----------|--------------|---------|
| Llama 2 7B (seq=2048) | 42 | 118 | 2.8x |
| Llama 2 13B (seq=4096) | 18 | 52 | 2.9x |
| Llama 2 70B (seq=2048) | 4 | 11 | 2.75x |
### Training throughput (samples/sec)
| Model | Batch Size | Standard | Flash Attn 2 | Speedup |
|-------|------------|----------|--------------|---------|
| Llama 2 7B | 4 | 1.2 | 3.1 | 2.6x |
| Llama 2 7B | 8 | 2.1 | 5.8 | 2.8x |
| Llama 2 13B | 2 | 0.6 | 1.7 | 2.8x |
## Troubleshooting model-specific issues
### Issue: Model doesn't support Flash Attention
Check support list above. If not supported, use PyTorch SDPA as fallback:
```python
model = AutoModelForCausalLM.from_pretrained(
"model-name",
attn_implementation="sdpa", # PyTorch native (still faster)
torch_dtype=torch.float16
)
```
### Issue: CUDA out of memory during loading
Reduce memory footprint:
```python
model = AutoModelForCausalLM.from_pretrained(
"model-name",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
device_map="auto",
max_memory={0: "18GB"}, # Reserve memory for KV cache
low_cpu_mem_usage=True
)
```
### Issue: Slower inference than expected
Ensure dtype matches:
```python
# Model and inputs must both be float16/bfloat16
model = model.to(torch.float16)
inputs = tokenizer(..., return_tensors="pt").to("cuda")
inputs = {k: v.to(torch.float16) if v.dtype == torch.float32 else v
for k, v in inputs.items()}
```
### Issue: Different outputs vs standard attention
Flash Attention is numerically equivalent but uses different computation order. Small differences (<1e-3) are normal:
```python
# Compare outputs
model_standard = AutoModelForCausalLM.from_pretrained("model-name", torch_dtype=torch.float16)
model_flash = AutoModelForCausalLM.from_pretrained(
"model-name",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16
)
inputs = tokenizer("Test", return_tensors="pt").to("cuda")
with torch.no_grad():
out_standard = model_standard(**inputs).logits
out_flash = model_flash(**inputs).logits
diff = (out_standard - out_flash).abs().max()
print(f"Max diff: {diff:.6f}") # Should be ~1e-3 to 1e-4
```
### Issue: ImportError during model loading
Install flash-attn:
```bash
pip install flash-attn --no-build-isolation
```
Or disable Flash Attention:
```python
model = AutoModelForCausalLM.from_pretrained(
"model-name",
attn_implementation="eager", # Standard PyTorch
torch_dtype=torch.float16
)
```
## Best practices
1. **Always use float16/bfloat16** with Flash Attention (not float32)
2. **Set device_map="auto"** for automatic memory management
3. **Use bfloat16 for long context** (better numerical stability)
4. **Enable gradient checkpointing** for training large models
5. **Monitor memory** with `torch.cuda.max_memory_allocated()`
**Example with all best practices**:
```python
from transformers import AutoModelForCausalLM, TrainingArguments
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16, # Better for training
device_map="auto",
low_cpu_mem_usage=True
)
# Enable gradient checkpointing for memory
model.gradient_checkpointing_enable()
# Training with optimizations
training_args = TrainingArguments(
output_dir="./results",
per_device_train_batch_size=8,
gradient_accumulation_steps=2,
bf16=True, # Match model dtype
optim="adamw_torch_fused",
gradient_checkpointing=True
)
```


@ -0,0 +1,302 @@
---
name: hermes-atropos-environments
description: Build, test, and debug Hermes Agent RL environments for Atropos training. Covers the HermesAgentBaseEnv interface, reward functions, agent loop integration, evaluation with tools, wandb logging, and the three CLI modes (serve/process/evaluate). Use when creating, reviewing, or fixing RL environments in the hermes-agent repo.
version: 1.1.0
author: Hermes Agent
license: MIT
metadata:
hermes:
tags: [atropos, rl, environments, training, reinforcement-learning, reward-functions]
related_skills: [axolotl, grpo-rl-training, trl-fine-tuning, lm-evaluation-harness]
---
# Hermes Agent Atropos Environments
Guide for building RL environments in the hermes-agent repo that integrate with the Atropos training framework.
## Architecture Overview
```
Atropos BaseEnv (atroposlib/envs/base.py)
└── HermesAgentBaseEnv (environments/hermes_base_env.py)
├── Handles agent loop orchestration
├── Handles tool resolution per group
├── Handles ToolContext for reward verification
└── YOUR ENVIRONMENT (environments/your_env.py)
Only implements: setup, get_next_item, format_prompt,
compute_reward, evaluate, wandb_log
```
Hermes environments are special because they run a **multi-turn agent loop with tool calling** — not just single-turn completions. The base env handles the loop; you implement the task and scoring.
## File Locations
| File | Purpose |
|------|---------|
| `environments/hermes_base_env.py` | Base class with agent loop + tool resolution |
| `environments/agent_loop.py` | `HermesAgentLoop` + `AgentResult` dataclass |
| `environments/tool_context.py` | `ToolContext` for reward verification |
| `environments/tool_call_parsers.py` | Phase 2 tool call parsers (hermes, mistral, etc.) |
| `environments/your_env.py` | Your environment implementation |
## Inference Setup — Ask the User First
**IMPORTANT:** Before running any test, evaluation, or data generation command, always ask the user how they want to handle inference. Do NOT assume OpenRouter or any specific endpoint. Present these options:
1. **OpenRouter** — Ask which model they want to use (e.g., `anthropic/claude-sonnet-4.5`, `google/gemini-2.5-pro`, `meta-llama/llama-3.3-70b-instruct`, etc.). Requires `OPENROUTER_API_KEY` in environment.
2. **Self-hosted VLLM endpoint** — Ask for their base URL (e.g., `http://localhost:8000/v1`) and model name. Set `--openai.server_type vllm`.
3. **Other OpenAI-compatible API** — Ask for the base URL, model name, and any required API key. Set `--openai.server_type openai` and `--openai.health_check false`.
4. **Local Atropos training server** — For `serve` mode with a live training loop. Default `http://localhost:8000/v1`.
Once the user tells you their setup, use those values in all CLI commands for that session. Example prompts:
> "Before I run this, how would you like to handle inference?
> 1. OpenRouter (I'll need your preferred model, e.g. claude-sonnet-4.5)
> 2. A self-hosted VLLM endpoint (give me the URL and model name)
> 3. Another OpenAI-compatible API (give me the URL, model, and any auth details)
> 4. Local Atropos training server (serve mode)"
### Key flags by provider:
| Provider | `--openai.server_type` | `--openai.health_check` | `--openai.api_key` |
|----------|----------------------|------------------------|-------------------|
| OpenRouter | `openai` | `false` | `$OPENROUTER_API_KEY` |
| VLLM (self-hosted) | `vllm` | (default) | (not needed) |
| Other OpenAI-compatible | `openai` | `false` | As needed |
| Local Atropos | (default) | (default) | (not needed) |
## Required Methods
### 1. `setup()` — Load dataset and initialize state
```python
async def setup(self) -> None:
"""Called once at startup. Load datasets, initialize state."""
# Try HuggingFace first, fallback to built-in samples
try:
from datasets import load_dataset
ds = load_dataset("your/dataset", split="test")
self._items = [...]
except Exception:
self._items = BUILTIN_SAMPLES
# Always split into train/eval
random.shuffle(self._items)
eval_size = max(20, int(len(self._items) * 0.1))
self._eval_items = self._items[:eval_size]
self._items = self._items[eval_size:]
```
### 2. `get_next_item()` — Return next training item
```python
async def get_next_item(self) -> dict:
"""Return next item, cycling through dataset."""
item = self._items[self._index % len(self._items)]
self._index += 1
return item
```
### 3. `format_prompt(item)` — Convert item to user message
```python
def format_prompt(self, item: dict) -> str:
"""Convert a dataset item into the user-facing prompt."""
return f"Research this question: {item['question']}"
```
### 4. `compute_reward(item, result, ctx)` — Score the rollout
**CRITICAL**: `result` is an `AgentResult`, NOT a dict. It has these attributes:
- `result.messages` — List of message dicts (OpenAI format)
- `result.turns_used` — Number of LLM calls made
- `result.finished_naturally` — True if model stopped voluntarily
- `result.tool_errors` — List of ToolError objects
**AgentResult does NOT have**: `final_response`, `tool_calls`, `tools_used`.
You must extract these from `result.messages`:
```python
async def compute_reward(self, item, result: AgentResult, ctx: ToolContext) -> float:
# Extract final response (last assistant message with content)
final_response = ""
tools_used = []
for msg in reversed(result.messages):
if msg.get("role") == "assistant" and msg.get("content") and not final_response:
final_response = msg["content"]
if msg.get("role") == "assistant" and msg.get("tool_calls"):
for tc in msg["tool_calls"]:
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
name = fn.get("name", "")
if name:
tools_used.append(name)
# Score using LLM judge, heuristic, or ToolContext verification
correctness = await self._llm_judge(item, final_response)
return correctness
```
`ctx` (ToolContext) gives you terminal/file access to the agent's sandbox for verification:
```python
# Run tests in the agent's sandbox (use a name that doesn't shadow the AgentResult argument)
test_run = ctx.terminal("pytest /workspace/test.py")
return 1.0 if test_run["exit_code"] == 0 else 0.0
```
### 5. `evaluate()` — Periodic evaluation with full agent loop
**MUST use the full agent loop with tools**, not single-turn chat_completion.
The whole point of hermes-agent environments is agentic evaluation:
```python
async def evaluate(self, *args, **kwargs) -> None:
import time, uuid
from environments.agent_loop import HermesAgentLoop
from environments.tool_context import ToolContext
start_time = time.time()
tools, valid_names = self._resolve_tools_for_group()
samples = []
for item in self._eval_items[:self.config.eval_size]:
task_id = str(uuid.uuid4())
messages = []
if self.config.system_prompt:
messages.append({"role": "system", "content": self.config.system_prompt})
messages.append({"role": "user", "content": self.format_prompt(item)})
agent = HermesAgentLoop(
server=self.server,
tool_schemas=tools,
valid_tool_names=valid_names,
max_turns=self.config.max_agent_turns,
task_id=task_id,
temperature=0.0, # Deterministic for eval
max_tokens=self.config.max_token_length,
extra_body=self.config.extra_body,
)
result = await agent.run(messages)
ctx = ToolContext(task_id)
try:
reward = await self.compute_reward(item, result, ctx)
finally:
ctx.cleanup()
samples.append({"prompt": ..., "response": ..., "reward": reward})
eval_metrics = {"eval/mean_reward": ...}
await self.evaluate_log(metrics=eval_metrics, samples=samples,
start_time=start_time, end_time=time.time())
```
### 6. `wandb_log()` — Custom metrics logging
Always call `super().wandb_log()` at the end:
```python
async def wandb_log(self, wandb_metrics=None):
if wandb_metrics is None:
wandb_metrics = {}
if self._reward_buffer:
n = len(self._reward_buffer)
wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
self._reward_buffer.clear()
await super().wandb_log(wandb_metrics) # MUST call super
```
**Pitfall**: `compute_reward` appends to metric buffers. During eval, this pollutes training metrics. Roll back buffer entries added during eval.
## Config Class
Always create a custom config subclass with Pydantic Field descriptors. Key inherited fields you can tune: `enabled_toolsets`, `max_agent_turns`, `agent_temperature`, `system_prompt`, `terminal_backend`, `group_size`, `steps_per_eval`, `total_steps`.
## config_init() — Default Configuration
Classmethod returning `(YourEnvConfig, [APIServerConfig(...)])`. Set server_type to "openai" for OpenRouter/external APIs. Load API key from environment variable.
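A hedged sketch of a config subclass plus `config_init()`. The config base-class name, the `APIServerConfig` import path and fields, and the model/URL values are assumptions inferred from the CLI flags above; check the actual definitions in `atroposlib` and `environments/hermes_base_env.py`:
```python
import os

from pydantic import Field

from atroposlib.envs.base import APIServerConfig        # import path is an assumption
from environments.hermes_base_env import (
    HermesAgentBaseEnv,
    HermesAgentBaseEnvConfig,                            # config base-class name is an assumption
)


class MyEnvConfig(HermesAgentBaseEnvConfig):
    dataset_name: str = Field(default="your/dataset", description="HF dataset loaded in setup()")
    eval_size: int = Field(default=20, description="Number of items scored in evaluate()")


class MyEnv(HermesAgentBaseEnv):
    name = "my-env"
    env_config_cls = MyEnvConfig

    @classmethod
    def config_init(cls):
        env_config = MyEnvConfig(
            group_size=4,
            max_agent_turns=8,
            steps_per_eval=100,
            total_steps=1000,
        )
        server_configs = [
            APIServerConfig(
                model_name="anthropic/claude-sonnet-4.5",    # placeholder model
                base_url="https://openrouter.ai/api/v1",
                api_key=os.environ.get("OPENROUTER_API_KEY", ""),
                server_type="openai",                        # required for external APIs
            )
        ]
        return env_config, server_configs
```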
## Three CLI Modes
```bash
# SERVE — Full training loop (connects to Atropos API server)
python environments/my_env.py serve --openai.base_url http://localhost:8000/v1
# PROCESS — Offline data generation (saves JSONL)
python environments/my_env.py process --env.total_steps 10 --env.group_size 1 \
--env.use_wandb false --env.data_path_to_save_groups output.jsonl \
--openai.base_url "<USER_BASE_URL>" \
--openai.model_name "<USER_MODEL>" \
--openai.server_type <USER_SERVER_TYPE> --openai.health_check false
# EVALUATE — Standalone eval (runs setup + evaluate only)
python environments/my_env.py evaluate --env.eval_size 20 \
--env.data_dir_to_save_evals /tmp/eval_results \
--openai.base_url "<USER_BASE_URL>" \
--openai.model_name "<USER_MODEL>" \
--openai.server_type <USER_SERVER_TYPE> --openai.health_check false
```
Config priority: CLI args > YAML file > config_init() defaults.
## Common Pitfalls
1. **AgentResult has .messages, not .final_response** — Extract the final response by iterating reversed(result.messages) looking for the last assistant message with content.
2. **evaluate() must use HermesAgentLoop, not chat_completion** — Single-turn chat_completion has no tools. The whole point of hermes-agent benchmarks is agentic evaluation with tool use.
3. **Don't call _llm_judge twice** — If compute_reward already calls it, extract the score from the buffer instead of calling judge separately in evaluate().
4. **Eval pollutes training buffers** — compute_reward appends to metric buffers. During eval, roll back buffer entries to keep training metrics clean (see the rollback sketch after this list).
5. **Always set health_check=false for OpenRouter** — OpenRouter has no /health endpoint.
6. **Set data_dir_to_save_evals in evaluate mode** — Without it, results aren't saved.
7. **default_toolsets class variable vs enabled_toolsets config** — The class variable is a hint; the config field is what actually controls tool resolution.
8. **Tool call parsing in messages** — Tool calls are dicts with `{"function": {"name": ..., "arguments": ...}}`. Always check `isinstance(tc, dict)`.
9. **ToolContext.cleanup()** — Always call in a finally block to release sandbox resources.
10. **server_type must be "openai" for external APIs** — Without it, Atropos assumes a local VLLM server.
11. **Always ask the user for their inference setup** — Never hardcode or assume a specific provider/model. See the "Inference Setup" section above.
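One hedged way to handle pitfall 4, assuming the reward buffer is a plain Python list on the environment:
```python
async def evaluate(self, *args, **kwargs) -> None:
    # Snapshot training buffers so eval-time compute_reward() calls can be rolled back
    train_buffer_len = len(self._reward_buffer)
    try:
        ...  # run the full agent-loop evaluation as shown earlier
    finally:
        # Drop anything appended during eval so training metrics stay clean
        del self._reward_buffer[train_buffer_len:]
```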
## Reward Function Patterns
### LLM Judge (for open-ended tasks)
Use `self.server.chat_completion()` with a scoring prompt. Parse JSON response for score float. Always include a heuristic fallback (keyword overlap) for when the judge call fails.
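A minimal judge sketch along these lines. The `chat_completion()` keyword arguments and the OpenAI-style response shape are assumptions; adjust to the actual server wrapper:
```python
import json

async def _llm_judge(self, item: dict, final_response: str) -> float:
    prompt = (
        "Score the answer from 0.0 to 1.0 for correctness.\n"
        f"Question: {item['question']}\n"
        f"Reference: {item.get('answer', '')}\n"
        f"Answer: {final_response}\n"
        'Reply with JSON only: {"score": <float>}'
    )
    try:
        completion = await self.server.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
        )
        text = completion.choices[0].message.content      # assumes OpenAI-style response object
        return max(0.0, min(1.0, float(json.loads(text)["score"])))
    except Exception:
        # Heuristic fallback: keyword overlap with the reference answer
        ref = set(str(item.get("answer", "")).lower().split())
        got = set(final_response.lower().split())
        return len(ref & got) / max(1, len(ref))
```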
### Binary Verification (for code/terminal tasks)
Use `ctx.terminal("pytest test.py -q")` to run tests in the agent's sandbox. Return 1.0 for pass, 0.0 for fail.
### Multi-Signal (combine multiple indicators)
Weight correctness (0.6) + tool usage (0.2) + efficiency (0.2) + optional bonuses. Clamp to [0, 1].
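A sketch of the weighting; `_judge_correctness` is a hypothetical helper (reuse your LLM-judge or verification scorer), and `self.config.max_agent_turns` assumes the env exposes its config that way:
```python
async def compute_reward(self, item, result, ctx):
    correctness = await self._judge_correctness(item, result)  # hypothetical helper, in [0, 1]
    used_tools = any(m.get("role") == "tool" for m in result.messages)
    tool_score = 1.0 if used_tools else 0.0
    # Fewer turns than the budget earns a higher efficiency score.
    efficiency = max(0.0, 1.0 - result.turns_used / self.config.max_agent_turns)
    score = 0.6 * correctness + 0.2 * tool_score + 0.2 * efficiency
    return max(0.0, min(1.0, score))  # clamp to [0, 1]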
## Testing Your Environment
1. **Import test**: `python -c "from environments.my_env import MyEnv; print('OK')"`
2. **Ask the user for inference setup** (see "Inference Setup" section above)
3. **Process mode** (1 item): Verify JSONL output has valid tokens, masks, scores
4. **Evaluate mode**: Verify full agent loop runs with tools, metrics logged correctly
5. **Check reward range**: Scores should be in [0, 1], not all identical
## Minimum Implementation Checklist
```python
class MyEnv(HermesAgentBaseEnv):
name = "my-env"
env_config_cls = MyEnvConfig
@classmethod
def config_init(cls): ... # Default server + env config
async def setup(self): ... # Load dataset + train/eval split
async def get_next_item(self): ... # Cycle through training items
def format_prompt(self, item): ... # Item → user message string
async def compute_reward(self, item, result, ctx): ... # Score rollout
async def evaluate(self, *args, **kwargs): ... # Full agent loop eval
async def wandb_log(self, metrics=None): ... # Custom metrics + super()
if __name__ == "__main__":
MyEnv.cli()
```

View file

@ -0,0 +1,59 @@
# AgentResult Fields Reference
`AgentResult` is defined in `environments/agent_loop.py` as a dataclass.
## Fields
| Field | Type | Description |
|-------|------|-------------|
| `messages` | `List[Dict[str, Any]]` | Full conversation history in OpenAI message format |
| `managed_state` | `Optional[Dict]` | ManagedServer.get_state() if Phase 2, else None |
| `turns_used` | `int` | Number of LLM calls made during the loop |
| `finished_naturally` | `bool` | True if model stopped calling tools on its own |
| `reasoning_per_turn` | `List[Optional[str]]` | Extracted reasoning content per turn |
| `tool_errors` | `List[ToolError]` | Tool errors encountered during the loop |
## ToolError Fields
| Field | Type | Description |
|-------|------|-------------|
| `turn` | `int` | Turn in which the error occurred |
| `tool_name` | `str` | Name of the tool that failed |
| `arguments` | `str` | Arguments passed to the tool |
| `error` | `str` | Error message |
| `tool_result` | `str` | The result returned to the model |
## Extracting Data from Messages
Messages follow OpenAI format. Common patterns:
```python
# Get final assistant response
for msg in reversed(result.messages):
if msg.get("role") == "assistant" and msg.get("content"):
final_response = msg["content"]
break
# Get all tool names used
tools = []
for msg in result.messages:
if msg.get("role") == "assistant" and msg.get("tool_calls"):
for tc in msg["tool_calls"]:
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
tools.append(fn.get("name", ""))
# Get tool results
for msg in result.messages:
if msg.get("role") == "tool":
tool_output = msg.get("content", "")
call_id = msg.get("tool_call_id", "")
```
## Fields that DO NOT EXIST
These are common mistakes — AgentResult does NOT have:
- `final_response` — extract from messages
- `tool_calls` — extract from messages
- `tools_used` — extract from messages
- `output` — extract from messages
- `response` — extract from messages

View file

@ -0,0 +1,65 @@
# Atropos BaseEnv Reference
Source: `atroposlib/envs/base.py` (~2124 lines)
## Abstract Methods (MUST implement)
| Method | Signature | Description |
|--------|-----------|-------------|
| `get_next_item()` | `async def get_next_item(self) -> Item` | Return next item for trajectory. Return None to pause. |
| `evaluate()` | `async def evaluate(self, *args, **kwargs)` | Called every steps_per_eval steps. |
| `setup()` | `async def setup(self)` | Called once at start. Load datasets, init models. |
| `collect_trajectory()` | `async def collect_trajectory(self, item) -> Tuple[Optional[ScoredDataItem], List[Item]]` | Single rollout. Or override collect_trajectories instead. |
## Overridable Methods
| Method | Default Behavior | Override When |
|--------|-----------------|---------------|
| `collect_trajectories()` | Runs collect_trajectory group_size times in parallel | Batch generation, MCTS, coupled rollouts |
| `wandb_log()` | Logs completion lengths, rollout table, perf stats | Add custom metrics (always call super) |
| `config_init()` | Returns (env_config_cls(), ServerBaseline()) | Custom defaults + server configs |
| `postprocess_histories()` | Passthrough | Final processing before sending to trainer |
| `save_checkpoint()` | Saves JSON to checkpoint_dir | Custom serialization |
| `cleanup()` | No-op | Release resources after each rollout |
## ScoredDataGroup Structure
```python
ScoredDataGroup = TypedDict with:
tokens: List[List[int]] # Token IDs per rollout
masks: List[List[int]] # -100=prompt, token_id=completion
scores: List[float] # Score per rollout
advantages: Optional[...] # Per-token advantages
ref_logprobs: Optional[...] # Reference model logprobs
messages: Optional[...] # OpenAI-format messages
inference_logprobs: Optional[...] # Inference logprobs
```
## BaseEnvConfig Key Fields
| Field | Default | Description |
|-------|---------|-------------|
| `group_size` | 4 | Responses grouped for scoring |
| `steps_per_eval` | 100 | Steps between evaluations |
| `max_token_length` | 2048 | Max token length for generations |
| `total_steps` | 1000 | Total training steps |
| `use_wandb` | True | Enable wandb logging |
| `tokenizer_name` | DeepHermes-3 | Tokenizer for token encoding |
| `ensure_scores_are_not_same` | True | Skip groups with identical scores |
| `worker_timeout` | 600 | Task timeout seconds |
## Data Flow
```
env_manager() → add_train_workers() → handle_env()
→ collect_trajectories() → postprocess_histories()
→ handle_send_to_api() → training server
```
## Atropos Environment Statistics (82 environments analyzed)
- 95% implement setup, collect_trajectories, evaluate, get_next_item
- 76% override wandb_log
- 54% have custom config class
- Most use collect_trajectories (plural), not collect_trajectory (singular)
- Common reward patterns: LLM-judge (~40), regex-extract (~35), code-exec (~12)

View file

@ -0,0 +1,199 @@
# Usage Patterns — Testing Environments and Evaluating Models
## Pattern 1: Test Your Environment Works (process mode)
Use `process` mode to verify your environment runs end-to-end before
committing. This generates trajectories without needing an Atropos
training server.
**Before running:** Ask the user for their inference setup (see SKILL.md "Inference Setup" section). Replace `<BASE_URL>`, `<MODEL>`, and `<SERVER_TYPE>` below with their chosen values.
### Step 1: Run 1 trajectory
```bash
cd ~/.hermes/hermes-agent
source venv/bin/activate
python environments/your_env.py process \
--env.total_steps 1 \
--env.group_size 1 \
--env.use_wandb false \
--env.data_path_to_save_groups /tmp/test_output.jsonl \
--openai.base_url "<BASE_URL>" \
--openai.model_name "<MODEL>" \
--openai.server_type <SERVER_TYPE> \
--openai.health_check false
```
### Step 2: Verify the output
```python
import json
for line in open("/tmp/test_output.jsonl"):
data = json.loads(line)
print(f"Scores: {data.get('scores', [])}")
print(f"Token sequences: {len(data.get('tokens', []))}")
# Check messages include tool calls
for msg_list in data.get("messages", []):
roles = [m.get("role") for m in msg_list]
print(f"Roles: {roles}")
for m in reversed(msg_list):
if m.get("role") == "assistant" and m.get("content"):
print(f"Response: {m['content'][:200]}...")
break
```
### What to check:
- **Scores are not all 0.0** — if so, compute_reward is broken
- **Scores are in [0, 1]** — not negative, not >1
- **Messages include "tool" role entries** — agent used tools
- **Token sequences are non-empty**
- **An HTML visualization is generated** next to the .jsonl
### Common failures:
- `'AgentResult' object has no attribute 'X'` — accessing a field that doesn't exist. See agentresult-fields.md.
- Score always 0.0 — reward function erroring silently
- Score always 1.0 — verification too lenient or not running
## Pattern 2: Evaluate a Model (evaluate mode)
Use `evaluate` mode to benchmark a model on your environment's eval
split. This runs the full agent loop with tools for each eval item.
### Step 1: Run evaluation
```bash
python environments/your_env.py evaluate \
--env.eval_size 20 \
--env.use_wandb false \
--env.data_dir_to_save_evals /tmp/eval_results \
--openai.base_url "<BASE_URL>" \
--openai.model_name "<MODEL>" \
--openai.server_type <SERVER_TYPE> \
--openai.health_check false
```
### Step 2: Read results
Stdout shows a lighteval-compatible table:
```
Evaluation Results: your-env_eval
|Metric | Value|
|mean correctness| 0.850 |
|mean reward | 0.920 |
|mean tool calls | 4.300 |
|n items | 20 |
Evaluation completed in 367 seconds
```
JSON results saved to the eval directory:
```python
import json
data = json.load(open("/tmp/eval_results/metrics.json"))
for metric, value in data["results"]["all"].items():
print(f"{metric}: {value}")
```
### Step 3: Compare models
Run evaluate with different models and compare the metrics.json files.
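A minimal comparison sketch — the run names and directories below are examples; point them at wherever each run saved its `metrics.json`:
```python
import json

runs = {
    "model-a": "/tmp/eval_results_model_a/metrics.json",
    "model-b": "/tmp/eval_results_model_b/metrics.json",
}
for name, path in runs.items():
    metrics = json.load(open(path))["results"]["all"]
    summary = ", ".join(
        f"{k}={v:.3f}" for k, v in metrics.items() if isinstance(v, (int, float))
    )
    print(f"{name}: {summary}")
```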
### What to check:
- **"data_dir_to_save_evals is not set"** — you forgot the flag, results won't be saved
- **Tool usage rate = 0** — evaluate() is using chat_completion instead of HermesAgentLoop
- **All scores identical** — judge failing, falling back to heuristic
- **Very slow** — each item runs a full agent loop (~30-90s). Use `--env.eval_size 5` for quick checks.
## Pattern 3: Generate Training Data (process mode, larger scale)
Generate trajectory data for offline training or analysis:
```bash
python environments/your_env.py process \
--env.total_steps 50 \
--env.group_size 4 \
--env.use_wandb false \
--env.data_path_to_save_groups data/trajectories.jsonl \
--openai.base_url "<BASE_URL>" \
--openai.model_name "<MODEL>" \
--openai.server_type <SERVER_TYPE> \
--openai.health_check false
```
### Analyze the distribution:
```python
import json
scores = []
for line in open("data/trajectories.jsonl"):
data = json.loads(line)
scores.extend(data.get("scores", []))
print(f"Total: {len(scores)}, Mean: {sum(scores)/len(scores):.3f}")
for bucket in [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]:
count = sum(1 for s in scores if abs(s - bucket) < 0.1)
print(f" {bucket:.1f}: {'█' * count} ({count})")
```
### What to check:
- **Score distribution has variance** — RL needs score variance. All-same scores are useless.
## Pattern 4: Full RL Training (serve mode)
For actual RL training with Atropos:
```bash
# Terminal 1: Start Atropos API server
run-api
# Terminal 2: Start your environment
python environments/your_env.py serve \
--config environments/your_env/default.yaml
```
For Phase 2 with VLLM:
```bash
# Terminal 1: VLLM server
python -m vllm.entrypoints.openai.api_server --model your-model --port 8000
# Terminal 2: Atropos API
run-api
# Terminal 3: Environment
python environments/your_env.py serve \
--openai.base_url http://localhost:8000/v1 \
--openai.model_name your-model \
--openai.server_type vllm
```
## Pattern 5: Quick Smoke Test
Verify imports and config before spending money on API calls:
```python
from environments.your_env import YourEnv
print(f"Name: {YourEnv.name}")
cfg, servers = YourEnv.config_init()
print(f"Toolsets: {cfg.enabled_toolsets}")
print(f"Server: {servers[0].model_name}")
print("All imports OK")
```
## Timing Expectations
| Mode | Items | Time per item | Total |
|------|-------|--------------|-------|
| process (1 item) | 1 | 30-90s | ~1 min |
| evaluate (5 items) | 5 | 30-90s | ~5 min |
| evaluate (20 items) | 20 | 30-90s | ~15-30 min |
| process (50 items) | 50 | 30-90s | ~30-75 min |
Times are for cloud APIs with Claude Sonnet-class models. Local models may be faster or slower depending on hardware.

View file

@ -0,0 +1,519 @@
---
name: huggingface-tokenizers
description: Fast tokenizers optimized for research and production. Rust-based implementation tokenizes 1GB in <20 seconds. Supports BPE, WordPiece, and Unigram algorithms. Train custom vocabularies, track alignments, handle padding/truncation. Integrates seamlessly with transformers. Use when you need high-performance tokenization or custom tokenizer training.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [tokenizers, transformers, datasets]
metadata:
hermes:
tags: [Tokenization, HuggingFace, BPE, WordPiece, Unigram, Fast Tokenization, Rust, Custom Tokenizer, Alignment Tracking, Production]
---
# HuggingFace Tokenizers - Fast Tokenization for NLP
Fast, production-ready tokenizers with Rust performance and Python ease-of-use.
## When to use HuggingFace Tokenizers
**Use HuggingFace Tokenizers when:**
- Need extremely fast tokenization (<20s per GB of text)
- Training custom tokenizers from scratch
- Want alignment tracking (token → original text position)
- Building production NLP pipelines
- Need to tokenize large corpora efficiently
**Performance**:
- **Speed**: <20 seconds to tokenize 1GB on CPU
- **Implementation**: Rust core with Python/Node.js bindings
- **Efficiency**: 10-100× faster than pure Python implementations
**Use alternatives instead**:
- **SentencePiece**: Language-independent, used by T5/ALBERT
- **tiktoken**: OpenAI's BPE tokenizer for GPT models
- **transformers AutoTokenizer**: Loading pretrained only (uses this library internally)
## Quick start
### Installation
```bash
# Install tokenizers
pip install tokenizers
# With transformers integration
pip install tokenizers transformers
```
### Load pretrained tokenizer
```python
from tokenizers import Tokenizer
# Load from HuggingFace Hub
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
# Encode text
output = tokenizer.encode("Hello, how are you?")
print(output.tokens) # ['hello', ',', 'how', 'are', 'you', '?']
print(output.ids) # [7592, 1010, 2129, 2024, 2017, 1029]
# Decode back
text = tokenizer.decode(output.ids)
print(text) # "hello, how are you?"
```
### Train custom BPE tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
# Initialize tokenizer with BPE model
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
# Configure trainer
trainer = BpeTrainer(
vocab_size=30000,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
min_frequency=2
)
# Train on files
files = ["train.txt", "validation.txt"]
tokenizer.train(files, trainer)
# Save
tokenizer.save("my-tokenizer.json")
```
**Training time**: ~1-2 minutes for 100MB corpus, ~10-20 minutes for 1GB
### Batch encoding with padding
```python
# Enable padding
tokenizer.enable_padding(pad_id=3, pad_token="[PAD]")
# Encode batch
texts = ["Hello world", "This is a longer sentence"]
encodings = tokenizer.encode_batch(texts)
for encoding in encodings:
print(encoding.ids)
# [101, 7592, 2088, 102, 3, 3, 3]
# [101, 2023, 2003, 1037, 2936, 6251, 102]
```
## Tokenization algorithms
### BPE (Byte-Pair Encoding)
**How it works**:
1. Start with character-level vocabulary
2. Find most frequent character pair
3. Merge into new token, add to vocabulary
4. Repeat until vocabulary size reached
**Used by**: GPT-2, GPT-3, RoBERTa, BART, DeBERTa
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
tokenizer = Tokenizer(BPE(unk_token="<|endoftext|>"))
tokenizer.pre_tokenizer = ByteLevel()
trainer = BpeTrainer(
vocab_size=50257,
special_tokens=["<|endoftext|>"],
min_frequency=2
)
tokenizer.train(files=["data.txt"], trainer=trainer)
```
**Advantages**:
- Handles OOV words well (breaks into subwords)
- Flexible vocabulary size
- Good for morphologically rich languages
**Trade-offs**:
- Tokenization depends on merge order
- May split common words unexpectedly
### WordPiece
**How it works**:
1. Start with character vocabulary
2. Score merge pairs: `frequency(pair) / (frequency(first) × frequency(second))`
3. Merge highest scoring pair
4. Repeat until vocabulary size reached
**Used by**: BERT, DistilBERT, MobileBERT
```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.normalizers import BertNormalizer
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
tokenizer.normalizer = BertNormalizer(lowercase=True)
tokenizer.pre_tokenizer = Whitespace()
trainer = WordPieceTrainer(
vocab_size=30522,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
continuing_subword_prefix="##"
)
tokenizer.train(files=["corpus.txt"], trainer=trainer)
```
**Advantages**:
- Prioritizes meaningful merges (high score = semantically related)
- Used successfully in BERT (state-of-the-art results)
**Trade-offs**:
- Unknown words become `[UNK]` if no subword match
- Saves vocabulary, not merge rules (larger files)
### Unigram
**How it works**:
1. Start with large vocabulary (all substrings)
2. Compute loss for corpus with current vocabulary
3. Remove tokens with minimal impact on loss
4. Repeat until vocabulary size reached
**Used by**: ALBERT, T5, mBART, XLNet (via SentencePiece)
```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
tokenizer = Tokenizer(Unigram())
trainer = UnigramTrainer(
vocab_size=8000,
special_tokens=["<unk>", "<s>", "</s>"],
unk_token="<unk>"
)
tokenizer.train(files=["data.txt"], trainer=trainer)
```
**Advantages**:
- Probabilistic (finds most likely tokenization)
- Works well for languages without word boundaries
- Handles diverse linguistic contexts
**Trade-offs**:
- Computationally expensive to train
- More hyperparameters to tune
## Tokenization pipeline
Complete pipeline: **Normalization → Pre-tokenization → Model → Post-processing**
### Normalization
Clean and standardize text:
```python
from tokenizers.normalizers import NFD, StripAccents, Lowercase, Sequence
tokenizer.normalizer = Sequence([
NFD(), # Unicode normalization (decompose)
Lowercase(), # Convert to lowercase
StripAccents() # Remove accents
])
# Input: "Héllo WORLD"
# After normalization: "hello world"
```
**Common normalizers**:
- `NFD`, `NFC`, `NFKD`, `NFKC` - Unicode normalization forms
- `Lowercase()` - Convert to lowercase
- `StripAccents()` - Remove accents (é → e)
- `Strip()` - Remove whitespace
- `Replace(pattern, content)` - Regex replacement
### Pre-tokenization
Split text into word-like units:
```python
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence, ByteLevel
# Split on whitespace and punctuation
tokenizer.pre_tokenizer = Sequence([
Whitespace(),
Punctuation()
])
# Input: "Hello, world!"
# After pre-tokenization: ["Hello", ",", "world", "!"]
```
**Common pre-tokenizers**:
- `Whitespace()` - Split on spaces, tabs, newlines
- `ByteLevel()` - GPT-2 style byte-level splitting
- `Punctuation()` - Isolate punctuation
- `Digits(individual_digits=True)` - Split digits individually
- `Metaspace()` - Replace spaces with ▁ (SentencePiece style)
### Post-processing
Add special tokens for model input:
```python
from tokenizers.processors import TemplateProcessing
# BERT-style: [CLS] sentence [SEP]
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B [SEP]",
special_tokens=[
("[CLS]", 1),
("[SEP]", 2),
],
)
```
**Common patterns**:
```python
# GPT-2: sentence <|endoftext|>
TemplateProcessing(
single="$A <|endoftext|>",
special_tokens=[("<|endoftext|>", 50256)]
)
# RoBERTa: <s> sentence </s>
TemplateProcessing(
single="<s> $A </s>",
pair="<s> $A </s> </s> $B </s>",
special_tokens=[("<s>", 0), ("</s>", 2)]
)
```
## Alignment tracking
Track token positions in original text:
```python
text = "Hello, world!"
output = tokenizer.encode(text)
# Get token offsets
for token, offset in zip(output.tokens, output.offsets):
start, end = offset
print(f"{token:10} → [{start:2}, {end:2}): {text[start:end]!r}")
# Output:
# hello → [ 0, 5): 'Hello'
# , → [ 5, 6): ','
# world → [ 7, 12): 'world'
# ! → [12, 13): '!'
```
**Use cases**:
- Named entity recognition (map predictions back to text)
- Question answering (extract answer spans)
- Token classification (align labels to original positions)
## Integration with transformers
### Load with AutoTokenizer
```python
from transformers import AutoTokenizer
# AutoTokenizer automatically uses fast tokenizers
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Check if using fast tokenizer
print(tokenizer.is_fast) # True
# Access underlying tokenizers.Tokenizer
fast_tokenizer = tokenizer.backend_tokenizer
print(type(fast_tokenizer)) # <class 'tokenizers.Tokenizer'>
```
### Convert custom tokenizer to transformers
```python
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast
# Train custom tokenizer
tokenizer = Tokenizer(BPE())
# ... train tokenizer ...
tokenizer.save("my-tokenizer.json")
# Wrap for transformers
transformers_tokenizer = PreTrainedTokenizerFast(
tokenizer_file="my-tokenizer.json",
unk_token="[UNK]",
pad_token="[PAD]",
cls_token="[CLS]",
sep_token="[SEP]",
mask_token="[MASK]"
)
# Use like any transformers tokenizer
outputs = transformers_tokenizer(
"Hello world",
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
)
```
## Common patterns
### Train from iterator (large datasets)
```python
from datasets import load_dataset
# Load dataset
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
# Create batch iterator
def batch_iterator(batch_size=1000):
for i in range(0, len(dataset), batch_size):
yield dataset[i:i + batch_size]["text"]
# Train tokenizer
tokenizer.train_from_iterator(
batch_iterator(),
trainer=trainer,
length=len(dataset) # For progress bar
)
```
**Performance**: Processes 1GB in ~10-20 minutes
### Enable truncation and padding
```python
# Enable truncation
tokenizer.enable_truncation(max_length=512)
# Enable padding
tokenizer.enable_padding(
pad_id=tokenizer.token_to_id("[PAD]"),
pad_token="[PAD]",
length=512 # Fixed length, or None for batch max
)
# Encode with both
output = tokenizer.encode("This is a long sentence that will be truncated...")
print(len(output.ids)) # 512
```
### Multi-processing
```python
from tokenizers import Tokenizer
from multiprocessing import Pool
# Load tokenizer
tokenizer = Tokenizer.from_file("tokenizer.json")
def encode_batch(texts):
return tokenizer.encode_batch(texts)
# Process a large corpus in parallel (`corpus` below is a list of text strings)
with Pool(8) as pool:
# Split corpus into chunks
chunk_size = 1000
chunks = [corpus[i:i+chunk_size] for i in range(0, len(corpus), chunk_size)]
# Encode in parallel
results = pool.map(encode_batch, chunks)
```
**Speedup**: 5-8× with 8 cores
## Performance benchmarks
### Training speed
| Corpus Size | BPE (30k vocab) | WordPiece (30k) | Unigram (8k) |
|-------------|-----------------|-----------------|--------------|
| 10 MB | 15 sec | 18 sec | 25 sec |
| 100 MB | 1.5 min | 2 min | 4 min |
| 1 GB | 15 min | 20 min | 40 min |
**Hardware**: 16-core CPU, tested on English Wikipedia
### Tokenization speed
| Implementation | 1 GB corpus | Throughput |
|----------------|-------------|---------------|
| Pure Python | ~20 minutes | ~50 MB/min |
| HF Tokenizers | ~15 seconds | ~4 GB/min |
| **Speedup** | **80×** | **80×** |
**Test**: English text, average sentence length 20 words
### Memory usage
| Task | Memory |
|-------------------------|---------|
| Load tokenizer | ~10 MB |
| Train BPE (30k vocab) | ~200 MB |
| Encode 1M sentences | ~500 MB |
## Supported models
Pre-trained tokenizers available via `from_pretrained()`:
**BERT family**:
- `bert-base-uncased`, `bert-large-cased`
- `distilbert-base-uncased`
- `roberta-base`, `roberta-large`
**GPT family**:
- `gpt2`, `gpt2-medium`, `gpt2-large`
- `distilgpt2`
**T5 family**:
- `t5-small`, `t5-base`, `t5-large`
- `google/flan-t5-xxl`
**Other**:
- `facebook/bart-base`, `facebook/mbart-large-cc25`
- `albert-base-v2`, `albert-xlarge-v2`
- `xlm-roberta-base`, `xlm-roberta-large`
Browse all: https://huggingface.co/models?library=tokenizers
## References
- **[Training Guide](references/training.md)** - Train custom tokenizers, configure trainers, handle large datasets
- **[Algorithms Deep Dive](references/algorithms.md)** - BPE, WordPiece, Unigram explained in detail
- **[Pipeline Components](references/pipeline.md)** - Normalizers, pre-tokenizers, post-processors, decoders
- **[Transformers Integration](references/integration.md)** - AutoTokenizer, PreTrainedTokenizerFast, special tokens
## Resources
- **Docs**: https://huggingface.co/docs/tokenizers
- **GitHub**: https://github.com/huggingface/tokenizers ⭐ 9,000+
- **Version**: 0.20.0+
- **Course**: https://huggingface.co/learn/nlp-course/chapter6/1
- **Paper**: BPE (Sennrich et al., 2016), WordPiece (Schuster & Nakajima, 2012)

View file

@ -0,0 +1,653 @@
# Tokenization Algorithms Deep Dive
Comprehensive explanation of BPE, WordPiece, and Unigram algorithms.
## Byte-Pair Encoding (BPE)
### Algorithm overview
BPE iteratively merges the most frequent pair of tokens in a corpus.
**Training process**:
1. Initialize vocabulary with all characters
2. Count frequency of all adjacent token pairs
3. Merge most frequent pair into new token
4. Add new token to vocabulary
5. Update corpus with new token
6. Repeat until vocabulary size reached
### Step-by-step example
**Corpus**:
```
low: 5
lower: 2
newest: 6
widest: 3
```
**Iteration 1**:
```
Count pairs:
'e' + 's': 9 (newest: 6, widest: 3) ← most frequent
'l' + 'o': 7
'o' + 'w': 7
...
Merge: 'e' + 's' → 'es'
Updated corpus:
low: 5
lower: 2
newest: 6 → newes|t: 6
widest: 3 → wides|t: 3
Vocabulary: [a-z] + ['es']
```
**Iteration 2**:
```
Count pairs:
'es' + 't': 9 ← most frequent
'l' + 'o': 7
...
Merge: 'es' + 't' → 'est'
Updated corpus:
low: 5
lower: 2
newest: 6 → new|est: 6
widest: 3 → wid|est: 3
Vocabulary: [a-z] + ['es', 'est']
```
**Continue until desired vocabulary size...**
### Tokenization with trained BPE
Given vocabulary: `['l', 'o', 'w', 'e', 'r', 'n', 's', 't', 'i', 'd', 'es', 'est', 'lo', 'low', 'ne', 'new', 'newest', 'wi', 'wid', 'widest']`
Tokenize "lowest":
```
Step 1: Split into characters
['l', 'o', 'w', 'e', 's', 't']
Step 2: Apply merges in order learned during training
- Merge 'l' + 'o' → 'lo' (if this merge was learned)
- Merge 'lo' + 'w' → 'low' (if learned)
- Merge 'e' + 's' → 'es' (learned)
- Merge 'es' + 't' → 'est' (learned)
Final: ['low', 'est']
```
### Implementation
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
# Initialize
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
# Configure trainer
trainer = BpeTrainer(
vocab_size=1000,
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
)
# Train
corpus = [
"This is a sample corpus for BPE training.",
"BPE learns subword units from the training data.",
# ... more sentences
]
tokenizer.train_from_iterator(corpus, trainer=trainer)
# Use
output = tokenizer.encode("This is tokenization")
print(output.tokens) # ['This', 'is', 'token', 'ization']
```
### Byte-level BPE (GPT-2 variant)
**Problem**: Standard BPE has limited character coverage (256+ Unicode chars)
**Solution**: Operate on byte level (256 bytes)
```python
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
tokenizer = Tokenizer(BPE())
# Byte-level pre-tokenization
tokenizer.pre_tokenizer = ByteLevel()
tokenizer.decoder = ByteLevelDecoder()
# This handles ALL possible characters, including emojis
text = "Hello 🌍 世界"
tokens = tokenizer.encode(text).tokens
```
**Advantages**:
- Handles any Unicode character (256 byte coverage)
- No unknown tokens (worst case: bytes)
- Used by GPT-2, GPT-3, BART
**Trade-offs**:
- Slightly worse compression (bytes vs characters)
- More tokens for non-ASCII text
### BPE variants
**SentencePiece BPE**:
- Language-independent (no pre-tokenization)
- Treats input as raw byte stream
- Used by T5, ALBERT, XLNet
**Robust BPE**:
- Dropout during training (randomly skip merges)
- More robust tokenization at inference
- Reduces overfitting to training data
## WordPiece
### Algorithm overview
WordPiece is similar to BPE but uses a different merge selection criterion.
**Training process**:
1. Initialize vocabulary with all characters
2. Count frequency of all token pairs
3. Score each pair: `score = freq(pair) / (freq(first) × freq(second))`
4. Merge pair with highest score
5. Repeat until vocabulary size reached
### Why different scoring?
**BPE**: Merges most frequent pairs
- "aa" appears 100 times → high priority
- Even if 'a' appears 1000 times alone
**WordPiece**: Merges pairs that are semantically related
- "aa" appears 100 times, 'a' appears 1000 times → low score (100 / (1000 × 1000))
- "th" appears 50 times, 't' appears 60 times, 'h' appears 55 times → high score (50 / (60 × 55))
- Prioritizes pairs that appear together more than expected
### Step-by-step example
**Corpus**:
```
low: 5
lower: 2
newest: 6
widest: 3
```
**Iteration 1**:
```
Count frequencies:
'e': 11 (lower: 2, newest: 6, widest: 3)
's': 9
't': 9
...
Count pairs:
'e' + 's': 9 (newest: 6, widest: 3)
'es' + 't': 9 (newest: 6, widest: 3)
...
Compute scores:
score('e' + 's') = 9 / (11 × 9) = 0.091
score('es' + 't') = 9 / (9 × 9) = 0.111 ← highest score
score('l' + 'o') = 7 / (7 × 9) = 0.111 ← tied
Choose: 'es' + 't' → 'est' (or 'lo' if tied)
```
**Key difference**: WordPiece prioritizes rare combinations over frequent ones.
### Tokenization with WordPiece
Given vocabulary: `['##e', '##s', '##t', 'l', 'o', 'w', 'new', 'est', 'low']`
Tokenize "lowest":
```
Step 1: Find longest matching prefix
'lowest' → 'low' (matches)
Step 2: Find longest match for remainder
'est' → 'est' (matches)
Final: ['low', 'est']
```
**If no match**:
```
Tokenize "unknownword":
'unknownword' → no match
'unknown' → no match
'unkn' → no match
'un' → no match
'u' → no match
→ [UNK]
```
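A minimal pure-Python sketch of this longest-match-first rule (illustrative only — in practice the library's Rust implementation does this; the tiny vocabulary below is made up):
```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]", prefix="##"):
    tokens, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        # Greedily try the longest remaining substring first.
        while end > start:
            piece = word[start:end]
            if start > 0:
                piece = prefix + piece  # continuation pieces carry the ## prefix
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk_token]  # no subword matches → the whole word becomes [UNK]
        tokens.append(match)
        start = end
    return tokens

vocab = {"low", "##est"}
print(wordpiece_tokenize("lowest", vocab))  # ['low', '##est']
```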
### Implementation
```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer
# Initialize BERT-style tokenizer
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
# Normalization (lowercase, accent stripping)
tokenizer.normalizer = BertNormalizer(lowercase=True)
# Pre-tokenization (whitespace + punctuation)
tokenizer.pre_tokenizer = BertPreTokenizer()
# Configure trainer
trainer = WordPieceTrainer(
vocab_size=30522, # BERT vocab size
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
continuing_subword_prefix="##" # BERT uses ##
)
# Train
tokenizer.train_from_iterator(corpus, trainer=trainer)
# Use
output = tokenizer.encode("Tokenization works great!")
print(output.tokens) # ['token', '##ization', 'works', 'great', '!']
```
### Subword prefix
**BERT uses `##` prefix**:
```
"unbelievable" → ['un', '##believ', '##able']
```
**Why?**
- Indicates token is a continuation
- Allows reconstruction: remove ##, concatenate
- Helps model distinguish word boundaries
### WordPiece advantages
**Semantic merges**:
- Prioritizes meaningful combinations
- "qu" has high score (always together)
- "qx" has low score (rare combination)
**Better for morphology**:
- Captures affixes: un-, -ing, -ed
- Preserves word stems
**Trade-offs**:
- Slower training than BPE
- More memory (stores vocabulary, not merges)
- Original implementation not open-source (HF reimplementation)
## Unigram
### Algorithm overview
Unigram works backward: start with large vocabulary, remove tokens.
**Training process**:
1. Initialize with large vocabulary (all substrings)
2. Estimate probability of each token (frequency-based)
3. For each token, compute loss increase if removed
4. Remove 10-20% of tokens with lowest loss impact
5. Re-estimate probabilities
6. Repeat until desired vocabulary size
### Probabilistic tokenization
**Unigram assumption**: Each token is independent.
Given vocabulary with probabilities:
```
P('low') = 0.02
P('l') = 0.01
P('o') = 0.015
P('w') = 0.01
P('est') = 0.03
P('e') = 0.02
P('s') = 0.015
P('t') = 0.015
```
Tokenize "lowest":
```
Option 1: ['low', 'est']
P = P('low') × P('est') = 0.02 × 0.03 = 0.0006
Option 2: ['l', 'o', 'w', 'est']
P = 0.01 × 0.015 × 0.01 × 0.03 = 0.000000045
Option 3: ['low', 'e', 's', 't']
P = 0.02 × 0.02 × 0.015 × 0.015 = 0.0000009
Choose option 1 (highest probability)
```
### Viterbi algorithm
Finding best tokenization is expensive (exponential possibilities).
**Viterbi algorithm** (dynamic programming):
```python
from math import log

def tokenize_viterbi(word, vocab, probs):
    n = len(word)
    # dp[i] = (best_log_prob, best_tokens) for word[:i]
    dp = [(float('-inf'), []) for _ in range(n + 1)]
    dp[0] = (0.0, [])  # empty prefix has log-probability 0
    for i in range(1, n + 1):
        best_prob = float('-inf')
        best_tokens = []
        # Try every possible last token word[j:i]
        for j in range(i):
            token = word[j:i]
            if token in vocab:
                prob = dp[j][0] + log(probs[token])
                if prob > best_prob:
                    best_prob = prob
                    best_tokens = dp[j][1] + [token]
        dp[i] = (best_prob, best_tokens)
    return dp[n][1]
```
**Time complexity**: roughly O(n²) vocabulary lookups, vs O(2^n) for brute-force enumeration
### Implementation
```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
# Initialize
tokenizer = Tokenizer(Unigram())
# Configure trainer
trainer = UnigramTrainer(
vocab_size=8000,
special_tokens=["<unk>", "<s>", "</s>"],
unk_token="<unk>",
max_piece_length=16, # Max token length
n_sub_iterations=2, # EM iterations
shrinking_factor=0.75 # Remove 25% each iteration
)
# Train
tokenizer.train_from_iterator(corpus, trainer=trainer)
# Use
output = tokenizer.encode("Tokenization with Unigram")
print(output.tokens) # ['▁Token', 'ization', '▁with', '▁Un', 'igram']
```
### Unigram advantages
**Probabilistic**:
- Multiple valid tokenizations
- Can sample different tokenizations (data augmentation)
**Subword regularization**:
```python
# Sample different tokenizations
for _ in range(3):
tokens = tokenizer.encode("tokenization", is_pretokenized=False).tokens
print(tokens)
# Output (different each time):
# ['token', 'ization']
# ['tok', 'en', 'ization']
# ['token', 'iz', 'ation']
```
**Language-independent**:
- No word boundaries needed
- Works for CJK languages (Chinese, Japanese, Korean)
- Treats input as character stream
**Trade-offs**:
- Slower training (EM algorithm)
- More hyperparameters
- Larger model (stores probabilities)
## Algorithm comparison
### Training speed
| Algorithm | Small (10MB) | Medium (100MB) | Large (1GB) |
|------------|--------------|----------------|-------------|
| BPE | 10-15 sec | 1-2 min | 10-20 min |
| WordPiece | 15-20 sec | 2-3 min | 15-30 min |
| Unigram | 20-30 sec | 3-5 min | 30-60 min |
**Tested on**: 16-core CPU, 30k vocab
### Tokenization quality
Tested on English Wikipedia (perplexity measurement):
| Algorithm | Vocab Size | Tokens/Word | Unknown Rate |
|------------|------------|-------------|--------------|
| BPE | 30k | 1.3 | 0.5% |
| WordPiece | 30k | 1.2 | 1.2% |
| Unigram | 8k | 1.5 | 0.3% |
**Key observations**:
- WordPiece: Slightly better compression
- BPE: Lower unknown rate
- Unigram: Smallest vocab, good coverage
### Compression ratio
Characters per token (higher = better compression):
| Language | BPE (30k) | WordPiece (30k) | Unigram (8k) |
|----------|-----------|-----------------|--------------|
| English | 4.2 | 4.5 | 3.8 |
| Chinese | 2.1 | 2.3 | 2.5 |
| Arabic | 3.5 | 3.8 | 3.2 |
**Best for each**:
- English: WordPiece
- Chinese: Unigram (language-independent)
- Arabic: WordPiece
### Use case recommendations
**BPE** - Best for:
- English language models
- Code (handles symbols well)
- Fast training needed
- **Models**: GPT-2, GPT-3, RoBERTa, BART
**WordPiece** - Best for:
- Masked language modeling (BERT-style)
- Morphologically rich languages
- Semantic understanding tasks
- **Models**: BERT, DistilBERT, ELECTRA
**Unigram** - Best for:
- Multilingual models
- Languages without word boundaries (CJK)
- Data augmentation via subword regularization
- **Models**: T5, ALBERT, XLNet (via SentencePiece)
## Advanced topics
### Handling rare words
**BPE approach**:
```
"antidisestablishmentarianism"
→ ['anti', 'dis', 'establish', 'ment', 'arian', 'ism']
```
**WordPiece approach**:
```
"antidisestablishmentarianism"
→ ['anti', '##dis', '##establish', '##ment', '##arian', '##ism']
```
**Unigram approach**:
```
"antidisestablishmentarianism"
→ ['▁anti', 'dis', 'establish', 'ment', 'arian', 'ism']
```
### Handling numbers
**Challenge**: Infinite number combinations
**BPE solution**: Byte-level (handles any digit sequence)
```python
tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = ByteLevel()
# Handles any number
"123456789" → byte-level tokens
```
**WordPiece solution**: Digit pre-tokenization
```python
from tokenizers.pre_tokenizers import Digits
# Split digits individually or as groups
tokenizer.pre_tokenizer = Digits(individual_digits=True)
"123" → ['1', '2', '3']
```
**Unigram solution**: Learns common number patterns
```python
# Learns patterns during training
"2023" → ['202', '3'] or ['20', '23']
```
### Handling case sensitivity
**Lowercase (BERT)**:
```python
from tokenizers.normalizers import Lowercase
tokenizer.normalizer = Lowercase()
"Hello WORLD" → "hello world" → ['hello', 'world']
```
**Preserve case (GPT-2)**:
```python
# No case normalization
tokenizer.normalizer = None
"Hello WORLD" → ['Hello', 'WORLD']
```
**Cased tokens (RoBERTa)**:
```python
# Learns separate tokens for different cases
# Vocabulary: ['Hello', 'hello', 'HELLO', 'world', 'WORLD']
```
### Handling emojis and special characters
**Byte-level (GPT-2)**:
```python
tokenizer.pre_tokenizer = ByteLevel()
"Hello 🌍 👋" → byte-level representation (always works)
```
**Unicode normalization**:
```python
from tokenizers.normalizers import NFKC
tokenizer.normalizer = NFKC()
"é" (composed) ↔ "é" (decomposed) → normalized to one form
```
## Troubleshooting
### Issue: Poor subword splitting
**Symptom**:
```
"running" → ['r', 'u', 'n', 'n', 'i', 'n', 'g'] (too granular)
```
**Solutions**:
1. Increase vocabulary size
2. Train longer (more merge iterations)
3. Lower `min_frequency` threshold
### Issue: Too many unknown tokens
**Symptom**:
```
5% of tokens are [UNK]
```
**Solutions**:
1. Increase vocabulary size
2. Use byte-level BPE (no UNK possible)
3. Verify training corpus is representative
### Issue: Inconsistent tokenization
**Symptom**:
```
"running" → ['run', 'ning']
"runner" → ['r', 'u', 'n', 'n', 'e', 'r']
```
**Solutions**:
1. Check normalization consistency
2. Ensure pre-tokenization is deterministic
3. Use Unigram for probabilistic variance
## Best practices
1. **Match algorithm to model architecture**:
- BERT-style → WordPiece
- GPT-style → BPE
- T5-style → Unigram
2. **Use byte-level for multilingual**:
- Handles any Unicode
- No unknown tokens
3. **Test on representative data**:
- Measure compression ratio
- Check unknown token rate
- Inspect sample tokenizations
4. **Version control tokenizers**:
- Save with model
- Document special tokens
- Track vocabulary changes

View file

@ -0,0 +1,637 @@
# Transformers Integration
Complete guide to using HuggingFace Tokenizers with the Transformers library.
## AutoTokenizer
The easiest way to load tokenizers.
### Loading pretrained tokenizers
```python
from transformers import AutoTokenizer
# Load from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Check if using fast tokenizer (Rust-based)
print(tokenizer.is_fast) # True
# Access underlying tokenizers.Tokenizer
if tokenizer.is_fast:
fast_tokenizer = tokenizer.backend_tokenizer
print(type(fast_tokenizer)) # <class 'tokenizers.Tokenizer'>
```
### Fast vs slow tokenizers
| Feature | Fast (Rust) | Slow (Python) |
|--------------------------|----------------|---------------|
| Speed | 5-10× faster | Baseline |
| Alignment tracking | ✅ Full support | ❌ Limited |
| Batch processing | ✅ Optimized | ⚠️ Slower |
| Offset mapping | ✅ Yes | ❌ No |
| Installation | `tokenizers` | Built-in |
**Always use fast tokenizers when available.**
### Check available tokenizers
```python
from transformers import TOKENIZER_MAPPING
# List all fast tokenizers
for config_class, (slow, fast) in TOKENIZER_MAPPING.items():
if fast is not None:
print(f"{config_class.__name__}: {fast.__name__}")
```
## PreTrainedTokenizerFast
Wrap custom tokenizers for transformers.
### Convert custom tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast
# Train custom tokenizer
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(
vocab_size=30000,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
)
tokenizer.train(files=["corpus.txt"], trainer=trainer)
# Save tokenizer
tokenizer.save("my-tokenizer.json")
# Wrap for transformers
transformers_tokenizer = PreTrainedTokenizerFast(
tokenizer_file="my-tokenizer.json",
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]"
)
# Save in transformers format
transformers_tokenizer.save_pretrained("my-tokenizer")
```
**Result**: Directory with `tokenizer.json` + `tokenizer_config.json` + `special_tokens_map.json`
### Use like any transformers tokenizer
```python
# Load
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("my-tokenizer")
# Encode with all transformers features
outputs = tokenizer(
"Hello world",
padding="max_length",
truncation=True,
max_length=128,
return_tensors="pt"
)
print(outputs.keys())
# dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
```
## Special tokens
### Default special tokens
| Model Family | CLS/BOS | SEP/EOS | PAD | UNK | MASK |
|--------------|---------|---------------|---------|---------|---------|
| BERT | [CLS] | [SEP] | [PAD] | [UNK] | [MASK] |
| GPT-2 | - | <\|endoftext\|> | <\|endoftext\|> | <\|endoftext\|> | - |
| RoBERTa | <s> | </s> | <pad> | <unk> | <mask> |
| T5 | - | </s> | <pad> | <unk> | - |
### Adding special tokens
```python
# Add new special tokens
special_tokens_dict = {
"additional_special_tokens": ["<|image|>", "<|video|>", "<|audio|>"]
}
num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
print(f"Added {num_added_tokens} tokens")
# Resize model embeddings
model.resize_token_embeddings(len(tokenizer))
# Use new tokens
text = "This is an image: <|image|>"
tokens = tokenizer.encode(text)
```
### Adding regular tokens
```python
# Add domain-specific tokens
new_tokens = ["COVID-19", "mRNA", "vaccine"]
num_added = tokenizer.add_tokens(new_tokens)
# These are NOT special tokens (can be split if needed)
tokenizer.add_tokens(new_tokens, special_tokens=False)
# These ARE special tokens (never split)
tokenizer.add_tokens(new_tokens, special_tokens=True)
```
## Encoding and decoding
### Basic encoding
```python
# Single sentence
text = "Hello, how are you?"
encoded = tokenizer(text)
print(encoded)
# {'input_ids': [101, 7592, 1010, 2129, 2024, 2017, 1029, 102],
# 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0],
# 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
```
### Batch encoding
```python
# Multiple sentences
texts = ["Hello world", "How are you?", "I am fine"]
encoded = tokenizer(texts, padding=True, truncation=True, max_length=10)
print(encoded['input_ids'])
# [[101, 7592, 2088, 102, 0, 0, 0, 0, 0, 0],
# [101, 2129, 2024, 2017, 1029, 102, 0, 0, 0, 0],
# [101, 1045, 2572, 2986, 102, 0, 0, 0, 0, 0]]
```
### Return tensors
```python
# Return PyTorch tensors
outputs = tokenizer("Hello world", return_tensors="pt")
print(outputs['input_ids'].shape) # torch.Size([1, 5])
# Return TensorFlow tensors
outputs = tokenizer("Hello world", return_tensors="tf")
# Return NumPy arrays
outputs = tokenizer("Hello world", return_tensors="np")
# Return lists (default)
outputs = tokenizer("Hello world", return_tensors=None)
```
### Decoding
```python
# Decode token IDs
ids = [101, 7592, 2088, 102]
text = tokenizer.decode(ids)
print(text) # "[CLS] hello world [SEP]"
# Skip special tokens
text = tokenizer.decode(ids, skip_special_tokens=True)
print(text) # "hello world"
# Batch decode
batch_ids = [[101, 7592, 102], [101, 2088, 102]]
texts = tokenizer.batch_decode(batch_ids, skip_special_tokens=True)
print(texts) # ["hello", "world"]
```
## Padding and truncation
### Padding strategies
```python
# Pad to max length in batch
tokenizer(texts, padding="longest")
# Pad to model max length
tokenizer(texts, padding="max_length", max_length=128)
# No padding
tokenizer(texts, padding=False)
# Pad to multiple of value (for efficient computation)
tokenizer(texts, padding="max_length", max_length=128, pad_to_multiple_of=8)
# Result: length will be 128 (already multiple of 8)
```
### Truncation strategies
```python
# Truncate to max length
tokenizer(text, truncation=True, max_length=10)
# Only truncate first sequence (for pairs)
tokenizer(text1, text2, truncation="only_first", max_length=20)
# Only truncate second sequence
tokenizer(text1, text2, truncation="only_second", max_length=20)
# Truncate longest first (default for pairs)
tokenizer(text1, text2, truncation="longest_first", max_length=20)
# No truncation (error if too long)
tokenizer(text, truncation=False)
```
### Stride for long documents
```python
# For documents longer than max_length
text = "Very long document " * 1000
# Encode with overlap
encodings = tokenizer(
text,
max_length=512,
stride=128, # Overlap between chunks
truncation=True,
return_overflowing_tokens=True,
return_offsets_mapping=True
)
# Get all chunks
num_chunks = len(encodings['input_ids'])
print(f"Split into {num_chunks} chunks")
# Each chunk overlaps by stride tokens
for i, chunk in enumerate(encodings['input_ids']):
print(f"Chunk {i}: {len(chunk)} tokens")
```
**Use case**: Long document QA, sliding window inference
## Alignment and offsets
### Offset mapping
```python
# Get character offsets for each token
encoded = tokenizer("Hello, world!", return_offsets_mapping=True)
for token, (start, end) in zip(
encoded.tokens(),
    encoded['offset_mapping']  # flat list of (start, end) for a single, un-batched input
):
print(f"{token:10s} → [{start:2d}, {end:2d})")
# Output:
# [CLS] → [ 0, 0)
# Hello → [ 0, 5)
# , → [ 5, 6)
# world → [ 7, 12)
# ! → [12, 13)
# [SEP] → [ 0, 0)
```
### Word IDs
```python
# Get word index for each token
encoded = tokenizer("Hello world", return_offsets_mapping=True)
word_ids = encoded.word_ids()
print(word_ids)
# [None, 0, 1, None]
# None = special token, 0 = first word, 1 = second word
```
**Use case**: Token classification (NER, POS tagging)
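For example, word IDs make it straightforward to propagate word-level labels to subword tokens (a sketch, assuming the fast BERT tokenizer loaded above; the words and labels are made up, and in real training you would map label strings to ids):
```python
words = ["Hugging", "Face", "is", "in", "Paris"]
labels = ["B-ORG", "I-ORG", "O", "O", "B-LOC"]

encoded = tokenizer(words, is_split_into_words=True)
aligned = [
    -100 if word_id is None else labels[word_id]  # -100 = ignore special tokens in the loss
    for word_id in encoded.word_ids()
]
print(list(zip(encoded.tokens(), aligned)))
```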
### Character to token mapping
```python
text = "Machine learning is awesome"
encoded = tokenizer(text, return_offsets_mapping=True)
# Find token for character position
char_pos = 8 # "l" in "learning"
token_idx = encoded.char_to_token(char_pos)
print(f"Character {char_pos} is in token {token_idx}: {encoded.tokens()[token_idx]}")
# Character 8 is in token 2: learning
```
**Use case**: Question answering (map answer character span to tokens)
### Sequence pairs
```python
# Encode sentence pair
encoded = tokenizer("Question here", "Answer here", return_offsets_mapping=True)
# Get sequence IDs (which sequence each token belongs to)
sequence_ids = encoded.sequence_ids()
print(sequence_ids)
# [None, 0, 0, 0, None, 1, 1, 1, None]
# None = special token, 0 = question, 1 = answer
```
## Model integration
### Use with transformers models
```python
from transformers import AutoModel, AutoTokenizer
import torch
# Load model and tokenizer
model = AutoModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Tokenize
text = "Hello world"
inputs = tokenizer(text, return_tensors="pt")
# Forward pass
with torch.no_grad():
outputs = model(**inputs)
# Get embeddings
last_hidden_state = outputs.last_hidden_state
print(last_hidden_state.shape) # [1, seq_len, hidden_size]
```
### Custom model with custom tokenizer
```python
from transformers import BertConfig, BertModel
# Train custom tokenizer
from tokenizers import Tokenizer, models, trainers
tokenizer = Tokenizer(models.BPE())
trainer = trainers.BpeTrainer(vocab_size=30000)
tokenizer.train(files=["data.txt"], trainer=trainer)
# Wrap for transformers
from transformers import PreTrainedTokenizerFast
fast_tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
unk_token="[UNK]",
pad_token="[PAD]"
)
# Create model with custom vocab size
config = BertConfig(vocab_size=30000)
model = BertModel(config)
# Use together
inputs = fast_tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs)
```
### Save and load together
```python
# Save both
model.save_pretrained("my-model")
tokenizer.save_pretrained("my-model")
# Directory structure:
# my-model/
# ├── config.json
# ├── pytorch_model.bin
# ├── tokenizer.json
# ├── tokenizer_config.json
# └── special_tokens_map.json
# Load both
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("my-model")
tokenizer = AutoTokenizer.from_pretrained("my-model")
```
## Advanced features
### Multimodal tokenization
```python
from transformers import AutoTokenizer
# LLaVA-style (image + text)
tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")
# Add image placeholder token
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
# Use in prompt
text = "Describe this image: <image>"
inputs = tokenizer(text, return_tensors="pt")
```
### Template formatting
```python
# Chat template
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi! How can I help?"},
{"role": "user", "content": "What's the weather?"}
]
# Apply chat template (if tokenizer has one)
if hasattr(tokenizer, "apply_chat_template"):
text = tokenizer.apply_chat_template(messages, tokenize=False)
inputs = tokenizer(text, return_tensors="pt")
```
### Custom template
```python
from transformers import PreTrainedTokenizerFast
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
# Define chat template
tokenizer.chat_template = """
{%- for message in messages %}
{%- if message['role'] == 'system' %}
System: {{ message['content'] }}\\n
{%- elif message['role'] == 'user' %}
User: {{ message['content'] }}\\n
{%- elif message['role'] == 'assistant' %}
Assistant: {{ message['content'] }}\\n
{%- endif %}
{%- endfor %}
Assistant:
"""
# Use template
text = tokenizer.apply_chat_template(messages, tokenize=False)
```
## Performance optimization
### Batch processing
```python
# Process large datasets efficiently
from datasets import load_dataset
dataset = load_dataset("imdb", split="train[:1000]")
# Tokenize in batches
def tokenize_function(examples):
return tokenizer(
examples["text"],
padding="max_length",
truncation=True,
max_length=512
)
# Map over dataset (batched)
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
batch_size=1000,
num_proc=4 # Parallel processing
)
```
### Caching
```python
# Enable caching for repeated tokenization
tokenizer = AutoTokenizer.from_pretrained(
"bert-base-uncased",
use_fast=True,
cache_dir="./cache" # Cache tokenizer files
)
# Tokenize with caching
from functools import lru_cache
@lru_cache(maxsize=10000)
def cached_tokenize(text):
return tuple(tokenizer.encode(text))
# Reuses cached results for repeated inputs
```
### Memory efficiency
```python
# For very large datasets, use streaming
from datasets import load_dataset
dataset = load_dataset("pile", split="train", streaming=True)
def process_batch(batch):
# Tokenize
tokens = tokenizer(batch["text"], truncation=True, max_length=512)
# Process tokens...
return tokens
# Process in chunks (memory efficient)
for batch in dataset.batch(batch_size=1000):
processed = process_batch(batch)
```
## Troubleshooting
### Issue: Tokenizer not fast
**Symptom**:
```python
tokenizer.is_fast # False
```
**Solution**: Install tokenizers library
```bash
pip install tokenizers
```
### Issue: Special tokens not working
**Symptom**: Special tokens are split into subwords
**Solution**: Add as special tokens, not regular tokens
```python
# Wrong
tokenizer.add_tokens(["<|image|>"])
# Correct
tokenizer.add_special_tokens({"additional_special_tokens": ["<|image|>"]})
```
### Issue: Offset mapping not available
**Symptom**:
```python
tokenizer("text", return_offsets_mapping=True)
# Error: return_offsets_mapping not supported
```
**Solution**: Use fast tokenizer
```python
from transformers import AutoTokenizer
# Load fast version
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
```
### Issue: Padding inconsistent
**Symptom**: Some sequences padded, others not
**Solution**: Specify padding strategy
```python
# Explicit padding
tokenizer(
texts,
padding="max_length", # or "longest"
max_length=128
)
```
## Best practices
1. **Always use fast tokenizers**:
- 5-10× faster
- Full alignment tracking
- Better batch processing
2. **Save tokenizer with model**:
- Ensures reproducibility
- Prevents version mismatches
3. **Use batch processing for datasets**:
- Tokenize with `.map(batched=True)`
- Set `num_proc` for parallelism
4. **Enable caching for repeated inputs**:
- Use `lru_cache` for inference
- Cache tokenizer files with `cache_dir`
5. **Handle special tokens properly**:
- Use `add_special_tokens()` for never-split tokens
- Resize embeddings after adding tokens
6. **Test alignment for downstream tasks**:
- Verify `offset_mapping` is correct
- Test `char_to_token()` on samples
7. **Version control tokenizer config**:
- Save `tokenizer_config.json`
- Document custom templates
- Track vocabulary changes

View file

@ -0,0 +1,723 @@
# Tokenization Pipeline Components
Complete guide to normalizers, pre-tokenizers, models, post-processors, and decoders.
## Pipeline overview
**Full tokenization pipeline**:
```
Raw Text
   ↓ Normalization (cleaning, lowercasing)
   ↓ Pre-tokenization (split into words)
   ↓ Model (apply BPE/WordPiece/Unigram)
   ↓ Post-processing (add special tokens)
Token IDs
```
**Decoding reverses the process**:
```
Token IDs
   ↓ Decoder (handle special encodings)
Raw Text
```
## Normalizers
Clean and standardize input text.
### Common normalizers
**Lowercase**:
```python
from tokenizers.normalizers import Lowercase
tokenizer.normalizer = Lowercase()
# Input: "Hello WORLD"
# Output: "hello world"
```
**Unicode normalization**:
```python
from tokenizers.normalizers import NFD, NFC, NFKD, NFKC
# NFD: Canonical decomposition
tokenizer.normalizer = NFD()
# "é" → "e" + "́" (separate characters)
# NFC: Canonical composition (default)
tokenizer.normalizer = NFC()
# "e" + "́" → "é" (composed)
# NFKD: Compatibility decomposition
tokenizer.normalizer = NFKD()
# "fi" → "f" + "i"
# NFKC: Compatibility composition
tokenizer.normalizer = NFKC()
# Most aggressive normalization
```
**Strip accents**:
```python
from tokenizers.normalizers import StripAccents
tokenizer.normalizer = StripAccents()
# Input: "café"
# Output: "cafe"
```
**Whitespace handling**:
```python
from tokenizers.normalizers import Strip, StripAccents
# Remove leading/trailing whitespace
tokenizer.normalizer = Strip()
# Input: " hello "
# Output: "hello"
```
**Replace patterns**:
```python
from tokenizers.normalizers import Replace
# Replace newlines with spaces
tokenizer.normalizer = Replace("\\n", " ")
# Input: "hello\\nworld"
# Output: "hello world"
```
### Combining normalizers
```python
from tokenizers.normalizers import Sequence, NFD, Lowercase, StripAccents
# BERT-style normalization
tokenizer.normalizer = Sequence([
NFD(), # Unicode decomposition
Lowercase(), # Convert to lowercase
StripAccents() # Remove accents
])
# Input: "Café au Lait"
# After NFD: "Café au Lait" (e + ́)
# After Lowercase: "café au lait"
# After StripAccents: "cafe au lait"
```
### Use case examples
**Case-insensitive model (BERT)**:
```python
from tokenizers.normalizers import BertNormalizer
# All-in-one BERT normalization
tokenizer.normalizer = BertNormalizer(
clean_text=True, # Remove control characters
handle_chinese_chars=True, # Add spaces around Chinese
strip_accents=True, # Remove accents
lowercase=True # Lowercase
)
```
**Case-sensitive model (GPT-2)**:
```python
# Minimal normalization
tokenizer.normalizer = NFC() # Only normalize Unicode
```
**Multilingual (mBERT)**:
```python
# Preserve scripts, normalize form
tokenizer.normalizer = NFKC()
```
## Pre-tokenizers
Split text into word-like units before tokenization.
### Whitespace splitting
```python
from tokenizers.pre_tokenizers import Whitespace
tokenizer.pre_tokenizer = Whitespace()
# Input: "Hello world! How are you?"
# Output: [("Hello", (0, 5)), ("world!", (6, 12)), ("How", (13, 16)), ("are", (17, 20)), ("you?", (21, 25))]
```
### Punctuation isolation
```python
from tokenizers.pre_tokenizers import Punctuation
tokenizer.pre_tokenizer = Punctuation()
# Input: "Hello, world!"
# Output: [("Hello", ...), (",", ...), ("world", ...), ("!", ...)]
```
### Byte-level (GPT-2)
```python
from tokenizers.pre_tokenizers import ByteLevel
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
# Input: "Hello world"
# Output: Byte-level tokens with Ġ prefix for spaces
# [("ĠHello", ...), ("Ġworld", ...)]
```
**Key feature**: Handles ALL Unicode text — every input is mapped onto the 256 possible byte values, so nothing falls outside the initial alphabet and no unknown tokens are needed.
### Metaspace (SentencePiece)
```python
from tokenizers.pre_tokenizers import Metaspace
tokenizer.pre_tokenizer = Metaspace(replacement="▁", add_prefix_space=True)
# Input: "Hello world"
# Output: [("▁Hello", ...), ("▁world", ...)]
```
**Used by**: T5, ALBERT (via SentencePiece)
### Digits splitting
```python
from tokenizers.pre_tokenizers import Digits
# Split digits individually
tokenizer.pre_tokenizer = Digits(individual_digits=True)
# Input: "Room 123"
# Output: [("Room", ...), ("1", ...), ("2", ...), ("3", ...)]
# Keep digits together
tokenizer.pre_tokenizer = Digits(individual_digits=False)
# Input: "Room 123"
# Output: [("Room", ...), ("123", ...)]
```
### BERT pre-tokenizer
```python
from tokenizers.pre_tokenizers import BertPreTokenizer
tokenizer.pre_tokenizer = BertPreTokenizer()
# Splits on whitespace and punctuation, preserves CJK
# Input: "Hello, 世界!"
# Output: [("Hello", ...), (",", ...), ("世", ...), ("界", ...), ("!", ...)]
```
### Combining pre-tokenizers
```python
from tokenizers.pre_tokenizers import Sequence, Whitespace, Punctuation
tokenizer.pre_tokenizer = Sequence([
Whitespace(), # Split on whitespace first
Punctuation() # Then isolate punctuation
])
# Input: "Hello, world!"
# After Whitespace: [("Hello,", ...), ("world!", ...)]
# After Punctuation: [("Hello", ...), (",", ...), ("world", ...), ("!", ...)]
```
### Pre-tokenizer comparison
| Pre-tokenizer | Use Case | Example |
|-------------------|---------------------------------|--------------------------------------------|
| Whitespace | Simple English | "Hello world" → ["Hello", "world"] |
| Punctuation | Isolate symbols | "world!" → ["world", "!"] |
| ByteLevel | Multilingual, emojis | "🌍" → byte tokens |
| Metaspace | SentencePiece-style | "Hello" → ["▁Hello"] |
| BertPreTokenizer | BERT-style (CJK aware) | "世界" → ["世", "界"] |
| Digits | Handle numbers | "123" → ["1", "2", "3"] or ["123"] |
## Models
Core tokenization algorithms.
### BPE Model
```python
from tokenizers.models import BPE
model = BPE(
vocab=None, # Or provide pre-built vocab
merges=None, # Or provide merge rules
unk_token="[UNK]", # Unknown token
continuing_subword_prefix="",
end_of_word_suffix="",
fuse_unk=False # Keep unknown tokens separate
)
tokenizer = Tokenizer(model)
```
**Parameters**:
- `vocab`: Dict of token → id
- `merges`: List of merge rules `["a b", "ab c"]`
- `unk_token`: Token for unknown words
- `continuing_subword_prefix`: Prefix for subwords (empty for GPT-2)
- `end_of_word_suffix`: Suffix for last subword (empty for GPT-2)
### WordPiece Model
```python
from tokenizers.models import WordPiece
model = WordPiece(
vocab=None,
unk_token="[UNK]",
max_input_chars_per_word=100, # Max word length
continuing_subword_prefix="##" # BERT-style prefix
)
tokenizer = Tokenizer(model)
```
**Key difference**: Uses `##` prefix for continuing subwords.
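For illustration — the exact split depends on the trained vocabulary, so the pieces below are assumptions — the prefix shows up directly in the output tokens:
```python
# Hypothetical output for a BERT-style WordPiece vocabulary
output = tokenizer.encode("tokenization")
print(output.tokens)  # e.g. ['token', '##ization']
```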
### Unigram Model
```python
from tokenizers.models import Unigram
model = Unigram(
vocab=None, # List of (token, score) tuples
unk_id=0, # ID for unknown token
byte_fallback=False # Fall back to bytes if no match
)
tokenizer = Tokenizer(model)
```
**Probabilistic**: Selects tokenization with highest probability.
### WordLevel Model
```python
from tokenizers.models import WordLevel
# Simple word-to-ID mapping (no subwords)
model = WordLevel(
vocab=None,
unk_token="[UNK]"
)
tokenizer = Tokenizer(model)
```
**Warning**: Requires huge vocabulary (one token per word).
## Post-processors
Add special tokens and format output.
### Template processing
**BERT-style** (`[CLS] sentence [SEP]`):
```python
from tokenizers.processors import TemplateProcessing
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B [SEP]",
special_tokens=[
("[CLS]", 101),
("[SEP]", 102),
],
)
# Single sentence
output = tokenizer.encode("Hello world")
# [101, ..., 102] ([CLS] hello world [SEP])
# Sentence pair
output = tokenizer.encode("Hello", "world")
# [101, ..., 102, ..., 102] ([CLS] hello [SEP] world [SEP])
```
**GPT-2 style** (`sentence <|endoftext|>`):
```python
tokenizer.post_processor = TemplateProcessing(
single="$A <|endoftext|>",
special_tokens=[
("<|endoftext|>", 50256),
],
)
```
**RoBERTa style** (`<s> sentence </s>`):
```python
tokenizer.post_processor = TemplateProcessing(
single="<s> $A </s>",
pair="<s> $A </s> </s> $B </s>",
special_tokens=[
("<s>", 0),
("</s>", 2),
],
)
```
**T5 style** (no special tokens):
```python
# T5 doesn't add special tokens via post-processor
tokenizer.post_processor = None
```
### RobertaProcessing
```python
from tokenizers.processors import RobertaProcessing
tokenizer.post_processor = RobertaProcessing(
sep=("</s>", 2),
cls=("<s>", 0),
add_prefix_space=True, # Add space before first token
trim_offsets=True # Trim leading space from offsets
)
```
### ByteLevelProcessing
```python
from tokenizers.processors import ByteLevel as ByteLevelProcessing
tokenizer.post_processor = ByteLevelProcessing(
trim_offsets=True # Remove Ġ from offsets
)
```
## Decoders
Convert token IDs back to text.
### ByteLevel decoder
```python
from tokenizers.decoders import ByteLevel
tokenizer.decoder = ByteLevel()
# Handles byte-level tokens
# ["ĠHello", "Ġworld"] → "Hello world"
```
### WordPiece decoder
```python
from tokenizers.decoders import WordPiece
tokenizer.decoder = WordPiece(prefix="##")
# Removes ## prefix and concatenates
# ["token", "##ization"] → "tokenization"
```
### Metaspace decoder
```python
from tokenizers.decoders import Metaspace
tokenizer.decoder = Metaspace(replacement="▁", add_prefix_space=True)
# Converts ▁ back to spaces
# ["▁Hello", "▁world"] → "Hello world"
```
### BPEDecoder
```python
from tokenizers.decoders import BPEDecoder
tokenizer.decoder = BPEDecoder(suffix="</w>")
# Removes suffix and concatenates
# ["token", "ization</w>"] → "tokenization"
```
### Sequence decoder
```python
from tokenizers.decoders import Sequence, ByteLevel, Strip
tokenizer.decoder = Sequence([
ByteLevel(), # Decode byte-level first
Strip(' ', 1, 1) # Strip leading/trailing spaces
])
```
## Complete pipeline examples
### BERT tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer
from tokenizers.processors import TemplateProcessing
from tokenizers.decoders import WordPiece as WordPieceDecoder
# Model
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
# Normalization
tokenizer.normalizer = BertNormalizer(lowercase=True)
# Pre-tokenization
tokenizer.pre_tokenizer = BertPreTokenizer()
# Post-processing
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B [SEP]",
special_tokens=[("[CLS]", 101), ("[SEP]", 102)],
)
# Decoder
tokenizer.decoder = WordPieceDecoder(prefix="##")
# Enable padding
tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")
# Enable truncation
tokenizer.enable_truncation(max_length=512)
```
### GPT-2 tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.normalizers import NFC
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.processors import TemplateProcessing
# Model
tokenizer = Tokenizer(BPE())
# Normalization (minimal)
tokenizer.normalizer = NFC()
# Byte-level pre-tokenization
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
# Post-processing
tokenizer.post_processor = TemplateProcessing(
single="$A <|endoftext|>",
special_tokens=[("<|endoftext|>", 50256)],
)
# Byte-level decoder
tokenizer.decoder = ByteLevelDecoder()
```
### T5 tokenizer (SentencePiece-style)
```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.normalizers import NFKC
from tokenizers.pre_tokenizers import Metaspace
from tokenizers.decoders import Metaspace as MetaspaceDecoder
# Model
tokenizer = Tokenizer(Unigram())
# Normalization
tokenizer.normalizer = NFKC()
# Metaspace pre-tokenization
tokenizer.pre_tokenizer = Metaspace(replacement="▁", add_prefix_space=True)
# No post-processing (T5 doesn't add CLS/SEP)
tokenizer.post_processor = None
# Metaspace decoder
tokenizer.decoder = MetaspaceDecoder(replacement="▁", add_prefix_space=True)
```
## Alignment tracking
Track token positions in original text.
### Basic alignment
```python
text = "Hello, world!"
output = tokenizer.encode(text)
for token, (start, end) in zip(output.tokens, output.offsets):
print(f"{token:10s} → [{start:2d}, {end:2d}): {text[start:end]!r}")
# Output:
# [CLS] → [ 0, 0): ''
# hello → [ 0, 5): 'Hello'
# , → [ 5, 6): ','
# world → [ 7, 12): 'world'
# ! → [12, 13): '!'
# [SEP] → [ 0, 0): ''
```
### Word-level alignment
```python
# Get word_ids (which word each token belongs to)
encoding = tokenizer.encode("Hello world")
word_ids = encoding.word_ids
print(word_ids)
# [None, 0, 0, 1, None]
# None = special token, 0 = first word, 1 = second word
```
**Use case**: Token classification (NER)
```python
# Align predictions to words
predictions = ["O", "B-PER", "I-PER", "O", "O"]
word_predictions = {}
for token_idx, word_idx in enumerate(encoding.word_ids):
if word_idx is not None and word_idx not in word_predictions:
word_predictions[word_idx] = predictions[token_idx]
print(word_predictions)
# {0: "B-PER", 1: "O"} # First word is PERSON, second is OTHER
```
### Span alignment
```python
# Find token span for character span
text = "Machine learning is awesome"
char_start, char_end = 8, 16 # "learning"
encoding = tokenizer.encode(text)
# Find token span
token_start = encoding.char_to_token(char_start)
token_end = encoding.char_to_token(char_end - 1) + 1
print(f"Tokens {token_start}:{token_end} = {encoding.tokens[token_start:token_end]}")
# Tokens 2:3 = ['learning']
```
**Use case**: Question answering (extract answer span)
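A minimal sketch of that use case, assuming a QA model has already predicted the answer span as *token* indices (the indices below are placeholders):
```python
text = "Machine learning is awesome"
encoding = tokenizer.encode(text)

# Hypothetical model output: inclusive token indices of the answer span
start_token, end_token = 2, 2

# Map the token span back to character offsets in the original text
char_start = encoding.offsets[start_token][0]
char_end = encoding.offsets[end_token][1]
print(text[char_start:char_end])  # "learning"
```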
## Custom components
### Custom normalizer
```python
from tokenizers import NormalizedString
from tokenizers.normalizers import Normalizer

class CustomNormalizer:
    def normalize(self, normalized: NormalizedString):
        # Custom normalization logic, applied in place
        normalized.lowercase()
        normalized.replace("  ", " ")  # Collapse double spaces

# Pure-Python components must be wrapped with Normalizer.custom
tokenizer.normalizer = Normalizer.custom(CustomNormalizer())
```
### Custom pre-tokenizer
```python
import re
from tokenizers import PreTokenizedString
from tokenizers.pre_tokenizers import PreTokenizer

class CustomPreTokenizer:
    def pre_tokenize(self, pretok: PreTokenizedString):
        # The split callback receives (index, NormalizedString) and returns NormalizedString pieces
        pretok.split(lambda i, ns: [ns[m.start():m.end()] for m in re.finditer(r"\S+", str(ns))])

# Pure-Python components must be wrapped with PreTokenizer.custom
tokenizer.pre_tokenizer = PreTokenizer.custom(CustomPreTokenizer())
```
## Troubleshooting
### Issue: Misaligned offsets
**Symptom**: Offsets don't match original text
```python
text = " hello" # Leading spaces
offsets = [(0, 5)] # Expects " hel"
```
**Solution**: Check normalization strips spaces
```python
# Don't strip whitespace in the normalizer — it shifts offsets:
# tokenizer.normalizer = Sequence([Strip()])   # this changes offsets!

# Keep the normalizer offset-preserving and trim in the post-processor instead
tokenizer.post_processor = ByteLevelProcessing(trim_offsets=True)
```
### Issue: Special tokens not added
**Symptom**: No [CLS] or [SEP] in output
**Solution**: Check post-processor is set
```python
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
special_tokens=[("[CLS]", 101), ("[SEP]", 102)],
)
```
### Issue: Incorrect decoding
**Symptom**: Decoded text has ## or ▁
**Solution**: Set correct decoder
```python
# For WordPiece
tokenizer.decoder = WordPieceDecoder(prefix="##")
# For SentencePiece
tokenizer.decoder = MetaspaceDecoder(replacement="▁")
```
## Best practices
1. **Match pipeline to model architecture**:
- BERT → BertNormalizer + BertPreTokenizer + WordPiece
- GPT-2 → NFC + ByteLevel + BPE
- T5 → NFKC + Metaspace + Unigram
2. **Test pipeline on sample inputs** (see the sketch after this list):
- Check normalization doesn't over-normalize
- Verify pre-tokenization splits correctly
- Ensure decoding reconstructs text
3. **Preserve alignment for downstream tasks**:
- Use `trim_offsets` instead of stripping in normalizer
- Test `char_to_token()` on sample spans
4. **Document your pipeline**:
- Save complete tokenizer config
- Document special tokens
- Note any custom components
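A quick sanity check for point 2, assuming `tokenizer` is the pipeline under test (the sample sentences are placeholders):
```python
samples = ["Hello, world!", "Café au lait", "Room 123"]

for text in samples:
    encoding = tokenizer.encode(text)
    decoded = tokenizer.decode(encoding.ids, skip_special_tokens=True)
    print(f"{text!r:>20} → {encoding.tokens}")
    # Decoding should reconstruct the text up to the pipeline's own normalization
    if decoded.replace(" ", "").lower() != text.replace(" ", "").lower():
        print(f"  round-trip drift: {decoded!r}")
```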

View file

@ -0,0 +1,565 @@
# Training Custom Tokenizers
Complete guide to training tokenizers from scratch.
## Training workflow
### Step 1: Choose tokenization algorithm
**Decision tree**:
- **GPT-style model** → BPE
- **BERT-style model** → WordPiece
- **Multilingual/No word boundaries** → Unigram
### Step 2: Prepare training data
```python
# Option 1: From files
files = ["train.txt", "validation.txt"]
# Option 2: From Python list
texts = [
"This is the first sentence.",
"This is the second sentence.",
# ... more texts
]
# Option 3: From dataset iterator
from datasets import load_dataset
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
def batch_iterator(batch_size=1000):
for i in range(0, len(dataset), batch_size):
yield dataset[i:i + batch_size]["text"]
```
### Step 3: Initialize tokenizer
**BPE example**:
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = ByteLevel()
tokenizer.decoder = ByteLevelDecoder()
trainer = BpeTrainer(
vocab_size=50000,
min_frequency=2,
special_tokens=["<|endoftext|>", "<|padding|>"],
show_progress=True
)
```
**WordPiece example**:
```python
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
tokenizer.normalizer = BertNormalizer(lowercase=True)
tokenizer.pre_tokenizer = BertPreTokenizer()
trainer = WordPieceTrainer(
vocab_size=30522,
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
continuing_subword_prefix="##",
show_progress=True
)
```
**Unigram example**:
```python
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
tokenizer = Tokenizer(Unigram())
trainer = UnigramTrainer(
vocab_size=8000,
special_tokens=["<unk>", "<s>", "</s>", "<pad>"],
unk_token="<unk>",
show_progress=True
)
```
### Step 4: Train
```python
# From files
tokenizer.train(files=files, trainer=trainer)
# From iterator (recommended for large datasets)
tokenizer.train_from_iterator(
batch_iterator(),
trainer=trainer,
length=len(dataset) # Optional, for progress bar
)
```
**Training time** (30k vocab on 16-core CPU):
- 10 MB: 15-30 seconds
- 100 MB: 1-3 minutes
- 1 GB: 15-30 minutes
- 10 GB: 2-4 hours
### Step 5: Add post-processing
```python
from tokenizers.processors import TemplateProcessing
# BERT-style
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B [SEP]",
special_tokens=[
("[CLS]", tokenizer.token_to_id("[CLS]")),
("[SEP]", tokenizer.token_to_id("[SEP]")),
],
)
# GPT-2 style
tokenizer.post_processor = TemplateProcessing(
single="$A <|endoftext|>",
special_tokens=[
("<|endoftext|>", tokenizer.token_to_id("<|endoftext|>")),
],
)
```
### Step 6: Save
```python
# Save to JSON
tokenizer.save("my-tokenizer.json")
# Save to directory (for transformers)
tokenizer.save("my-tokenizer-dir/tokenizer.json")
# Convert to transformers format
from transformers import PreTrainedTokenizerFast
transformers_tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
unk_token="[UNK]",
pad_token="[PAD]",
cls_token="[CLS]",
sep_token="[SEP]",
mask_token="[MASK]"
)
transformers_tokenizer.save_pretrained("my-tokenizer-dir")
```
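Loading the saved tokenizer back is the mirror image (paths match the save example above):
```python
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

# From the raw tokenizers JSON
tokenizer = Tokenizer.from_file("my-tokenizer.json")

# From the transformers-format directory saved above
hf_tokenizer = PreTrainedTokenizerFast.from_pretrained("my-tokenizer-dir")
```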
## Trainer configuration
### BpeTrainer parameters
```python
from tokenizers.trainers import BpeTrainer
trainer = BpeTrainer(
vocab_size=30000, # Target vocabulary size
min_frequency=2, # Minimum frequency for merges
special_tokens=["[UNK]"], # Special tokens (added first)
limit_alphabet=1000, # Limit initial alphabet size
initial_alphabet=[], # Pre-defined initial characters
show_progress=True, # Show progress bar
continuing_subword_prefix="", # Prefix for continuing subwords
end_of_word_suffix="" # Suffix for end of words
)
```
**Parameter tuning**:
- **vocab_size**: Start with 30k for English, 50k for multilingual
- **min_frequency**: 2-5 for large corpora, 1 for small
- **limit_alphabet**: Raise it (or set to `None`) for CJK and other large-alphabet scripts so rare characters aren't dropped from the initial alphabet
### WordPieceTrainer parameters
```python
from tokenizers.trainers import WordPieceTrainer
trainer = WordPieceTrainer(
vocab_size=30522, # BERT uses 30,522
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
limit_alphabet=1000,
continuing_subword_prefix="##", # BERT-style prefix
show_progress=True
)
```
### UnigramTrainer parameters
```python
from tokenizers.trainers import UnigramTrainer
trainer = UnigramTrainer(
vocab_size=8000, # Typically smaller than BPE/WordPiece
special_tokens=["<unk>", "<s>", "</s>"],
unk_token="<unk>",
max_piece_length=16, # Maximum token length
n_sub_iterations=2, # EM algorithm iterations
shrinking_factor=0.75, # Vocabulary reduction rate
show_progress=True
)
```
## Training from large datasets
### Memory-efficient training
```python
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
# Load dataset
dataset = load_dataset("wikipedia", "20220301.en", split="train", streaming=True)
# Create iterator (yields batches)
def batch_iterator(batch_size=1000):
batch = []
for sample in dataset:
batch.append(sample["text"])
if len(batch) >= batch_size:
yield batch
batch = []
if batch:
yield batch
# Initialize tokenizer
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(vocab_size=50000, special_tokens=["<|endoftext|>"])
# Train (memory efficient - streams data)
tokenizer.train_from_iterator(
batch_iterator(),
trainer=trainer
)
```
**Memory usage**: ~200 MB (vs 10+ GB loading full dataset)
### Multi-file training
```python
import glob
# Find all training files
files = glob.glob("data/train/*.txt")
print(f"Training on {len(files)} files")
# Train on all files
tokenizer.train(files=files, trainer=trainer)
```
### Parallel training (multi-processing)
```python
from multiprocessing import Pool, cpu_count
import os
def train_shard(shard_files):
"""Train tokenizer on a shard of files."""
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(vocab_size=50000)
tokenizer.train(files=shard_files, trainer=trainer)
return tokenizer.get_vocab()
# Split files into shards
num_shards = cpu_count()
file_shards = [files[i::num_shards] for i in range(num_shards)]
# Train shards in parallel
with Pool(num_shards) as pool:
vocab_shards = pool.map(train_shard, file_shards)
# Merge vocabularies (custom logic needed)
# This is a simplified example - real implementation would merge intelligently
final_vocab = {}
for vocab in vocab_shards:
final_vocab.update(vocab)
```
## Domain-specific tokenizers
### Code tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.normalizers import Sequence, NFC
# Code-optimized configuration
tokenizer = Tokenizer(BPE())
# Minimal normalization (preserve case, whitespace)
tokenizer.normalizer = NFC() # Only normalize Unicode
# Byte-level pre-tokenization (handles all characters)
tokenizer.pre_tokenizer = ByteLevel()
# Train on code corpus
trainer = BpeTrainer(
vocab_size=50000,
special_tokens=["<|endoftext|>", "<|pad|>"],
min_frequency=2
)
tokenizer.train(files=["code_corpus.txt"], trainer=trainer)
```
### Medical/scientific tokenizer
```python
# Preserve case and special characters
from tokenizers.normalizers import NFKC
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence
tokenizer = Tokenizer(BPE())
# Minimal normalization
tokenizer.normalizer = NFKC()
# Preserve medical terms
tokenizer.pre_tokenizer = Sequence([
Whitespace(),
Punctuation(behavior="isolated") # Keep punctuation separate
])
trainer = BpeTrainer(
vocab_size=50000,
special_tokens=["[UNK]", "[CLS]", "[SEP]"],
min_frequency=3 # Higher threshold for rare medical terms
)
tokenizer.train(files=["pubmed_corpus.txt"], trainer=trainer)
```
### Multilingual tokenizer
```python
# Handle multiple scripts
from tokenizers.normalizers import NFKC, Lowercase, Sequence
tokenizer = Tokenizer(BPE())
# Normalize but don't lowercase (preserves script differences)
tokenizer.normalizer = NFKC()
# Byte-level handles all Unicode
from tokenizers.pre_tokenizers import ByteLevel
tokenizer.pre_tokenizer = ByteLevel()
trainer = BpeTrainer(
vocab_size=100000, # Larger vocab for multiple languages
special_tokens=["<unk>", "<s>", "</s>"],
limit_alphabet=None # No limit (handles all scripts)
)
# Train on multilingual corpus
tokenizer.train(files=["multilingual_corpus.txt"], trainer=trainer)
```
## Vocabulary size selection
### Guidelines by task
| Task | Recommended Vocab Size | Rationale |
|-----------------------|------------------------|-----------|
| English (monolingual) | 30,000 - 50,000 | Balanced coverage |
| Multilingual | 50,000 - 250,000 | More languages = more tokens |
| Code | 30,000 - 50,000 | Similar to English |
| Domain-specific | 10,000 - 30,000 | Smaller, focused vocabulary |
| Character-level tasks | 1,000 - 5,000 | Only characters + subwords |
### Vocabulary size impact
**Small vocab (10k)**:
- Pros: Faster training, smaller model, less memory
- Cons: More tokens per sentence, worse OOV handling
**Medium vocab (30k-50k)**:
- Pros: Good balance, standard choice
- Cons: None (recommended default)
**Large vocab (100k+)**:
- Pros: Fewer tokens per sentence, better OOV
- Cons: Slower training, larger embedding table
### Empirical testing
```python
# Train multiple tokenizers with different vocab sizes
vocab_sizes = [10000, 30000, 50000, 100000]
for vocab_size in vocab_sizes:
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(vocab_size=vocab_size)
tokenizer.train(files=["sample.txt"], trainer=trainer)
# Evaluate on test set
test_text = "Test sentence for evaluation..."
tokens = tokenizer.encode(test_text).ids
print(f"Vocab: {vocab_size:6d} | Tokens: {len(tokens):3d} | Avg: {len(test_text)/len(tokens):.2f} chars/token")
# Example output:
# Vocab: 10000 | Tokens: 12 | Avg: 2.33 chars/token
# Vocab: 30000 | Tokens: 8 | Avg: 3.50 chars/token
# Vocab: 50000 | Tokens: 7 | Avg: 4.00 chars/token
# Vocab: 100000 | Tokens: 6 | Avg: 4.67 chars/token
```
## Testing tokenizer quality
### Coverage test
```python
# Test on held-out data
test_corpus = load_dataset("wikitext", "wikitext-103-raw-v1", split="test")
total_tokens = 0
unk_tokens = 0
unk_id = tokenizer.token_to_id("[UNK]")
for text in test_corpus["text"]:
if text.strip():
encoding = tokenizer.encode(text)
total_tokens += len(encoding.ids)
unk_tokens += encoding.ids.count(unk_id)
unk_rate = unk_tokens / total_tokens
print(f"Unknown token rate: {unk_rate:.2%}")
# Good quality: <1% unknown tokens
# Acceptable: 1-5%
# Poor: >5%
```
### Compression test
```python
# Measure tokenization efficiency
import numpy as np
token_lengths = []
for text in test_corpus["text"][:1000]:
if text.strip():
encoding = tokenizer.encode(text)
chars_per_token = len(text) / len(encoding.ids)
token_lengths.append(chars_per_token)
avg_chars_per_token = np.mean(token_lengths)
print(f"Average characters per token: {avg_chars_per_token:.2f}")
# Good: 4-6 chars/token (English)
# Acceptable: 3-4 chars/token
# Poor: <3 chars/token (under-compression)
```
### Semantic test
```python
# Manually inspect tokenization of common words/phrases
test_phrases = [
"tokenization",
"machine learning",
"artificial intelligence",
"preprocessing",
"hello world"
]
for phrase in test_phrases:
tokens = tokenizer.encode(phrase).tokens
print(f"{phrase:25s} → {tokens}")
# Good tokenization:
# tokenization → ['token', 'ization']
# machine learning → ['machine', 'learning']
# artificial intelligence → ['artificial', 'intelligence']
```
## Troubleshooting
### Issue: Training too slow
**Solutions**:
1. Reduce vocabulary size
2. Increase `min_frequency`
3. Use `limit_alphabet` to reduce initial alphabet
4. Train on subset first
```python
# Fast training configuration
trainer = BpeTrainer(
vocab_size=20000, # Smaller vocab
min_frequency=5, # Higher threshold
limit_alphabet=500, # Limit alphabet
show_progress=True
)
```
### Issue: High unknown token rate
**Solutions**:
1. Increase vocabulary size
2. Decrease `min_frequency`
3. Check normalization (might be too aggressive)
```python
# Better coverage configuration
trainer = BpeTrainer(
vocab_size=50000, # Larger vocab
min_frequency=1, # Lower threshold
)
```
### Issue: Poor quality tokenization
**Solutions**:
1. Verify normalization matches your use case
2. Check pre-tokenization splits correctly
3. Ensure training data is representative
4. Try different algorithm (BPE vs WordPiece vs Unigram)
```python
# Debug tokenization pipeline
text = "Sample text to debug"
# Check normalization
normalized = tokenizer.normalizer.normalize_str(text)
print(f"Normalized: {normalized}")
# Check pre-tokenization
pre_tokens = tokenizer.pre_tokenizer.pre_tokenize_str(text)
print(f"Pre-tokens: {pre_tokens}")
# Check final tokenization
tokens = tokenizer.encode(text).tokens
print(f"Tokens: {tokens}")
```
## Best practices
1. **Use representative training data** - Match your target domain
2. **Start with standard configs** - BERT WordPiece or GPT-2 BPE
3. **Test on held-out data** - Measure unknown token rate
4. **Iterate on vocabulary size** - Test 30k, 50k, 100k
5. **Save tokenizer with model** - Ensure reproducibility
6. **Version your tokenizers** - Track changes for reproducibility
7. **Document special tokens** - Critical for model training

View file

@ -0,0 +1,743 @@
---
name: instructor
description: Extract structured data from LLM responses with Pydantic validation, retry failed extractions automatically, parse complex JSON with type safety, and stream partial results with Instructor - battle-tested structured output library
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [instructor, pydantic, openai, anthropic]
metadata:
hermes:
tags: [Prompt Engineering, Instructor, Structured Output, Pydantic, Data Extraction, JSON Parsing, Type Safety, Validation, Streaming, OpenAI, Anthropic]
---
# Instructor: Structured LLM Outputs
## When to Use This Skill
Use Instructor when you need to:
- **Extract structured data** from LLM responses reliably
- **Validate outputs** against Pydantic schemas automatically
- **Retry failed extractions** with automatic error handling
- **Parse complex JSON** with type safety and validation
- **Stream partial results** for real-time processing
- **Support multiple LLM providers** with consistent API
**GitHub Stars**: 15,000+ | **Battle-tested**: 100,000+ developers
## Installation
```bash
# Base installation
pip install instructor
# With specific providers
pip install "instructor[anthropic]" # Anthropic Claude
pip install "instructor[openai]" # OpenAI
pip install "instructor[all]" # All providers
```
## Quick Start
### Basic Example: Extract User Data
```python
import instructor
from pydantic import BaseModel
from anthropic import Anthropic
# Define output structure
class User(BaseModel):
name: str
age: int
email: str
# Create instructor client
client = instructor.from_anthropic(Anthropic())
# Extract structured data
user = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "John Doe is 30 years old. His email is john@example.com"
}],
response_model=User
)
print(user.name) # "John Doe"
print(user.age) # 30
print(user.email) # "john@example.com"
```
### With OpenAI
```python
from openai import OpenAI
client = instructor.from_openai(OpenAI())
user = client.chat.completions.create(
model="gpt-4o-mini",
response_model=User,
messages=[{"role": "user", "content": "Extract: Alice, 25, alice@email.com"}]
)
```
## Core Concepts
### 1. Response Models (Pydantic)
Response models define the structure and validation rules for LLM outputs.
#### Basic Model
```python
from pydantic import BaseModel, Field
class Article(BaseModel):
title: str = Field(description="Article title")
author: str = Field(description="Author name")
word_count: int = Field(description="Number of words", gt=0)
tags: list[str] = Field(description="List of relevant tags")
article = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "Analyze this article: [article text]"
}],
response_model=Article
)
```
**Benefits:**
- Type safety with Python type hints
- Automatic validation (word_count > 0)
- Self-documenting with Field descriptions
- IDE autocomplete support
#### Nested Models
```python
class Address(BaseModel):
street: str
city: str
country: str
class Person(BaseModel):
name: str
age: int
address: Address # Nested model
person = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "John lives at 123 Main St, Boston, USA"
}],
response_model=Person
)
print(person.address.city) # "Boston"
```
#### Optional Fields
```python
from typing import Optional
class Product(BaseModel):
name: str
price: float
discount: Optional[float] = None # Optional
description: str = Field(default="No description") # Default value
# LLM doesn't need to provide discount or description
```
#### Enums for Constraints
```python
from enum import Enum
class Sentiment(str, Enum):
POSITIVE = "positive"
NEGATIVE = "negative"
NEUTRAL = "neutral"
class Review(BaseModel):
text: str
sentiment: Sentiment # Only these 3 values allowed
review = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "This product is amazing!"
}],
response_model=Review
)
print(review.sentiment) # Sentiment.POSITIVE
```
### 2. Validation
Pydantic validates LLM outputs automatically. If validation fails, Instructor retries.
#### Built-in Validators
```python
from pydantic import Field, EmailStr, HttpUrl
class Contact(BaseModel):
name: str = Field(min_length=2, max_length=100)
age: int = Field(ge=0, le=120) # 0 <= age <= 120
email: EmailStr # Validates email format
website: HttpUrl # Validates URL format
# If LLM provides invalid data, Instructor retries automatically
```
#### Custom Validators
```python
from pydantic import field_validator
class Event(BaseModel):
name: str
date: str
attendees: int
@field_validator('date')
def validate_date(cls, v):
"""Ensure date is in YYYY-MM-DD format."""
import re
if not re.match(r'\d{4}-\d{2}-\d{2}', v):
raise ValueError('Date must be YYYY-MM-DD format')
return v
@field_validator('attendees')
def validate_attendees(cls, v):
"""Ensure positive attendees."""
if v < 1:
raise ValueError('Must have at least 1 attendee')
return v
```
#### Model-Level Validation
```python
from pydantic import model_validator
class DateRange(BaseModel):
start_date: str
end_date: str
@model_validator(mode='after')
def check_dates(self):
"""Ensure end_date is after start_date."""
from datetime import datetime
start = datetime.strptime(self.start_date, '%Y-%m-%d')
end = datetime.strptime(self.end_date, '%Y-%m-%d')
if end < start:
raise ValueError('end_date must be after start_date')
return self
```
### 3. Automatic Retrying
Instructor retries automatically when validation fails, providing error feedback to the LLM.
```python
# Retries up to 3 times if validation fails
user = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "Extract user from: John, age unknown"
}],
response_model=User,
max_retries=3 # Default is 3
)
# If age can't be extracted, Instructor tells the LLM:
# "Validation error: age - field required"
# LLM tries again with better extraction
```
**How it works:**
1. LLM generates output
2. Pydantic validates
3. If invalid: Error message sent back to LLM
4. LLM tries again with error feedback
5. Repeats up to max_retries
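A concrete sketch of that loop (the validator and prompt here are illustrative, not part of Instructor itself): a failed check raises a `ValueError`, Instructor feeds the message back, and the model gets another attempt.
```python
from pydantic import BaseModel, field_validator

class Handle(BaseModel):
    username: str

    @field_validator("username")
    def must_be_lowercase(cls, v):
        # If the model returns "JohnDoe", this message is sent back and it retries
        if v != v.lower():
            raise ValueError("username must be lowercase, e.g. 'johndoe'")
        return v

handle = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{"role": "user", "content": "Pick a username for John Doe"}],
    response_model=Handle,
    max_retries=3,
)
```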
### 4. Streaming
Stream partial results for real-time processing.
#### Streaming Partial Objects
```python
from instructor import Partial
class Story(BaseModel):
title: str
content: str
tags: list[str]
# Stream partial updates as LLM generates
for partial_story in client.messages.create_partial(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "Write a short sci-fi story"
}],
response_model=Story
):
print(f"Title: {partial_story.title}")
print(f"Content so far: {partial_story.content[:100]}...")
# Update UI in real-time
```
#### Streaming Iterables
```python
class Task(BaseModel):
title: str
priority: str
# Stream list items as they're generated
tasks = client.messages.create_iterable(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "Generate 10 project tasks"
}],
response_model=Task
)
for task in tasks:
print(f"- {task.title} ({task.priority})")
# Process each task as it arrives
```
## Provider Configuration
### Anthropic Claude
```python
import instructor
from anthropic import Anthropic
client = instructor.from_anthropic(
Anthropic(api_key="your-api-key")
)
# Use with Claude models
response = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[...],
response_model=YourModel
)
```
### OpenAI
```python
from openai import OpenAI
client = instructor.from_openai(
OpenAI(api_key="your-api-key")
)
response = client.chat.completions.create(
model="gpt-4o-mini",
response_model=YourModel,
messages=[...]
)
```
### Local Models (Ollama)
```python
from openai import OpenAI
# Point to local Ollama server
client = instructor.from_openai(
OpenAI(
base_url="http://localhost:11434/v1",
api_key="ollama" # Required but ignored
),
mode=instructor.Mode.JSON
)
response = client.chat.completions.create(
model="llama3.1",
response_model=YourModel,
messages=[...]
)
```
## Common Patterns
### Pattern 1: Data Extraction from Text
```python
class CompanyInfo(BaseModel):
name: str
founded_year: int
industry: str
employees: int
headquarters: str
text = """
Tesla, Inc. was founded in 2003. It operates in the automotive and energy
industry with approximately 140,000 employees. The company is headquartered
in Austin, Texas.
"""
company = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": f"Extract company information from: {text}"
}],
response_model=CompanyInfo
)
```
### Pattern 2: Classification
```python
class Category(str, Enum):
TECHNOLOGY = "technology"
FINANCE = "finance"
HEALTHCARE = "healthcare"
EDUCATION = "education"
OTHER = "other"
class ArticleClassification(BaseModel):
category: Category
confidence: float = Field(ge=0.0, le=1.0)
keywords: list[str]
classification = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "Classify this article: [article text]"
}],
response_model=ArticleClassification
)
```
### Pattern 3: Multi-Entity Extraction
```python
class Person(BaseModel):
name: str
role: str
class Organization(BaseModel):
name: str
industry: str
class Entities(BaseModel):
people: list[Person]
organizations: list[Organization]
locations: list[str]
text = "Tim Cook, CEO of Apple, announced at the event in Cupertino..."
entities = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": f"Extract all entities from: {text}"
}],
response_model=Entities
)
for person in entities.people:
print(f"{person.name} - {person.role}")
```
### Pattern 4: Structured Analysis
```python
class SentimentAnalysis(BaseModel):
overall_sentiment: Sentiment
positive_aspects: list[str]
negative_aspects: list[str]
suggestions: list[str]
score: float = Field(ge=-1.0, le=1.0)
review = "The product works well but setup was confusing..."
analysis = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": f"Analyze this review: {review}"
}],
response_model=SentimentAnalysis
)
```
### Pattern 5: Batch Processing
```python
def extract_person(text: str) -> Person:
return client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": f"Extract person from: {text}"
}],
response_model=Person
)
texts = [
"John Doe is a 30-year-old engineer",
"Jane Smith, 25, works in marketing",
"Bob Johnson, age 40, software developer"
]
people = [extract_person(text) for text in texts]
```
## Advanced Features
### Union Types
```python
from typing import Union
class TextContent(BaseModel):
type: str = "text"
content: str
class ImageContent(BaseModel):
type: str = "image"
url: HttpUrl
caption: str
class Post(BaseModel):
title: str
content: Union[TextContent, ImageContent] # Either type
# LLM chooses appropriate type based on content
```
### Dynamic Models
```python
from pydantic import create_model
# Create model at runtime
DynamicUser = create_model(
'User',
name=(str, ...),
age=(int, Field(ge=0)),
email=(EmailStr, ...)
)
user = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[...],
response_model=DynamicUser
)
```
### Custom Modes
```python
# For providers without native structured outputs
client = instructor.from_anthropic(
Anthropic(),
mode=instructor.Mode.JSON # JSON mode
)
# Available modes:
# - Mode.ANTHROPIC_TOOLS (recommended for Claude)
# - Mode.JSON (fallback)
# - Mode.TOOLS (OpenAI tools)
```
### Context Management
```python
# Single-use client
with instructor.from_anthropic(Anthropic()) as client:
result = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[...],
response_model=YourModel
)
# Client closed automatically
```
## Error Handling
### Handling Validation Errors
```python
from pydantic import ValidationError
try:
user = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[...],
response_model=User,
max_retries=3
)
except ValidationError as e:
print(f"Failed after retries: {e}")
# Handle gracefully
except Exception as e:
print(f"API error: {e}")
```
### Custom Error Messages
```python
class ValidatedUser(BaseModel):
name: str = Field(description="Full name, 2-100 characters")
age: int = Field(description="Age between 0 and 120", ge=0, le=120)
email: EmailStr = Field(description="Valid email address")
class Config:
# Custom error messages
json_schema_extra = {
"examples": [
{
"name": "John Doe",
"age": 30,
"email": "john@example.com"
}
]
}
```
## Best Practices
### 1. Clear Field Descriptions
```python
# ❌ Bad: Vague
class Product(BaseModel):
name: str
price: float
# ✅ Good: Descriptive
class Product(BaseModel):
name: str = Field(description="Product name from the text")
price: float = Field(description="Price in USD, without currency symbol")
```
### 2. Use Appropriate Validation
```python
# ✅ Good: Constrain values
class Rating(BaseModel):
score: int = Field(ge=1, le=5, description="Rating from 1 to 5 stars")
review: str = Field(min_length=10, description="Review text, at least 10 chars")
```
### 3. Provide Examples in Prompts
```python
messages = [{
"role": "user",
"content": """Extract person info from: "John, 30, engineer"
Example format:
{
"name": "John Doe",
"age": 30,
"occupation": "engineer"
}"""
}]
```
### 4. Use Enums for Fixed Categories
```python
# ✅ Good: Enum ensures valid values
class Status(str, Enum):
PENDING = "pending"
APPROVED = "approved"
REJECTED = "rejected"
class Application(BaseModel):
status: Status # LLM must choose from enum
```
### 5. Handle Missing Data Gracefully
```python
class PartialData(BaseModel):
required_field: str
optional_field: Optional[str] = None
default_field: str = "default_value"
# LLM only needs to provide required_field
```
## Comparison to Alternatives
| Feature | Instructor | Manual JSON | LangChain | DSPy |
|---------|------------|-------------|-----------|------|
| Type Safety | ✅ Yes | ❌ No | ⚠️ Partial | ✅ Yes |
| Auto Validation | ✅ Yes | ❌ No | ❌ No | ⚠️ Limited |
| Auto Retry | ✅ Yes | ❌ No | ❌ No | ✅ Yes |
| Streaming | ✅ Yes | ❌ No | ✅ Yes | ❌ No |
| Multi-Provider | ✅ Yes | ⚠️ Manual | ✅ Yes | ✅ Yes |
| Learning Curve | Low | Low | Medium | High |
**When to choose Instructor:**
- Need structured, validated outputs
- Want type safety and IDE support
- Require automatic retries
- Building data extraction systems
**When to choose alternatives:**
- DSPy: Need prompt optimization
- LangChain: Building complex chains
- Manual: Simple, one-off extractions
## Resources
- **Documentation**: https://python.useinstructor.com
- **GitHub**: https://github.com/jxnl/instructor (15k+ stars)
- **Cookbook**: https://python.useinstructor.com/examples
- **Discord**: Community support available
## See Also
- `references/validation.md` - Advanced validation patterns
- `references/providers.md` - Provider-specific configuration
- `references/examples.md` - Real-world use cases

View file

@ -0,0 +1,107 @@
# Real-World Examples
Practical examples of using Instructor for structured data extraction.
## Data Extraction
```python
class CompanyInfo(BaseModel):
name: str
founded: int
industry: str
employees: int
text = "Apple was founded in 1976 in the technology industry with 164,000 employees."
company = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": f"Extract: {text}"}],
response_model=CompanyInfo
)
```
## Classification
```python
class Sentiment(str, Enum):
POSITIVE = "positive"
NEGATIVE = "negative"
NEUTRAL = "neutral"
class Review(BaseModel):
sentiment: Sentiment
confidence: float = Field(ge=0.0, le=1.0)
review = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": "This product is amazing!"}],
response_model=Review
)
```
## Multi-Entity Extraction
```python
class Person(BaseModel):
name: str
role: str
class Entities(BaseModel):
people: list[Person]
organizations: list[str]
locations: list[str]
entities = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": "Tim Cook, CEO of Apple, spoke in Cupertino..."}],
response_model=Entities
)
```
## Structured Analysis
```python
class Analysis(BaseModel):
summary: str
key_points: list[str]
sentiment: Sentiment
actionable_items: list[str]
analysis = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": "Analyze: [long text]"}],
response_model=Analysis
)
```
## Batch Processing
```python
texts = ["text1", "text2", "text3"]
results = [
client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": text}],
response_model=YourModel
)
for text in texts
]
```
## Streaming
```python
for partial in client.messages.create_partial(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": "Generate report..."}],
response_model=Report
):
print(f"Progress: {partial.title}")
# Update UI in real-time
```

View file

@ -0,0 +1,70 @@
# Provider Configuration
Guide to using Instructor with different LLM providers.
## Anthropic Claude
```python
import instructor
from anthropic import Anthropic
# Basic setup
client = instructor.from_anthropic(Anthropic())
# With API key
client = instructor.from_anthropic(
Anthropic(api_key="your-api-key")
)
# Recommended mode
client = instructor.from_anthropic(
Anthropic(),
mode=instructor.Mode.ANTHROPIC_TOOLS
)
# Usage
result = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": "..."}],
response_model=YourModel
)
```
## OpenAI
```python
from openai import OpenAI
client = instructor.from_openai(OpenAI())
result = client.chat.completions.create(
model="gpt-4o-mini",
response_model=YourModel,
messages=[{"role": "user", "content": "..."}]
)
```
## Local Models (Ollama)
```python
client = instructor.from_openai(
OpenAI(
base_url="http://localhost:11434/v1",
api_key="ollama"
),
mode=instructor.Mode.JSON
)
result = client.chat.completions.create(
model="llama3.1",
response_model=YourModel,
messages=[...]
)
```
## Modes
- `Mode.ANTHROPIC_TOOLS`: Recommended for Claude
- `Mode.TOOLS`: OpenAI function calling
- `Mode.JSON`: Fallback for unsupported providers

View file

@ -0,0 +1,606 @@
# Advanced Validation Patterns
Complete guide to validation in Instructor using Pydantic.
## Table of Contents
- Built-in Validators
- Custom Field Validators
- Model-Level Validation
- Complex Validation Patterns
- Error Handling
## Built-in Validators
### Numeric Constraints
```python
from pydantic import BaseModel, Field
class Product(BaseModel):
price: float = Field(gt=0, description="Price must be positive")
discount: float = Field(ge=0, le=100, description="Discount 0-100%")
quantity: int = Field(ge=1, description="At least 1 item")
rating: float = Field(ge=0.0, le=5.0, description="Rating 0-5 stars")
# If LLM provides invalid values, automatic retry with error feedback
```
**Available constraints:**
- `gt`: Greater than
- `ge`: Greater than or equal
- `lt`: Less than
- `le`: Less than or equal
- `multiple_of`: Must be multiple of this number
### String Constraints
```python
class User(BaseModel):
username: str = Field(
min_length=3,
max_length=20,
pattern=r'^[a-zA-Z0-9_]+$',
description="3-20 alphanumeric characters"
)
bio: str = Field(max_length=500, description="Bio up to 500 chars")
status: str = Field(pattern=r'^(active|inactive|pending)$')
# pattern validates against regex
```
### Email and URL Validation
```python
from pydantic import EmailStr, HttpUrl, AnyUrl
class Contact(BaseModel):
email: EmailStr # Validates email format
website: HttpUrl # Validates HTTP/HTTPS URLs
portfolio: AnyUrl # Any valid URL scheme
contact = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "Extract: john@example.com, https://example.com"
}],
response_model=Contact
)
```
### Date and DateTime Validation
```python
from datetime import date, datetime
from pydantic import Field, field_validator
class Event(BaseModel):
event_date: date # Validates date format
created_at: datetime # Validates datetime format
year: int = Field(ge=1900, le=2100)
@field_validator('event_date')
def future_date(cls, v):
"""Ensure event is in the future."""
if v < date.today():
raise ValueError('Event must be in the future')
return v
```
### List and Dict Validation
```python
class Document(BaseModel):
tags: list[str] = Field(min_length=1, max_length=10)
keywords: list[str] = Field(min_length=3, description="At least 3 keywords")
metadata: dict[str, str] = Field(description="String key-value pairs")
@field_validator('tags')
def unique_tags(cls, v):
"""Ensure tags are unique."""
if len(v) != len(set(v)):
raise ValueError('Tags must be unique')
return v
```
## Custom Field Validators
### Basic Field Validator
```python
from pydantic import field_validator
class Person(BaseModel):
name: str
age: int
@field_validator('name')
def name_must_not_be_empty(cls, v):
"""Validate name is not empty or just whitespace."""
if not v or not v.strip():
raise ValueError('Name cannot be empty')
return v.strip()
@field_validator('age')
def age_must_be_reasonable(cls, v):
"""Validate age is between 0 and 120."""
if v < 0 or v > 120:
raise ValueError('Age must be between 0 and 120')
return v
```
### Validator with Field Info
```python
from pydantic import ValidationInfo
class Article(BaseModel):
title: str
content: str
@field_validator('content')
def content_length(cls, v, info: ValidationInfo):
"""Validate content is longer than title."""
if 'title' in info.data:
title_len = len(info.data['title'])
if len(v) < title_len * 2:
raise ValueError('Content should be at least 2x title length')
return v
```
### Multiple Fields Validation
```python
class TimeRange(BaseModel):
start_time: str
end_time: str
@field_validator('start_time', 'end_time')
def valid_time_format(cls, v):
"""Validate both times are in HH:MM format."""
import re
if not re.match(r'^\d{2}:\d{2}$', v):
raise ValueError('Time must be in HH:MM format')
return v
```
### Transform and Validate
```python
class URL(BaseModel):
url: str
@field_validator('url')
def normalize_url(cls, v):
"""Add https:// if missing."""
if not v.startswith(('http://', 'https://')):
v = f'https://{v}'
return v
```
## Model-Level Validation
### Cross-Field Validation
```python
from pydantic import model_validator
class DateRange(BaseModel):
start_date: str
end_date: str
@model_validator(mode='after')
def check_dates(self):
"""Ensure end_date is after start_date."""
from datetime import datetime
start = datetime.strptime(self.start_date, '%Y-%m-%d')
end = datetime.strptime(self.end_date, '%Y-%m-%d')
if end < start:
raise ValueError('end_date must be after start_date')
return self
class PriceRange(BaseModel):
min_price: float
max_price: float
@model_validator(mode='after')
def check_price_range(self):
"""Ensure max > min."""
if self.max_price <= self.min_price:
raise ValueError('max_price must be greater than min_price')
return self
```
### Conditional Validation
```python
class Order(BaseModel):
order_type: str # "standard" or "express"
delivery_date: str
delivery_time: Optional[str] = None
@model_validator(mode='after')
def check_delivery_time(self):
"""Express orders need delivery time."""
if self.order_type == "express" and not self.delivery_time:
raise ValueError('Express orders require delivery_time')
return self
```
### Complex Business Logic
```python
class Discount(BaseModel):
code: str
percentage: float = Field(ge=0, le=100)
min_purchase: float = Field(ge=0)
max_discount: float = Field(ge=0)
@model_validator(mode='after')
def validate_discount(self):
"""Ensure discount logic is sound."""
# Max discount can't exceed percentage of min_purchase
theoretical_max = (self.percentage / 100) * self.min_purchase
if self.max_discount > theoretical_max:
self.max_discount = theoretical_max
return self
```
## Complex Validation Patterns
### Nested Model Validation
```python
class Address(BaseModel):
street: str
city: str
country: str
postal_code: str
    @field_validator('postal_code')
    def validate_postal_code(cls, v, info: ValidationInfo):
        """Validate postal code format based on country."""
        import re
        if 'country' in info.data:
            country = info.data['country']
            if country == "USA":
                if not re.match(r'^\d{5}(-\d{4})?$', v):
                    raise ValueError('Invalid US postal code')
            elif country == "Canada":
                if not re.match(r'^[A-Z]\d[A-Z] \d[A-Z]\d$', v):
                    raise ValueError('Invalid Canadian postal code')
        return v
class Person(BaseModel):
name: str
address: Address
# Nested validation runs automatically
```
### List of Models
```python
class Task(BaseModel):
title: str = Field(min_length=1)
priority: int = Field(ge=1, le=5)
class Project(BaseModel):
name: str
tasks: list[Task] = Field(min_length=1, description="At least 1 task")
@field_validator('tasks')
def at_least_one_high_priority(cls, v):
"""Ensure at least one task has priority >= 4."""
if not any(task.priority >= 4 for task in v):
raise ValueError('Project needs at least one high-priority task')
return v
```
### Union Type Validation
```python
from typing import Union
class TextBlock(BaseModel):
type: str = "text"
content: str = Field(min_length=1)
class ImageBlock(BaseModel):
type: str = "image"
url: HttpUrl
alt_text: str
class Page(BaseModel):
title: str
blocks: list[Union[TextBlock, ImageBlock]]
@field_validator('blocks')
def validate_block_types(cls, v):
"""Ensure first block is TextBlock."""
if v and not isinstance(v[0], TextBlock):
raise ValueError('First block must be text')
return v
```
### Dependent Fields
```python
class Subscription(BaseModel):
plan: str # "free", "pro", "enterprise"
max_users: int
features: list[str]
@model_validator(mode='after')
def validate_plan_limits(self):
"""Enforce plan-specific limits."""
limits = {
"free": {"max_users": 1, "required_features": ["basic"]},
"pro": {"max_users": 10, "required_features": ["basic", "advanced"]},
"enterprise": {"max_users": 999, "required_features": ["basic", "advanced", "premium"]}
}
if self.plan in limits:
limit = limits[self.plan]
if self.max_users > limit["max_users"]:
raise ValueError(f'{self.plan} plan limited to {limit["max_users"]} users')
for feature in limit["required_features"]:
if feature not in self.features:
raise ValueError(f'{self.plan} plan requires {feature} feature')
return self
```
## Error Handling
### Graceful Degradation
```python
class OptionalExtraction(BaseModel):
# Required fields
title: str
# Optional fields with defaults
author: Optional[str] = None
date: Optional[str] = None
tags: list[str] = Field(default_factory=list)
# LLM can succeed even if it can't extract everything
```
### Partial Validation
```python
from pydantic import ValidationError
def extract_with_fallback(text: str):
"""Try full extraction, fall back to partial."""
try:
# Try full extraction
return client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": text}],
response_model=FullModel
)
except ValidationError:
# Fall back to partial model
return client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": text}],
response_model=PartialModel
)
```
### Validation Error Inspection
```python
from pydantic import ValidationError
try:
result = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[...],
response_model=MyModel,
max_retries=3
)
except ValidationError as e:
# Inspect specific errors
for error in e.errors():
field = error['loc'][0]
message = error['msg']
print(f"Field '{field}' failed: {message}")
# Custom handling per field
if field == 'email':
# Handle email validation failure
pass
```
### Custom Error Messages
```python
class DetailedModel(BaseModel):
name: str = Field(
min_length=2,
max_length=100,
description="Name between 2-100 characters"
)
age: int = Field(
ge=0,
le=120,
description="Age between 0 and 120 years"
)
@field_validator('name')
def validate_name(cls, v):
"""Provide helpful error message."""
if not v.strip():
raise ValueError(
'Name cannot be empty. '
'Please provide a valid name from the text.'
)
return v
# When validation fails, LLM sees these helpful messages
```
## Validation Best Practices
### 1. Be Specific
```python
# ❌ Bad: Vague validation
class Item(BaseModel):
name: str
# ✅ Good: Specific constraints
class Item(BaseModel):
name: str = Field(
min_length=1,
max_length=200,
description="Item name, 1-200 characters"
)
```
### 2. Provide Context
```python
# ✅ Good: Explain why validation failed
@field_validator('price')
def validate_price(cls, v):
if v <= 0:
raise ValueError(
'Price must be positive. '
'Extract numeric price from text without currency symbols.'
)
return v
```
### 3. Use Enums for Fixed Sets
```python
# ❌ Bad: String validation
status: str
@field_validator('status')
def validate_status(cls, v):
if v not in ['active', 'inactive', 'pending']:
raise ValueError('Invalid status')
return v
# ✅ Good: Enum
class Status(str, Enum):
ACTIVE = "active"
INACTIVE = "inactive"
PENDING = "pending"
status: Status # Validation automatic
```
### 4. Balance Strictness
```python
# Too strict: May fail unnecessarily
class StrictModel(BaseModel):
date: str = Field(pattern=r'^\d{4}-\d{2}-\d{2}$')
# Fails if LLM uses "2024-1-5" instead of "2024-01-05"
# Better: Normalize in validator
class FlexibleModel(BaseModel):
date: str
@field_validator('date')
def normalize_date(cls, v):
from datetime import datetime
# Parse flexible formats
for fmt in ['%Y-%m-%d', '%Y/%m/%d', '%m/%d/%Y']:
try:
dt = datetime.strptime(v, fmt)
return dt.strftime('%Y-%m-%d') # Normalize
except ValueError:
continue
raise ValueError('Invalid date format')
```
### 5. Test Validation
```python
# Test your validators with edge cases
def test_validation():
# Should succeed
valid = MyModel(field="valid_value")
# Should fail
try:
invalid = MyModel(field="invalid")
assert False, "Should have raised ValidationError"
except ValidationError:
pass # Expected
# Run tests before using in production
```
## Advanced Techniques
### Conditional Required Fields
```python
from typing import Optional
class ConditionalModel(BaseModel):
type: str
detail_a: Optional[str] = None
detail_b: Optional[str] = None
@model_validator(mode='after')
def check_required_details(self):
"""Require different fields based on type."""
if self.type == "type_a" and not self.detail_a:
raise ValueError('type_a requires detail_a')
if self.type == "type_b" and not self.detail_b:
raise ValueError('type_b requires detail_b')
return self
```
### Validation with External Data
```python
class Product(BaseModel):
sku: str
name: str
@field_validator('sku')
def validate_sku(cls, v):
"""Check SKU exists in database."""
# Query database or API
if not database.sku_exists(v):
raise ValueError(f'SKU {v} not found in catalog')
return v
```
### Progressive Validation
```python
# Start with loose validation
class Stage1(BaseModel):
data: str # Any string
# Then strict validation
class Stage2(BaseModel):
data: str = Field(pattern=r'^[A-Z]{3}-\d{6}$')
# Use Stage1 for initial extraction
# Use Stage2 for final validation
```
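A minimal sketch of how the two stages fit together, assuming the instructor-patched Anthropic `client` and the `text` variable from the earlier examples:
```python
from pydantic import ValidationError
# Stage 1: loose extraction by the LLM
draft = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{"role": "user", "content": text}],
    response_model=Stage1,
)
# Stage 2: strict validation locally, no extra LLM call
try:
    final = Stage2(data=draft.data)
except ValidationError as e:
    # Re-prompt with stricter instructions or clean up the value manually
    print(f"Strict validation failed: {e.errors()}")
```
This keeps retries cheap: only the loose Stage 1 extraction touches the model.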
## Resources
- **Pydantic Docs**: https://docs.pydantic.dev/latest/concepts/validators/
- **Instructor Examples**: https://python.useinstructor.com/examples

View file

@ -0,0 +1,548 @@
---
name: lambda-labs-gpu-cloud
description: Reserved and on-demand GPU cloud instances for ML training and inference. Use when you need dedicated GPU instances with simple SSH access, persistent filesystems, or high-performance multi-node clusters for large-scale training.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [lambda-cloud-client>=1.0.0]
metadata:
hermes:
tags: [Infrastructure, GPU Cloud, Training, Inference, Lambda Labs]
---
# Lambda Labs GPU Cloud
Comprehensive guide to running ML workloads on Lambda Labs GPU cloud with on-demand instances and 1-Click Clusters.
## When to use Lambda Labs
**Use Lambda Labs when:**
- Need dedicated GPU instances with full SSH access
- Running long training jobs (hours to days)
- Want simple pricing with no egress fees
- Need persistent storage across sessions
- Require high-performance multi-node clusters (16-512 GPUs)
- Want pre-installed ML stack (Lambda Stack with PyTorch, CUDA, NCCL)
**Key features:**
- **GPU variety**: B200, H100, GH200, A100, A10, A6000, V100
- **Lambda Stack**: Pre-installed PyTorch, TensorFlow, CUDA, cuDNN, NCCL
- **Persistent filesystems**: Keep data across instance restarts
- **1-Click Clusters**: 16-512 GPU Slurm clusters with InfiniBand
- **Simple pricing**: Pay-per-minute, no egress fees
- **Global regions**: 12+ regions worldwide
**Use alternatives instead:**
- **Modal**: For serverless, auto-scaling workloads
- **SkyPilot**: For multi-cloud orchestration and cost optimization
- **RunPod**: For cheaper spot instances and serverless endpoints
- **Vast.ai**: For GPU marketplace with lowest prices
## Quick start
### Account setup
1. Create account at https://lambda.ai
2. Add payment method
3. Generate API key from dashboard (a quick check is shown below)
4. Add SSH key (required before launching instances)
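To confirm the API key works before launching anything, a quick check (assuming the key is exported as `LAMBDA_API_KEY`, the variable used throughout this guide):
```bash
export LAMBDA_API_KEY="<your-api-key>"
curl -u $LAMBDA_API_KEY: \
  https://cloud.lambdalabs.com/api/v1/instance-types | jq
```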
### Launch via console
1. Go to https://cloud.lambda.ai/instances
2. Click "Launch instance"
3. Select GPU type and region
4. Choose SSH key
5. Optionally attach filesystem
6. Launch and wait 3-15 minutes
### Connect via SSH
```bash
# Get instance IP from console
ssh ubuntu@<INSTANCE-IP>
# Or with specific key
ssh -i ~/.ssh/lambda_key ubuntu@<INSTANCE-IP>
```
## GPU instances
### Available GPUs
| GPU | VRAM | Price/GPU/hr | Best For |
|-----|------|--------------|----------|
| B200 SXM6 | 180 GB | $4.99 | Largest models, fastest training |
| H100 SXM | 80 GB | $2.99-3.29 | Large model training |
| H100 PCIe | 80 GB | $2.49 | Cost-effective H100 |
| GH200 | 96 GB | $1.49 | Single-GPU large models |
| A100 80GB | 80 GB | $1.79 | Production training |
| A100 40GB | 40 GB | $1.29 | Standard training |
| A10 | 24 GB | $0.75 | Inference, fine-tuning |
| A6000 | 48 GB | $0.80 | Good VRAM/price ratio |
| V100 | 16 GB | $0.55 | Budget training |
### Instance configurations
```
8x GPU: Best for distributed training (DDP, FSDP)
4x GPU: Large models, multi-GPU training
2x GPU: Medium workloads
1x GPU: Fine-tuning, inference, development
```
### Launch times
- Single-GPU: 3-5 minutes
- Multi-GPU: 10-15 minutes
## Lambda Stack
All instances come with Lambda Stack pre-installed:
```bash
# Included software
- Ubuntu 22.04 LTS
- NVIDIA drivers (latest)
- CUDA 12.x
- cuDNN 8.x
- NCCL (for multi-GPU)
- PyTorch (latest)
- TensorFlow (latest)
- JAX
- JupyterLab
```
### Verify installation
```bash
# Check GPU
nvidia-smi
# Check PyTorch
python -c "import torch; print(torch.cuda.is_available())"
# Check CUDA version
nvcc --version
```
## Python API
### Installation
```bash
pip install lambda-cloud-client
```
### Authentication
```python
import os
import lambda_cloud_client
# Configure with API key
configuration = lambda_cloud_client.Configuration(
host="https://cloud.lambdalabs.com/api/v1",
access_token=os.environ["LAMBDA_API_KEY"]
)
```
### List available instances
```python
with lambda_cloud_client.ApiClient(configuration) as api_client:
api = lambda_cloud_client.DefaultApi(api_client)
# Get available instance types
types = api.instance_types()
for name, info in types.data.items():
print(f"{name}: {info.instance_type.description}")
```
### Launch instance
```python
from lambda_cloud_client.models import LaunchInstanceRequest
request = LaunchInstanceRequest(
region_name="us-west-1",
instance_type_name="gpu_1x_h100_sxm5",
ssh_key_names=["my-ssh-key"],
file_system_names=["my-filesystem"], # Optional
name="training-job"
)
response = api.launch_instance(request)
instance_id = response.data.instance_ids[0]
print(f"Launched: {instance_id}")
```
### List running instances
```python
instances = api.list_instances()
for instance in instances.data:
print(f"{instance.name}: {instance.ip} ({instance.status})")
```
### Terminate instance
```python
from lambda_cloud_client.models import TerminateInstanceRequest
request = TerminateInstanceRequest(
instance_ids=[instance_id]
)
api.terminate_instance(request)
```
### SSH key management
```python
from lambda_cloud_client.models import AddSshKeyRequest
# Add SSH key
request = AddSshKeyRequest(
name="my-key",
public_key="ssh-rsa AAAA..."
)
api.add_ssh_key(request)
# List keys
keys = api.list_ssh_keys()
# Delete key
api.delete_ssh_key(key_id)
```
## CLI with curl
### List instance types
```bash
curl -u $LAMBDA_API_KEY: \
https://cloud.lambdalabs.com/api/v1/instance-types | jq
```
### Launch instance
```bash
curl -u $LAMBDA_API_KEY: \
-X POST https://cloud.lambdalabs.com/api/v1/instance-operations/launch \
-H "Content-Type: application/json" \
-d '{
"region_name": "us-west-1",
"instance_type_name": "gpu_1x_h100_sxm5",
"ssh_key_names": ["my-key"]
}' | jq
```
### Terminate instance
```bash
curl -u $LAMBDA_API_KEY: \
-X POST https://cloud.lambdalabs.com/api/v1/instance-operations/terminate \
-H "Content-Type: application/json" \
-d '{"instance_ids": ["<INSTANCE-ID>"]}' | jq
```
## Persistent storage
### Filesystems
Filesystems persist data across instance restarts:
```bash
# Mount location
/lambda/nfs/<FILESYSTEM_NAME>
# Example: save checkpoints
python train.py --checkpoint-dir /lambda/nfs/my-storage/checkpoints
```
### Create filesystem
1. Go to Storage in Lambda console
2. Click "Create filesystem"
3. Select region (must match instance region)
4. Name and create
### Attach to instance
Filesystems must be attached at instance launch time:
- Via console: Select filesystem when launching
- Via API: Include `file_system_names` in launch request
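For example, the curl launch request from the CLI section can carry the filesystem name (a sketch; `my-key` and `my-filesystem` are placeholders):
```bash
curl -u $LAMBDA_API_KEY: \
  -X POST https://cloud.lambdalabs.com/api/v1/instance-operations/launch \
  -H "Content-Type: application/json" \
  -d '{
    "region_name": "us-west-1",
    "instance_type_name": "gpu_1x_h100_sxm5",
    "ssh_key_names": ["my-key"],
    "file_system_names": ["my-filesystem"]
  }' | jq
```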
### Best practices
```bash
# Store on filesystem (persists)
/lambda/nfs/storage/
├── datasets/
├── checkpoints/
├── models/
└── outputs/
# Local SSD (faster, ephemeral)
/home/ubuntu/
└── working/ # Temporary files
```
## SSH configuration
### Add SSH key
```bash
# Generate key locally
ssh-keygen -t ed25519 -f ~/.ssh/lambda_key
# Add public key to Lambda console
# Or via API
```
### Multiple keys
```bash
# On instance, add more keys
echo 'ssh-rsa AAAA...' >> ~/.ssh/authorized_keys
```
### Import from GitHub
```bash
# On instance
ssh-import-id gh:username
```
### SSH tunneling
```bash
# Forward Jupyter
ssh -L 8888:localhost:8888 ubuntu@<IP>
# Forward TensorBoard
ssh -L 6006:localhost:6006 ubuntu@<IP>
# Multiple ports
ssh -L 8888:localhost:8888 -L 6006:localhost:6006 ubuntu@<IP>
```
## JupyterLab
### Launch from console
1. Go to Instances page
2. Click "Launch" in Cloud IDE column
3. JupyterLab opens in browser
### Manual access
```bash
# On instance
jupyter lab --ip=0.0.0.0 --port=8888
# From local machine with tunnel
ssh -L 8888:localhost:8888 ubuntu@<IP>
# Open http://localhost:8888
```
## Training workflows
### Single-GPU training
```bash
# SSH to instance
ssh ubuntu@<IP>
# Clone repo
git clone https://github.com/user/project
cd project
# Install dependencies
pip install -r requirements.txt
# Train
python train.py --epochs 100 --checkpoint-dir /lambda/nfs/storage/checkpoints
```
### Multi-GPU training (single node)
```python
# train_ddp.py
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def main():
dist.init_process_group("nccl")
rank = dist.get_rank()
device = rank % torch.cuda.device_count()
model = MyModel().to(device)
model = DDP(model, device_ids=[device])
# Training loop...
if __name__ == "__main__":
main()
```
```bash
# Launch with torchrun (8 GPUs)
torchrun --nproc_per_node=8 train_ddp.py
```
### Checkpoint to filesystem
```python
import os
checkpoint_dir = "/lambda/nfs/my-storage/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
# Save checkpoint
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, f"{checkpoint_dir}/checkpoint_{epoch}.pt")
```
## 1-Click Clusters
### Overview
High-performance Slurm clusters with:
- 16-512 NVIDIA H100 or B200 GPUs
- NVIDIA Quantum-2 400 Gb/s InfiniBand
- GPUDirect RDMA at 3200 Gb/s
- Pre-installed distributed ML stack
### Included software
- Ubuntu 22.04 LTS + Lambda Stack
- NCCL, Open MPI
- PyTorch with DDP and FSDP
- TensorFlow
- OFED drivers
### Storage
- 24 TB NVMe per compute node (ephemeral)
- Lambda filesystems for persistent data
### Multi-node training
```bash
# On Slurm cluster
srun --nodes=4 --ntasks-per-node=8 --gpus-per-node=8 \
torchrun --nnodes=4 --nproc_per_node=8 \
--rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29500 \
train.py
```
## Networking
### Bandwidth
- Inter-instance (same region): up to 200 Gbps
- Internet outbound: 20 Gbps max
### Firewall
- Default: Only port 22 (SSH) open
- Configure additional ports in Lambda console
- ICMP traffic allowed by default
### Private IPs
```bash
# Find private IP
ip addr show | grep 'inet '
```
## Common workflows
### Workflow 1: Fine-tuning LLM
```bash
# 1. Launch 8x H100 instance with filesystem
# 2. SSH and setup
ssh ubuntu@<IP>
pip install transformers accelerate peft
# 3. Download model to filesystem
python -c "
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
model.save_pretrained('/lambda/nfs/storage/models/llama-2-7b')
"
# 4. Fine-tune with checkpoints on filesystem
accelerate launch --num_processes 8 train.py \
--model_path /lambda/nfs/storage/models/llama-2-7b \
--output_dir /lambda/nfs/storage/outputs \
--checkpoint_dir /lambda/nfs/storage/checkpoints
```
### Workflow 2: Batch inference
```bash
# 1. Launch A10 instance (cost-effective for inference)
# 2. Run inference
python inference.py \
--model /lambda/nfs/storage/models/fine-tuned \
--input /lambda/nfs/storage/data/inputs.jsonl \
--output /lambda/nfs/storage/data/outputs.jsonl
```
## Cost optimization
### Choose right GPU
| Task | Recommended GPU |
|------|-----------------|
| LLM fine-tuning (7B) | A100 40GB |
| LLM fine-tuning (70B) | 8x H100 |
| Inference | A10, A6000 |
| Development | V100, A10 |
| Maximum performance | B200 |
### Reduce costs
1. **Use filesystems**: Avoid re-downloading data
2. **Checkpoint frequently**: Resume interrupted training
3. **Right-size**: Don't over-provision GPUs
4. **Terminate idle instances**: There is no auto-stop; terminate manually when finished
### Monitor usage
- Dashboard shows real-time GPU utilization
- API for programmatic monitoring
## Common issues
| Issue | Solution |
|-------|----------|
| Instance won't launch | Check region availability, try different GPU |
| SSH connection refused | Wait for instance to initialize (3-15 min) |
| Data lost after termination | Use persistent filesystems |
| Slow data transfer | Use filesystem in same region |
| GPU not detected | Reboot instance, check drivers |
## References
- **[Advanced Usage](references/advanced-usage.md)** - Multi-node training, API automation
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
## Resources
- **Documentation**: https://docs.lambda.ai
- **Console**: https://cloud.lambda.ai
- **Pricing**: https://lambda.ai/instances
- **Support**: https://support.lambdalabs.com
- **Blog**: https://lambda.ai/blog

View file

@ -0,0 +1,611 @@
# Lambda Labs Advanced Usage Guide
## Multi-Node Distributed Training
### PyTorch DDP across nodes
```python
# train_multi_node.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup_distributed():
# Environment variables set by launcher
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(
backend="nccl",
rank=rank,
world_size=world_size
)
torch.cuda.set_device(local_rank)
return rank, world_size, local_rank
def main():
rank, world_size, local_rank = setup_distributed()
model = MyModel().cuda(local_rank)
model = DDP(model, device_ids=[local_rank])
# Training loop with synchronized gradients
for epoch in range(num_epochs):
train_one_epoch(model, dataloader)
# Save checkpoint on rank 0 only
if rank == 0:
torch.save(model.module.state_dict(), f"checkpoint_{epoch}.pt")
dist.destroy_process_group()
if __name__ == "__main__":
main()
```
### Launch on multiple instances
```bash
# On Node 0 (master)
export MASTER_ADDR=<NODE0_PRIVATE_IP>
export MASTER_PORT=29500
torchrun \
--nnodes=2 \
--nproc_per_node=8 \
--node_rank=0 \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
train_multi_node.py
# On Node 1
export MASTER_ADDR=<NODE0_PRIVATE_IP>
export MASTER_PORT=29500
torchrun \
--nnodes=2 \
--nproc_per_node=8 \
--node_rank=1 \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
train_multi_node.py
```
### FSDP for large models
```python
import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
# Wrap policy for transformer models
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={LlamaDecoderLayer}
)
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
device_id=local_rank,
)
```
### DeepSpeed ZeRO
Save this configuration as `ds_config.json`:
```json
{
"train_batch_size": 64,
"gradient_accumulation_steps": 4,
"fp16": {"enabled": true},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {"device": "cpu"},
"offload_param": {"device": "cpu"}
}
}
```
```bash
# Launch with DeepSpeed
deepspeed --num_nodes=2 \
--num_gpus=8 \
--hostfile=hostfile.txt \
train.py --deepspeed ds_config.json
```
### Hostfile for multi-node
```bash
# hostfile.txt
node0_ip slots=8
node1_ip slots=8
```
## API Automation
### Auto-launch training jobs
```python
import os
import time
import lambda_cloud_client
from lambda_cloud_client.models import LaunchInstanceRequest
class LambdaJobManager:
def __init__(self, api_key: str):
self.config = lambda_cloud_client.Configuration(
host="https://cloud.lambdalabs.com/api/v1",
access_token=api_key
)
def find_available_gpu(self, gpu_types: list[str], regions: list[str] = None):
"""Find first available GPU type across regions."""
with lambda_cloud_client.ApiClient(self.config) as client:
api = lambda_cloud_client.DefaultApi(client)
types = api.instance_types()
for gpu_type in gpu_types:
if gpu_type in types.data:
info = types.data[gpu_type]
for region in info.regions_with_capacity_available:
if regions is None or region.name in regions:
return gpu_type, region.name
return None, None
def launch_and_wait(self, instance_type: str, region: str,
ssh_key: str, filesystem: str = None,
timeout: int = 900) -> dict:
"""Launch instance and wait for it to be ready."""
with lambda_cloud_client.ApiClient(self.config) as client:
api = lambda_cloud_client.DefaultApi(client)
request = LaunchInstanceRequest(
region_name=region,
instance_type_name=instance_type,
ssh_key_names=[ssh_key],
file_system_names=[filesystem] if filesystem else [],
)
response = api.launch_instance(request)
instance_id = response.data.instance_ids[0]
# Poll until ready
start = time.time()
while time.time() - start < timeout:
instance = api.get_instance(instance_id)
if instance.data.status == "active":
return {
"id": instance_id,
"ip": instance.data.ip,
"status": "active"
}
time.sleep(30)
raise TimeoutError(f"Instance {instance_id} not ready after {timeout}s")
def terminate(self, instance_ids: list[str]):
"""Terminate instances."""
from lambda_cloud_client.models import TerminateInstanceRequest
with lambda_cloud_client.ApiClient(self.config) as client:
api = lambda_cloud_client.DefaultApi(client)
request = TerminateInstanceRequest(instance_ids=instance_ids)
api.terminate_instance(request)
# Usage
manager = LambdaJobManager(os.environ["LAMBDA_API_KEY"])
# Find available H100 or A100
gpu_type, region = manager.find_available_gpu(
["gpu_8x_h100_sxm5", "gpu_8x_a100_80gb_sxm4"],
regions=["us-west-1", "us-east-1"]
)
if gpu_type:
instance = manager.launch_and_wait(
gpu_type, region,
ssh_key="my-key",
filesystem="training-data"
)
print(f"Ready: ssh ubuntu@{instance['ip']}")
```
### Batch job submission
```python
import subprocess
import paramiko
def run_remote_job(ip: str, ssh_key_path: str, commands: list[str]):
"""Execute commands on remote instance."""
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(ip, username="ubuntu", key_filename=ssh_key_path)
for cmd in commands:
stdin, stdout, stderr = client.exec_command(cmd)
print(stdout.read().decode())
        err = stderr.read().decode()
        if err:
            print(f"Error: {err}")
client.close()
# Submit training job
# Note: each exec_command starts a fresh shell, so chain steps with &&
commands = [
    "cd /lambda/nfs/storage/project && git pull && pip install -r requirements.txt",
    "cd /lambda/nfs/storage/project && nohup torchrun --nproc_per_node=8 train.py > train.log 2>&1 &"
]
run_remote_job(instance["ip"], "~/.ssh/lambda_key", commands)
```
### Monitor training progress
```python
def monitor_job(ip: str, ssh_key_path: str, log_file: str = "train.log"):
"""Stream training logs from remote instance."""
import time
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(ip, username="ubuntu", key_filename=ssh_key_path)
# Tail log file
stdin, stdout, stderr = client.exec_command(f"tail -f {log_file}")
try:
for line in stdout:
print(line.strip())
except KeyboardInterrupt:
pass
finally:
client.close()
```
## 1-Click Cluster Workflows
### Slurm job submission
```bash
#!/bin/bash
#SBATCH --job-name=llm-training
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-node=8
#SBATCH --time=24:00:00
#SBATCH --output=logs/%j.out
#SBATCH --error=logs/%j.err
# Set up distributed environment
export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
export MASTER_PORT=29500
# Launch training
srun torchrun \
--nnodes=$SLURM_NNODES \
--nproc_per_node=$SLURM_GPUS_PER_NODE \
--rdzv_backend=c10d \
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
train.py \
--config config.yaml
```
### Interactive cluster session
```bash
# Request interactive session
srun --nodes=1 --ntasks=1 --gpus=8 --time=4:00:00 --pty bash
# Now on compute node with 8 GPUs
nvidia-smi
python train.py
```
### Monitoring cluster jobs
```bash
# View job queue
squeue
# View job details
scontrol show job <JOB_ID>
# Cancel job
scancel <JOB_ID>
# View node status
sinfo
# View GPU usage across cluster
srun --nodes=4 nvidia-smi --query-gpu=name,utilization.gpu --format=csv
```
## Advanced Filesystem Usage
### Data staging workflow
```bash
# Stage data from S3 to filesystem (one-time)
aws s3 sync s3://my-bucket/dataset /lambda/nfs/storage/datasets/
# Or use rclone
rclone sync s3:my-bucket/dataset /lambda/nfs/storage/datasets/
```
### Shared filesystem across instances
```python
# Instance 1: Write checkpoints
checkpoint_path = "/lambda/nfs/shared/checkpoints/model_step_1000.pt"
torch.save(model.state_dict(), checkpoint_path)
# Instance 2: Read checkpoints
model.load_state_dict(torch.load(checkpoint_path))
```
### Filesystem best practices
```bash
# Organize for ML workflows
/lambda/nfs/storage/
├── datasets/
│ ├── raw/ # Original data
│ └── processed/ # Preprocessed data
├── models/
│ ├── pretrained/ # Base models
│ └── fine-tuned/ # Your trained models
├── checkpoints/
│ └── experiment_1/ # Per-experiment checkpoints
├── logs/
│ └── tensorboard/ # Training logs
└── outputs/
└── inference/ # Inference results
```
## Environment Management
### Custom Python environments
```bash
# Don't modify system Python, create venv
python -m venv ~/myenv
source ~/myenv/bin/activate
# Install packages
pip install torch transformers accelerate
# Save to filesystem for reuse
cp -r ~/myenv /lambda/nfs/storage/envs/myenv
```
### Conda environments
```bash
# Install miniconda (if not present)
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh -b -p ~/miniconda3
# Create environment
~/miniconda3/bin/conda create -n ml python=3.10 pytorch pytorch-cuda=12.1 -c pytorch -c nvidia -y
# Activate
source ~/miniconda3/bin/activate ml
```
### Docker containers
```bash
# Pull and run NVIDIA container
docker run --gpus all -it --rm \
-v /lambda/nfs/storage:/data \
nvcr.io/nvidia/pytorch:24.01-py3
# Run training in container
docker run --gpus all -d \
-v /lambda/nfs/storage:/data \
-v $(pwd):/workspace \
nvcr.io/nvidia/pytorch:24.01-py3 \
python /workspace/train.py
```
## Monitoring and Observability
### GPU monitoring
```bash
# Real-time GPU stats
watch -n 1 nvidia-smi
# GPU utilization over time
nvidia-smi dmon -s u -d 1
# Detailed GPU info
nvidia-smi -q
```
### System monitoring
```bash
# CPU and memory
htop
# Disk I/O
iostat -x 1
# Network
iftop
# All resources
glances
```
### TensorBoard integration
```bash
# Start TensorBoard
tensorboard --logdir /lambda/nfs/storage/logs --port 6006 --bind_all
# SSH tunnel from local machine
ssh -L 6006:localhost:6006 ubuntu@<IP>
# Access at http://localhost:6006
```
### Weights & Biases integration
```python
import wandb
# Initialize with API key
wandb.login(key=os.environ["WANDB_API_KEY"])
# Start run
wandb.init(
project="lambda-training",
config={"learning_rate": 1e-4, "epochs": 100}
)
# Log metrics
wandb.log({"loss": loss, "accuracy": acc})
# Save artifacts to filesystem + W&B
wandb.save("/lambda/nfs/storage/checkpoints/best_model.pt")
```
## Cost Optimization Strategies
### Checkpointing for interruption recovery
```python
import os
def save_checkpoint(model, optimizer, epoch, loss, path):
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, path)
def load_checkpoint(path, model, optimizer):
if os.path.exists(path):
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return checkpoint['epoch'], checkpoint['loss']
return 0, float('inf')
# Save every N steps to filesystem
checkpoint_path = "/lambda/nfs/storage/checkpoints/latest.pt"
if step % 1000 == 0:
save_checkpoint(model, optimizer, epoch, loss, checkpoint_path)
```
### Instance selection by workload
```python
def recommend_instance(model_params: int, batch_size: int, task: str) -> str:
"""Recommend Lambda instance based on workload."""
if task == "inference":
if model_params < 7e9:
return "gpu_1x_a10" # $0.75/hr
elif model_params < 13e9:
return "gpu_1x_a6000" # $0.80/hr
else:
return "gpu_1x_h100_pcie" # $2.49/hr
elif task == "fine-tuning":
if model_params < 7e9:
return "gpu_1x_a100" # $1.29/hr
elif model_params < 13e9:
return "gpu_4x_a100" # $5.16/hr
else:
return "gpu_8x_h100_sxm5" # $23.92/hr
elif task == "pretraining":
return "gpu_8x_h100_sxm5" # Maximum performance
return "gpu_1x_a100" # Default
```
### Auto-terminate idle instances
```python
import time
from datetime import datetime, timedelta
def auto_terminate_idle(api_key: str, idle_threshold_hours: float = 2):
"""Terminate instances idle for too long."""
manager = LambdaJobManager(api_key)
with lambda_cloud_client.ApiClient(manager.config) as client:
api = lambda_cloud_client.DefaultApi(client)
instances = api.list_instances()
for instance in instances.data:
# Check if instance has been running without activity
# (You'd need to track this separately)
launch_time = instance.launched_at
if datetime.now() - launch_time > timedelta(hours=idle_threshold_hours):
print(f"Terminating idle instance: {instance.id}")
manager.terminate([instance.id])
```
## Security Best Practices
### SSH key rotation
```bash
# Generate new key pair
ssh-keygen -t ed25519 -f ~/.ssh/lambda_key_new -C "lambda-$(date +%Y%m)"
# Add new key via Lambda console or API
# Update authorized_keys on running instances
ssh ubuntu@<IP> "echo '$(cat ~/.ssh/lambda_key_new.pub)' >> ~/.ssh/authorized_keys"
# Test new key
ssh -i ~/.ssh/lambda_key_new ubuntu@<IP>
# Remove old key from Lambda console
```
### Firewall configuration
```bash
# Lambda console: Only open necessary ports
# Recommended:
# - 22 (SSH) - Always needed
# - 6006 (TensorBoard) - If using
# - 8888 (Jupyter) - If using
# - 29500 (PyTorch distributed) - For multi-node only
```
### Secrets management
```bash
# Don't hardcode API keys in code
# Use environment variables
export HF_TOKEN="hf_..."
export WANDB_API_KEY="..."
# Or use .env file (add to .gitignore)
source .env
# On instance, store in ~/.bashrc
echo 'export HF_TOKEN="..."' >> ~/.bashrc
```

View file

@ -0,0 +1,530 @@
# Lambda Labs Troubleshooting Guide
## Instance Launch Issues
### No instances available
**Error**: "No capacity available" or instance type not listed
**Solutions**:
```bash
# Check availability via API
curl -u $LAMBDA_API_KEY: \
https://cloud.lambdalabs.com/api/v1/instance-types | jq '.data | to_entries[] | select(.value.regions_with_capacity_available | length > 0) | .key'
# Try different regions
# US regions: us-west-1, us-east-1, us-south-1
# International: eu-west-1, asia-northeast-1, etc.
# Try alternative GPU types
# H100 not available? Try A100
# A100 not available? Try A10 or A6000
```
### Instance stuck launching
**Problem**: Instance shows "booting" for over 20 minutes
**Solutions**:
```bash
# Single-GPU: Should be ready in 3-5 minutes
# Multi-GPU (8x): May take 10-15 minutes
# If stuck longer:
# 1. Terminate the instance
# 2. Try a different region
# 3. Try a different instance type
# 4. Contact Lambda support if persistent
```
### API authentication fails
**Error**: `401 Unauthorized` or `403 Forbidden`
**Solutions**:
```bash
# Verify API key format (should start with specific prefix)
echo $LAMBDA_API_KEY
# Test API key
curl -u $LAMBDA_API_KEY: \
https://cloud.lambdalabs.com/api/v1/instance-types
# Generate new API key from Lambda console if needed
# Settings > API keys > Generate
```
### Quota limits reached
**Error**: "Instance limit reached" or "Quota exceeded"
**Solutions**:
- Check current running instances in console
- Terminate unused instances
- Contact Lambda support to request quota increase
- Use 1-Click Clusters for large-scale needs
## SSH Connection Issues
### Connection refused
**Error**: `ssh: connect to host <IP> port 22: Connection refused`
**Solutions**:
```bash
# Wait for instance to fully initialize
# Single-GPU: 3-5 minutes
# Multi-GPU: 10-15 minutes
# Check instance status in console (should be "active")
# Verify correct IP address
curl -u $LAMBDA_API_KEY: \
https://cloud.lambdalabs.com/api/v1/instances | jq '.data[].ip'
```
### Permission denied
**Error**: `Permission denied (publickey)`
**Solutions**:
```bash
# Verify SSH key matches
ssh -v -i ~/.ssh/lambda_key ubuntu@<IP>
# Check key permissions
chmod 600 ~/.ssh/lambda_key
chmod 644 ~/.ssh/lambda_key.pub
# Verify key was added to Lambda console before launch
# Keys must be added BEFORE launching instance
# Check authorized_keys on instance (if you have another way in)
cat ~/.ssh/authorized_keys
```
### Host key verification failed
**Error**: `WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!`
**Solutions**:
```bash
# This happens when IP is reused by different instance
# Remove old key
ssh-keygen -R <IP>
# Then connect again
ssh ubuntu@<IP>
```
### Timeout during SSH
**Error**: `ssh: connect to host <IP> port 22: Operation timed out`
**Solutions**:
```bash
# Check if instance is in "active" state
# Verify firewall allows SSH (port 22)
# Lambda console > Firewall
# Check your local network allows outbound SSH
# Try from different network/VPN
```
## GPU Issues
### GPU not detected
**Error**: `nvidia-smi: command not found` or no GPUs shown
**Solutions**:
```bash
# Reboot instance
sudo reboot
# Reinstall NVIDIA drivers (if needed)
wget -nv -O- https://lambdalabs.com/install-lambda-stack.sh | sh -
sudo reboot
# Check driver status
nvidia-smi
lsmod | grep nvidia
```
### CUDA out of memory
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
**Solutions**:
```python
# Check GPU memory
import torch
print(torch.cuda.get_device_properties(0).total_memory / 1e9, "GB")
# Clear cache
torch.cuda.empty_cache()
# Reduce batch size
batch_size = batch_size // 2
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Use mixed precision
from torch.cuda.amp import autocast
with autocast():
outputs = model(**inputs)
# Use larger GPU instance
# A100-40GB → A100-80GB → H100
```
### CUDA version mismatch
**Error**: `CUDA driver version is insufficient for CUDA runtime version`
**Solutions**:
```bash
# Check versions
nvidia-smi # Shows driver CUDA version
nvcc --version # Shows toolkit version
# Lambda Stack should have compatible versions
# If mismatch, reinstall Lambda Stack
wget -nv -O- https://lambdalabs.com/install-lambda-stack.sh | sh -
sudo reboot
# Or install specific PyTorch version
pip install torch==2.1.0+cu121 -f https://download.pytorch.org/whl/torch_stable.html
```
### Multi-GPU not working
**Error**: Only one GPU being used
**Solutions**:
```python
# Check all GPUs visible
import torch
print(f"GPUs available: {torch.cuda.device_count()}")
# Verify CUDA_VISIBLE_DEVICES not set restrictively
import os
print(os.environ.get("CUDA_VISIBLE_DEVICES", "not set"))
# Use DataParallel or DistributedDataParallel
model = torch.nn.DataParallel(model)
# or
model = torch.nn.parallel.DistributedDataParallel(model)
```
## Filesystem Issues
### Filesystem not mounted
**Error**: `/lambda/nfs/<name>` doesn't exist
**Solutions**:
```bash
# Filesystem must be attached at launch time
# Cannot attach to running instance
# Verify filesystem was selected during launch
# Check mount points
df -h | grep lambda
# If missing, terminate and relaunch with filesystem
```
### Slow filesystem performance
**Problem**: Reading/writing to filesystem is slow
**Solutions**:
```bash
# Use local SSD for temporary/intermediate files
# /home/ubuntu has fast NVMe storage
# Copy frequently accessed data to local storage
cp -r /lambda/nfs/storage/dataset /home/ubuntu/dataset
# Use filesystem for checkpoints and final outputs only
# Check network bandwidth
iperf3 -c <filesystem_server>
```
### Data lost after termination
**Problem**: Files disappeared after instance terminated
**Solutions**:
```bash
# Root volume (/home/ubuntu) is EPHEMERAL
# Data there is lost on termination
# ALWAYS use filesystem for persistent data
/lambda/nfs/<filesystem_name>/
# Sync important local files before terminating
rsync -av /home/ubuntu/outputs/ /lambda/nfs/storage/outputs/
```
### Filesystem full
**Error**: `No space left on device`
**Solutions**:
```bash
# Check filesystem usage
df -h /lambda/nfs/storage
# Find large files
du -sh /lambda/nfs/storage/* | sort -h
# Clean up old checkpoints
find /lambda/nfs/storage/checkpoints -mtime +7 -delete
# Increase filesystem size in Lambda console
# (may require support request)
```
## Network Issues
### Port not accessible
**Error**: Cannot connect to service (TensorBoard, Jupyter, etc.)
**Solutions**:
```bash
# Lambda default: Only port 22 is open
# Configure firewall in Lambda console
# Or use SSH tunneling (recommended)
ssh -L 6006:localhost:6006 ubuntu@<IP>
# Access at http://localhost:6006
# For Jupyter
ssh -L 8888:localhost:8888 ubuntu@<IP>
```
### Slow data download
**Problem**: Downloading datasets is slow
**Solutions**:
```bash
# Check available bandwidth
speedtest-cli
# Use multi-threaded download
aria2c -x 16 <URL>
# For HuggingFace models
export HF_HUB_ENABLE_HF_TRANSFER=1
pip install hf_transfer
# For S3, use parallel transfer
aws s3 sync s3://bucket/data /local/data --quiet
```
### Inter-node communication fails
**Error**: Distributed training can't connect between nodes
**Solutions**:
```bash
# Verify nodes in same region (required)
# Check private IPs can communicate
ping <other_node_private_ip>
# Verify NCCL settings
export NCCL_DEBUG=INFO
export NCCL_IB_DISABLE=0 # Enable InfiniBand if available
# Check firewall allows distributed ports
# Need: 29500 (PyTorch), or configured MASTER_PORT
```
## Software Issues
### Package installation fails
**Error**: `pip install` errors
**Solutions**:
```bash
# Use virtual environment (don't modify system Python)
python -m venv ~/myenv
source ~/myenv/bin/activate
pip install <package>
# For CUDA packages, match CUDA version
pip install torch --index-url https://download.pytorch.org/whl/cu121
# Clear pip cache if corrupted
pip cache purge
```
### Python version issues
**Error**: Package requires different Python version
**Solutions**:
```bash
# Install alternate Python (don't replace system Python)
sudo apt install python3.11 python3.11-venv python3.11-dev
# Create venv with specific Python
python3.11 -m venv ~/py311env
source ~/py311env/bin/activate
```
### ImportError or ModuleNotFoundError
**Error**: Module not found despite installation
**Solutions**:
```bash
# Verify correct Python environment
which python
pip list | grep <module>
# Ensure virtual environment is activated
source ~/myenv/bin/activate
# Reinstall in correct environment
pip uninstall <package>
pip install <package>
```
## Training Issues
### Training hangs
**Problem**: Training stops progressing, no output
**Solutions**:
```bash
# Check GPU utilization
watch -n 1 nvidia-smi
# If GPUs at 0%, likely data loading bottleneck
# Increase num_workers in DataLoader
# Check for deadlocks in distributed training
export NCCL_DEBUG=INFO
# Add timeouts
dist.init_process_group(..., timeout=timedelta(minutes=30))
```
### Checkpoint corruption
**Error**: `RuntimeError: storage has wrong size` or similar
**Solutions**:
```python
# Use safe saving pattern
checkpoint_path = "/lambda/nfs/storage/checkpoint.pt"
temp_path = checkpoint_path + ".tmp"
# Save to temp first
torch.save(state_dict, temp_path)
# Then atomic rename
os.rename(temp_path, checkpoint_path)
# For loading corrupted checkpoint
try:
state = torch.load(checkpoint_path)
except Exception:
# Fall back to previous checkpoint
state = torch.load(checkpoint_path + ".backup")
```
### Memory leak
**Problem**: Memory usage grows over time
**Solutions**:
```python
# Clear CUDA cache periodically
torch.cuda.empty_cache()
# Detach tensors when logging
loss_value = loss.detach().cpu().item()
# Don't accumulate gradients unintentionally
optimizer.zero_grad(set_to_none=True)
# Use gradient accumulation properly
if (step + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
```
## Billing Issues
### Unexpected charges
**Problem**: Bill higher than expected
**Solutions**:
```bash
# Check for forgotten running instances
curl -u $LAMBDA_API_KEY: \
https://cloud.lambdalabs.com/api/v1/instances | jq '.data[].id'
# Terminate all instances
# Lambda console > Instances > Terminate all
# Lambda charges by the minute for as long as an instance exists
# There is no "stop" feature; terminate instances to stop charges
```
### Instance terminated unexpectedly
**Problem**: Instance disappeared without manual termination
**Possible causes**:
- Payment issue (card declined)
- Account suspension
- Instance health check failure
**Solutions**:
- Check email for Lambda notifications
- Verify payment method in console
- Contact Lambda support
- Always checkpoint to filesystem
## Common Error Messages
| Error | Cause | Solution |
|-------|-------|----------|
| `No capacity available` | Region/GPU sold out | Try different region or GPU type |
| `Permission denied (publickey)` | SSH key mismatch | Re-add key, check permissions |
| `CUDA out of memory` | Model too large | Reduce batch size, use larger GPU |
| `No space left on device` | Disk full | Clean up or use filesystem |
| `Connection refused` | Instance not ready | Wait 3-15 minutes for boot |
| `Module not found` | Wrong Python env | Activate correct virtualenv |
## Getting Help
1. **Documentation**: https://docs.lambda.ai
2. **Support**: https://support.lambdalabs.com
3. **Email**: support@lambdalabs.com
4. **Status**: Check Lambda status page for outages
### Information to Include
When contacting support, include:
- Instance ID
- Region
- Instance type
- Error message (full traceback)
- Steps to reproduce
- Time of occurrence

View file

@ -0,0 +1,307 @@
---
name: llava
description: Large Language and Vision Assistant. Enables visual instruction tuning and image-based conversations. Combines CLIP vision encoder with Vicuna/LLaMA language models. Supports multi-turn image chat, visual question answering, and instruction following. Use for vision-language chatbots or image understanding tasks. Best for conversational image analysis.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [transformers, torch, pillow]
metadata:
hermes:
tags: [LLaVA, Vision-Language, Multimodal, Visual Question Answering, Image Chat, CLIP, Vicuna, Conversational AI, Instruction Tuning, VQA]
---
# LLaVA - Large Language and Vision Assistant
Open-source vision-language model for conversational image understanding.
## When to use LLaVA
**Use when:**
- Building vision-language chatbots
- Visual question answering (VQA)
- Image description and captioning
- Multi-turn image conversations
- Visual instruction following
- Document understanding with images
**Metrics**:
- **23,000+ GitHub stars**
- GPT-4V level capabilities (targeted)
- Apache 2.0 License
- Multiple model sizes (7B-34B params)
**Use alternatives instead**:
- **GPT-4V**: Highest quality, API-based
- **CLIP**: Simple zero-shot classification
- **BLIP-2**: Better for captioning only
- **Flamingo**: Research, not open-source
## Quick start
### Installation
```bash
# Clone repository
git clone https://github.com/haotian-liu/LLaVA
cd LLaVA
# Install
pip install -e .
```
### Basic usage
```python
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from PIL import Image
import torch
# Load model
model_path = "liuhaotian/llava-v1.5-7b"
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path=model_path,
model_base=None,
model_name=get_model_name_from_path(model_path)
)
# Load image
image = Image.open("image.jpg")
image_tensor = process_images([image], image_processor, model.config)
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
# Create conversation
conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
# Generate response
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor,
do_sample=True,
temperature=0.2,
max_new_tokens=512
)
response = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
print(response)
```
## Available models
| Model | Parameters | VRAM | Quality |
|-------|------------|------|---------|
| LLaVA-v1.5-7B | 7B | ~14 GB | Good |
| LLaVA-v1.5-13B | 13B | ~28 GB | Better |
| LLaVA-v1.6-34B | 34B | ~70 GB | Best |
```python
# Load different models
model_7b = "liuhaotian/llava-v1.5-7b"
model_13b = "liuhaotian/llava-v1.5-13b"
model_34b = "liuhaotian/llava-v1.6-34b"
# 4-bit quantization for lower VRAM
load_4bit = True # Reduces VRAM by ~4×
```
## CLI usage
```bash
# Single image query
python -m llava.serve.cli \
--model-path liuhaotian/llava-v1.5-7b \
--image-file image.jpg \
--query "What is in this image?"
# Multi-turn conversation
python -m llava.serve.cli \
--model-path liuhaotian/llava-v1.5-7b \
--image-file image.jpg
# Then type questions interactively
```
## Web UI (Gradio)
```bash
# Launch Gradio interface
python -m llava.serve.gradio_web_server \
--model-path liuhaotian/llava-v1.5-7b \
--load-4bit # Optional: reduce VRAM
# Access at http://localhost:7860
```
## Multi-turn conversations
```python
# Initialize conversation
conv = conv_templates["llava_v1"].copy()
# Turn 1
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
conv.append_message(conv.roles[1], None)
response1 = generate(conv, model, image) # "A dog playing in a park"
# Turn 2
conv.messages[-1][1] = response1 # Add previous response
conv.append_message(conv.roles[0], "What breed is the dog?")
conv.append_message(conv.roles[1], None)
response2 = generate(conv, model, image) # "Golden Retriever"
# Turn 3
conv.messages[-1][1] = response2
conv.append_message(conv.roles[0], "What time of day is it?")
conv.append_message(conv.roles[1], None)
response3 = generate(conv, model, image)
```
## Common tasks
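The snippets below use a small `ask()` helper that is not part of the LLaVA package. A minimal sketch, reusing the `tokenizer`, `model`, and `image_processor` loaded in the quick start:
```python
import torch
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from llava.mm_utils import process_images, tokenizer_image_token
def ask(model, image, question, temperature=0.2, max_new_tokens=512):
    """Single-turn helper: build a llava_v1 prompt for one image and generate a reply."""
    conv = conv_templates["llava_v1"].copy()
    conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    image_tensor = process_images([image], image_processor, model.config).to(
        model.device, dtype=torch.float16
    )
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(model.device)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
```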
### Image captioning
```python
question = "Describe this image in detail."
response = ask(model, image, question)
```
### Visual question answering
```python
question = "How many people are in the image?"
response = ask(model, image, question)
```
### Object detection (textual)
```python
question = "List all the objects you can see in this image."
response = ask(model, image, question)
```
### Scene understanding
```python
question = "What is happening in this scene?"
response = ask(model, image, question)
```
### Document understanding
```python
question = "What is the main topic of this document?"
response = ask(model, document_image, question)
```
## Training custom model
```bash
# Stage 1: Feature alignment (558K image-caption pairs)
bash scripts/v1_5/pretrain.sh
# Stage 2: Visual instruction tuning (150K instruction data)
bash scripts/v1_5/finetune.sh
```
## Quantization (reduce VRAM)
```python
# 4-bit quantization
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path="liuhaotian/llava-v1.5-13b",
model_base=None,
model_name=get_model_name_from_path("liuhaotian/llava-v1.5-13b"),
load_4bit=True # Reduces VRAM ~4×
)
# 8-bit quantization
load_8bit=True # Reduces VRAM ~2×
```
## Best practices
1. **Start with 7B model** - Good quality, manageable VRAM
2. **Use 4-bit quantization** - Reduces VRAM significantly
3. **GPU required** - CPU inference extremely slow
4. **Clear prompts** - Specific questions get better answers
5. **Multi-turn conversations** - Maintain conversation context
6. **Temperature 0.2-0.7** - Balance creativity/consistency
7. **max_new_tokens 512-1024** - For detailed responses
8. **Batch processing** - Process multiple images sequentially
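For item 8, a minimal sequential loop over several images, reusing the `ask()` helper sketched under Common tasks (file paths are hypothetical):
```python
from PIL import Image
image_paths = ["photo1.jpg", "photo2.jpg", "photo3.jpg"]  # hypothetical files
question = "Describe this image in detail."
captions = {}
for path in image_paths:
    image = Image.open(path)
    captions[path] = ask(model, image, question)
```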
## Performance
| Model | VRAM (FP16) | VRAM (4-bit) | Speed (tokens/s) |
|-------|-------------|--------------|------------------|
| 7B | ~14 GB | ~4 GB | ~20 |
| 13B | ~28 GB | ~8 GB | ~12 |
| 34B | ~70 GB | ~18 GB | ~5 |
*On A100 GPU*
## Benchmarks
LLaVA achieves competitive scores on:
- **VQAv2**: 78.5%
- **GQA**: 62.0%
- **MM-Vet**: 35.4%
- **MMBench**: 64.3%
## Limitations
1. **Hallucinations** - May describe things not in image
2. **Spatial reasoning** - Struggles with precise locations
3. **Small text** - Difficulty reading fine print
4. **Object counting** - Imprecise for many objects
5. **VRAM requirements** - Need powerful GPU
6. **Inference speed** - Slower than CLIP
## Integration with frameworks
### LangChain
```python
from langchain.llms.base import LLM
class LLaVALLM(LLM):
def _call(self, prompt, stop=None):
# Custom LLaVA inference
return response
llm = LLaVALLM()
```
### Gradio App
```python
import gradio as gr
def chat(message, history, image):
    # gr.ChatInterface passes (message, history, *additional_inputs)
    response = ask_llava(model, image, message)
    return response
demo = gr.ChatInterface(
chat,
additional_inputs=[gr.Image(type="pil")],
title="LLaVA Chat"
)
demo.launch()
```
## Resources
- **GitHub**: https://github.com/haotian-liu/LLaVA ⭐ 23,000+
- **Paper**: https://arxiv.org/abs/2304.08485
- **Demo**: https://llava.hliu.cc
- **Models**: https://huggingface.co/liuhaotian
- **License**: Apache 2.0

View file

@ -0,0 +1,197 @@
# LLaVA Training Guide
Guide to training and fine-tuning LLaVA models.
## Training stages
### Stage 1: Feature alignment (Pretraining)
**Purpose**: Align vision encoder with language model
**Data**: 558K image-caption pairs (CC3M subset)
```bash
# Download pretrained projector or train from scratch
bash scripts/v1_5/pretrain.sh
```
**Configuration:**
- Base model: Vicuna-7B or LLaMA-2-7B
- Vision encoder: CLIP ViT-L/14
- Training time: ~20 hours on 8× A100
### Stage 2: Visual instruction tuning
**Purpose**: Teach model to follow visual instructions
**Data**: 150K GPT-generated multimodal instruction data
```bash
# Fine-tune with instruction data
bash scripts/v1_5/finetune.sh
```
**Configuration:**
- Epochs: 1
- Batch size: 128 (across 8 GPUs)
- Learning rate: 2e-5
- Training time: ~24 hours on 8× A100
## Data format
### Instruction data format
```json
[
{
"id": "001",
"image": "path/to/image.jpg",
"conversations": [
{
"from": "human",
"value": "<image>\nWhat is in this image?"
},
{
"from": "gpt",
"value": "The image shows a dog playing in a park."
},
{
"from": "human",
"value": "What breed is the dog?"
},
{
"from": "gpt",
"value": "It appears to be a Golden Retriever."
}
]
}
]
```
## Fine-tuning on custom data
### Prepare your data
```python
import json
# Create instruction data
data = []
for image_path, qa_pairs in your_dataset:
conversations = []
for q, a in qa_pairs:
conversations.append({"from": "human", "value": f"<image>\n{q}"})
conversations.append({"from": "gpt", "value": a})
data.append({
"id": str(len(data)),
"image": image_path,
"conversations": conversations
})
# Save
with open("custom_data.json", "w") as f:
json.dump(data, f, indent=2)
```
### Fine-tune script
```bash
#!/bin/bash
# Set paths
DATA_PATH="custom_data.json"
IMAGE_FOLDER="path/to/images"
MODEL_PATH="liuhaotian/llava-v1.5-7b"
OUTPUT_DIR="./checkpoints/llava-custom"
# Fine-tune
deepspeed llava/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--model_name_or_path $MODEL_PATH \
--version v1 \
--data_path $DATA_PATH \
--image_folder $IMAGE_FOLDER \
--vision_tower openai/clip-vit-large-patch14-336 \
--mm_projector_type mlp2x_gelu \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--image_aspect_ratio pad \
--group_by_modality_length True \
--bf16 True \
--output_dir $OUTPUT_DIR \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 50000 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb
```
## LoRA fine-tuning (memory efficient)
```python
from peft import LoraConfig, get_peft_model
# LoRA config
lora_config = LoraConfig(
r=8, # LoRA rank
lora_alpha=16,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
# Apply LoRA
model = get_peft_model(base_model, lora_config)
# Train with much lower memory
```
## Hardware requirements
### Full fine-tuning
- **7B model**: 8× A100 (40GB)
- **13B model**: 8× A100 (80GB)
- **Training time**: 20-48 hours
### LoRA fine-tuning
- **7B model**: 1× A100 (40GB)
- **13B model**: 2× A100 (40GB)
- **Training time**: 10-24 hours
## Best practices
1. **Start with pretrained** - Don't train from scratch
2. **Use LoRA for efficiency** - 10× less memory
3. **Quality over quantity** - 1K high-quality > 10K low-quality
4. **Multi-turn conversations** - More engaging than single Q&A
5. **Diverse images** - Cover different scenarios
6. **Clear instructions** - Specific questions get better answers
7. **Monitor loss** - Should decrease smoothly
8. **Save checkpoints** - Training can fail
9. **Test regularly** - Validate on held-out set
10. **Use DeepSpeed** - For multi-GPU training
## Resources
- **Training script**: https://github.com/haotian-liu/LLaVA/tree/main/scripts
- **Data format**: https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md
- **Paper**: https://arxiv.org/abs/2304.08485

View file

@ -0,0 +1,386 @@
---
name: nemo-curator
description: GPU-accelerated data curation for LLM training. Supports text/image/video/audio. Features fuzzy deduplication (16× faster), quality filtering (30+ heuristics), semantic deduplication, PII redaction, NSFW detection. Scales across GPUs with RAPIDS. Use for preparing high-quality training datasets, cleaning web data, or deduplicating large corpora.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [nemo-curator, cudf, dask, rapids]
metadata:
hermes:
tags: [Data Processing, NeMo Curator, Data Curation, GPU Acceleration, Deduplication, Quality Filtering, NVIDIA, RAPIDS, PII Redaction, Multimodal, LLM Training Data]
---
# NeMo Curator - GPU-Accelerated Data Curation
NVIDIA's toolkit for preparing high-quality training data for LLMs.
## When to use NeMo Curator
**Use NeMo Curator when:**
- Preparing LLM training data from web scrapes (Common Crawl)
- Need fast deduplication (16× faster than CPU)
- Curating multi-modal datasets (text, images, video, audio)
- Filtering low-quality or toxic content
- Scaling data processing across GPU cluster
**Performance**:
- **16× faster** fuzzy deduplication (8TB RedPajama v2)
- **40% lower TCO** vs CPU alternatives
- **Near-linear scaling** across GPU nodes
**Use alternatives instead**:
- **datatrove**: CPU-based, open-source data processing
- **dolma**: Allen AI's data toolkit
- **Ray Data**: General ML data processing (no curation focus)
## Quick start
### Installation
```bash
# Text curation (CUDA 12)
uv pip install "nemo-curator[text_cuda12]"
# All modalities
uv pip install "nemo-curator[all_cuda12]"
# CPU-only (slower)
uv pip install "nemo-curator[cpu]"
```
### Basic text curation pipeline
```python
from nemo_curator import ScoreFilter, Modify
from nemo_curator.datasets import DocumentDataset
import pandas as pd
# Load data
df = pd.DataFrame({"text": ["Good document", "Bad doc", "Excellent text"]})
dataset = DocumentDataset(df)
# Quality filtering
def quality_score(doc):
return len(doc["text"].split()) > 5 # Filter short docs
filtered = ScoreFilter(quality_score)(dataset)
# Deduplication
from nemo_curator.modules import ExactDuplicates
deduped = ExactDuplicates()(filtered)
# Save
deduped.to_parquet("curated_data/")
```
## Data curation pipeline
### Stage 1: Quality filtering
```python
from nemo_curator.filters import (
WordCountFilter,
RepeatedLinesFilter,
UrlRatioFilter,
NonAlphaNumericFilter
)
# Apply 30+ heuristic filters
from nemo_curator import ScoreFilter
# Word count filter
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
# Remove repetitive content
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
# URL ratio filter
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
```
### Stage 2: Deduplication
**Exact deduplication**:
```python
from nemo_curator.modules import ExactDuplicates
# Remove exact duplicates
deduped = ExactDuplicates(id_field="id", text_field="text")(dataset)
```
**Fuzzy deduplication** (16× faster on GPU):
```python
from nemo_curator.modules import FuzzyDuplicates
# MinHash + LSH deduplication
fuzzy_dedup = FuzzyDuplicates(
id_field="id",
text_field="text",
num_hashes=260, # MinHash parameters
num_buckets=20,
hash_method="md5"
)
deduped = fuzzy_dedup(dataset)
```
**Semantic deduplication**:
```python
from nemo_curator.modules import SemanticDuplicates
# Embedding-based deduplication
semantic_dedup = SemanticDuplicates(
id_field="id",
text_field="text",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
threshold=0.8 # Cosine similarity threshold
)
deduped = semantic_dedup(dataset)
```
### Stage 3: PII redaction
```python
from nemo_curator.modules import Modify
from nemo_curator.modifiers import PIIRedactor
# Redact personally identifiable information
pii_redactor = PIIRedactor(
supported_entities=["EMAIL_ADDRESS", "PHONE_NUMBER", "PERSON", "LOCATION"],
anonymize_action="replace" # or "redact"
)
redacted = Modify(pii_redactor)(dataset)
```
### Stage 4: Classifier filtering
```python
from nemo_curator.classifiers import QualityClassifier
# Quality classification
quality_clf = QualityClassifier(
model_path="nvidia/quality-classifier-deberta",
batch_size=256,
device="cuda"
)
# Filter low-quality documents
high_quality = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
```
## GPU acceleration
### GPU vs CPU performance
| Operation | CPU (16 cores) | GPU (A100) | Speedup |
|-----------|----------------|------------|---------|
| Fuzzy dedup (8TB) | 120 hours | 7.5 hours | 16× |
| Exact dedup (1TB) | 8 hours | 0.5 hours | 16× |
| Quality filtering | 2 hours | 0.2 hours | 10× |
### Multi-GPU scaling
```python
from nemo_curator import get_client
import dask_cuda
# Initialize GPU cluster
client = get_client(cluster_type="gpu", n_workers=8)
# Process with 8 GPUs
deduped = FuzzyDuplicates(...)(dataset)
```
## Multi-modal curation
### Image curation
```python
from nemo_curator.image import (
AestheticFilter,
NSFWFilter,
CLIPEmbedder
)
# Aesthetic scoring
aesthetic_filter = AestheticFilter(threshold=5.0)
filtered_images = aesthetic_filter(image_dataset)
# NSFW detection
nsfw_filter = NSFWFilter(threshold=0.9)
safe_images = nsfw_filter(filtered_images)
# Generate CLIP embeddings
clip_embedder = CLIPEmbedder(model="openai/clip-vit-base-patch32")
image_embeddings = clip_embedder(safe_images)
```
### Video curation
```python
from nemo_curator.video import (
SceneDetector,
ClipExtractor,
InternVideo2Embedder
)
# Detect scenes
scene_detector = SceneDetector(threshold=27.0)
scenes = scene_detector(video_dataset)
# Extract clips
clip_extractor = ClipExtractor(min_duration=2.0, max_duration=10.0)
clips = clip_extractor(scenes)
# Generate embeddings
video_embedder = InternVideo2Embedder()
video_embeddings = video_embedder(clips)
```
### Audio curation
```python
from nemo_curator.audio import (
ASRInference,
WERFilter,
DurationFilter
)
# ASR transcription
asr = ASRInference(model="nvidia/stt_en_fastconformer_hybrid_large_pc")
transcribed = asr(audio_dataset)
# Filter by WER (word error rate)
wer_filter = WERFilter(max_wer=0.3)
high_quality_audio = wer_filter(transcribed)
# Duration filtering
duration_filter = DurationFilter(min_duration=1.0, max_duration=30.0)
filtered_audio = duration_filter(high_quality_audio)
```
## Common patterns
### Web scrape curation (Common Crawl)
```python
from nemo_curator import ScoreFilter, Modify
from nemo_curator.filters import *
from nemo_curator.modules import *
from nemo_curator.datasets import DocumentDataset
# Load Common Crawl data
dataset = DocumentDataset.read_parquet("common_crawl/*.parquet")
# Pipeline
pipeline = [
# 1. Quality filtering
WordCountFilter(min_words=100, max_words=50000),
RepeatedLinesFilter(max_repeated_line_fraction=0.2),
SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3),
UrlRatioFilter(max_url_ratio=0.3),
# 2. Language filtering
LanguageIdentificationFilter(target_languages=["en"]),
# 3. Deduplication
ExactDuplicates(id_field="id", text_field="text"),
FuzzyDuplicates(id_field="id", text_field="text", num_hashes=260),
# 4. PII redaction
PIIRedactor(),
# 5. NSFW filtering
NSFWClassifier(threshold=0.8)
]
# Execute
for stage in pipeline:
dataset = stage(dataset)
# Save
dataset.to_parquet("curated_common_crawl/")
```
### Distributed processing
```python
from nemo_curator import get_client
from dask_cuda import LocalCUDACluster
# Multi-GPU cluster
cluster = LocalCUDACluster(n_workers=8)
client = get_client(cluster=cluster)
# Process large dataset
dataset = DocumentDataset.read_parquet("s3://large_dataset/*.parquet")
deduped = FuzzyDuplicates(...)(dataset)
# Cleanup
client.close()
cluster.close()
```
## Performance benchmarks
### Fuzzy deduplication (8TB RedPajama v2)
- **CPU (256 cores)**: 120 hours
- **GPU (8× A100)**: 7.5 hours
- **Speedup**: 16×
### Exact deduplication (1TB)
- **CPU (64 cores)**: 8 hours
- **GPU (4× A100)**: 0.5 hours
- **Speedup**: 16×
### Quality filtering (100GB)
- **CPU (32 cores)**: 2 hours
- **GPU (2× A100)**: 0.2 hours
- **Speedup**: 10×
## Cost comparison
**CPU-based curation** (AWS c5.18xlarge × 10):
- Cost: $3.60/hour × 10 = $36/hour
- Time for 8TB: 120 hours
- **Total**: $4,320
**GPU-based curation** (AWS p4d.24xlarge × 2):
- Cost: $32.77/hour × 2 = $65.54/hour
- Time for 8TB: 7.5 hours
- **Total**: $491.55
**Savings**: 89% reduction ($3,828 saved)
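The savings figure follows directly from the hourly rates and runtimes quoted above; a minimal sketch of the arithmetic:
```python
# Back-of-envelope cost comparison using the AWS figures quoted above
cpu_hourly = 3.60 * 10       # 10x c5.18xlarge
gpu_hourly = 32.77 * 2       # 2x p4d.24xlarge
cpu_total = cpu_hourly * 120   # 120 hours for 8TB on CPU
gpu_total = gpu_hourly * 7.5   # 7.5 hours for 8TB on GPU
print(f"CPU: ${cpu_total:,.2f}  GPU: ${gpu_total:,.2f}")
print(f"Saved: ${cpu_total - gpu_total:,.2f} ({1 - gpu_total / cpu_total:.0%} reduction)")
# CPU: $4,320.00  GPU: $491.55 -> ~89% reduction
```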
## Supported data formats
- **Input**: Parquet, JSONL, CSV
- **Output**: Parquet (recommended), JSONL
- **WebDataset**: TAR archives for multi-modal
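As a small illustration of the formats listed above (a sketch; the `read_json`/`to_json` helpers are assumed to mirror the Parquet readers shown elsewhere in this skill):
```python
from nemo_curator.datasets import DocumentDataset
# Parquet in, Parquet out (the recommended path used throughout this skill)
dataset = DocumentDataset.read_parquet("raw/*.parquet")
dataset.to_parquet("curated/")
# JSONL variants -- assumed API, mirroring the Parquet helpers above
dataset = DocumentDataset.read_json("raw/*.jsonl")
dataset.to_json("curated_jsonl/")
```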
## Use cases
**Production deployments**:
- NVIDIA used NeMo Curator to prepare Nemotron-4 training data
- Open-source datasets curated: RedPajama v2, The Pile
## References
- **[Filtering Guide](references/filtering.md)** - 30+ quality filters, heuristics
- **[Deduplication Guide](references/deduplication.md)** - Exact, fuzzy, semantic methods
## Resources
- **GitHub**: https://github.com/NVIDIA/NeMo-Curator ⭐ 500+
- **Docs**: https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/
- **Version**: 0.4.0+
- **License**: Apache 2.0

View file

@ -0,0 +1,87 @@
# Deduplication Guide
Complete guide to exact, fuzzy, and semantic deduplication.
## Exact deduplication
Remove documents with identical content.
```python
from nemo_curator.modules import ExactDuplicates
# Exact deduplication
exact_dedup = ExactDuplicates(
id_field="id",
text_field="text",
hash_method="md5" # or "sha256"
)
deduped = exact_dedup(dataset)
```
**Performance**: ~16× faster on GPU vs CPU
## Fuzzy deduplication
Remove near-duplicate documents using MinHash + LSH.
```python
from nemo_curator.modules import FuzzyDuplicates
fuzzy_dedup = FuzzyDuplicates(
id_field="id",
text_field="text",
num_hashes=260, # MinHash permutations (more = accurate)
num_buckets=20, # LSH buckets (more = faster, less recall)
hash_method="md5",
jaccard_threshold=0.8 # Similarity threshold
)
deduped = fuzzy_dedup(dataset)
```
**Parameters**:
- `num_hashes`: 128-512 (default 260)
- `num_buckets`: 10-50 (default 20)
- `jaccard_threshold`: 0.7-0.9 (default 0.8)
**Performance**: 16× faster on 8TB dataset (120h → 7.5h)
## Semantic deduplication
Remove semantically similar documents using embeddings.
```python
from nemo_curator.modules import SemanticDuplicates
semantic_dedup = SemanticDuplicates(
id_field="id",
text_field="text",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
embedding_batch_size=256,
threshold=0.85, # Cosine similarity threshold
device="cuda"
)
deduped = semantic_dedup(dataset)
```
**Models**:
- `all-MiniLM-L6-v2`: Fast, 384 dims
- `all-mpnet-base-v2`: Better quality, 768 dims
- Custom models supported
## Comparison
| Method | Speed | Recall | Use Case |
|--------|-------|--------|----------|
| Exact | Fastest | 100% | Exact matches only |
| Fuzzy | Fast | ~95% | Near-duplicates (recommended) |
| Semantic | Slow | ~90% | Paraphrases, rewrites |
## Best practices
1. **Start with exact dedup** - Remove obvious duplicates
2. **Use fuzzy for large datasets** - Best speed/quality trade-off
3. **Semantic for high-value data** - Expensive but thorough
4. **GPU acceleration required** - 10-16× speedup
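Putting the ordering above together, a minimal exact-then-fuzzy pass (a sketch reusing the modules and calling convention shown earlier in this guide):
```python
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules import ExactDuplicates, FuzzyDuplicates

dataset = DocumentDataset.read_parquet("raw/*.parquet")
# 1. Cheap pass: drop byte-identical documents first
dataset = ExactDuplicates(id_field="id", text_field="text", hash_method="md5")(dataset)
# 2. Fuzzy pass on what remains (GPU recommended)
dataset = FuzzyDuplicates(
    id_field="id",
    text_field="text",
    num_hashes=260,
    jaccard_threshold=0.8,
)(dataset)
dataset.to_parquet("deduped/")
```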

View file

@ -0,0 +1,102 @@
# Quality Filtering Guide
Complete guide to NeMo Curator's 30+ quality filters.
## Text-based filters
### Word count
```python
from nemo_curator.filters import WordCountFilter
# Filter by word count
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
```
### Repeated content
```python
from nemo_curator.filters import RepeatedLinesFilter
# Remove documents with >30% repeated lines
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
```
### Symbol ratio
```python
from nemo_curator.filters import SymbolToWordRatioFilter
# Remove documents with too many symbols
dataset = dataset.filter(SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3))
```
### URL ratio
```python
from nemo_curator.filters import UrlRatioFilter
# Remove documents with many URLs
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
```
## Language filtering
```python
from nemo_curator.filters import LanguageIdentificationFilter
# Keep only English documents
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en"]))
# Multiple languages
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en", "es", "fr"]))
```
## Classifier-based filtering
### Quality classifier
```python
from nemo_curator.classifiers import QualityClassifier
quality_clf = QualityClassifier(
model_path="nvidia/quality-classifier-deberta",
batch_size=256,
device="cuda"
)
# Filter low-quality (threshold > 0.5 = high quality)
dataset = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
```
### NSFW classifier
```python
from nemo_curator.classifiers import NSFWClassifier
nsfw_clf = NSFWClassifier(threshold=0.9, device="cuda")
# Remove NSFW content
dataset = dataset.filter(lambda doc: nsfw_clf(doc["text"]) < 0.9)
```
## Heuristic filters
Full list of 30+ filters:
- WordCountFilter
- RepeatedLinesFilter
- UrlRatioFilter
- SymbolToWordRatioFilter
- NonAlphaNumericFilter
- BulletsFilter
- WhiteSpaceFilter
- ParenthesesFilter
- LongWordFilter
- And 20+ more...
## Best practices
1. **Apply cheap filters first** - Word count before GPU classifiers
2. **Tune thresholds on sample** - Test on 10k docs before full run
3. **Use GPU classifiers sparingly** - Expensive but effective
4. **Chain filters efficiently** - Order by cost (cheap → expensive)
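A compact sketch of that ordering, cheap heuristics first and the GPU classifier last (reusing the filters and classifier shown above; thresholds are illustrative):
```python
from nemo_curator.filters import WordCountFilter, RepeatedLinesFilter, SymbolToWordRatioFilter
from nemo_curator.classifiers import QualityClassifier

# Cheap heuristic filters first
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
dataset = dataset.filter(SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3))
# Expensive GPU classifier last, on the already-reduced set
quality_clf = QualityClassifier(model_path="nvidia/quality-classifier-deberta", device="cuda")
dataset = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
```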

View file

@ -0,0 +1,361 @@
---
name: pinecone
description: Managed vector database for production AI applications. Fully managed, auto-scaling, with hybrid search (dense + sparse), metadata filtering, and namespaces. Low latency (<100ms p95). Use for production RAG, recommendation systems, or semantic search at scale. Best for serverless, managed infrastructure.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [pinecone-client]
metadata:
hermes:
tags: [RAG, Pinecone, Vector Database, Managed Service, Serverless, Hybrid Search, Production, Auto-Scaling, Low Latency, Recommendations]
---
# Pinecone - Managed Vector Database
The vector database for production AI applications.
## When to use Pinecone
**Use when:**
- Need managed, serverless vector database
- Production RAG applications
- Auto-scaling required
- Low latency critical (<100ms)
- Don't want to manage infrastructure
- Need hybrid search (dense + sparse vectors)
**Metrics**:
- Fully managed SaaS
- Auto-scales to billions of vectors
- **p95 latency <100ms**
- 99.9% uptime SLA
**Use alternatives instead**:
- **Chroma**: Self-hosted, open-source
- **FAISS**: Offline, pure similarity search
- **Weaviate**: Self-hosted with more features
## Quick start
### Installation
```bash
pip install pinecone-client
```
### Basic usage
```python
from pinecone import Pinecone, ServerlessSpec
# Initialize
pc = Pinecone(api_key="your-api-key")
# Create index
pc.create_index(
name="my-index",
dimension=1536, # Must match embedding dimension
metric="cosine", # or "euclidean", "dotproduct"
spec=ServerlessSpec(cloud="aws", region="us-east-1")
)
# Connect to index
index = pc.Index("my-index")
# Upsert vectors
index.upsert(vectors=[
{"id": "vec1", "values": [0.1, 0.2, ...], "metadata": {"category": "A"}},
{"id": "vec2", "values": [0.3, 0.4, ...], "metadata": {"category": "B"}}
])
# Query
results = index.query(
vector=[0.1, 0.2, ...],
top_k=5,
include_metadata=True
)
print(results["matches"])
```
## Core operations
### Create index
```python
# Serverless (recommended)
pc.create_index(
name="my-index",
dimension=1536,
metric="cosine",
spec=ServerlessSpec(
cloud="aws", # or "gcp", "azure"
region="us-east-1"
)
)
# Pod-based (for consistent performance)
from pinecone import PodSpec
pc.create_index(
name="my-index",
dimension=1536,
metric="cosine",
spec=PodSpec(
environment="us-east1-gcp",
pod_type="p1.x1"
)
)
```
### Upsert vectors
```python
# Single upsert
index.upsert(vectors=[
{
"id": "doc1",
"values": [0.1, 0.2, ...], # 1536 dimensions
"metadata": {
"text": "Document content",
"category": "tutorial",
"timestamp": "2025-01-01"
}
}
])
# Batch upsert (recommended)
vectors = [
{"id": f"vec{i}", "values": embedding, "metadata": metadata}
for i, (embedding, metadata) in enumerate(zip(embeddings, metadatas))
]
index.upsert(vectors=vectors, batch_size=100)
```
### Query vectors
```python
# Basic query
results = index.query(
vector=[0.1, 0.2, ...],
top_k=10,
include_metadata=True,
include_values=False
)
# With metadata filtering
results = index.query(
vector=[0.1, 0.2, ...],
top_k=5,
filter={"category": {"$eq": "tutorial"}}
)
# Namespace query
results = index.query(
vector=[0.1, 0.2, ...],
top_k=5,
namespace="production"
)
# Access results
for match in results["matches"]:
print(f"ID: {match['id']}")
print(f"Score: {match['score']}")
print(f"Metadata: {match['metadata']}")
```
### Metadata filtering
```python
# Exact match
filter = {"category": "tutorial"}
# Comparison
filter = {"price": {"$gte": 100}} # $gt, $gte, $lt, $lte, $ne
# Logical operators
filter = {
"$and": [
{"category": "tutorial"},
{"difficulty": {"$lte": 3}}
]
} # Also: $or
# In operator
filter = {"tags": {"$in": ["python", "ml"]}}
```
## Namespaces
```python
# Partition data by namespace
index.upsert(
vectors=[{"id": "vec1", "values": [...]}],
namespace="user-123"
)
# Query specific namespace
results = index.query(
vector=[...],
namespace="user-123",
top_k=5
)
# List namespaces
stats = index.describe_index_stats()
print(stats['namespaces'])
```
## Hybrid search (dense + sparse)
```python
# Upsert with sparse vectors
index.upsert(vectors=[
{
"id": "doc1",
"values": [0.1, 0.2, ...], # Dense vector
"sparse_values": {
"indices": [10, 45, 123], # Token IDs
"values": [0.5, 0.3, 0.8] # TF-IDF scores
},
"metadata": {"text": "..."}
}
])
# Hybrid query
results = index.query(
vector=[0.1, 0.2, ...],
sparse_vector={
"indices": [10, 45],
"values": [0.5, 0.3]
},
top_k=5,
alpha=0.5 # 0=sparse, 1=dense, 0.5=hybrid
)
```
## LangChain integration
```python
from langchain_pinecone import PineconeVectorStore
from langchain_openai import OpenAIEmbeddings
# Create vector store
vectorstore = PineconeVectorStore.from_documents(
documents=docs,
embedding=OpenAIEmbeddings(),
index_name="my-index"
)
# Query
results = vectorstore.similarity_search("query", k=5)
# With metadata filter
results = vectorstore.similarity_search(
"query",
k=5,
filter={"category": "tutorial"}
)
# As retriever
retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
```
## LlamaIndex integration
```python
from llama_index.vector_stores.pinecone import PineconeVectorStore
# Connect to Pinecone
pc = Pinecone(api_key="your-key")
pinecone_index = pc.Index("my-index")
# Create vector store
vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
# Use in LlamaIndex
from llama_index.core import StorageContext, VectorStoreIndex
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
```
## Index management
```python
# List indices
indexes = pc.list_indexes()
# Describe index
index_info = pc.describe_index("my-index")
print(index_info)
# Get index stats
stats = index.describe_index_stats()
print(f"Total vectors: {stats['total_vector_count']}")
print(f"Namespaces: {stats['namespaces']}")
# Delete index
pc.delete_index("my-index")
```
## Delete vectors
```python
# Delete by ID
index.delete(ids=["vec1", "vec2"])
# Delete by filter
index.delete(filter={"category": "old"})
# Delete all in namespace
index.delete(delete_all=True, namespace="test")
# Delete all vectors in the default namespace (use pc.delete_index to drop the index itself)
index.delete(delete_all=True)
```
## Best practices
1. **Use serverless** - Auto-scaling, cost-effective
2. **Batch upserts** - More efficient (100-200 per batch)
3. **Add metadata** - Enable filtering
4. **Use namespaces** - Isolate data by user/tenant
5. **Monitor usage** - Check Pinecone dashboard
6. **Optimize filters** - Index frequently filtered fields
7. **Test with free tier** - 1 index, 100K vectors free
8. **Use hybrid search** - Better quality
9. **Set appropriate dimensions** - Match embedding model
10. **Regular backups** - Export important data
## Performance
| Operation | Latency | Notes |
|-----------|---------|-------|
| Upsert | ~50-100ms | Per batch |
| Query (p50) | ~50ms | Depends on index size |
| Query (p95) | ~100ms | SLA target |
| Metadata filter | ~+10-20ms | Additional overhead |
## Pricing (as of 2025)
**Serverless**:
- $0.096 per million read units
- $0.06 per million write units
- $0.06 per GB storage/month
**Free tier**:
- 1 serverless index
- 100K vectors (1536 dimensions)
- Great for prototyping
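A back-of-envelope monthly estimate using the serverless rates quoted above (the workload numbers are made-up assumptions, not Pinecone guidance):
```python
# Hypothetical workload: 50M read units, 10M write units, 20 GB stored per month
reads_m, writes_m, storage_gb = 50, 10, 20
monthly = reads_m * 0.096 + writes_m * 0.06 + storage_gb * 0.06
print(f"Estimated serverless cost: ${monthly:.2f}/month")
# -> roughly $6.60/month at the rates listed above
```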
## Resources
- **Website**: https://www.pinecone.io
- **Docs**: https://docs.pinecone.io
- **Console**: https://app.pinecone.io
- **Pricing**: https://www.pinecone.io/pricing

View file

@ -0,0 +1,181 @@
# Pinecone Deployment Guide
Production deployment patterns for Pinecone.
## Serverless vs Pod-based
### Serverless (Recommended)
```python
from pinecone import Pinecone, ServerlessSpec
pc = Pinecone(api_key="your-key")
# Create serverless index
pc.create_index(
name="my-index",
dimension=1536,
metric="cosine",
spec=ServerlessSpec(
cloud="aws", # or "gcp", "azure"
region="us-east-1"
)
)
```
**Benefits:**
- Auto-scaling
- Pay per usage
- No infrastructure management
- Cost-effective for variable load
**Use when:**
- Variable traffic
- Cost optimization important
- Don't need consistent latency
### Pod-based
```python
from pinecone import PodSpec
pc.create_index(
name="my-index",
dimension=1536,
metric="cosine",
spec=PodSpec(
environment="us-east1-gcp",
pod_type="p1.x1", # or p1.x2, p1.x4, p1.x8
pods=2, # Number of pods
replicas=2 # High availability
)
)
```
**Benefits:**
- Consistent performance
- Predictable latency
- Higher throughput
- Dedicated resources
**Use when:**
- Production workloads
- Need consistent p95 latency
- High throughput required
## Hybrid search
### Dense + Sparse vectors
```python
# Upsert with both dense and sparse vectors
index.upsert(vectors=[
{
"id": "doc1",
"values": [0.1, 0.2, ...], # Dense (semantic)
"sparse_values": {
"indices": [10, 45, 123], # Token IDs
"values": [0.5, 0.3, 0.8] # TF-IDF/BM25 scores
},
"metadata": {"text": "..."}
}
])
# Hybrid query
results = index.query(
vector=[0.1, 0.2, ...], # Dense query
sparse_vector={
"indices": [10, 45],
"values": [0.5, 0.3]
},
top_k=10,
alpha=0.5 # 0=sparse only, 1=dense only, 0.5=balanced
)
```
**Benefits:**
- Best of both worlds
- Semantic + keyword matching
- Better recall than either alone
## Namespaces for multi-tenancy
```python
# Separate data by user/tenant
index.upsert(
vectors=[{"id": "doc1", "values": [...]}],
namespace="user-123"
)
# Query specific namespace
results = index.query(
vector=[...],
namespace="user-123",
top_k=5
)
# List namespaces
stats = index.describe_index_stats()
print(stats['namespaces'])
```
**Use cases:**
- Multi-tenant SaaS
- User-specific data isolation
- A/B testing (prod/staging namespaces)
## Metadata filtering
### Exact match
```python
results = index.query(
vector=[...],
filter={"category": "tutorial"},
top_k=5
)
```
### Range queries
```python
results = index.query(
vector=[...],
filter={"price": {"$gte": 100, "$lte": 500}},
top_k=5
)
```
### Complex filters
```python
results = index.query(
vector=[...],
filter={
"$and": [
{"category": {"$in": ["tutorial", "guide"]}},
{"difficulty": {"$lte": 3}},
{"published": {"$gte": "2024-01-01"}}
]
},
top_k=5
)
```
## Best practices
1. **Use serverless for development** - Cost-effective
2. **Switch to pods for production** - Consistent performance
3. **Implement namespaces** - Multi-tenancy
4. **Add metadata strategically** - Enable filtering
5. **Use hybrid search** - Better quality
6. **Batch upserts** - 100-200 vectors per batch
7. **Monitor usage** - Check Pinecone dashboard
8. **Set up alerts** - Usage/cost thresholds
9. **Regular backups** - Export important data
10. **Test filters** - Verify performance
## Resources
- **Docs**: https://docs.pinecone.io
- **Console**: https://app.pinecone.io

View file

@ -0,0 +1,349 @@
---
name: pytorch-lightning
description: High-level PyTorch framework with Trainer class, automatic distributed training (DDP/FSDP/DeepSpeed), callbacks system, and minimal boilerplate. Scales from laptop to supercomputer with same code. Use when you want clean training loops with built-in best practices.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [lightning, torch, transformers]
metadata:
hermes:
tags: [PyTorch Lightning, Training Framework, Distributed Training, DDP, FSDP, DeepSpeed, High-Level API, Callbacks, Best Practices, Scalable]
---
# PyTorch Lightning - High-Level Training Framework
## Quick start
PyTorch Lightning organizes PyTorch code to eliminate boilerplate while maintaining flexibility.
**Installation**:
```bash
pip install lightning
```
**Convert PyTorch to Lightning** (3 steps):
```python
import lightning as L
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
# Step 1: Define LightningModule (organize your PyTorch code)
class LitModel(L.LightningModule):
def __init__(self, hidden_size=128):
super().__init__()
self.model = nn.Sequential(
nn.Linear(28 * 28, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, 10)
)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = nn.functional.cross_entropy(y_hat, y)
self.log('train_loss', loss) # Auto-logged to TensorBoard
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
# Step 2: Create data
train_loader = DataLoader(train_dataset, batch_size=32)
# Step 3: Train with Trainer (handles everything else!)
trainer = L.Trainer(max_epochs=10, accelerator='gpu', devices=2)
model = LitModel()
trainer.fit(model, train_loader)
```
**That's it!** Trainer handles:
- GPU/TPU/CPU switching
- Distributed training (DDP, FSDP, DeepSpeed)
- Mixed precision (FP16, BF16)
- Gradient accumulation
- Checkpointing
- Logging
- Progress bars
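To make the list above concrete, here is a hedged sketch of a Trainer with several of those knobs turned on at once (the argument names are standard Trainer options; the values are illustrative):
```python
import lightning as L

trainer = L.Trainer(
    max_epochs=10,
    accelerator="gpu",          # or "cpu", "tpu", "auto"
    devices=2,                  # devices per node
    strategy="ddp",             # distributed strategy
    precision="bf16-mixed",     # mixed precision
    accumulate_grad_batches=4,  # gradient accumulation
    gradient_clip_val=1.0,      # gradient clipping
    default_root_dir="runs/",   # checkpoints and logs land here
)
# model and train_loader as defined in the quick start above
trainer.fit(model, train_loader)
```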
## Common workflows
### Workflow 1: From PyTorch to Lightning
**Original PyTorch code**:
```python
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
model.to('cuda')
for epoch in range(max_epochs):
for batch in train_loader:
batch = batch.to('cuda')
optimizer.zero_grad()
loss = model(batch)
loss.backward()
optimizer.step()
```
**Lightning version**:
```python
class LitModel(L.LightningModule):
def __init__(self):
super().__init__()
self.model = MyModel()
def training_step(self, batch, batch_idx):
loss = self.model(batch) # No .to('cuda') needed!
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters())
# Train
trainer = L.Trainer(max_epochs=10, accelerator='gpu')
trainer.fit(LitModel(), train_loader)
```
**Benefits**: 40+ lines → 15 lines, no device management, automatic distributed
### Workflow 2: Validation and testing
```python
class LitModel(L.LightningModule):
def __init__(self):
super().__init__()
self.model = MyModel()
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = nn.functional.cross_entropy(y_hat, y)
self.log('train_loss', loss)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
val_loss = nn.functional.cross_entropy(y_hat, y)
acc = (y_hat.argmax(dim=1) == y).float().mean()
self.log('val_loss', val_loss)
self.log('val_acc', acc)
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
test_loss = nn.functional.cross_entropy(y_hat, y)
self.log('test_loss', test_loss)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
# Train with validation
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, train_loader, val_loader)
# Test
trainer.test(model, test_loader)
```
**Automatic features**:
- Validation runs every epoch by default
- Metrics logged to TensorBoard
- Best model checkpointing based on val_loss
### Workflow 3: Distributed training (DDP)
```python
# Same code as single GPU!
model = LitModel()
# 8 GPUs with DDP (automatic!)
trainer = L.Trainer(
accelerator='gpu',
devices=8,
strategy='ddp' # Or 'fsdp', 'deepspeed'
)
trainer.fit(model, train_loader)
```
**Launch**:
```bash
# Single command, Lightning handles the rest
python train.py
```
**No changes needed**:
- Automatic data distribution
- Gradient synchronization
- Multi-node support (just set `num_nodes=2`)
### Workflow 4: Callbacks for monitoring
```python
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
# Create callbacks
checkpoint = ModelCheckpoint(
monitor='val_loss',
mode='min',
save_top_k=3,
filename='model-{epoch:02d}-{val_loss:.2f}'
)
early_stop = EarlyStopping(
monitor='val_loss',
patience=5,
mode='min'
)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
# Add to Trainer
trainer = L.Trainer(
max_epochs=100,
callbacks=[checkpoint, early_stop, lr_monitor]
)
trainer.fit(model, train_loader, val_loader)
```
**Result**:
- Auto-saves best 3 models
- Stops early if no improvement for 5 epochs
- Logs learning rate to TensorBoard
### Workflow 5: Learning rate scheduling
```python
class LitModel(L.LightningModule):
# ... (training_step, etc.)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
# Cosine annealing
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=100,
eta_min=1e-5
)
return {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': scheduler,
'interval': 'epoch', # Update per epoch
'frequency': 1
}
}
# Learning rate auto-logged!
trainer = L.Trainer(max_epochs=100)
trainer.fit(model, train_loader)
```
## When to use vs alternatives
**Use PyTorch Lightning when**:
- Want clean, organized code
- Need production-ready training loops
- Switching between single GPU, multi-GPU, TPU
- Want built-in callbacks and logging
- Team collaboration (standardized structure)
**Key advantages**:
- **Organized**: Separates research code from engineering
- **Automatic**: DDP, FSDP, DeepSpeed with 1 line
- **Callbacks**: Modular training extensions
- **Reproducible**: Less boilerplate = fewer bugs
- **Tested**: 1M+ downloads/month, battle-tested
**Use alternatives instead**:
- **Accelerate**: Minimal changes to existing code, more flexibility
- **Ray Train**: Multi-node orchestration, hyperparameter tuning
- **Raw PyTorch**: Maximum control, learning purposes
- **Keras**: TensorFlow ecosystem
## Common issues
**Issue: Loss not decreasing**
Check data and model setup:
```python
# Add to training_step
def training_step(self, batch, batch_idx):
if batch_idx == 0:
print(f"Batch shape: {batch[0].shape}")
print(f"Labels: {batch[1]}")
loss = ...
return loss
```
**Issue: Out of memory**
Reduce batch size or use gradient accumulation:
```python
trainer = L.Trainer(
accumulate_grad_batches=4, # Effective batch = batch_size × 4
precision='bf16' # Or 'fp16', reduces memory 50%
)
```
**Issue: Validation not running**
Ensure you pass val_loader:
```python
# WRONG
trainer.fit(model, train_loader)
# CORRECT
trainer.fit(model, train_loader, val_loader)
```
**Issue: DDP spawns multiple processes unexpectedly**
Lightning auto-detects GPUs. Explicitly set devices:
```python
# Test on CPU first
trainer = L.Trainer(accelerator='cpu', devices=1)
# Then GPU
trainer = L.Trainer(accelerator='gpu', devices=1)
```
## Advanced topics
**Callbacks**: See [references/callbacks.md](references/callbacks.md) for EarlyStopping, ModelCheckpoint, custom callbacks, and callback hooks.
**Distributed strategies**: See [references/distributed.md](references/distributed.md) for DDP, FSDP, DeepSpeed ZeRO integration, multi-node setup.
**Hyperparameter tuning**: See [references/hyperparameter-tuning.md](references/hyperparameter-tuning.md) for integration with Optuna, Ray Tune, and WandB sweeps.
## Hardware requirements
- **CPU**: Works (good for debugging)
- **Single GPU**: Works
- **Multi-GPU**: DDP (default), FSDP, or DeepSpeed
- **Multi-node**: DDP, FSDP, DeepSpeed
- **TPU**: Supported (8 cores)
- **Apple MPS**: Supported
**Precision options**:
- FP32 (default)
- FP16 (V100, older GPUs)
- BF16 (A100/H100, recommended)
- FP8 (H100)
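A small sketch for picking the precision flag from the detected GPU (the capability thresholds are a rough heuristic, not an official rule):
```python
import torch
import lightning as L

def pick_precision() -> str:
    if not torch.cuda.is_available():
        return "32-true"
    major, _ = torch.cuda.get_device_capability()
    if major >= 8 and torch.cuda.is_bf16_supported():
        return "bf16-mixed"   # A100/H100-class GPUs
    return "16-mixed"         # V100 and older

trainer = L.Trainer(accelerator="auto", precision=pick_precision())
```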
## Resources
- Docs: https://lightning.ai/docs/pytorch/stable/
- GitHub: https://github.com/Lightning-AI/pytorch-lightning ⭐ 29,000+
- Version: 2.5.5+
- Examples: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples
- Discord: https://discord.gg/lightning-ai
- Used by: Kaggle winners, research labs, production teams

View file

@ -0,0 +1,436 @@
# PyTorch Lightning Callbacks
## Overview
Callbacks add functionality to training without modifying the LightningModule. They capture **non-essential logic** like checkpointing, early stopping, and logging.
## Built-In Callbacks
### 1. ModelCheckpoint
**Saves best models during training**:
```python
from lightning.pytorch.callbacks import ModelCheckpoint
# Save top 3 models based on validation loss
checkpoint = ModelCheckpoint(
dirpath='checkpoints/',
filename='model-{epoch:02d}-{val_loss:.2f}',
monitor='val_loss',
mode='min',
save_top_k=3,
save_last=True, # Also save last epoch
verbose=True
)
trainer = L.Trainer(callbacks=[checkpoint])
trainer.fit(model, train_loader, val_loader)
```
**Configuration options**:
```python
checkpoint = ModelCheckpoint(
monitor='val_acc', # Metric to monitor
mode='max', # 'max' for accuracy, 'min' for loss
save_top_k=5, # Keep best 5 models
save_last=True, # Save last epoch separately
every_n_epochs=1, # Save every N epochs
save_on_train_epoch_end=False, # Save on validation end instead
filename='best-{epoch}-{val_acc:.3f}', # Naming pattern
auto_insert_metric_name=False # Don't auto-add metric to filename
)
```
**Load checkpoint**:
```python
# Load best model
best_model_path = checkpoint.best_model_path
model = LitModel.load_from_checkpoint(best_model_path)
# Resume training
trainer = L.Trainer(callbacks=[checkpoint])
trainer.fit(model, train_loader, val_loader, ckpt_path='checkpoints/last.ckpt')
```
### 2. EarlyStopping
**Stops training when metric stops improving**:
```python
from lightning.pytorch.callbacks import EarlyStopping
early_stop = EarlyStopping(
monitor='val_loss',
patience=5, # Wait 5 epochs
mode='min',
min_delta=0.001, # Minimum change to qualify as improvement
verbose=True,
strict=True, # Crash if monitored metric not found
check_on_train_epoch_end=False # Check on validation end
)
trainer = L.Trainer(callbacks=[early_stop])
trainer.fit(model, train_loader, val_loader)
# Stops automatically if no improvement for 5 epochs
```
**Advanced usage**:
```python
early_stop = EarlyStopping(
monitor='val_loss',
patience=10,
min_delta=0.0,
verbose=True,
mode='min',
stopping_threshold=0.1, # Stop if val_loss < 0.1
divergence_threshold=5.0, # Stop if val_loss > 5.0
check_finite=True # Stop on NaN/Inf
)
```
### 3. LearningRateMonitor
**Logs learning rate**:
```python
from lightning.pytorch.callbacks import LearningRateMonitor
lr_monitor = LearningRateMonitor(
logging_interval='epoch', # Or 'step'
log_momentum=True # Also log momentum
)
trainer = L.Trainer(callbacks=[lr_monitor])
# Learning rate automatically logged to TensorBoard/WandB
```
### 4. TQDMProgressBar
**Customizes progress bar**:
```python
from lightning.pytorch.callbacks import TQDMProgressBar
progress_bar = TQDMProgressBar(
refresh_rate=10, # Update every 10 batches
process_position=0
)
trainer = L.Trainer(callbacks=[progress_bar])
```
### 5. GradientAccumulationScheduler
**Dynamic gradient accumulation**:
```python
from lightning.pytorch.callbacks import GradientAccumulationScheduler
# Accumulate more gradients as training progresses
accumulator = GradientAccumulationScheduler(
scheduling={
0: 8, # Epochs 0-4: accumulate 8 batches
5: 4, # Epochs 5-9: accumulate 4 batches
10: 2 # Epochs 10+: accumulate 2 batches
}
)
trainer = L.Trainer(callbacks=[accumulator])
```
### 6. StochasticWeightAveraging (SWA)
**Averages weights for better generalization**:
```python
from lightning.pytorch.callbacks import StochasticWeightAveraging
swa = StochasticWeightAveraging(
swa_lrs=1e-2, # SWA learning rate
swa_epoch_start=0.8, # Start at 80% of training
annealing_epochs=10, # Annealing period
annealing_strategy='cos' # 'cos' or 'linear'
)
trainer = L.Trainer(callbacks=[swa])
```
## Custom Callbacks
### Basic Custom Callback
```python
from lightning.pytorch.callbacks import Callback
class PrintingCallback(Callback):
def on_train_start(self, trainer, pl_module):
print("Training is starting!")
def on_train_end(self, trainer, pl_module):
print("Training is done!")
    def on_train_epoch_end(self, trainer, pl_module):
print(f"Epoch {trainer.current_epoch} ended")
# Use it
trainer = L.Trainer(callbacks=[PrintingCallback()])
```
### Advanced Custom Callback
```python
class MetricsCallback(Callback):
"""Logs custom metrics every N batches."""
def __init__(self, log_every_n_batches=100):
self.log_every_n_batches = log_every_n_batches
self.metrics = []
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if batch_idx % self.log_every_n_batches == 0:
# Compute custom metric
metric = self.compute_metric(outputs)
self.metrics.append(metric)
# Log to Lightning
pl_module.log('custom_metric', metric)
def compute_metric(self, outputs):
# Your custom logic
return outputs['loss'].item()
def state_dict(self):
"""Save callback state in checkpoint."""
return {'metrics': self.metrics}
def load_state_dict(self, state_dict):
"""Restore callback state from checkpoint."""
self.metrics = state_dict['metrics']
```
### Gradient Monitoring Callback
```python
class GradientMonitorCallback(Callback):
"""Monitor gradient norms."""
def on_after_backward(self, trainer, pl_module):
# Compute gradient norm
total_norm = 0.0
for p in pl_module.parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
# Log
pl_module.log('grad_norm', total_norm)
# Warn if exploding
if total_norm > 100:
print(f"Warning: Large gradient norm: {total_norm:.2f}")
```
### Model Inspection Callback
```python
class ModelInspectionCallback(Callback):
"""Inspect model activations during training."""
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
if batch_idx == 0: # First batch of epoch
# Register hooks
self.activations = {}
def get_activation(name):
def hook(model, input, output):
self.activations[name] = output.detach()
return hook
# Attach to specific layers
pl_module.model.layer1.register_forward_hook(get_activation('layer1'))
pl_module.model.layer2.register_forward_hook(get_activation('layer2'))
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if batch_idx == 0:
# Log activation statistics
for name, activation in self.activations.items():
mean = activation.mean().item()
std = activation.std().item()
pl_module.log(f'{name}_mean', mean)
pl_module.log(f'{name}_std', std)
```
## Callback Hooks
**All available hooks**:
```python
class MyCallback(Callback):
# Setup/Teardown
def setup(self, trainer, pl_module, stage):
"""Called at beginning of fit/test/predict."""
pass
def teardown(self, trainer, pl_module, stage):
"""Called at end of fit/test/predict."""
pass
# Training
def on_train_start(self, trainer, pl_module):
pass
def on_train_epoch_start(self, trainer, pl_module):
pass
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
pass
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
pass
def on_train_epoch_end(self, trainer, pl_module):
pass
def on_train_end(self, trainer, pl_module):
pass
# Validation
def on_validation_start(self, trainer, pl_module):
pass
def on_validation_epoch_start(self, trainer, pl_module):
pass
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
pass
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
pass
def on_validation_epoch_end(self, trainer, pl_module):
pass
def on_validation_end(self, trainer, pl_module):
pass
# Test (same structure as validation)
def on_test_start(self, trainer, pl_module):
pass
# ... (test_epoch_start, test_batch_start, etc.)
# Predict
def on_predict_start(self, trainer, pl_module):
pass
# ... (predict_epoch_start, predict_batch_start, etc.)
# Backward
def on_before_backward(self, trainer, pl_module, loss):
pass
def on_after_backward(self, trainer, pl_module):
pass
# Optimizer
def on_before_optimizer_step(self, trainer, pl_module, optimizer):
pass
# Checkpointing
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
"""Add data to checkpoint."""
pass
def on_load_checkpoint(self, trainer, pl_module, checkpoint):
"""Restore data from checkpoint."""
pass
```
## Combining Multiple Callbacks
```python
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
# Create all callbacks
checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=3)
early_stop = EarlyStopping(monitor='val_loss', patience=5)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
custom_callback = MyCustomCallback()
# Add all to Trainer
trainer = L.Trainer(
callbacks=[checkpoint, early_stop, lr_monitor, custom_callback]
)
trainer.fit(model, train_loader, val_loader)
```
**Execution order**: Callbacks execute in the order they're added
## Best Practices
### 1. Keep Callbacks Independent
**Bad** (dependent on other callback):
```python
class BadCallback(Callback):
def on_train_end(self, trainer, pl_module):
# Assumes ModelCheckpoint is present
best_path = trainer.checkpoint_callback.best_model_path # Fragile!
```
**Good** (self-contained):
```python
class GoodCallback(Callback):
def on_train_end(self, trainer, pl_module):
# Find checkpoint callback if present
for callback in trainer.callbacks:
if isinstance(callback, ModelCheckpoint):
best_path = callback.best_model_path
break
```
### 2. Use State Dict for Persistence
```python
class StatefulCallback(Callback):
def __init__(self):
self.counter = 0
self.history = []
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.counter += 1
self.history.append(outputs['loss'].item())
def state_dict(self):
"""Save state."""
return {
'counter': self.counter,
'history': self.history
}
def load_state_dict(self, state_dict):
"""Restore state."""
self.counter = state_dict['counter']
self.history = state_dict['history']
```
### 3. Handle Distributed Training
```python
class DistributedCallback(Callback):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
# Only run on main process
if trainer.is_global_zero:
print("This only prints once in distributed training")
# Run on all processes
loss = outputs['loss']
# ... do something with loss on each GPU
```
## Resources
- Callback API: https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html
- Built-in callbacks: https://lightning.ai/docs/pytorch/stable/api_references.html#callbacks
- Examples: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/callbacks

View file

@ -0,0 +1,490 @@
# PyTorch Lightning Distributed Training
## Distributed Strategies
Lightning supports multiple distributed strategies with a single parameter change.
### 1. DDP (DistributedDataParallel)
**Default strategy for multi-GPU**:
```python
# Automatic DDP on all available GPUs
trainer = L.Trainer(accelerator='gpu', devices=4, strategy='ddp')
# Or auto-detect
trainer = L.Trainer(accelerator='gpu', devices='auto')
```
**How DDP works**:
- Replicates model on each GPU
- Each GPU processes different batch
- Gradients all-reduced across GPUs
- Model weights synchronized
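Roughly what that looks like in raw `torch.distributed` terms (a simplified sketch of the gradient all-reduce Lightning performs for you, not Lightning's actual implementation):
```python
import torch.distributed as dist

def allreduce_gradients(model, world_size):
    # After loss.backward() on each rank, average gradients across GPUs
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad /= world_size
# DDP does this bucketed and overlapped with the backward pass; every rank
# then applies the same optimizer step, keeping the replicas in sync.
```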
**Launch**:
```bash
# Lightning handles spawning processes automatically
python train.py
```
**DDP Configuration**:
```python
from lightning.pytorch.strategies import DDPStrategy
strategy = DDPStrategy(
find_unused_parameters=False, # Set True if model has unused params
gradient_as_bucket_view=True, # Memory optimization
static_graph=False, # Set True if graph doesn't change
)
trainer = L.Trainer(strategy=strategy)
```
### 2. FSDP (Fully Sharded Data Parallel)
**For large models (7B+ parameters)**:
```python
from lightning.pytorch.strategies import FSDPStrategy
strategy = FSDPStrategy(
sharding_strategy="FULL_SHARD", # ZeRO-3 equivalent
activation_checkpointing=None, # Or specify layer types
cpu_offload=False, # CPU offload for memory
)
trainer = L.Trainer(
accelerator='gpu',
devices=8,
strategy=strategy,
precision='bf16' # Recommended with FSDP
)
trainer.fit(model, train_loader)
```
**FSDP Sharding Strategies**:
```python
# FULL_SHARD (most memory efficient, equivalent to ZeRO-3)
strategy = FSDPStrategy(sharding_strategy="FULL_SHARD")
# SHARD_GRAD_OP (less memory efficient, equivalent to ZeRO-2)
strategy = FSDPStrategy(sharding_strategy="SHARD_GRAD_OP")
# NO_SHARD (no sharding, like DDP)
strategy = FSDPStrategy(sharding_strategy="NO_SHARD")
```
**Auto-wrap policy** (wrap transformer blocks):
```python
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
import functools
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={GPT2Block}
)
strategy = FSDPStrategy(
auto_wrap_policy=auto_wrap_policy,
activation_checkpointing_policy={GPT2Block} # Checkpoint these blocks
)
```
### 3. DeepSpeed
**For massive models (70B+ parameters)**:
```python
from lightning.pytorch.strategies import DeepSpeedStrategy
# DeepSpeed ZeRO-3 with CPU offload
strategy = DeepSpeedStrategy(
stage=3, # ZeRO-3
offload_optimizer=True, # CPU offload optimizer
offload_parameters=True, # CPU offload parameters
cpu_checkpointing=True, # Checkpoint to CPU
)
trainer = L.Trainer(
accelerator='gpu',
devices=8,
strategy=strategy,
precision='bf16'
)
trainer.fit(model, train_loader)
```
**DeepSpeed configuration file**:
```json
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 5e8,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e6
},
"bf16": {
"enabled": true
}
}
```
**Use config file**:
```python
strategy = DeepSpeedStrategy(config='deepspeed_config.json')
trainer = L.Trainer(strategy=strategy)
```
### 4. DDP Spawn
**Windows-compatible DDP**:
```python
# Use when DDP doesn't work (e.g., Windows, Jupyter)
trainer = L.Trainer(
accelerator='gpu',
devices=2,
strategy='ddp_spawn' # Spawns new processes
)
```
**Note**: Slower than DDP due to process spawning overhead
## Multi-Node Training
### Setup Multi-Node Cluster
**Node 0 (master)**:
```bash
export MASTER_ADDR=192.168.1.100
export MASTER_PORT=12355
export WORLD_SIZE=16 # 2 nodes × 8 GPUs
export NODE_RANK=0
python train.py
```
**Node 1 (worker)**:
```bash
export MASTER_ADDR=192.168.1.100
export MASTER_PORT=12355
export WORLD_SIZE=16
export NODE_RANK=1
python train.py
```
**Training script**:
```python
trainer = L.Trainer(
accelerator='gpu',
devices=8, # GPUs per node
num_nodes=2, # Total nodes
strategy='ddp'
)
trainer.fit(model, train_loader)
```
### SLURM Integration
**SLURM job script**:
```bash
#!/bin/bash
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8
#SBATCH --time=24:00:00
# Lightning auto-detects SLURM environment
srun python train.py
```
**Training script** (no changes needed):
```python
# Lightning automatically reads SLURM environment variables
trainer = L.Trainer(
accelerator='gpu',
devices=8,
num_nodes=4, # From SBATCH --nodes
strategy='ddp'
)
```
### Kubernetes (KubeFlow)
**Training script**:
```python
import os
# Lightning auto-detects Kubernetes
trainer = L.Trainer(
accelerator='gpu',
devices=int(os.getenv('WORLD_SIZE', 1)),
strategy='ddp'
)
```
## Mixed Precision Training
### BF16 (A100/H100)
```python
trainer = L.Trainer(
precision='bf16', # Or 'bf16-mixed'
accelerator='gpu'
)
```
**Advantages**:
- No gradient scaler needed
- Same dynamic range as FP32
- 2× speedup, 50% memory reduction
### FP16 (V100, older GPUs)
```python
trainer = L.Trainer(
precision='16-mixed', # Or just '16'
accelerator='gpu'
)
```
**Automatic gradient scaling** handled by Lightning
### FP8 (H100)
```python
# Requires transformer_engine
# pip install transformer-engine[pytorch]
trainer = L.Trainer(
precision='transformer-engine',
accelerator='gpu'
)
```
**Benefits**: 2× faster than BF16 on H100
## Gradient Accumulation
**Simulate larger batch size**:
```python
trainer = L.Trainer(
accumulate_grad_batches=4, # Accumulate 4 batches
precision='bf16'
)
# Effective batch = batch_size × accumulate_grad_batches × num_gpus
# Example: 32 × 4 × 8 = 1024
```
**Dynamic accumulation**:
```python
# Accumulate more early in training
trainer = L.Trainer(
accumulate_grad_batches={
0: 8, # Epochs 0-4: accumulate 8
5: 4, # Epochs 5-9: accumulate 4
10: 2 # Epochs 10+: accumulate 2
}
)
```
## Checkpointing in Distributed
### Save Checkpoint
```python
from lightning.pytorch.callbacks import ModelCheckpoint
# Only rank 0 saves by default
checkpoint = ModelCheckpoint(
dirpath='checkpoints/',
filename='model-{epoch:02d}',
save_top_k=3
)
trainer = L.Trainer(callbacks=[checkpoint], strategy='ddp')
trainer.fit(model, train_loader)
```
**Manual save**:
```python
class MyModel(L.LightningModule):
def training_step(self, batch, batch_idx):
# Training...
loss = ...
# Save every 1000 steps (only rank 0)
if batch_idx % 1000 == 0 and self.trainer.is_global_zero:
self.trainer.save_checkpoint(f'checkpoint_step_{batch_idx}.ckpt')
return loss
```
### Load Checkpoint
```python
# Resume training
trainer = L.Trainer(strategy='ddp')
trainer.fit(model, train_loader, ckpt_path='checkpoints/last.ckpt')
# Load for inference
model = MyModel.load_from_checkpoint('checkpoints/best.ckpt')
model.eval()
```
## Strategy Comparison
| Strategy | Memory Efficiency | Speed | Use Case |
|----------|------------------|-------|----------|
| DDP | Low | Fast | Small models (<7B), single node |
| FSDP | High | Medium | Large models (7-70B) |
| DeepSpeed ZeRO-2 | Medium | Fast | Medium models (1-13B) |
| DeepSpeed ZeRO-3 | Very High | Slower | Massive models (70B+) |
| DDP Spawn | Low | Slow | Windows, debugging |
## Best Practices
### 1. Choose Right Strategy
```python
# Model size guide
if model_params < 1e9: # <1B
strategy = 'ddp'
elif model_params < 7e9: # 1-7B
    strategy = 'ddp'  # or DeepSpeedStrategy(stage=2)
elif model_params < 70e9: # 7-70B
strategy = FSDPStrategy(sharding_strategy="FULL_SHARD")
else: # 70B+
strategy = DeepSpeedStrategy(stage=3, offload_optimizer=True)
trainer = L.Trainer(strategy=strategy)
```
### 2. Avoid Sync Issues
```python
class MyModel(L.LightningModule):
def training_step(self, batch, batch_idx):
# WRONG: This runs on all GPUs independently
if batch_idx % 100 == 0:
self.log_something() # Logged 8 times on 8 GPUs!
# CORRECT: Use is_global_zero
if batch_idx % 100 == 0 and self.trainer.is_global_zero:
self.log_something() # Logged once
loss = ...
return loss
```
### 3. Efficient Data Loading
```python
from torch.utils.data import DataLoader, DistributedSampler
# Lightning handles DistributedSampler automatically
train_loader = DataLoader(
dataset,
batch_size=32,
num_workers=4, # 4 workers per GPU
pin_memory=True,
persistent_workers=True
)
# Lightning automatically wraps with DistributedSampler in DDP
trainer.fit(model, train_loader)
```
### 4. Reduce Communication Overhead
```python
from lightning.pytorch.strategies import DDPStrategy
strategy = DDPStrategy(
gradient_as_bucket_view=True, # Reduce memory copies
static_graph=True, # If model graph doesn't change (faster)
)
trainer = L.Trainer(strategy=strategy)
```
## Common Issues
### Issue: NCCL Timeout
**Symptom**: Training hangs with `NCCL timeout` error
**Solution 1**: Increase timeout
```bash
export NCCL_TIMEOUT=3600 # 1 hour
python train.py
```
**Solution 2**: Check network
```bash
# Test inter-node communication
nvidia-smi nvlink -s
# Verify all nodes can ping each other
ping <node-2-ip>
```
### Issue: OOM with FSDP
**Solution**: Enable CPU offload
```python
strategy = FSDPStrategy(
sharding_strategy="FULL_SHARD",
cpu_offload=True # Offload to CPU
)
```
### Issue: Different Results with DDP
**Cause**: Different random seeds per GPU
**Solution**: Set seed in LightningModule
```python
class MyModel(L.LightningModule):
def __init__(self):
super().__init__()
L.seed_everything(42, workers=True) # Same seed everywhere
```
### Issue: DeepSpeed Config Errors
**Solution**: Use Lightning's auto config
```python
strategy = DeepSpeedStrategy(
stage=3,
# Don't specify config file, Lightning generates automatically
)
```
## Resources
- Distributed strategies: https://lightning.ai/docs/pytorch/stable/accelerators/gpu_intermediate.html
- FSDP guide: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html
- DeepSpeed: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/deepspeed.html
- Multi-node: https://lightning.ai/docs/pytorch/stable/clouds/cluster.html

View file

@ -0,0 +1,556 @@
# Hyperparameter Tuning with PyTorch Lightning
## Integration with Tuning Frameworks
Lightning integrates seamlessly with popular hyperparameter tuning libraries.
### 1. Ray Tune Integration
**Installation**:
```bash
pip install ray[tune]
pip install lightning
```
**Basic Ray Tune example**:
```python
import lightning as L
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback
class LitModel(L.LightningModule):
def __init__(self, lr, batch_size):
super().__init__()
self.lr = lr
self.batch_size = batch_size
self.model = nn.Sequential(nn.Linear(10, 128), nn.ReLU(), nn.Linear(128, 1))
def training_step(self, batch, batch_idx):
loss = self.model(batch).mean()
self.log('train_loss', loss)
return loss
def validation_step(self, batch, batch_idx):
val_loss = self.model(batch).mean()
self.log('val_loss', val_loss)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.lr)
def train_fn(config):
"""Training function for Ray Tune."""
model = LitModel(lr=config["lr"], batch_size=config["batch_size"])
# Add callback to report metrics to Tune
trainer = L.Trainer(
max_epochs=10,
callbacks=[TuneReportCallback({"loss": "val_loss"}, on="validation_end")]
)
trainer.fit(model, train_loader, val_loader)
# Define search space
config = {
"lr": tune.loguniform(1e-5, 1e-1),
"batch_size": tune.choice([16, 32, 64, 128])
}
# Run hyperparameter search
analysis = tune.run(
train_fn,
config=config,
num_samples=20, # 20 trials
resources_per_trial={"gpu": 1}
)
# Best hyperparameters
best_config = analysis.get_best_config(metric="loss", mode="min")
print(f"Best config: {best_config}")
```
**Advanced: Population-Based Training (PBT)**:
```python
from ray.tune.schedulers import PopulationBasedTraining
# PBT scheduler
scheduler = PopulationBasedTraining(
time_attr='training_iteration',
metric='val_loss',
mode='min',
perturbation_interval=5, # Perturb every 5 epochs
hyperparam_mutations={
"lr": tune.loguniform(1e-5, 1e-1),
"batch_size": [16, 32, 64, 128]
}
)
analysis = tune.run(
train_fn,
config=config,
num_samples=8, # Population size
scheduler=scheduler,
resources_per_trial={"gpu": 1}
)
```
### 2. Optuna Integration
**Installation**:
```bash
pip install optuna
pip install optuna-integration
```
**Optuna example**:
```python
import optuna
from optuna.integration import PyTorchLightningPruningCallback
def objective(trial):
# Suggest hyperparameters
lr = trial.suggest_loguniform('lr', 1e-5, 1e-1)
batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128])
n_layers = trial.suggest_int('n_layers', 1, 3)
hidden_size = trial.suggest_int('hidden_size', 64, 512, step=64)
# Create model
model = LitModel(lr=lr, n_layers=n_layers, hidden_size=hidden_size)
# Pruning callback (early stopping for bad trials)
pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_loss")
trainer = L.Trainer(
max_epochs=20,
callbacks=[pruning_callback],
enable_progress_bar=False,
logger=False
)
trainer.fit(model, train_loader, val_loader)
return trainer.callback_metrics["val_loss"].item()
# Create study
study = optuna.create_study(
direction='minimize',
pruner=optuna.pruners.MedianPruner() # Prune bad trials early
)
# Optimize
study.optimize(objective, n_trials=50, timeout=3600)
# Best params
print(f"Best trial: {study.best_trial.params}")
print(f"Best value: {study.best_value}")
# Visualization
optuna.visualization.plot_optimization_history(study).show()
optuna.visualization.plot_param_importances(study).show()
```
**Optuna with distributed training**:
```python
import optuna
# Shared database for distributed optimization
storage = optuna.storages.RDBStorage(
url='postgresql://user:pass@localhost/optuna'
)
study = optuna.create_study(
study_name='distributed_study',
storage=storage,
load_if_exists=True,
direction='minimize'
)
# Run on multiple machines
study.optimize(objective, n_trials=50)
```
### 3. Weights & Biases (WandB) Sweeps
**Installation**:
```bash
pip install wandb
```
**WandB sweep config** (`sweep.yaml`):
```yaml
program: train.py
method: bayes
metric:
name: val_loss
goal: minimize
parameters:
lr:
distribution: log_uniform_values
min: 0.00001
max: 0.1
batch_size:
values: [16, 32, 64, 128]
optimizer:
values: ['adam', 'sgd', 'adamw']
dropout:
distribution: uniform
min: 0.0
max: 0.5
```
**Training script** (`train.py`):
```python
import wandb
import lightning as L
from lightning.pytorch.loggers import WandbLogger
def train():
# Initialize wandb
wandb.init()
config = wandb.config
# Create model with sweep params
model = LitModel(
lr=config.lr,
batch_size=config.batch_size,
optimizer=config.optimizer,
dropout=config.dropout
)
# WandB logger
wandb_logger = WandbLogger(project='hyperparameter-sweep')
trainer = L.Trainer(
max_epochs=20,
logger=wandb_logger
)
trainer.fit(model, train_loader, val_loader)
if __name__ == '__main__':
train()
```
**Launch sweep**:
```bash
# Initialize sweep
wandb sweep sweep.yaml
# Output: wandb: Created sweep with ID: abc123
# Run agent (can run on multiple machines)
wandb agent your-entity/your-project/abc123
```
### 4. Hyperopt Integration
**Installation**:
```bash
pip install hyperopt
```
**Hyperopt example**:
```python
import numpy as np
from hyperopt import hp, fmin, tpe, Trials
def objective(params):
model = LitModel(
lr=params['lr'],
batch_size=int(params['batch_size']),
hidden_size=int(params['hidden_size'])
)
trainer = L.Trainer(
max_epochs=10,
enable_progress_bar=False,
logger=False
)
trainer.fit(model, train_loader, val_loader)
# Return loss (minimize)
return trainer.callback_metrics["val_loss"].item()
# Define search space
space = {
'lr': hp.loguniform('lr', np.log(1e-5), np.log(1e-1)),
'batch_size': hp.quniform('batch_size', 16, 128, 16),
'hidden_size': hp.quniform('hidden_size', 64, 512, 64)
}
# Optimize
trials = Trials()
best = fmin(
fn=objective,
space=space,
algo=tpe.suggest, # Tree-structured Parzen Estimator
max_evals=50,
trials=trials
)
print(f"Best hyperparameters: {best}")
```
## Built-In Lightning Tuning
### Auto Learning Rate Finder
```python
class LitModel(L.LightningModule):
def __init__(self, lr=1e-3):
super().__init__()
self.lr = lr
self.model = nn.Linear(10, 1)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.lr)
def training_step(self, batch, batch_idx):
loss = self.model(batch).mean()
return loss
# Find optimal learning rate
model = LitModel()
trainer = L.Trainer()
# Lightning 2.x: auto_lr_find / trainer.tune() were removed; use the Tuner API
from lightning.pytorch.tuner import Tuner
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model, train_loader)
# Plot results
fig = lr_finder.plot(suggest=True)
fig.show()
# Get suggested LR
suggested_lr = lr_finder.suggestion()
print(f"Suggested LR: {suggested_lr}")
# Update model
model.lr = suggested_lr
# Train with optimal LR
trainer.fit(model, train_loader)
```
### Auto Batch Size Finder
```python
class LitModel(L.LightningModule):
def __init__(self, batch_size=32):
super().__init__()
self.batch_size = batch_size
self.model = nn.Linear(10, 1)
def train_dataloader(self):
return DataLoader(dataset, batch_size=self.batch_size)
model = LitModel()
trainer = L.Trainer()
# Lightning 2.x: auto_scale_batch_size / trainer.tune() were removed; use the Tuner API
from lightning.pytorch.tuner import Tuner
tuner = Tuner(trainer)
tuner.scale_batch_size(model, mode='binsearch')
print(f"Optimal batch size: {model.batch_size}")
# Train with optimal batch size
trainer.fit(model, train_loader)
```
## Advanced Tuning Strategies
### 1. Multi-Fidelity Optimization (Successive Halving)
```python
from ray.tune.schedulers import ASHAScheduler
# ASHA: Asynchronous Successive Halving Algorithm
scheduler = ASHAScheduler(
max_t=100, # Max epochs
grace_period=10, # Min epochs before stopping
reduction_factor=2 # Halve resources each round
)
analysis = tune.run(
train_fn,
config=config,
num_samples=64,
scheduler=scheduler,
resources_per_trial={"gpu": 1}
)
```
**How it works**:
- Start 64 trials
- After 10 epochs, stop bottom 50% (32 trials remain)
- After 20 epochs, stop bottom 50% (16 trials remain)
- After 40 epochs, stop bottom 50% (8 trials remain)
- After 80 epochs, stop bottom 50% (4 trials remain)
- Run remaining 4 trials to completion (100 epochs)
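The same halving schedule computed in a few lines (a sketch mirroring the numbers above, not Ray's implementation):
```python
# ASHA-style successive halving: num_samples=64, grace_period=10, reduction_factor=2
trials, epochs = 64, 10
while trials > 4:
    print(f"rung at {epochs:>3} epochs: {trials} trials, keep {trials // 2}")
    trials //= 2
    epochs *= 2
print(f"remaining {trials} trials run to completion (100 epochs)")
```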
### 2. Bayesian Optimization
```python
from ray.tune.search.bayesopt import BayesOptSearch
search = BayesOptSearch(
metric="val_loss",
mode="min"
)
analysis = tune.run(
train_fn,
config=config,
num_samples=50,
search_alg=search,
resources_per_trial={"gpu": 1}
)
```
### 3. Grid Search
```python
from ray import tune
# Exhaustive grid search
config = {
"lr": tune.grid_search([1e-5, 1e-4, 1e-3, 1e-2]),
"batch_size": tune.grid_search([16, 32, 64, 128]),
"optimizer": tune.grid_search(['adam', 'sgd', 'adamw'])
}
# Total trials: 4 × 4 × 3 = 48
analysis = tune.run(train_fn, config=config)
```
### 4. Random Search
```python
config = {
"lr": tune.loguniform(1e-5, 1e-1),
"batch_size": tune.choice([16, 32, 64, 128]),
"dropout": tune.uniform(0.0, 0.5),
"hidden_size": tune.randint(64, 512)
}
# Random sampling
analysis = tune.run(
train_fn,
config=config,
num_samples=100 # 100 random samples
)
```
## Best Practices
### 1. Start Simple
```python
# Phase 1: Coarse search (fast)
coarse_config = {
"lr": tune.loguniform(1e-5, 1e-1),
"batch_size": tune.choice([32, 64])
}
coarse_analysis = tune.run(train_fn, config=coarse_config, num_samples=10, max_epochs=5)
# Phase 2: Fine-tune around best (slow)
best_lr = coarse_analysis.best_config["lr"]
fine_config = {
"lr": tune.uniform(best_lr * 0.5, best_lr * 2),
"batch_size": tune.choice([16, 32, 64, 128])
}
fine_analysis = tune.run(train_fn, config=fine_config, num_samples=20, max_epochs=20)
```
### 2. Use Checkpointing
```python
import os
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

def train_fn(config, checkpoint_dir=None):
model = LitModel(lr=config["lr"])
trainer = L.Trainer(
max_epochs=100,
callbacks=[
TuneReportCheckpointCallback(
metrics={"loss": "val_loss"},
filename="checkpoint",
on="validation_end"
)
]
)
# Resume from checkpoint if exists
ckpt_path = None
if checkpoint_dir:
ckpt_path = os.path.join(checkpoint_dir, "checkpoint")
trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
```
### 3. Monitor Resource Usage
```python
import GPUtil
def train_fn(config):
# Before training
GPUs = GPUtil.getGPUs()
print(f"GPU memory before: {GPUs[0].memoryUsed} MB")
# Train
model = LitModel(lr=config["lr"], batch_size=config["batch_size"])
trainer.fit(model, train_loader)
# After training
GPUs = GPUtil.getGPUs()
print(f"GPU memory after: {GPUs[0].memoryUsed} MB")
```
## Common Issues
### Issue: Trials Running Out of Memory
**Solution**: Reduce concurrent trials or batch size
```python
analysis = tune.run(
train_fn,
config=config,
resources_per_trial={"gpu": 0.5}, # 2 trials per GPU
max_concurrent_trials=2 # Limit concurrent trials
)
```
### Issue: Slow Hyperparameter Search
**Solution**: Use early stopping scheduler
```python
from ray.tune.schedulers import ASHAScheduler
scheduler = ASHAScheduler(
max_t=100,
grace_period=5, # Stop bad trials after 5 epochs
reduction_factor=3
)
```
### Issue: Can't Reproduce Best Trial
**Solution**: Set seeds in training function
```python
def train_fn(config):
L.seed_everything(42, workers=True)
# Rest of training...
```
## Resources
- Ray Tune + Lightning: https://docs.ray.io/en/latest/tune/examples/tune-pytorch-lightning.html
- Optuna: https://optuna.readthedocs.io/
- WandB Sweeps: https://docs.wandb.ai/guides/sweeps
- Lightning Tuner: https://lightning.ai/docs/pytorch/stable/tuning.html


@ -0,0 +1,496 @@
---
name: qdrant-vector-search
description: High-performance vector similarity search engine for RAG and semantic search. Use when building production RAG systems requiring fast nearest neighbor search, hybrid search with filtering, or scalable vector storage with Rust-powered performance.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [qdrant-client>=1.12.0]
metadata:
hermes:
tags: [RAG, Vector Search, Qdrant, Semantic Search, Embeddings, Similarity Search, HNSW, Production, Distributed]
---
# Qdrant - Vector Similarity Search Engine
High-performance vector database written in Rust for production RAG and semantic search.
## When to use Qdrant
**Use Qdrant when:**
- Building production RAG systems requiring low latency
- Need hybrid search (vectors + metadata filtering)
- Require horizontal scaling with sharding/replication
- Want on-premise deployment with full data control
- Need multi-vector storage per record (dense + sparse)
- Building real-time recommendation systems
**Key features:**
- **Rust-powered**: Memory-safe, high performance
- **Rich filtering**: Filter by any payload field during search
- **Multiple vectors**: Dense, sparse, multi-dense per point
- **Quantization**: Scalar, product, binary for memory efficiency
- **Distributed**: Raft consensus, sharding, replication
- **REST + gRPC**: Both APIs with full feature parity
**Use alternatives instead:**
- **Chroma**: Simpler setup, embedded use cases
- **FAISS**: Maximum raw speed, research/batch processing
- **Pinecone**: Fully managed, zero ops preferred
- **Weaviate**: GraphQL preference, built-in vectorizers
## Quick start
### Installation
```bash
# Python client
pip install qdrant-client
# Docker (recommended for development)
docker run -p 6333:6333 -p 6334:6334 qdrant/qdrant
# Docker with persistent storage
docker run -p 6333:6333 -p 6334:6334 \
-v $(pwd)/qdrant_storage:/qdrant/storage \
qdrant/qdrant
```
### Basic usage
```python
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
# Connect to Qdrant
client = QdrantClient(host="localhost", port=6333)
# Create collection
client.create_collection(
collection_name="documents",
vectors_config=VectorParams(size=384, distance=Distance.COSINE)
)
# Insert vectors with payload
client.upsert(
collection_name="documents",
points=[
PointStruct(
id=1,
vector=[0.1, 0.2, ...], # 384-dim vector
payload={"title": "Doc 1", "category": "tech"}
),
PointStruct(
id=2,
vector=[0.3, 0.4, ...],
payload={"title": "Doc 2", "category": "science"}
)
]
)
# Search with filtering
results = client.search(
collection_name="documents",
query_vector=[0.15, 0.25, ...],
query_filter={
"must": [{"key": "category", "match": {"value": "tech"}}]
},
limit=10
)
for point in results:
print(f"ID: {point.id}, Score: {point.score}, Payload: {point.payload}")
```
## Core concepts
### Points - Basic data unit
```python
from qdrant_client.models import PointStruct
# Point = ID + Vector(s) + Payload
point = PointStruct(
id=123, # Integer or UUID string
vector=[0.1, 0.2, 0.3, ...], # Dense vector
payload={ # Arbitrary JSON metadata
"title": "Document title",
"category": "tech",
"timestamp": 1699900000,
"tags": ["python", "ml"]
}
)
# Batch upsert (recommended)
client.upsert(
collection_name="documents",
points=[point1, point2, point3],
wait=True # Wait for indexing
)
```
### Collections - Vector containers
```python
from qdrant_client.models import VectorParams, Distance, HnswConfigDiff
# Create with HNSW configuration
client.create_collection(
collection_name="documents",
vectors_config=VectorParams(
size=384, # Vector dimensions
distance=Distance.COSINE # COSINE, EUCLID, DOT, MANHATTAN
),
hnsw_config=HnswConfigDiff(
m=16, # Connections per node (default 16)
ef_construct=100, # Build-time accuracy (default 100)
full_scan_threshold=10000 # Switch to brute force below this
),
on_disk_payload=True # Store payload on disk
)
# Collection info
info = client.get_collection("documents")
print(f"Points: {info.points_count}, Vectors: {info.vectors_count}")
```
### Distance metrics
| Metric | Use Case | Range |
|--------|----------|-------|
| `COSINE` | Text embeddings, normalized vectors | 0 to 2 |
| `EUCLID` | Spatial data, image features | 0 to ∞ |
| `DOT` | Recommendations, unnormalized | -∞ to ∞ |
| `MANHATTAN` | Sparse features, discrete data | 0 to ∞ |
## Search operations
### Basic search
```python
# Simple nearest neighbor search
results = client.search(
collection_name="documents",
query_vector=[0.1, 0.2, ...],
limit=10,
with_payload=True,
with_vectors=False # Don't return vectors (faster)
)
```
### Filtered search
```python
from qdrant_client.models import Filter, FieldCondition, MatchValue, Range
# Complex filtering
results = client.search(
collection_name="documents",
query_vector=query_embedding,
query_filter=Filter(
must=[
FieldCondition(key="category", match=MatchValue(value="tech")),
FieldCondition(key="timestamp", range=Range(gte=1699000000))
],
must_not=[
FieldCondition(key="status", match=MatchValue(value="archived"))
]
),
limit=10
)
# Shorthand filter syntax
results = client.search(
collection_name="documents",
query_vector=query_embedding,
query_filter={
"must": [
{"key": "category", "match": {"value": "tech"}},
{"key": "price", "range": {"gte": 10, "lte": 100}}
]
},
limit=10
)
```
### Batch search
```python
from qdrant_client.models import SearchRequest
# Multiple queries in one request
results = client.search_batch(
collection_name="documents",
requests=[
SearchRequest(vector=[0.1, ...], limit=5),
SearchRequest(vector=[0.2, ...], limit=5, filter={"must": [...]}),
SearchRequest(vector=[0.3, ...], limit=10)
]
)
```
## RAG integration
### With sentence-transformers
```python
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct
# Initialize
encoder = SentenceTransformer("all-MiniLM-L6-v2")
client = QdrantClient(host="localhost", port=6333)
# Create collection
client.create_collection(
collection_name="knowledge_base",
vectors_config=VectorParams(size=384, distance=Distance.COSINE)
)
# Index documents
documents = [
{"id": 1, "text": "Python is a programming language", "source": "wiki"},
{"id": 2, "text": "Machine learning uses algorithms", "source": "textbook"},
]
points = [
PointStruct(
id=doc["id"],
vector=encoder.encode(doc["text"]).tolist(),
payload={"text": doc["text"], "source": doc["source"]}
)
for doc in documents
]
client.upsert(collection_name="knowledge_base", points=points)
# RAG retrieval
def retrieve(query: str, top_k: int = 5) -> list[dict]:
query_vector = encoder.encode(query).tolist()
results = client.search(
collection_name="knowledge_base",
query_vector=query_vector,
limit=top_k
)
return [{"text": r.payload["text"], "score": r.score} for r in results]
# Use in RAG pipeline
context = retrieve("What is Python?")
prompt = f"Context: {context}\n\nQuestion: What is Python?"
```
### With LangChain
```python
from langchain_community.vectorstores import Qdrant
from langchain_community.embeddings import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vectorstore = Qdrant.from_documents(documents, embeddings, url="http://localhost:6333", collection_name="docs")
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
```
### With LlamaIndex
```python
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core import VectorStoreIndex, StorageContext
vector_store = QdrantVectorStore(client=client, collection_name="llama_docs")
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
query_engine = index.as_query_engine()
```
## Multi-vector support
### Named vectors (different embedding models)
```python
from qdrant_client.models import VectorParams, Distance
# Collection with multiple vector types
client.create_collection(
collection_name="hybrid_search",
vectors_config={
"dense": VectorParams(size=384, distance=Distance.COSINE),
"sparse": VectorParams(size=30000, distance=Distance.DOT)
}
)
# Insert with named vectors
client.upsert(
collection_name="hybrid_search",
points=[
PointStruct(
id=1,
vector={
"dense": dense_embedding,
"sparse": sparse_embedding
},
payload={"text": "document text"}
)
]
)
# Search specific vector
results = client.search(
collection_name="hybrid_search",
query_vector=("dense", query_dense), # Specify which vector
limit=10
)
```
### Sparse vectors (BM25, SPLADE)
```python
from qdrant_client.models import SparseVectorParams, SparseIndexParams, SparseVector
# Collection with sparse vectors
client.create_collection(
collection_name="sparse_search",
vectors_config={},
sparse_vectors_config={"text": SparseVectorParams(index=SparseIndexParams(on_disk=False))}
)
# Insert sparse vector
client.upsert(
collection_name="sparse_search",
points=[PointStruct(id=1, vector={"text": SparseVector(indices=[1, 5, 100], values=[0.5, 0.8, 0.2])}, payload={"text": "document"})]
)
```
## Quantization (memory optimization)
```python
from qdrant_client.models import ScalarQuantization, ScalarQuantizationConfig, ScalarType
# Scalar quantization (4x memory reduction)
client.create_collection(
collection_name="quantized",
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
quantization_config=ScalarQuantization(
scalar=ScalarQuantizationConfig(
type=ScalarType.INT8,
quantile=0.99, # Clip outliers
always_ram=True # Keep quantized in RAM
)
)
)
# Search with rescoring
results = client.search(
collection_name="quantized",
query_vector=query,
search_params={"quantization": {"rescore": True}}, # Rescore top results
limit=10
)
```
## Payload indexing
```python
from qdrant_client.models import PayloadSchemaType
# Create payload index for faster filtering
client.create_payload_index(
collection_name="documents",
field_name="category",
field_schema=PayloadSchemaType.KEYWORD
)
client.create_payload_index(
collection_name="documents",
field_name="timestamp",
field_schema=PayloadSchemaType.INTEGER
)
# Index types: KEYWORD, INTEGER, FLOAT, GEO, TEXT (full-text), BOOL
```
## Production deployment
### Qdrant Cloud
```python
from qdrant_client import QdrantClient
# Connect to Qdrant Cloud
client = QdrantClient(
url="https://your-cluster.cloud.qdrant.io",
api_key="your-api-key"
)
```
### Performance tuning
```python
# Optimize for search speed (higher recall)
client.update_collection(
collection_name="documents",
hnsw_config=HnswConfigDiff(ef_construct=200, m=32)
)
# Optimize for indexing speed (bulk loads)
client.update_collection(
collection_name="documents",
optimizer_config={"indexing_threshold": 20000}
)
```
## Best practices
1. **Batch operations** - Use batch upsert/search for efficiency (see the sketch after this list)
2. **Payload indexing** - Index fields used in filters
3. **Quantization** - Enable for large collections (>1M vectors)
4. **Sharding** - Use for collections >10M vectors
5. **On-disk storage** - Enable `on_disk_payload` for large payloads
6. **Connection pooling** - Reuse client instances
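A minimal sketch of practices 1 and 6 combined: reuse one client instance and send upserts in fixed-size batches (the batch size and helper name are illustrative):

```python
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct

client = QdrantClient(host="localhost", port=6333)  # reuse a single client instance

def upsert_in_batches(collection: str, points: list[PointStruct], batch_size: int = 256):
    """Upsert points in fixed-size batches instead of one huge request."""
    for i in range(0, len(points), batch_size):
        client.upsert(
            collection_name=collection,
            points=points[i:i + batch_size],
            wait=True,  # wait for indexing before sending the next batch
        )
```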
## Common issues
**Slow search with filters:**
```python
# Create payload index for filtered fields
client.create_payload_index(
collection_name="docs",
field_name="category",
field_schema=PayloadSchemaType.KEYWORD
)
```
**Out of memory:**
```python
# Enable quantization and on-disk storage
client.create_collection(
collection_name="large_collection",
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
quantization_config=ScalarQuantization(...),
on_disk_payload=True
)
```
**Connection issues:**
```python
# Use timeout and retry
client = QdrantClient(
host="localhost",
port=6333,
timeout=30,
prefer_grpc=True # gRPC for better performance
)
```
## References
- **[Advanced Usage](references/advanced-usage.md)** - Distributed mode, hybrid search, recommendations
- **[Troubleshooting](references/troubleshooting.md)** - Common issues, debugging, performance tuning
## Resources
- **GitHub**: https://github.com/qdrant/qdrant (22k+ stars)
- **Docs**: https://qdrant.tech/documentation/
- **Python Client**: https://github.com/qdrant/qdrant-client
- **Cloud**: https://cloud.qdrant.io
- **Version**: 1.12.0+
- **License**: Apache 2.0


@ -0,0 +1,648 @@
# Qdrant Advanced Usage Guide
## Distributed Deployment
### Cluster Setup
Qdrant uses Raft consensus for distributed coordination.
```yaml
# docker-compose.yml for 3-node cluster
version: '3.8'
services:
qdrant-node-1:
image: qdrant/qdrant:latest
ports:
- "6333:6333"
- "6334:6334"
- "6335:6335"
volumes:
- ./node1_storage:/qdrant/storage
environment:
- QDRANT__CLUSTER__ENABLED=true
- QDRANT__CLUSTER__P2P__PORT=6335
- QDRANT__SERVICE__HTTP_PORT=6333
- QDRANT__SERVICE__GRPC_PORT=6334
qdrant-node-2:
image: qdrant/qdrant:latest
ports:
- "6343:6333"
- "6344:6334"
- "6345:6335"
volumes:
- ./node2_storage:/qdrant/storage
environment:
- QDRANT__CLUSTER__ENABLED=true
- QDRANT__CLUSTER__P2P__PORT=6335
- QDRANT__CLUSTER__BOOTSTRAP=http://qdrant-node-1:6335
depends_on:
- qdrant-node-1
qdrant-node-3:
image: qdrant/qdrant:latest
ports:
- "6353:6333"
- "6354:6334"
- "6355:6335"
volumes:
- ./node3_storage:/qdrant/storage
environment:
- QDRANT__CLUSTER__ENABLED=true
- QDRANT__CLUSTER__P2P__PORT=6335
- QDRANT__CLUSTER__BOOTSTRAP=http://qdrant-node-1:6335
depends_on:
- qdrant-node-1
```
### Sharding Configuration
```python
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, ShardingMethod
client = QdrantClient(host="localhost", port=6333)
# Create sharded collection
client.create_collection(
collection_name="large_collection",
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
shard_number=6, # Number of shards
replication_factor=2, # Replicas per shard
write_consistency_factor=1 # Required acks for write
)
# Check cluster status
cluster_info = client.get_cluster_info()
print(f"Peers: {cluster_info.peers}")
print(f"Raft state: {cluster_info.raft_info}")
```
### Replication and Consistency
```python
from qdrant_client.models import WriteOrdering
# Strong consistency write
client.upsert(
collection_name="critical_data",
points=points,
ordering=WriteOrdering.STRONG # Wait for all replicas
)
# Eventual consistency (faster)
client.upsert(
collection_name="logs",
points=points,
ordering=WriteOrdering.WEAK # Return after primary ack
)
# Read from specific shard
results = client.search(
collection_name="documents",
query_vector=query,
consistency="majority" # Read from majority of replicas
)
```
## Hybrid Search
### Dense + Sparse Vectors
Combine semantic (dense) and keyword (sparse) search:
```python
from qdrant_client.models import (
VectorParams, SparseVectorParams, SparseIndexParams,
Distance, PointStruct, SparseVector, Prefetch, Query
)
# Create hybrid collection
client.create_collection(
collection_name="hybrid",
vectors_config={
"dense": VectorParams(size=384, distance=Distance.COSINE)
},
sparse_vectors_config={
"sparse": SparseVectorParams(
index=SparseIndexParams(on_disk=False)
)
}
)
# Insert with both vector types
def encode_sparse(text: str) -> SparseVector:
"""Simple BM25-like sparse encoding"""
from collections import Counter
tokens = text.lower().split()
counts = Counter(tokens)
# Map tokens to indices (use vocabulary in production)
indices = [hash(t) % 30000 for t in counts.keys()]
values = list(counts.values())
return SparseVector(indices=indices, values=values)
client.upsert(
collection_name="hybrid",
points=[
PointStruct(
id=1,
vector={
"dense": dense_encoder.encode("Python programming").tolist(),
"sparse": encode_sparse("Python programming language code")
},
payload={"text": "Python programming language code"}
)
]
)
# Hybrid search with Reciprocal Rank Fusion (RRF)
from qdrant_client.models import FusionQuery
results = client.query_points(
collection_name="hybrid",
prefetch=[
Prefetch(query=dense_query, using="dense", limit=20),
Prefetch(query=sparse_query, using="sparse", limit=20)
],
query=FusionQuery(fusion="rrf"), # Combine results
limit=10
)
```
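Qdrant performs the fusion server-side, but for intuition, Reciprocal Rank Fusion simply sums reciprocal ranks across the prefetched lists; a plain-Python sketch (k=60 is the conventional constant, not necessarily Qdrant's internal value):

```python
def rrf_fuse(rankings: list[list[int]], k: int = 60) -> list[int]:
    """Combine several ranked lists of point IDs into one fused ranking."""
    scores: dict[int, float] = {}
    for ranked_ids in rankings:
        for rank, point_id in enumerate(ranked_ids, start=1):
            scores[point_id] = scores.get(point_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)

# Example: fuse a dense result list with a sparse result list
fused_ids = rrf_fuse([[1, 2, 3], [3, 1, 4]])
```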
### Multi-Stage Search
```python
from qdrant_client.models import Prefetch, Query
# Two-stage retrieval: coarse then fine
results = client.query_points(
collection_name="documents",
prefetch=[
Prefetch(
query=query_vector,
limit=100, # Broad first stage
params={"quantization": {"rescore": False}} # Fast, approximate
)
],
query=Query(nearest=query_vector),
limit=10,
params={"quantization": {"rescore": True}} # Accurate reranking
)
```
## Recommendations
### Item-to-Item Recommendations
```python
# Find similar items
recommendations = client.recommend(
collection_name="products",
positive=[1, 2, 3], # IDs user liked
negative=[4], # IDs user disliked
limit=10
)
# With filtering
recommendations = client.recommend(
collection_name="products",
positive=[1, 2],
query_filter={
"must": [
{"key": "category", "match": {"value": "electronics"}},
{"key": "in_stock", "match": {"value": True}}
]
},
limit=10
)
```
### Lookup from Another Collection
```python
from qdrant_client.models import RecommendStrategy, LookupLocation
# Recommend using example points looked up from another collection
results = client.recommend(
    collection_name="products",
    positive=[123],  # point ID resolved in the lookup collection
    lookup_from=LookupLocation(collection="user_history"),
    strategy=RecommendStrategy.AVERAGE_VECTOR,
    limit=10
)
```
## Advanced Filtering
### Nested Payload Filtering
```python
from qdrant_client.models import Filter, FieldCondition, MatchValue, NestedCondition, Nested
# Filter on nested objects
results = client.search(
collection_name="documents",
query_vector=query,
query_filter=Filter(
must=[
            NestedCondition(
                nested=Nested(
                    key="metadata",
                    filter=Filter(
                        must=[
                            FieldCondition(
                                key="author.name",
                                match=MatchValue(value="John")
                            )
                        ]
                    )
                )
            )
]
),
limit=10
)
```
### Geo Filtering
```python
from qdrant_client.models import FieldCondition, GeoRadius, GeoPoint
# Find within radius
results = client.search(
collection_name="locations",
query_vector=query,
query_filter=Filter(
must=[
FieldCondition(
key="location",
geo_radius=GeoRadius(
center=GeoPoint(lat=40.7128, lon=-74.0060),
radius=5000 # meters
)
)
]
),
limit=10
)
# Geo bounding box
from qdrant_client.models import GeoBoundingBox
results = client.search(
collection_name="locations",
query_vector=query,
query_filter=Filter(
must=[
FieldCondition(
key="location",
geo_bounding_box=GeoBoundingBox(
top_left=GeoPoint(lat=40.8, lon=-74.1),
bottom_right=GeoPoint(lat=40.6, lon=-73.9)
)
)
]
),
limit=10
)
```
### Full-Text Search
```python
from qdrant_client.models import TextIndexParams, TokenizerType
# Create text index
client.create_payload_index(
collection_name="documents",
field_name="content",
field_schema=TextIndexParams(
type="text",
tokenizer=TokenizerType.WORD,
min_token_len=2,
max_token_len=15,
lowercase=True
)
)
# Full-text filter
from qdrant_client.models import MatchText
results = client.search(
collection_name="documents",
query_vector=query,
query_filter=Filter(
must=[
FieldCondition(
key="content",
match=MatchText(text="machine learning")
)
]
),
limit=10
)
```
## Quantization Strategies
### Scalar Quantization (INT8)
```python
from qdrant_client.models import ScalarQuantization, ScalarQuantizationConfig, ScalarType
# ~4x memory reduction, minimal accuracy loss
client.create_collection(
collection_name="scalar_quantized",
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
quantization_config=ScalarQuantization(
scalar=ScalarQuantizationConfig(
type=ScalarType.INT8,
quantile=0.99, # Clip extreme values
always_ram=True # Keep quantized vectors in RAM
)
)
)
```
### Product Quantization
```python
from qdrant_client.models import ProductQuantization, ProductQuantizationConfig, CompressionRatio
# ~16x memory reduction, some accuracy loss
client.create_collection(
collection_name="product_quantized",
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
quantization_config=ProductQuantization(
product=ProductQuantizationConfig(
compression=CompressionRatio.X16,
always_ram=True
)
)
)
```
### Binary Quantization
```python
from qdrant_client.models import BinaryQuantization, BinaryQuantizationConfig
# ~32x memory reduction, requires oversampling
client.create_collection(
collection_name="binary_quantized",
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
quantization_config=BinaryQuantization(
binary=BinaryQuantizationConfig(always_ram=True)
)
)
# Search with oversampling
results = client.search(
collection_name="binary_quantized",
query_vector=query,
search_params={
"quantization": {
"rescore": True,
"oversampling": 2.0 # Retrieve 2x candidates, rescore
}
},
limit=10
)
```
## Snapshots and Backups
### Create Snapshot
```python
# Create collection snapshot
snapshot_info = client.create_snapshot(collection_name="documents")
print(f"Snapshot: {snapshot_info.name}")
# List snapshots
snapshots = client.list_snapshots(collection_name="documents")
for s in snapshots:
print(f"{s.name}: {s.size} bytes")
# Full storage snapshot
full_snapshot = client.create_full_snapshot()
```
### Restore from Snapshot
```python
# Download snapshot
client.download_snapshot(
collection_name="documents",
snapshot_name="documents-2024-01-01.snapshot",
target_path="./backup/"
)
# Restore (via REST API)
import requests
response = requests.put(
"http://localhost:6333/collections/documents/snapshots/recover",
json={"location": "file:///backup/documents-2024-01-01.snapshot"}
)
```
## Collection Aliases
```python
# Create alias
client.update_collection_aliases(
change_aliases_operations=[
{"create_alias": {"alias_name": "production", "collection_name": "documents_v2"}}
]
)
# Blue-green deployment
# 1. Create new collection with updates
client.create_collection(collection_name="documents_v3", ...)
# 2. Populate new collection
client.upsert(collection_name="documents_v3", points=new_points)
# 3. Atomic switch
client.update_collection_aliases(
change_aliases_operations=[
{"delete_alias": {"alias_name": "production"}},
{"create_alias": {"alias_name": "production", "collection_name": "documents_v3"}}
]
)
# Search via alias
results = client.search(collection_name="production", query_vector=query, limit=10)
```
## Scroll and Iteration
### Scroll Through All Points
```python
# Paginated iteration
offset = None
all_points = []
while True:
results, offset = client.scroll(
collection_name="documents",
limit=100,
offset=offset,
with_payload=True,
with_vectors=False
)
all_points.extend(results)
if offset is None:
break
print(f"Total points: {len(all_points)}")
```
### Filtered Scroll
```python
# Scroll with filter
results, _ = client.scroll(
collection_name="documents",
scroll_filter=Filter(
must=[
FieldCondition(key="status", match=MatchValue(value="active"))
]
),
limit=1000
)
```
## Async Client
```python
import asyncio
from qdrant_client import AsyncQdrantClient
async def main():
client = AsyncQdrantClient(host="localhost", port=6333)
# Async operations
await client.create_collection(
collection_name="async_docs",
vectors_config=VectorParams(size=384, distance=Distance.COSINE)
)
await client.upsert(
collection_name="async_docs",
points=points
)
results = await client.search(
collection_name="async_docs",
query_vector=query,
limit=10
)
return results
results = asyncio.run(main())
```
## gRPC Client
```python
from qdrant_client import QdrantClient
# Prefer gRPC for better performance
client = QdrantClient(
host="localhost",
port=6333,
grpc_port=6334,
prefer_grpc=True # Use gRPC when available
)
# gRPC-only client
from qdrant_client import QdrantClient
client = QdrantClient(
host="localhost",
grpc_port=6334,
prefer_grpc=True,
https=False
)
```
## Multitenancy
### Payload-Based Isolation
```python
# Single collection, filter by tenant
client.upsert(
collection_name="multi_tenant",
points=[
PointStruct(
id=1,
vector=embedding,
payload={"tenant_id": "tenant_a", "text": "..."}
)
]
)
# Search within tenant
results = client.search(
collection_name="multi_tenant",
query_vector=query,
query_filter=Filter(
must=[FieldCondition(key="tenant_id", match=MatchValue(value="tenant_a"))]
),
limit=10
)
```
### Collection-Per-Tenant
```python
# Create tenant collection
def create_tenant_collection(tenant_id: str):
client.create_collection(
collection_name=f"tenant_{tenant_id}",
vectors_config=VectorParams(size=384, distance=Distance.COSINE)
)
# Search tenant collection
def search_tenant(tenant_id: str, query_vector: list, limit: int = 10):
return client.search(
collection_name=f"tenant_{tenant_id}",
query_vector=query_vector,
limit=limit
)
```
## Performance Monitoring
### Collection Statistics
```python
# Collection info
info = client.get_collection("documents")
print(f"Points: {info.points_count}")
print(f"Indexed vectors: {info.indexed_vectors_count}")
print(f"Segments: {len(info.segments)}")
print(f"Status: {info.status}")
# Detailed segment info
for i, segment in enumerate(info.segments):
print(f"Segment {i}: {segment}")
```
### Telemetry
```python
# Get telemetry data
telemetry = client.get_telemetry()
print(f"Collections: {telemetry.collections}")
print(f"Operations: {telemetry.operations}")
```


@ -0,0 +1,631 @@
# Qdrant Troubleshooting Guide
## Installation Issues
### Docker Issues
**Error**: `Cannot connect to Docker daemon`
**Fix**:
```bash
# Start Docker daemon
sudo systemctl start docker
# Or use Docker Desktop on Mac/Windows
open -a Docker
```
**Error**: `Port 6333 already in use`
**Fix**:
```bash
# Find process using port
lsof -i :6333
# Kill process or use different port
docker run -p 6334:6333 qdrant/qdrant
```
### Python Client Issues
**Error**: `ModuleNotFoundError: No module named 'qdrant_client'`
**Fix**:
```bash
pip install qdrant-client
# With specific version
pip install 'qdrant-client>=1.12.0'
```
**Error**: `grpc._channel._InactiveRpcError`
**Fix**:
```bash
# Install with gRPC support
pip install 'qdrant-client[grpc]'
# Or disable gRPC
client = QdrantClient(host="localhost", port=6333, prefer_grpc=False)
```
## Connection Issues
### Cannot Connect to Server
**Error**: `ConnectionRefusedError: [Errno 111] Connection refused`
**Solutions**:
1. **Check server is running**:
```bash
docker ps | grep qdrant
curl http://localhost:6333/healthz
```
2. **Verify port binding**:
```bash
# Check listening ports
netstat -tlnp | grep 6333
# Docker port mapping
docker port <container_id>
```
3. **Use correct host**:
```python
# Docker on Linux
client = QdrantClient(host="localhost", port=6333)
# Docker on Mac/Windows with networking issues
client = QdrantClient(host="127.0.0.1", port=6333)
# Inside Docker network
client = QdrantClient(host="qdrant", port=6333)
```
### Timeout Errors
**Error**: `TimeoutError: Connection timed out`
**Fix**:
```python
# Increase timeout
client = QdrantClient(
host="localhost",
port=6333,
timeout=60 # seconds
)
# For large operations
client.upsert(
collection_name="documents",
points=large_batch,
wait=False # Don't wait for indexing
)
```
### SSL/TLS Errors
**Error**: `ssl.SSLCertVerificationError`
**Fix**:
```python
# Qdrant Cloud
client = QdrantClient(
url="https://cluster.cloud.qdrant.io",
api_key="your-api-key"
)
# Self-signed certificate
client = QdrantClient(
host="localhost",
port=6333,
https=True,
verify=False # Disable verification (not recommended for production)
)
```
## Collection Issues
### Collection Already Exists
**Error**: `ValueError: Collection 'documents' already exists`
**Fix**:
```python
# Check before creating
collections = client.get_collections().collections
names = [c.name for c in collections]
if "documents" not in names:
client.create_collection(...)
# Or recreate
client.recreate_collection(
collection_name="documents",
vectors_config=VectorParams(size=384, distance=Distance.COSINE)
)
```
### Collection Not Found
**Error**: `NotFoundException: Collection 'docs' not found`
**Fix**:
```python
# List available collections
collections = client.get_collections()
print([c.name for c in collections.collections])
# Check exact name (case-sensitive)
try:
info = client.get_collection("documents")
except Exception as e:
print(f"Collection not found: {e}")
```
### Vector Dimension Mismatch
**Error**: `ValueError: Vector dimension mismatch. Expected 384, got 768`
**Fix**:
```python
# Check collection config
info = client.get_collection("documents")
print(f"Expected dimension: {info.config.params.vectors.size}")
# Recreate with correct dimension
client.recreate_collection(
collection_name="documents",
vectors_config=VectorParams(size=768, distance=Distance.COSINE) # Match your embeddings
)
```
## Search Issues
### Empty Search Results
**Problem**: Search returns empty results.
**Solutions**:
1. **Verify data exists**:
```python
info = client.get_collection("documents")
print(f"Points: {info.points_count}")
# Scroll to check data
points, _ = client.scroll(
collection_name="documents",
limit=10,
with_payload=True
)
print(points)
```
2. **Check vector format**:
```python
# Must be list of floats
query_vector = embedding.tolist() # Convert numpy to list
# Check dimensions
print(f"Query dimension: {len(query_vector)}")
```
3. **Verify filter conditions**:
```python
# Test without filter first
results = client.search(
collection_name="documents",
query_vector=query,
limit=10
# No filter
)
# Then add filter incrementally
```
### Slow Search Performance
**Problem**: Search takes too long.
**Solutions**:
1. **Create payload indexes**:
```python
# Index fields used in filters
client.create_payload_index(
collection_name="documents",
field_name="category",
field_schema="keyword"
)
```
2. **Enable quantization**:
```python
client.update_collection(
collection_name="documents",
quantization_config=ScalarQuantization(
scalar=ScalarQuantizationConfig(type=ScalarType.INT8)
)
)
```
3. **Tune HNSW parameters**:
```python
# Faster search (less accurate)
client.update_collection(
collection_name="documents",
hnsw_config=HnswConfigDiff(ef_construct=64, m=8)
)
# Use ef search parameter
results = client.search(
collection_name="documents",
query_vector=query,
search_params={"hnsw_ef": 64}, # Lower = faster
limit=10
)
```
4. **Use gRPC**:
```python
client = QdrantClient(
host="localhost",
port=6333,
grpc_port=6334,
prefer_grpc=True
)
```
### Inconsistent Results
**Problem**: Same query returns different results.
**Solutions**:
1. **Wait for indexing**:
```python
client.upsert(
collection_name="documents",
points=points,
wait=True # Wait for index update
)
```
2. **Check replication consistency**:
```python
# Strong consistency read
results = client.search(
collection_name="documents",
query_vector=query,
consistency="all" # Read from all replicas
)
```
## Upsert Issues
### Batch Upsert Fails
**Error**: `PayloadError: Payload too large`
**Fix**:
```python
# Split into smaller batches
def batch_upsert(client, collection, points, batch_size=100):
for i in range(0, len(points), batch_size):
batch = points[i:i + batch_size]
client.upsert(
collection_name=collection,
points=batch,
wait=True
)
batch_upsert(client, "documents", large_points_list)
```
### Invalid Point ID
**Error**: `ValueError: Invalid point ID`
**Fix**:
```python
# Valid ID types: int or UUID string
from uuid import uuid4
# Integer ID
PointStruct(id=123, vector=vec, payload={})
# UUID string
PointStruct(id=str(uuid4()), vector=vec, payload={})
# NOT valid
PointStruct(id="custom-string-123", ...) # Use UUID format
```
### Payload Validation Errors
**Error**: `ValidationError: Invalid payload`
**Fix**:
```python
# Ensure JSON-serializable payload
import json
from datetime import datetime
payload = {
"title": "Document",
"count": 42,
"tags": ["a", "b"],
"nested": {"key": "value"}
}
# Validate before upsert
json.dumps(payload) # Should not raise
# Avoid non-serializable types
# NOT valid: datetime, numpy arrays, custom objects
payload = {
"timestamp": datetime.now().isoformat(), # Convert to string
"vector": embedding.tolist() # Convert numpy to list
}
```
## Memory Issues
### Out of Memory
**Error**: `MemoryError` or container killed
**Solutions**:
1. **Enable on-disk storage**:
```python
client.create_collection(
collection_name="large_collection",
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
on_disk_payload=True, # Store payloads on disk
hnsw_config=HnswConfigDiff(on_disk=True) # Store HNSW on disk
)
```
2. **Use quantization**:
```python
# 4x memory reduction
client.update_collection(
collection_name="large_collection",
quantization_config=ScalarQuantization(
scalar=ScalarQuantizationConfig(
type=ScalarType.INT8,
always_ram=False # Keep on disk
)
)
)
```
3. **Increase Docker memory**:
```bash
docker run -m 8g -p 6333:6333 qdrant/qdrant
```
4. **Configure Qdrant storage**:
```yaml
# config.yaml
storage:
performance:
max_search_threads: 2
optimizers:
memmap_threshold_kb: 20000
```
### High Memory Usage During Indexing
**Fix**:
```python
# Increase indexing threshold for bulk loads
client.update_collection(
collection_name="documents",
optimizer_config={
"indexing_threshold": 50000 # Delay indexing
}
)
# Bulk insert
client.upsert(collection_name="documents", points=all_points, wait=False)
# Then optimize
client.update_collection(
collection_name="documents",
optimizer_config={
"indexing_threshold": 10000 # Resume normal indexing
}
)
```
## Cluster Issues
### Node Not Joining Cluster
**Problem**: New node fails to join cluster.
**Fix**:
```bash
# Check network connectivity
docker exec qdrant-node-2 ping qdrant-node-1
# Verify bootstrap URL
docker logs qdrant-node-2 | grep bootstrap
# Check Raft state
curl http://localhost:6333/cluster
```
### Split Brain
**Problem**: Cluster has inconsistent state.
**Fix**:
```bash
# Force leader election
curl -X POST http://localhost:6333/cluster/recover
# Or restart minority nodes
docker restart qdrant-node-2 qdrant-node-3
```
### Replication Lag
**Problem**: Replicas fall behind.
**Fix**:
```python
# Check collection status
info = client.get_collection("documents")
print(f"Status: {info.status}")
# Use strong consistency for critical writes
client.upsert(
collection_name="documents",
points=points,
ordering=WriteOrdering.STRONG
)
```
## Performance Tuning
### Benchmark Configuration
```python
import time
import numpy as np
def benchmark_search(client, collection, n_queries=100, dimension=384):
# Generate random queries
queries = [np.random.rand(dimension).tolist() for _ in range(n_queries)]
# Warmup
for q in queries[:10]:
client.search(collection_name=collection, query_vector=q, limit=10)
# Benchmark
start = time.perf_counter()
for q in queries:
client.search(collection_name=collection, query_vector=q, limit=10)
elapsed = time.perf_counter() - start
print(f"QPS: {n_queries / elapsed:.2f}")
print(f"Latency: {elapsed / n_queries * 1000:.2f}ms")
benchmark_search(client, "documents")
```
### Optimal HNSW Parameters
```python
# High recall (slower)
client.create_collection(
collection_name="high_recall",
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
hnsw_config=HnswConfigDiff(
m=32, # More connections
ef_construct=200 # Higher build quality
)
)
# High speed (lower recall)
client.create_collection(
collection_name="high_speed",
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
hnsw_config=HnswConfigDiff(
m=8, # Fewer connections
ef_construct=64 # Lower build quality
)
)
# Balanced
client.create_collection(
collection_name="balanced",
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
hnsw_config=HnswConfigDiff(
m=16, # Default
ef_construct=100 # Default
)
)
```
## Debugging Tips
### Enable Verbose Logging
```python
import logging
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("qdrant_client").setLevel(logging.DEBUG)
```
### Check Server Logs
```bash
# Docker logs
docker logs -f qdrant
# With timestamps
docker logs --timestamps qdrant
# Last 100 lines
docker logs --tail 100 qdrant
```
### Inspect Collection State
```python
# Collection info
info = client.get_collection("documents")
print(f"Status: {info.status}")
print(f"Points: {info.points_count}")
print(f"Segments: {len(info.segments)}")
print(f"Config: {info.config}")
# Sample points
points, _ = client.scroll(
collection_name="documents",
limit=5,
with_payload=True,
with_vectors=True
)
for p in points:
print(f"ID: {p.id}, Payload: {p.payload}")
```
### Test Connection
```python
def test_connection(host="localhost", port=6333):
try:
client = QdrantClient(host=host, port=port, timeout=5)
collections = client.get_collections()
print(f"Connected! Collections: {len(collections.collections)}")
return True
except Exception as e:
print(f"Connection failed: {e}")
return False
test_connection()
```
## Getting Help
1. **Documentation**: https://qdrant.tech/documentation/
2. **GitHub Issues**: https://github.com/qdrant/qdrant/issues
3. **Discord**: https://discord.gg/qdrant
4. **Stack Overflow**: Tag `qdrant`
### Reporting Issues
Include:
- Qdrant version: `curl http://localhost:6333/`
- Python client version: `pip show qdrant-client`
- Full error traceback
- Minimal reproducible code
- Collection configuration


@ -0,0 +1,389 @@
---
name: sparse-autoencoder-training
description: Provides guidance for training and analyzing Sparse Autoencoders (SAEs) using SAELens to decompose neural network activations into interpretable features. Use when discovering interpretable features, analyzing superposition, or studying monosemantic representations in language models.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [sae-lens>=6.0.0, transformer-lens>=2.0.0, torch>=2.0.0]
metadata:
hermes:
tags: [Sparse Autoencoders, SAE, Mechanistic Interpretability, Feature Discovery, Superposition]
---
# SAELens: Sparse Autoencoders for Mechanistic Interpretability
SAELens is the primary library for training and analyzing Sparse Autoencoders (SAEs) - a technique for decomposing polysemantic neural network activations into sparse, interpretable features. Based on Anthropic's groundbreaking research on monosemanticity.
**GitHub**: [jbloomAus/SAELens](https://github.com/jbloomAus/SAELens) (1,100+ stars)
## The Problem: Polysemanticity & Superposition
Individual neurons in neural networks are **polysemantic** - they activate in multiple, semantically distinct contexts. This happens because models use **superposition** to represent more features than they have neurons, making interpretability difficult.
**SAEs solve this** by decomposing dense activations into sparse, monosemantic features - typically only a small number of features activate for any given input, and each feature corresponds to an interpretable concept.
## When to Use SAELens
**Use SAELens when you need to:**
- Discover interpretable features in model activations
- Understand what concepts a model has learned
- Study superposition and feature geometry
- Perform feature-based steering or ablation
- Analyze safety-relevant features (deception, bias, harmful content)
**Consider alternatives when:**
- You need basic activation analysis → Use **TransformerLens** directly
- You want causal intervention experiments → Use **pyvene** or **TransformerLens**
- You need production steering → Consider direct activation engineering
## Installation
```bash
pip install sae-lens
```
Requirements: Python 3.10+, transformer-lens>=2.0.0
## Core Concepts
### What SAEs Learn
SAEs are trained to reconstruct model activations through a sparse bottleneck:
```
Input Activation → Encoder → Sparse Features → Decoder → Reconstructed Activation
    (d_model)               (d_sae >> d_model)                  (d_model)
                                    ↓                               ↓
                             sparsity penalty              reconstruction loss
```
**Loss Function**: `MSE(original, reconstructed) + L1_coefficient × L1(features)`
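A minimal PyTorch sketch of this objective (the function name and default coefficient are illustrative; SAELens' training loop adds refinements such as L1 warm-up and ghost gradients):

```python
import torch.nn.functional as F

def sae_loss(original, reconstructed, features, l1_coefficient=8e-5):
    """Reconstruction error plus an L1 sparsity penalty on the feature activations."""
    mse = F.mse_loss(reconstructed, original)
    sparsity = features.abs().sum(dim=-1).mean()
    return mse + l1_coefficient * sparsity
```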
### Key Validation (Anthropic Research)
In "Towards Monosemanticity", human evaluators found **70% of SAE features genuinely interpretable**. Features discovered include:
- DNA sequences, legal language, HTTP requests
- Hebrew text, nutrition statements, code syntax
- Sentiment, named entities, grammatical structures
## Workflow 1: Loading and Analyzing Pre-trained SAEs
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
# 1. Load model and pre-trained SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# 2. Get model activations
tokens = model.to_tokens("The capital of France is Paris")
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8] # [batch, pos, d_model]
# 3. Encode to SAE features
sae_features = sae.encode(activations) # [batch, pos, d_sae]
print(f"Active features: {(sae_features > 0).sum()}")
# 4. Find top features for each position
for pos in range(tokens.shape[1]):
top_features = sae_features[0, pos].topk(5)
token = model.to_str_tokens(tokens[0, pos:pos+1])[0]
print(f"Token '{token}': features {top_features.indices.tolist()}")
# 5. Reconstruct activations
reconstructed = sae.decode(sae_features)
reconstruction_error = (activations - reconstructed).norm()
```
### Available Pre-trained SAEs
| Release | Model | Layers |
|---------|-------|--------|
| `gpt2-small-res-jb` | GPT-2 Small | Multiple residual streams |
| `gemma-2b-res` | Gemma 2B | Residual streams |
| Various on HuggingFace | Search tag `saelens` | Various |
### Checklist
- [ ] Load model with TransformerLens
- [ ] Load matching SAE for target layer
- [ ] Encode activations to sparse features
- [ ] Identify top-activating features per token
- [ ] Validate reconstruction quality
## Workflow 2: Training a Custom SAE
### Step-by-Step
```python
from sae_lens import SAE, LanguageModelSAERunnerConfig, SAETrainingRunner
# 1. Configure training
cfg = LanguageModelSAERunnerConfig(
# Model
model_name="gpt2-small",
hook_name="blocks.8.hook_resid_pre",
hook_layer=8,
d_in=768, # Model dimension
# SAE architecture
architecture="standard", # or "gated", "topk"
d_sae=768 * 8, # Expansion factor of 8
activation_fn="relu",
# Training
lr=4e-4,
l1_coefficient=8e-5, # Sparsity penalty
l1_warm_up_steps=1000,
train_batch_size_tokens=4096,
training_tokens=100_000_000,
# Data
dataset_path="monology/pile-uncopyrighted",
context_size=128,
# Logging
log_to_wandb=True,
wandb_project="sae-training",
# Checkpointing
checkpoint_path="checkpoints",
n_checkpoints=5,
)
# 2. Train
trainer = SAETrainingRunner(cfg)
sae = trainer.run()
# 3. Evaluate
print(f"L0 (avg active features): {trainer.metrics['l0']}")
print(f"CE Loss Recovered: {trainer.metrics['ce_loss_score']}")
```
### Key Hyperparameters
| Parameter | Typical Value | Effect |
|-----------|---------------|--------|
| `d_sae` | 4-16× d_model | More features, higher capacity |
| `l1_coefficient` | 5e-5 to 1e-4 | Higher = sparser, less accurate |
| `lr` | 1e-4 to 1e-3 | Standard optimizer LR |
| `l1_warm_up_steps` | 500-2000 | Prevents early feature death |
### Evaluation Metrics
| Metric | Target | Meaning |
|--------|--------|---------|
| **L0** | 50-200 | Average active features per token |
| **CE Loss Score** | 80-95% | Cross-entropy recovered vs original |
| **Dead Features** | <5% | Features that never activate |
| **Explained Variance** | >90% | Reconstruction quality |
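L0, dead features, and explained variance can be estimated directly from a batch of activations and SAE outputs; a rough sketch (dead features here are only "dead within the batch", and CE loss recovery needs a full forward-pass comparison, so it is omitted):

```python
import torch

def quick_sae_metrics(activations, features, reconstructed):
    """Estimate L0, dead-feature fraction, and explained variance on one batch."""
    flat = features.flatten(0, -2)                         # [tokens, d_sae]
    l0 = (flat > 0).float().sum(dim=-1).mean()             # avg active features per token
    dead_fraction = ((flat > 0).sum(dim=0) == 0).float().mean()
    residual_var = (activations - reconstructed).var()
    explained_variance = 1 - residual_var / activations.var()
    return {
        "l0": l0.item(),
        "dead_fraction": dead_fraction.item(),
        "explained_variance": explained_variance.item(),
    }
```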
### Checklist
- [ ] Choose target layer and hook point
- [ ] Set expansion factor (d_sae = 4-16× d_model)
- [ ] Tune L1 coefficient for desired sparsity
- [ ] Enable L1 warm-up to prevent dead features
- [ ] Monitor metrics during training (W&B)
- [ ] Validate L0 and CE loss recovery
- [ ] Check dead feature ratio
## Workflow 3: Feature Analysis and Steering
### Analyzing Individual Features
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# Find what activates a specific feature
feature_idx = 1234
test_texts = [
"The scientist conducted an experiment",
"I love chocolate cake",
"The code compiles successfully",
"Paris is beautiful in spring",
]
for text in test_texts:
tokens = model.to_tokens(text)
_, cache = model.run_with_cache(tokens)
features = sae.encode(cache["resid_pre", 8])
activation = features[0, :, feature_idx].max().item()
print(f"{activation:.3f}: {text}")
```
### Feature Steering
```python
def steer_with_feature(model, sae, prompt, feature_idx, strength=5.0):
"""Add SAE feature direction to residual stream."""
tokens = model.to_tokens(prompt)
# Get feature direction from decoder
feature_direction = sae.W_dec[feature_idx] # [d_model]
def steering_hook(activation, hook):
# Add scaled feature direction at all positions
activation += strength * feature_direction
return activation
    # Generate with the steering hook active
    with model.hooks(fwd_hooks=[("blocks.8.hook_resid_pre", steering_hook)]):
        output = model.generate(tokens, max_new_tokens=50)
return model.to_string(output[0])
```
### Feature Attribution
```python
# Which features most affect a specific output?
tokens = model.to_tokens("The capital of France is")
_, cache = model.run_with_cache(tokens)
# Get features at final position
features = sae.encode(cache["resid_pre", 8])[0, -1] # [d_sae]
# Get logit attribution per feature
# Feature contribution = feature_activation × decoder_weight × unembedding
W_dec = sae.W_dec # [d_sae, d_model]
W_U = model.W_U # [d_model, vocab]
# Contribution to "Paris" logit
paris_token = model.to_single_token(" Paris")
feature_contributions = features * (W_dec @ W_U[:, paris_token])
top_features = feature_contributions.topk(10)
print("Top features for 'Paris' prediction:")
for idx, val in zip(top_features.indices, top_features.values):
print(f" Feature {idx.item()}: {val.item():.3f}")
```
## Common Issues & Solutions
### Issue: High dead feature ratio
```python
# WRONG: No warm-up, features die early
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=1e-4,
l1_warm_up_steps=0, # Bad!
)
# RIGHT: Warm-up L1 penalty
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=8e-5,
l1_warm_up_steps=1000, # Gradually increase
use_ghost_grads=True, # Revive dead features
)
```
### Issue: Poor reconstruction (low CE recovery)
```python
# Reduce sparsity penalty
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=5e-5, # Lower = better reconstruction
d_sae=768 * 16, # More capacity
)
```
### Issue: Features not interpretable
```python
# Increase sparsity (higher L1)
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=1e-4, # Higher = sparser, more interpretable
)
# Or use TopK architecture
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn_kwargs={"k": 50}, # Exactly 50 active features
)
```
### Issue: Memory errors during training
```python
cfg = LanguageModelSAERunnerConfig(
train_batch_size_tokens=2048, # Reduce batch size
store_batch_size_prompts=4, # Fewer prompts in buffer
n_batches_in_buffer=8, # Smaller activation buffer
)
```
## Integration with Neuronpedia
Browse pre-trained SAE features at [neuronpedia.org](https://neuronpedia.org):
```python
# Features are indexed by SAE ID
# Example: gpt2-small layer 8 feature 1234
# → neuronpedia.org/gpt2-small/8-res-jb/1234
```
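A tiny helper following this URL scheme (purely illustrative; check Neuronpedia for the exact SAE set identifiers):

```python
def neuronpedia_url(model: str = "gpt2-small", layer: int = 8,
                    sae_set: str = "res-jb", feature_id: int = 1234) -> str:
    """Build a Neuronpedia feature URL from the indexing scheme above."""
    return f"https://neuronpedia.org/{model}/{layer}-{sae_set}/{feature_id}"

print(neuronpedia_url(feature_id=1234))
```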
## Key Classes Reference
| Class | Purpose |
|-------|---------|
| `SAE` | Sparse Autoencoder model |
| `LanguageModelSAERunnerConfig` | Training configuration |
| `SAETrainingRunner` | Training loop manager |
| `ActivationsStore` | Activation collection and batching |
| `HookedSAETransformer` | TransformerLens + SAE integration |
## Reference Documentation
For detailed API documentation, tutorials, and advanced usage, see the `references/` folder:
| File | Contents |
|------|----------|
| [references/README.md](references/README.md) | Overview and quick start guide |
| [references/api.md](references/api.md) | Complete API reference for SAE, TrainingSAE, configurations |
| [references/tutorials.md](references/tutorials.md) | Step-by-step tutorials for training, analysis, steering |
## External Resources
### Tutorials
- [Basic Loading & Analysis](https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
- [Training a Sparse Autoencoder](https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
- [ARENA SAE Curriculum](https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab)
### Papers
- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features) - Anthropic (2023)
- [Scaling Monosemanticity](https://transformer-circuits.pub/2024/scaling-monosemanticity/) - Anthropic (2024)
- [Sparse Autoencoders Find Highly Interpretable Features](https://arxiv.org/abs/2309.08600) - Cunningham et al. (ICLR 2024)
### Official Documentation
- [SAELens Docs](https://jbloomaus.github.io/SAELens/)
- [Neuronpedia](https://neuronpedia.org) - Feature browser
## SAE Architectures
| Architecture | Description | Use Case |
|--------------|-------------|----------|
| **Standard** | ReLU + L1 penalty | General purpose |
| **Gated** | Learned gating mechanism | Better sparsity control |
| **TopK** | Exactly K active features | Consistent sparsity |
```python
# TopK SAE (exactly 50 features active)
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn="topk",
activation_fn_kwargs={"k": 50},
)
```


@ -0,0 +1,70 @@
# SAELens Reference Documentation
This directory contains comprehensive reference materials for SAELens.
## Contents
- [api.md](api.md) - Complete API reference for SAE, TrainingSAE, and configuration classes
- [tutorials.md](tutorials.md) - Step-by-step tutorials for training and analyzing SAEs
- [papers.md](papers.md) - Key research papers on sparse autoencoders
## Quick Links
- **GitHub Repository**: https://github.com/jbloomAus/SAELens
- **Neuronpedia**: https://neuronpedia.org (browse pre-trained SAE features)
- **HuggingFace SAEs**: Search for tag `saelens`
## Installation
```bash
pip install sae-lens
```
Requirements: Python 3.10+, transformer-lens>=2.0.0
## Basic Usage
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
# Load model and SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# Encode activations to sparse features
tokens = model.to_tokens("Hello world")
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations) # Sparse feature activations
reconstructed = sae.decode(features) # Reconstructed activations
```
## Key Concepts
### Sparse Autoencoders
SAEs decompose dense neural activations into sparse, interpretable features:
- **Encoder**: Maps d_model → d_sae (typically 4-16x expansion)
- **ReLU/TopK**: Enforces sparsity
- **Decoder**: Reconstructs original activations
### Training Loss
`Loss = MSE(original, reconstructed) + L1_coefficient × L1(features)`
### Key Metrics
- **L0**: Average number of active features (target: 50-200)
- **CE Loss Score**: Cross-entropy recovered vs original model (target: 80-95%)
- **Dead Features**: Features that never activate (target: <5%)
## Available Pre-trained SAEs
| Release | Model | Description |
|---------|-------|-------------|
| `gpt2-small-res-jb` | GPT-2 Small | Residual stream SAEs |
| `gemma-2b-res` | Gemma 2B | Residual stream SAEs |
| Various | Search HuggingFace | Community-trained SAEs |


@ -0,0 +1,333 @@
# SAELens API Reference
## SAE Class
The core class representing a Sparse Autoencoder.
### Loading Pre-trained SAEs
```python
from sae_lens import SAE
# From official releases
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# From HuggingFace
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="username/repo-name",
sae_id="path/to/sae",
device="cuda"
)
# From local disk
sae = SAE.load_from_disk("/path/to/sae", device="cuda")
```
### SAE Attributes
| Attribute | Shape | Description |
|-----------|-------|-------------|
| `W_enc` | [d_in, d_sae] | Encoder weights |
| `W_dec` | [d_sae, d_in] | Decoder weights |
| `b_enc` | [d_sae] | Encoder bias |
| `b_dec` | [d_in] | Decoder bias |
| `cfg` | SAEConfig | Configuration object |
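As a rough sketch of how these tensors combine in a standard SAE (the real `SAE.forward` differs by architecture and may, for example, skip subtracting `b_dec` from the input):

```python
import torch

def standard_sae_forward(x, sae):
    """Encode then decode using the weight/bias attributes listed above."""
    pre_acts = (x - sae.b_dec) @ sae.W_enc + sae.b_enc   # [..., d_sae]
    features = torch.relu(pre_acts)                      # sparse feature activations
    reconstructed = features @ sae.W_dec + sae.b_dec     # [..., d_in]
    return features, reconstructed
```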
### Core Methods
#### encode()
```python
# Encode activations to sparse features
features = sae.encode(activations)
# Input: [batch, pos, d_in]
# Output: [batch, pos, d_sae]
```
#### decode()
```python
# Reconstruct activations from features
reconstructed = sae.decode(features)
# Input: [batch, pos, d_sae]
# Output: [batch, pos, d_in]
```
#### forward()
```python
# Full forward pass (encode + decode)
reconstructed = sae(activations)
# Returns reconstructed activations
```
#### save_model()
```python
sae.save_model("/path/to/save")
```
---
## SAEConfig
Configuration class for SAE architecture and training context.
### Key Parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| `d_in` | int | Input dimension (model's d_model) |
| `d_sae` | int | SAE hidden dimension |
| `architecture` | str | "standard", "gated", "jumprelu", "topk" |
| `activation_fn_str` | str | Activation function name |
| `model_name` | str | Source model name |
| `hook_name` | str | Hook point in model |
| `normalize_activations` | str | Normalization method |
| `dtype` | str | Data type |
| `device` | str | Device |
### Accessing Config
```python
print(sae.cfg.d_in) # 768 for GPT-2 small
print(sae.cfg.d_sae) # e.g., 24576 (32x expansion)
print(sae.cfg.hook_name) # e.g., "blocks.8.hook_resid_pre"
```
---
## LanguageModelSAERunnerConfig
Comprehensive configuration for training SAEs.
### Example Configuration
```python
from sae_lens import LanguageModelSAERunnerConfig
cfg = LanguageModelSAERunnerConfig(
# Model and hook
model_name="gpt2-small",
hook_name="blocks.8.hook_resid_pre",
hook_layer=8,
d_in=768,
# SAE architecture
architecture="standard", # "standard", "gated", "jumprelu", "topk"
d_sae=768 * 8, # Expansion factor
activation_fn="relu",
# Training hyperparameters
lr=4e-4,
l1_coefficient=8e-5,
lp_norm=1.0,
lr_scheduler_name="constant",
lr_warm_up_steps=500,
# Sparsity control
l1_warm_up_steps=1000,
use_ghost_grads=True,
feature_sampling_window=1000,
dead_feature_window=5000,
dead_feature_threshold=1e-8,
# Data
dataset_path="monology/pile-uncopyrighted",
streaming=True,
context_size=128,
# Batch sizes
train_batch_size_tokens=4096,
store_batch_size_prompts=16,
n_batches_in_buffer=64,
# Training duration
training_tokens=100_000_000,
# Logging
log_to_wandb=True,
wandb_project="sae-training",
wandb_log_frequency=100,
# Checkpointing
checkpoint_path="checkpoints",
n_checkpoints=5,
# Hardware
device="cuda",
dtype="float32",
)
```
### Key Parameters Explained
#### Architecture Parameters
| Parameter | Description |
|-----------|-------------|
| `architecture` | SAE type: "standard", "gated", "jumprelu", "topk" |
| `d_sae` | Hidden dimension (or use `expansion_factor`) |
| `expansion_factor` | Alternative to d_sae: d_sae = d_in × expansion_factor |
| `activation_fn` | "relu", "topk", etc. |
| `activation_fn_kwargs` | Dict for activation params (e.g., {"k": 50} for topk) |
#### Sparsity Parameters
| Parameter | Description |
|-----------|-------------|
| `l1_coefficient` | L1 penalty weight (higher = sparser) |
| `l1_warm_up_steps` | Steps to ramp up L1 penalty |
| `use_ghost_grads` | Apply gradients to dead features |
| `dead_feature_threshold` | Activation threshold for "dead" |
| `dead_feature_window` | Steps to check for dead features |
#### Learning Rate Parameters
| Parameter | Description |
|-----------|-------------|
| `lr` | Base learning rate |
| `lr_scheduler_name` | "constant", "cosineannealing", etc. |
| `lr_warm_up_steps` | LR warmup steps |
| `lr_decay_steps` | Steps for LR decay |
---
## SAETrainingRunner
Main class for executing training.
### Basic Training
```python
from sae_lens import SAETrainingRunner, LanguageModelSAERunnerConfig
cfg = LanguageModelSAERunnerConfig(...)
runner = SAETrainingRunner(cfg)
sae = runner.run()
```
### Accessing Training Metrics
```python
# During training, metrics logged to W&B include:
# - l0: Average active features
# - ce_loss_score: Cross-entropy recovery
# - mse_loss: Reconstruction loss
# - l1_loss: Sparsity loss
# - dead_features: Count of dead features
```
---
## ActivationsStore
Manages activation collection and batching.
### Basic Usage
```python
from sae_lens import ActivationsStore
store = ActivationsStore.from_sae(
model=model,
sae=sae,
store_batch_size_prompts=8,
train_batch_size_tokens=4096,
n_batches_in_buffer=32,
device="cuda",
)
# Get a batch of tokens to run through the model
batch_tokens = store.get_batch_tokens()
```
---
## HookedSAETransformer
Integration of SAEs with TransformerLens models.
### Basic Usage
```python
from sae_lens import HookedSAETransformer
# Load model with SAE
model = HookedSAETransformer.from_pretrained("gpt2-small")
model.add_sae(sae)
# Run with SAE in the loop
output = model.run_with_saes(tokens, saes=[sae])
# Cache with SAE activations
output, cache = model.run_with_cache_with_saes(tokens, saes=[sae])
```
---
## SAE Architectures
### Standard (ReLU + L1)
```python
cfg = LanguageModelSAERunnerConfig(
architecture="standard",
activation_fn="relu",
l1_coefficient=8e-5,
)
```
### Gated
```python
cfg = LanguageModelSAERunnerConfig(
architecture="gated",
)
```
### TopK
```python
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn="topk",
activation_fn_kwargs={"k": 50}, # Exactly 50 active features
)
```
### JumpReLU (State-of-the-art)
```python
cfg = LanguageModelSAERunnerConfig(
architecture="jumprelu",
)
```
---
## Utility Functions
### Upload to HuggingFace
```python
from sae_lens import upload_saes_to_huggingface
upload_saes_to_huggingface(
saes=[sae],
repo_id="username/my-saes",
token="hf_token",
)
```
### Neuronpedia Integration
```python
# Features can be viewed on Neuronpedia
# URL format: neuronpedia.org/{model}/{layer}-{sae_type}/{feature_id}
# Example: neuronpedia.org/gpt2-small/8-res-jb/1234
```
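A small helper for building these URLs programmatically. It simply follows the format above; the release-specific `sae_type` slug (e.g. `res-jb`) is an assumption to verify against the dashboard:
```python
def neuronpedia_url(model: str, layer: int, sae_type: str, feature_id: int) -> str:
    """Build a Neuronpedia feature URL following the format described above."""
    return f"https://neuronpedia.org/{model}/{layer}-{sae_type}/{feature_id}"

# Example: feature 1234 of the layer-8 residual SAE on GPT-2 small
print(neuronpedia_url("gpt2-small", 8, "res-jb", 1234))
```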

View file

@ -0,0 +1,318 @@
# SAELens Tutorials
## Tutorial 1: Loading and Analyzing Pre-trained SAEs
### Goal
Load a pre-trained SAE and analyze which features activate on specific inputs.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
# 1. Load model and SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
print(f"SAE input dim: {sae.cfg.d_in}")
print(f"SAE hidden dim: {sae.cfg.d_sae}")
print(f"Expansion factor: {sae.cfg.d_sae / sae.cfg.d_in:.1f}x")
# 2. Get model activations
prompt = "The capital of France is Paris"
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8] # [1, seq_len, 768]
# 3. Encode to SAE features
features = sae.encode(activations) # [1, seq_len, d_sae]
# 4. Analyze sparsity
active_per_token = (features > 0).sum(dim=-1)
print(f"Average active features per token: {active_per_token.float().mean():.1f}")
# 5. Find top features for each token
str_tokens = model.to_str_tokens(prompt)
for pos in range(len(str_tokens)):
top_features = features[0, pos].topk(5)
print(f"\nToken '{str_tokens[pos]}':")
for feat_idx, feat_val in zip(top_features.indices, top_features.values):
print(f" Feature {feat_idx.item()}: {feat_val.item():.3f}")
# 6. Check reconstruction quality
reconstructed = sae.decode(features)
mse = ((activations - reconstructed) ** 2).mean()
print(f"\nReconstruction MSE: {mse.item():.6f}")
```
---
## Tutorial 2: Training a Custom SAE
### Goal
Train a Sparse Autoencoder on GPT-2 activations.
### Step-by-Step
```python
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner
# 1. Configure training
cfg = LanguageModelSAERunnerConfig(
# Model
model_name="gpt2-small",
hook_name="blocks.6.hook_resid_pre",
hook_layer=6,
d_in=768,
# SAE architecture
architecture="standard",
d_sae=768 * 8, # 8x expansion
activation_fn="relu",
# Training
lr=4e-4,
l1_coefficient=8e-5,
l1_warm_up_steps=1000,
train_batch_size_tokens=4096,
training_tokens=10_000_000, # Small run for demo
# Data
dataset_path="monology/pile-uncopyrighted",
streaming=True,
context_size=128,
# Dead feature prevention
use_ghost_grads=True,
dead_feature_window=5000,
# Logging
log_to_wandb=True,
wandb_project="sae-training-demo",
# Hardware
device="cuda",
dtype="float32",
)
# 2. Train
runner = SAETrainingRunner(cfg)
sae = runner.run()
# 3. Save
sae.save_model("./my_trained_sae")
```
### Hyperparameter Tuning Guide
| If you see... | Try... |
|---------------|--------|
| High L0 (>200) | Increase `l1_coefficient` |
| Low CE recovery (<80%) | Decrease `l1_coefficient`, increase `d_sae` |
| Many dead features (>5%) | Enable `use_ghost_grads`, increase `l1_warm_up_steps` |
| Training instability | Lower `lr`, increase `lr_warm_up_steps` |
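A quick way to read off L0 and reconstruction quality for this table is to probe the freshly trained SAE on a few prompts. A minimal sketch, assuming `sae` comes from the training run above and reusing the loading pattern from Tutorial 1:
```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
prompts = ["The quick brown fox jumps over the lazy dog."]

with torch.no_grad():
    tokens = model.to_tokens(prompts)
    _, cache = model.run_with_cache(tokens)
    acts = cache[sae.cfg.hook_name]          # activations at the SAE's hook point
    feats = sae.encode(acts)
    recon = sae.decode(feats)
    l0 = (feats > 0).float().sum(-1).mean()  # average active features per token
    mse = ((acts - recon) ** 2).mean()

print(f"L0: {l0:.1f}  reconstruction MSE: {mse:.6f}")
```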
---
## Tutorial 3: Feature Attribution and Steering
### Goal
Identify which SAE features contribute to specific predictions and use them for steering.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# 1. Feature attribution for a specific prediction
prompt = "The capital of France is"
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations)
# Target token
target_token = model.to_single_token(" Paris")
# Compute feature contributions to target logit
# contribution = feature_activation * decoder_weight * unembedding
W_dec = sae.W_dec # [d_sae, d_model]
W_U = model.W_U # [d_model, d_vocab]
# Feature direction projected to vocabulary
feature_to_logit = W_dec @ W_U # [d_sae, d_vocab]
# Contribution of each feature to "Paris" at final position
feature_acts = features[0, -1] # [d_sae]
contributions = feature_acts * feature_to_logit[:, target_token]
# Top contributing features
top_features = contributions.topk(10)
print("Top features contributing to 'Paris':")
for idx, val in zip(top_features.indices, top_features.values):
print(f" Feature {idx.item()}: {val.item():.3f}")
# 2. Feature steering
def steer_with_feature(feature_idx, strength=5.0):
    """Add a feature direction to the residual stream during generation."""
    feature_direction = sae.W_dec[feature_idx]  # [d_model]
    def hook(activation, hook_obj):
        activation[:, -1, :] += strength * feature_direction
        return activation
    # generate() does not accept fwd_hooks; attach them with the hooks() context manager
    with model.hooks(fwd_hooks=[("blocks.8.hook_resid_pre", hook)]):
        output = model.generate(tokens, max_new_tokens=10)
    return model.to_string(output[0])
# Try steering with top feature
top_feature_idx = top_features.indices[0].item()
print(f"\nSteering with feature {top_feature_idx}:")
print(steer_with_feature(top_feature_idx, strength=10.0))
```
---
## Tutorial 4: Feature Ablation
### Goal
Test the causal importance of features by ablating them.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
prompt = "The capital of France is"
tokens = model.to_tokens(prompt)
# Baseline prediction
baseline_logits = model(tokens)
target_token = model.to_single_token(" Paris")
baseline_prob = torch.softmax(baseline_logits[0, -1], dim=-1)[target_token].item()
print(f"Baseline P(Paris): {baseline_prob:.4f}")
# Get features to ablate
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations)
top_features = features[0, -1].topk(10).indices
# Ablate top features one by one
for feat_idx in top_features:
def ablation_hook(activation, hook, feat_idx=feat_idx):
# Encode → zero feature → decode
feats = sae.encode(activation)
feats[:, :, feat_idx] = 0
return sae.decode(feats)
ablated_logits = model.run_with_hooks(
tokens,
fwd_hooks=[("blocks.8.hook_resid_pre", ablation_hook)]
)
ablated_prob = torch.softmax(ablated_logits[0, -1], dim=-1)[target_token].item()
change = (ablated_prob - baseline_prob) / baseline_prob * 100
print(f"Ablate feature {feat_idx.item()}: P(Paris)={ablated_prob:.4f} ({change:+.1f}%)")
```
---
## Tutorial 5: Comparing Features Across Prompts
### Goal
Find which features activate consistently for a concept.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# Test prompts about the same concept
prompts = [
"The Eiffel Tower is located in",
"Paris is the capital of",
"France's largest city is",
"The Louvre museum is in",
]
# Collect feature activations
all_features = []
for prompt in prompts:
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations)
# Take max activation across positions
max_features = features[0].max(dim=0).values
all_features.append(max_features)
all_features = torch.stack(all_features) # [n_prompts, d_sae]
# Find features that activate consistently
mean_activation = all_features.mean(dim=0)
min_activation = all_features.min(dim=0).values
# Features active in ALL prompts
consistent_features = (min_activation > 0.5).nonzero().squeeze(-1)
print(f"Features active in all prompts: {len(consistent_features)}")
# Top consistent features
top_consistent = mean_activation[consistent_features].topk(min(10, len(consistent_features)))
print("\nTop consistent features (possibly 'France/Paris' related):")
for idx, val in zip(top_consistent.indices, top_consistent.values):
feat_idx = consistent_features[idx].item()
print(f" Feature {feat_idx}: mean activation {val.item():.3f}")
```
---
## External Resources
### Official Tutorials
- [Basic Loading & Analysis](https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
- [Training SAEs](https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
- [Logits Lens with Features](https://github.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
### ARENA Curriculum
Comprehensive SAE course: https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab
### Key Papers
- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features) - Anthropic (2023)
- [Scaling Monosemanticity](https://transformer-circuits.pub/2024/scaling-monosemanticity/) - Anthropic (2024)
- [Sparse Autoencoders Find Interpretable Features](https://arxiv.org/abs/2309.08600) - ICLR 2024

View file

@ -0,0 +1,222 @@
---
name: simpo-training
description: Simple Preference Optimization for LLM alignment. Reference-free alternative to DPO with better performance (+6.4 points on AlpacaEval 2.0). No reference model needed, more efficient than DPO. Use for preference alignment when you want simpler, faster training than DPO/PPO.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [torch, transformers, datasets, trl, accelerate]
metadata:
hermes:
tags: [Post-Training, SimPO, Preference Optimization, Alignment, DPO Alternative, Reference-Free, LLM Alignment, Efficient Training]
---
# SimPO - Simple Preference Optimization
## Quick start
SimPO is a reference-free preference optimization method that outperforms DPO without needing a reference model.
**Installation**:
```bash
# Create environment
conda create -n simpo python=3.10 && conda activate simpo
# Install PyTorch 2.2.2
# Visit: https://pytorch.org/get-started/locally/
# Install alignment-handbook
git clone https://github.com/huggingface/alignment-handbook.git
cd alignment-handbook
python -m pip install .
# Install Flash Attention 2
python -m pip install flash-attn --no-build-isolation
```
**Training** (Mistral 7B):
```bash
ACCELERATE_LOG_LEVEL=info accelerate launch \
--config_file accelerate_configs/deepspeed_zero3.yaml \
scripts/run_simpo.py \
training_configs/mistral-7b-base-simpo.yaml
```
## Common workflows
### Workflow 1: Train from base model (Mistral 7B)
**Config** (`mistral-7b-base-simpo.yaml`):
```yaml
# Model
model_name_or_path: mistralai/Mistral-7B-v0.1
torch_dtype: bfloat16
# Dataset
dataset_mixer:
HuggingFaceH4/ultrafeedback_binarized: 1.0
dataset_splits:
- train_prefs
- test_prefs
# SimPO hyperparameters
beta: 2.0 # Reward scaling (2.0-10.0)
gamma_beta_ratio: 0.5 # Target margin (0-1)
loss_type: sigmoid # sigmoid or hinge
sft_weight: 0.0 # Optional SFT regularization
# Training
learning_rate: 5e-7 # Critical: 3e-7 to 1e-6
num_train_epochs: 1
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
# Output
output_dir: ./outputs/mistral-7b-simpo
```
**Launch training**:
```bash
accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml \
scripts/run_simpo.py training_configs/mistral-7b-base-simpo.yaml
```
### Workflow 2: Fine-tune instruct model (Llama 3 8B)
**Config** (`llama3-8b-instruct-simpo.yaml`):
```yaml
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
dataset_mixer:
argilla/ultrafeedback-binarized-preferences-cleaned: 1.0
beta: 2.5
gamma_beta_ratio: 0.5
learning_rate: 5e-7
sft_weight: 0.1 # Add SFT loss to preserve capabilities
num_train_epochs: 1
per_device_train_batch_size: 2
gradient_accumulation_steps: 4
output_dir: ./outputs/llama3-8b-simpo
```
**Launch**:
```bash
accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml \
scripts/run_simpo.py training_configs/llama3-8b-instruct-simpo.yaml
```
### Workflow 3: Reasoning-intensive tasks (lower LR)
**For math/code tasks**:
```yaml
model_name_or_path: deepseek-ai/deepseek-math-7b-base
dataset_mixer:
argilla/distilabel-math-preference-dpo: 1.0
beta: 5.0 # Higher for stronger signal
gamma_beta_ratio: 0.7 # Larger margin
learning_rate: 3e-7 # Lower LR for reasoning
sft_weight: 0.0
num_train_epochs: 1
per_device_train_batch_size: 1
gradient_accumulation_steps: 16
```
## When to use vs alternatives
**Use SimPO when**:
- Want simpler training than DPO (no reference model)
- Have preference data (chosen/rejected pairs)
- Need better performance than DPO
- Limited compute resources
- Single-node training sufficient
**Algorithm selection**:
- **SimPO**: Simplest, best performance, no reference model
- **DPO**: Need reference model baseline, more conservative
- **PPO**: Maximum control, need reward model, complex setup
- **GRPO**: Memory-efficient RL, no critic
**Use alternatives instead**:
- **OpenRLHF**: Multi-node distributed training, PPO/GRPO
- **TRL**: Need multiple methods in one framework
- **DPO**: Established baseline comparison
## Common issues
**Issue: Loss divergence**
Reduce learning rate:
```yaml
learning_rate: 3e-7 # Reduce from 5e-7
```
Reduce beta:
```yaml
beta: 1.0 # Reduce from 2.0
```
**Issue: Model forgets capabilities**
Add SFT regularization:
```yaml
sft_weight: 0.1 # Add SFT loss component
```
**Issue: Poor preference separation**
Increase beta and margin:
```yaml
beta: 5.0 # Increase from 2.0
gamma_beta_ratio: 0.8 # Increase from 0.5
```
**Issue: OOM during training**
Reduce batch size:
```yaml
per_device_train_batch_size: 1
gradient_accumulation_steps: 16 # Maintain effective batch
```
Enable gradient checkpointing:
```yaml
gradient_checkpointing: true
```
## Advanced topics
**Loss functions**: See [references/loss-functions.md](references/loss-functions.md) for sigmoid vs hinge loss, mathematical formulations, and when to use each.
**Hyperparameter tuning**: See [references/hyperparameters.md](references/hyperparameters.md) for beta, gamma, learning rate selection guide, and model-size-specific recommendations.
**Dataset preparation**: See [references/datasets.md](references/datasets.md) for preference data formats, quality filtering, and custom dataset creation.
## Hardware requirements
- **GPU**: NVIDIA A100/H100 recommended
- **VRAM**:
- 7B model: 1× A100 40GB (DeepSpeed ZeRO-3)
- 8B model: 2× A100 40GB
- 70B model: 8× A100 80GB
- **Single-node**: DeepSpeed ZeRO-3 sufficient
- **Mixed precision**: BF16 recommended
**Memory optimization**:
- DeepSpeed ZeRO-3 (default config)
- Gradient checkpointing
- Flash Attention 2
## Resources
- Paper: https://arxiv.org/abs/2405.14734 (NeurIPS 2024)
- GitHub: https://github.com/princeton-nlp/SimPO
- Models: https://huggingface.co/princeton-nlp
- Alignment Handbook: https://github.com/huggingface/alignment-handbook

View file

@ -0,0 +1,478 @@
# Datasets
Complete guide to preference datasets for SimPO training.
## Dataset Format
### Required Fields
Preference datasets must contain:
```json
{
"prompt": "User question or instruction",
"chosen": "Better/preferred response",
"rejected": "Worse/rejected response"
}
```
**Alternative field names** (auto-detected):
- `prompt``question`, `instruction`, `input`
- `chosen``response_chosen`, `winner`, `preferred`
- `rejected``response_rejected`, `loser`
### Example Entry
```json
{
"prompt": "Explain quantum computing in simple terms.",
"chosen": "Quantum computing uses quantum bits (qubits) that can exist in multiple states simultaneously through superposition. This allows quantum computers to process many possibilities at once, making them potentially much faster than classical computers for specific tasks like cryptography and optimization.",
"rejected": "It's like regular computing but quantum."
}
```
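If your data uses the alternative field names listed above, a small normalization pass keeps the rest of the pipeline uniform. A sketch, assuming a local `preferences.jsonl` file; the alias lists simply mirror the names above:
```python
from datasets import load_dataset

ALIASES = {
    "prompt": ["prompt", "question", "instruction", "input"],
    "chosen": ["chosen", "response_chosen", "winner", "preferred"],
    "rejected": ["rejected", "response_rejected", "loser"],
}

def normalize(example):
    # Map whichever alias is present onto the canonical field name
    return {
        target: next(example[name] for name in names if example.get(name) is not None)
        for target, names in ALIASES.items()
    }

ds = load_dataset("json", data_files="preferences.jsonl", split="train")
ds = ds.map(normalize, remove_columns=[c for c in ds.column_names if c not in ALIASES])
```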
## Popular Datasets
### 1. UltraFeedback (Recommended)
**HuggingFaceH4/ultrafeedback_binarized**:
- **Size**: 60K preference pairs
- **Quality**: High (GPT-4 annotations)
- **Domain**: General instruction following
- **Format**: Clean, ready-to-use
**Config**:
```yaml
dataset_mixer:
HuggingFaceH4/ultrafeedback_binarized: 1.0
dataset_splits:
- train_prefs
- test_prefs
```
### 2. Argilla UltraFeedback (Cleaned)
**argilla/ultrafeedback-binarized-preferences-cleaned**:
- **Size**: 50K pairs (filtered)
- **Quality**: Very high (deduped, cleaned)
- **Domain**: General
- **Format**: Clean
**Config**:
```yaml
dataset_mixer:
argilla/ultrafeedback-binarized-preferences-cleaned: 1.0
```
### 3. Distilabel Math
**argilla/distilabel-math-preference-dpo**:
- **Size**: 30K pairs
- **Quality**: High (GSM8K, MATH)
- **Domain**: Math reasoning
- **Format**: Math-specific
**Config**:
```yaml
dataset_mixer:
argilla/distilabel-math-preference-dpo: 1.0
```
### 4. HelpSteer
**nvidia/HelpSteer**:
- **Size**: 38K samples
- **Quality**: High (human ratings)
- **Domain**: Helpfulness alignment
- **Format**: Multi-attribute ratings
**Config**:
```yaml
dataset_mixer:
nvidia/HelpSteer: 1.0
```
### 5. Anthropic HH-RLHF
**Anthropic/hh-rlhf**:
- **Size**: 161K samples
- **Quality**: High (human preferences)
- **Domain**: Harmless + helpful
- **Format**: Conversational
**Config**:
```yaml
dataset_mixer:
Anthropic/hh-rlhf: 1.0
```
## Dataset Mixing
### Multiple Datasets
**Equal mix**:
```yaml
dataset_mixer:
HuggingFaceH4/ultrafeedback_binarized: 0.5
Anthropic/hh-rlhf: 0.5
```
**Weighted mix**:
```yaml
dataset_mixer:
HuggingFaceH4/ultrafeedback_binarized: 0.7
argilla/distilabel-math-preference-dpo: 0.2
nvidia/HelpSteer: 0.1
```
**Domain-specific emphasis**:
```yaml
# 80% general + 20% math
dataset_mixer:
HuggingFaceH4/ultrafeedback_binarized: 0.8
argilla/distilabel-math-preference-dpo: 0.2
```
## Data Quality
### Quality Indicators
**Good preference data**:
- ✅ Clear quality difference between chosen/rejected
- ✅ Diverse prompts
- ✅ Minimal noise/annotation errors
- ✅ Appropriate difficulty level
**Poor preference data**:
- ❌ Ambiguous preferences
- ❌ Repetitive prompts
- ❌ Annotation noise
- ❌ Too easy/hard prompts
### Quality Filtering
**Filter by length difference**:
```python
def filter_by_length(example):
chosen_len = len(example['chosen'].split())
rejected_len = len(example['rejected'].split())
# Reject if chosen is much shorter (potential low-effort)
return chosen_len >= rejected_len * 0.5
dataset = dataset.filter(filter_by_length)
```
**Filter by diversity**:
```python
seen_prompts = set()
def filter_duplicates(example):
prompt = example['prompt']
if prompt in seen_prompts:
return False
seen_prompts.add(prompt)
return True
dataset = dataset.filter(filter_duplicates)
```
## Custom Dataset Creation
### Format 1: JSON Lines
**File** (`preferences.jsonl`):
```jsonl
{"prompt": "What is Python?", "chosen": "Python is a high-level programming language...", "rejected": "It's a snake."}
{"prompt": "Explain AI.", "chosen": "AI refers to systems that can...", "rejected": "It's computers that think."}
```
**Load**:
```yaml
dataset_mixer:
json:
data_files: preferences.jsonl
```
### Format 2: HuggingFace Dataset
**Create from dict**:
```python
from datasets import Dataset
data = {
"prompt": ["What is Python?", "Explain AI."],
"chosen": ["Python is...", "AI refers to..."],
"rejected": ["It's a snake.", "It's computers..."]
}
dataset = Dataset.from_dict(data)
dataset.push_to_hub("username/my-preferences")
```
**Use in config**:
```yaml
dataset_mixer:
username/my-preferences: 1.0
```
### Format 3: ChatML
**For conversational data**:
```json
{
"prompt": [
{"role": "user", "content": "What is quantum computing?"}
],
"chosen": [
{"role": "assistant", "content": "Quantum computing uses qubits..."}
],
"rejected": [
{"role": "assistant", "content": "It's like regular computing but quantum."}
]
}
```
**Apply chat template**:
```yaml
dataset_text_field: null # Will apply chat template
```
## Synthetic Data Generation
### Using GPT-4
**Prompt template**:
```
Given the following question:
{prompt}
Generate two responses:
1. A high-quality, detailed response (chosen)
2. A low-quality, brief response (rejected)
Format as JSON with "chosen" and "rejected" fields.
```
**Example code**:
```python
import json
import openai
def generate_pair(prompt):
response = openai.ChatCompletion.create(
model="gpt-4",
messages=[{
"role": "user",
"content": f"Given: {prompt}\n\nGenerate chosen/rejected pair in JSON."
}]
)
return json.loads(response.choices[0].message.content)
# Generate dataset
prompts = load_prompts()
dataset = [generate_pair(p) for p in prompts]
```
### Using Local Model
**With vLLM**:
```python
from vllm import LLM, SamplingParams
llm = LLM(model="meta-llama/Meta-Llama-3-70B-Instruct")
def generate_variations(prompt):
    # Generate multiple completions (sampling_params must be a SamplingParams object)
    outputs = llm.generate(
        [prompt] * 4,
        sampling_params=SamplingParams(
            temperature=0.8,
            top_p=0.9,
            max_tokens=512,
        ),
    )
# Select best/worst
chosen = max(outputs, key=lambda x: len(x.outputs[0].text))
rejected = min(outputs, key=lambda x: len(x.outputs[0].text))
return {
"prompt": prompt,
"chosen": chosen.outputs[0].text,
"rejected": rejected.outputs[0].text
}
```
## Data Preprocessing
### Truncation
**Limit sequence length**:
```yaml
max_prompt_length: 512
max_completion_length: 512
max_length: 1024 # Total
```
**Implementation**:
```python
def truncate_example(example):
    # Assumes a HuggingFace `tokenizer` is already loaded (e.g. AutoTokenizer.from_pretrained(...))
    tokenizer.truncation_side = "left"  # For prompts, keep the end
    prompt_tokens = tokenizer(
        example['prompt'],
        max_length=512,
        truncation=True
    )
    tokenizer.truncation_side = "right"  # For completions, keep the start
    chosen_tokens = tokenizer(
        example['chosen'],
        max_length=512,
        truncation=True
    )
    rejected_tokens = tokenizer(
        example['rejected'],
        max_length=512,
        truncation=True
    )
    return {
        "prompt": tokenizer.decode(prompt_tokens['input_ids']),
        "chosen": tokenizer.decode(chosen_tokens['input_ids']),
        "rejected": tokenizer.decode(rejected_tokens['input_ids'])
    }
dataset = dataset.map(truncate_example)
```
### Deduplication
**Remove exact duplicates**:
```python
# Dataset.unique() only returns the distinct values, so filter instead (single-process)
seen = set()
dataset = dataset.filter(lambda x: not (x['prompt'] in seen or seen.add(x['prompt'])))
```
**Remove near-duplicates** (MinHash):
```python
from datasets import Dataset
from datasketch import MinHash, MinHashLSH
def deduplicate_lsh(dataset, threshold=0.8):
lsh = MinHashLSH(threshold=threshold, num_perm=128)
seen = []
for i, example in enumerate(dataset):
m = MinHash(num_perm=128)
for word in example['prompt'].split():
m.update(word.encode('utf8'))
if not lsh.query(m):
lsh.insert(i, m)
seen.append(example)
return Dataset.from_list(seen)
dataset = deduplicate_lsh(dataset)
```
## Data Augmentation
### Paraphrasing Prompts
```python
def paraphrase_prompts(batch):
    # Use a paraphrasing model (assumed available as `paraphrase_model`); a batched map
    # may return more rows than it receives, so emit the original plus a paraphrased copy
    out = {"prompt": [], "chosen": [], "rejected": []}
    for prompt, chosen, rejected in zip(batch["prompt"], batch["chosen"], batch["rejected"]):
        out["prompt"] += [prompt, paraphrase_model(prompt)]
        out["chosen"] += [chosen, chosen]
        out["rejected"] += [rejected, rejected]
    return out
dataset = dataset.map(paraphrase_prompts, batched=True, remove_columns=dataset.column_names)
```
### Difficulty Balancing
**Mix easy/medium/hard**:
```python
from datasets import concatenate_datasets

def categorize_difficulty(example):
prompt_len = len(example['prompt'].split())
if prompt_len < 20:
return "easy"
elif prompt_len < 50:
return "medium"
else:
return "hard"
dataset = dataset.map(lambda x: {"difficulty": categorize_difficulty(x)})
# Sample balanced dataset
easy = dataset.filter(lambda x: x['difficulty'] == 'easy').shuffle().select(range(1000))
medium = dataset.filter(lambda x: x['difficulty'] == 'medium').shuffle().select(range(1000))
hard = dataset.filter(lambda x: x['difficulty'] == 'hard').shuffle().select(range(1000))
balanced = concatenate_datasets([easy, medium, hard]).shuffle()
```
## Dataset Statistics
### Compute Stats
```python
import numpy as np

def compute_stats(dataset):
prompt_lens = [len(x['prompt'].split()) for x in dataset]
chosen_lens = [len(x['chosen'].split()) for x in dataset]
rejected_lens = [len(x['rejected'].split()) for x in dataset]
print(f"Dataset size: {len(dataset)}")
print(f"Avg prompt length: {np.mean(prompt_lens):.1f} words")
print(f"Avg chosen length: {np.mean(chosen_lens):.1f} words")
print(f"Avg rejected length: {np.mean(rejected_lens):.1f} words")
print(f"Chosen > Rejected: {sum(c > r for c, r in zip(chosen_lens, rejected_lens)) / len(dataset):.1%}")
compute_stats(dataset)
```
**Expected output**:
```
Dataset size: 50000
Avg prompt length: 45.2 words
Avg chosen length: 180.5 words
Avg rejected length: 120.3 words
Chosen > Rejected: 85.2%
```
## Best Practices
### 1. Data Quality Over Quantity
- **Prefer**: 10K high-quality pairs
- **Over**: 100K noisy pairs
### 2. Clear Preference Signals
- Chosen should be noticeably better
- Avoid marginal differences
- Remove ambiguous pairs
### 3. Domain Matching
- Match dataset domain to target use case
- Mix datasets for broader coverage
- Include safety-filtered data
### 4. Validate Before Training
```python
# Sample 10 random examples
samples = dataset.shuffle().select(range(10))
for ex in samples:
print(f"Prompt: {ex['prompt']}")
print(f"Chosen: {ex['chosen'][:100]}...")
print(f"Rejected: {ex['rejected'][:100]}...")
print(f"Preference clear: {'✓' if len(ex['chosen']) > len(ex['rejected']) else '?'}")
print()
```
## References
- HuggingFace Datasets: https://huggingface.co/datasets
- Alignment Handbook: https://github.com/huggingface/alignment-handbook
- UltraFeedback: https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized

View file

@ -0,0 +1,452 @@
# Hyperparameters
Complete guide to SimPO hyperparameter selection and tuning.
## Overview
Key hyperparameters in SimPO:
1. **Learning Rate** - Most critical
2. **Beta (β)** - Reward scaling
3. **Gamma-Beta Ratio (γ/β)** - Target margin
4. **SFT Weight** - Regularization strength
## Learning Rate
### Recommended Ranges
**By model size**:
| Model Size | Learning Rate | Notes |
|------------|---------------|-------|
| 1B-3B | 5e-7 to 1e-6 | Higher end safe |
| 7B-8B | 3e-7 to 5e-7 | **Standard** |
| 13B-30B | 1e-7 to 3e-7 | Lower for stability |
| 70B+ | 5e-8 to 1e-7 | Very conservative |
**By task type**:
| Task | Learning Rate | Reason |
|------|---------------|--------|
| General chat | 5e-7 | Standard |
| Code generation | 3e-7 | **Precise reasoning** |
| Math reasoning | 3e-7 | **Careful optimization** |
| Creative writing | 1e-6 | More aggressive OK |
### Why Learning Rate Matters
**Too high** (> 1e-6 for 7B):
- Loss divergence
- Catastrophic forgetting
- Unstable training
**Too low** (< 1e-7 for 7B):
- Very slow convergence
- May not finish in time
- Undertraining
**Optimal** (3e-7 to 5e-7 for 7B):
- Stable convergence
- Good final performance
- Efficient training
### Config Examples
**Mistral 7B (general)**:
```yaml
learning_rate: 5e-7
num_train_epochs: 1
warmup_ratio: 0.1
lr_scheduler_type: cosine
```
**Llama 3 8B (reasoning)**:
```yaml
learning_rate: 3e-7
num_train_epochs: 1
warmup_ratio: 0.1
lr_scheduler_type: cosine
```
**Gemma 2 9B (creative)**:
```yaml
learning_rate: 1e-6
num_train_epochs: 1
warmup_ratio: 0.1
lr_scheduler_type: linear
```
## Beta (β)
### Recommended Values
**Range**: 2.0 to 10.0 (much higher than DPO's 0.01-0.1)
**By preference strength**:
| Beta | Preference Strength | Use Case |
|------|-------------------|----------|
| 1.0-2.0 | Weak | Subtle preferences |
| 2.0-5.0 | **Standard** | General alignment |
| 5.0-10.0 | Strong | Clear preferences |
**Default**: 2.0 to 2.5
### Why Beta Matters
**Low beta** (< 2.0):
- Weak reward signal
- Slow preference learning
- May underfit
**High beta** (> 10.0):
- Very strong reward signal
- Risk of overfitting
- May ignore weak preferences
**Optimal** (2.0-5.0):
- Balanced reward scaling
- Stable training
- Good generalization
### Interaction with Gamma
**Beta and gamma together**:
```
Target margin in reward space = gamma
Target margin in logit space = gamma / beta
```
**Example**:
```yaml
beta: 2.0
gamma_beta_ratio: 0.5
# Effective gamma = 2.0 * 0.5 = 1.0
```
### Config Examples
**Weak preferences**:
```yaml
beta: 2.0
gamma_beta_ratio: 0.3 # Small margin
```
**Standard**:
```yaml
beta: 2.5
gamma_beta_ratio: 0.5 # Default
```
**Strong preferences**:
```yaml
beta: 5.0
gamma_beta_ratio: 0.7 # Larger margin
```
## Gamma-Beta Ratio (γ/β)
### Recommended Values
**Range**: 0.0 to 1.0
**By scenario**:
| Ratio | Margin | Use Case |
|-------|--------|----------|
| 0.0-0.3 | Small | Weak preference data |
| 0.4-0.6 | **Standard** | General use |
| 0.7-1.0 | Large | Very clear preferences |
**Default**: 0.5
### Why Gamma Matters
**Low gamma** (< 0.3):
- Small target margin
- Less aggressive alignment
- More conservative
**High gamma** (> 0.7):
- Large target margin
- Stronger alignment
- More aggressive
**Optimal** (0.4-0.6):
- Balanced margin
- Stable training
- Good alignment
### Mathematical Meaning
**In loss function**:
```python
logits = pi_logratios - gamma_beta_ratio
loss = -log(sigmoid(beta * logits))
```
**Interpretation**:
- gamma_beta_ratio shifts the decision boundary
- Higher ratio = requires larger log prob difference
- Controls how "clear" preferences must be
### Config Examples
**Noisy preferences**:
```yaml
gamma_beta_ratio: 0.3 # Smaller margin, more tolerant
```
**Standard**:
```yaml
gamma_beta_ratio: 0.5 # Default
```
**High-quality preferences**:
```yaml
gamma_beta_ratio: 0.8 # Larger margin, stricter
```
## SFT Weight
### Recommended Values
**Range**: 0.0 to 1.0
**By model type**:
| Model Type | SFT Weight | Reason |
|------------|-----------|--------|
| Base model | 0.0 | No prior capabilities |
| **Instruct model** | 0.05-0.1 | Preserve instruction following |
| Chat model | 0.1-0.2 | Preserve conversational skills |
**Default**: 0.0 (no SFT regularization)
### Why SFT Weight Matters
**Zero SFT** (0.0):
- Pure preference optimization
- May forget capabilities
- Standard for base models
**Low SFT** (0.05-0.1):
- Balanced approach
- **Recommended for instruct models**
- Slight capability preservation
**High SFT** (> 0.2):
- Strong capability preservation
- Weaker preference alignment
- May reduce alignment gains
### Trade-off
```
Total Loss = SimPO Loss + (sft_weight * SFT Loss)
```
**Example**:
```yaml
sft_weight: 0.1
# 90% preference optimization + 10% capability preservation
```
### Config Examples
**Base model (no SFT)**:
```yaml
model_name_or_path: mistralai/Mistral-7B-v0.1
sft_weight: 0.0
```
**Instruct model (light SFT)**:
```yaml
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
sft_weight: 0.1
```
**Chat model (moderate SFT)**:
```yaml
model_name_or_path: HuggingFaceH4/zephyr-7b-beta
sft_weight: 0.2
```
## Model-Size-Specific Recommendations
### 7B Models (Mistral, Llama 3)
**Standard config**:
```yaml
learning_rate: 5e-7
beta: 2.0
gamma_beta_ratio: 0.5
sft_weight: 0.0 # 0.1 if instruct model
num_train_epochs: 1
per_device_train_batch_size: 2
gradient_accumulation_steps: 4
```
### 8B-13B Models
**Standard config**:
```yaml
learning_rate: 3e-7
beta: 2.5
gamma_beta_ratio: 0.5
sft_weight: 0.1 # If instruct
num_train_epochs: 1
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
```
### 70B Models
**Standard config**:
```yaml
learning_rate: 1e-7
beta: 2.0
gamma_beta_ratio: 0.5
sft_weight: 0.05
num_train_epochs: 1
per_device_train_batch_size: 1
gradient_accumulation_steps: 16
```
## Batch Size & Gradient Accumulation
### Effective Batch Size
```
Effective Batch Size = per_device_batch_size * num_gpus * grad_accum_steps
```
**Recommended effective batch sizes**:
- 7B: 128-256
- 13B: 64-128
- 70B: 32-64
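A tiny helper for solving the formula above for the accumulation steps given your hardware (a sketch; the function name is ours, not part of SimPO):
```python
def grad_accum_steps(target_effective_batch: int, per_device_batch: int, num_gpus: int) -> int:
    """Gradient accumulation steps needed to reach the target effective batch size (rounded up)."""
    per_step = per_device_batch * num_gpus
    return max(1, -(-target_effective_batch // per_step))  # ceiling division

# Example: 128 effective batch on 4 GPUs with per-device batch 2 -> 16 accumulation steps
print(grad_accum_steps(128, per_device_batch=2, num_gpus=4))
```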
### Config Examples
**Single GPU (A100 40GB)**:
```yaml
per_device_train_batch_size: 1
gradient_accumulation_steps: 128 # Effective batch = 128
```
**4 GPUs (A100 40GB)**:
```yaml
per_device_train_batch_size: 2
gradient_accumulation_steps: 16 # Effective batch = 2*4*16 = 128
```
**8 GPUs (A100 80GB)**:
```yaml
per_device_train_batch_size: 2
gradient_accumulation_steps: 8 # Effective batch = 2*8*8 = 128
```
## Loss Type
### Sigmoid vs Hinge
**Sigmoid** (default, recommended):
```yaml
loss_type: sigmoid
label_smoothing: 0.0
```
**Hinge** (experimental):
```yaml
loss_type: hinge
# No label smoothing for hinge
```
**When to use hinge**:
- Margin-based tasks
- SVM-style optimization
- Experimental purposes
**Generally**: Stick with sigmoid
## Tuning Guide
### Step 1: Start with Defaults
```yaml
learning_rate: 5e-7 # For 7B
beta: 2.0
gamma_beta_ratio: 0.5
sft_weight: 0.0 # 0.1 if instruct
loss_type: sigmoid
```
### Step 2: Monitor Training
**Check every 100 steps**:
- Loss curve (should decrease smoothly)
- Reward margin (should increase)
- Chosen/rejected logps (should separate)
### Step 3: Adjust if Needed
**If loss diverges**:
```yaml
learning_rate: 3e-7 # Reduce from 5e-7
beta: 1.0 # Reduce from 2.0
```
**If loss plateaus early**:
```yaml
learning_rate: 1e-6 # Increase from 5e-7
beta: 5.0 # Increase from 2.0
```
**If model forgets**:
```yaml
sft_weight: 0.2 # Increase from 0.0
```
## Complete Example Configs
### Mistral 7B Base (Standard)
```yaml
model_name_or_path: mistralai/Mistral-7B-v0.1
dataset_mixer:
HuggingFaceH4/ultrafeedback_binarized: 1.0
learning_rate: 5e-7
beta: 2.0
gamma_beta_ratio: 0.5
loss_type: sigmoid
sft_weight: 0.0
num_train_epochs: 1
per_device_train_batch_size: 2
gradient_accumulation_steps: 4
warmup_ratio: 0.1
lr_scheduler_type: cosine
bf16: true
gradient_checkpointing: true
```
### Llama 3 8B Instruct (Reasoning)
```yaml
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
dataset_mixer:
argilla/distilabel-math-preference-dpo: 1.0
learning_rate: 3e-7
beta: 5.0
gamma_beta_ratio: 0.7
loss_type: sigmoid
sft_weight: 0.1
num_train_epochs: 1
per_device_train_batch_size: 1
gradient_accumulation_steps: 16
warmup_ratio: 0.1
lr_scheduler_type: cosine
```
## References
- SimPO paper: https://arxiv.org/abs/2405.14734
- Alignment Handbook: https://github.com/huggingface/alignment-handbook

View file

@ -0,0 +1,350 @@
# Loss Functions
Complete guide to SimPO loss functions and mathematical formulations.
## Overview
SimPO supports two loss types:
- **Sigmoid** (default) - Smooth, differentiable loss
- **Hinge** - Margin-based, sparse loss
Both are reference-free (no reference model needed).
## SimPO Loss Formula
### Core Calculation
**Step 1: Log probability ratio**:
```
pi_logratios = log P_θ(y_chosen|x) - log P_θ(y_rejected|x)
```
**Step 2: Apply target margin**:
```
logits = pi_logratios - γ/β
```
Where:
- γ/β = `gamma_beta_ratio` (target margin)
**Step 3: Compute loss** (depends on loss type)
### Sigmoid Loss (Default)
**Formula**:
```
L = -log σ(β * logits) * (1 - ε) - log σ(-β * logits) * ε
```
Where:
- β = `beta` (reward scaling)
- σ = sigmoid function
- ε = `label_smoothing` (default 0.0)
**Implementation**:
```python
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
```
**Characteristics**:
- Smooth, continuous gradients
- Probabilistic interpretation
- Standard choice for most tasks
- Works well with higher beta values
### Hinge Loss
**Formula**:
```
L = max(0, 1 - β * logits)
```
**Implementation**:
```python
losses = torch.relu(1 - self.beta * logits)
```
**Characteristics**:
- Non-smooth (has kink at logits = 1/β)
- Margin-based (SVM-style)
- Can lead to sparser solutions
- Less commonly used
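Putting the three steps together, a self-contained sketch of the loss for both types (it mirrors the reference snippets above but is not the exact SimPOTrainer code):
```python
import torch
import torch.nn.functional as F

def simpo_loss(chosen_logps, rejected_logps, beta=2.0, gamma_beta_ratio=0.5,
               loss_type="sigmoid", label_smoothing=0.0):
    """chosen_logps / rejected_logps: average log probs per sequence, shape [batch]."""
    pi_logratios = chosen_logps - rejected_logps      # Step 1
    logits = pi_logratios - gamma_beta_ratio          # Step 2: subtract the target margin γ/β
    if loss_type == "sigmoid":                        # Step 3
        losses = (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
                  - F.logsigmoid(-beta * logits) * label_smoothing)
    elif loss_type == "hinge":
        losses = torch.relu(1 - beta * logits)
    else:
        raise ValueError(f"unknown loss_type: {loss_type}")
    chosen_rewards = beta * chosen_logps.detach()
    rejected_rewards = beta * rejected_logps.detach()
    return losses.mean(), chosen_rewards, rejected_rewards

# Matches the worked example later in this document: logits = 1.3 - 0.5 = 0.8, loss ≈ 0.184
loss, _, _ = simpo_loss(torch.tensor([-1.2]), torch.tensor([-2.5]))
print(f"{loss.item():.3f}")
```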
## Comparison to DPO
### DPO Loss (Reference Model Required)
**Formula**:
```
L_DPO = -E[log σ(β * log(π_θ(y_w|x)/π_ref(y_w|x)) - β * log(π_θ(y_l|x)/π_ref(y_l|x)))]
```
**Key features**:
- Requires reference model π_ref
- Normalizes by reference log probabilities
- More conservative (stays close to reference)
### SimPO Loss (Reference-Free)
**Formula**:
```
L_SimPO = -log σ(β * (log π_θ(y_w|x) - log π_θ(y_l|x) - γ/β))
```
**Key features**:
- No reference model needed
- Direct preference optimization
- Target margin γ/β controls preference strength
- More efficient (fewer model forward passes)
**Visual comparison**:
```
DPO: [Policy] - [Reference] → Loss
SimPO: [Policy] → Loss
```
## Average Log Probability Reward
### Calculation
**Per-token log probabilities**:
```python
# Get log probs for each token
per_token_logps = log_softmax(logits).gather(dim=-1, index=labels)
# Create mask to ignore padding
loss_mask = (labels != label_pad_token_id)
```
**Average log probability** (if `average_log_prob=True`):
```python
avg_logp = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
```
**Sum log probability** (if `average_log_prob=False`):
```python
sum_logp = (per_token_logps * loss_mask).sum(-1)
```
**Why average?**
- Normalizes for sequence length
- Prevents bias toward shorter/longer responses
- Standard practice in SimPO
### Reward Metrics
**Chosen reward**:
```python
chosen_rewards = beta * policy_chosen_logps.detach()
```
**Rejected reward**:
```python
rejected_rewards = beta * policy_rejected_logps.detach()
```
**Reward margin**:
```python
reward_margin = chosen_rewards.mean() - rejected_rewards.mean()
```
## Label Smoothing
### Formula with Smoothing
**Sigmoid loss**:
```
L = -log σ(β * logits) * (1 - ε) - log σ(-β * logits) * ε
```
**Effect**:
- ε = 0.0: No smoothing (default)
- ε = 0.1: 10% smoothing (soft labels)
- ε = 0.5: Maximum smoothing
**When to use**:
- Noisy preference labels
- Uncertain preferences
- Prevent overconfidence
**Config**:
```yaml
label_smoothing: 0.1 # 10% smoothing
```
## SFT Regularization
### Combined Loss
**With SFT component**:
```
L_total = L_SimPO + λ * L_SFT
```
Where:
- L_SFT = cross-entropy loss on chosen responses
- λ = `sft_weight` (0.0 to 1.0)
**Implementation**:
```python
if self.sft_weight > 0:
sft_loss = -policy_chosen_logps
total_loss = simpo_loss + self.sft_weight * sft_loss
```
**When to use**:
- Preserve model capabilities
- Prevent catastrophic forgetting
- Fine-tuning instruct models
**Trade-off**:
- Higher sft_weight: Preserve capabilities, less alignment
- Lower sft_weight: Stronger alignment, may forget capabilities
**Config**:
```yaml
sft_weight: 0.1 # 10% SFT regularization
```
## Loss Type Selection
### Sigmoid vs Hinge
| Aspect | Sigmoid | Hinge |
|--------|---------|-------|
| Smoothness | Smooth | Non-smooth |
| Gradients | Continuous | Discontinuous at margin |
| Sparsity | Dense solutions | Sparse solutions |
| Interpretability | Probabilistic | Geometric margin |
| Use case | **General purpose** | Margin-based tasks |
| Recommendation | **Default choice** | Experimental |
**Config**:
```yaml
# Sigmoid (default)
loss_type: sigmoid
# Hinge (alternative)
loss_type: hinge
```
## Mathematical Properties
### Gradient Analysis
**Sigmoid loss gradient**:
```
∂L/∂logits = -β * σ(-β * logits) * (1 - ε) + β * σ(β * logits) * ε
```
**Hinge loss gradient**:
```
∂L/∂logits = -β if logits < 1/β
0 otherwise
```
**Implications**:
- Sigmoid: Always provides gradient signal
- Hinge: No gradient when margin satisfied
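These derivatives are easy to sanity-check numerically against autograd; a small sketch for the sigmoid case (β = 2.0, ε = 0):
```python
import torch
import torch.nn.functional as F

beta, eps = 2.0, 0.0
logits = torch.tensor(0.8, requires_grad=True)

loss = -F.logsigmoid(beta * logits) * (1 - eps) - F.logsigmoid(-beta * logits) * eps
loss.backward()

with torch.no_grad():
    analytic = (-beta * torch.sigmoid(-beta * logits) * (1 - eps)
                + beta * torch.sigmoid(beta * logits) * eps)

print(logits.grad.item(), analytic.item())  # the two values should agree
```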
### Convergence Behavior
**Sigmoid**:
- Asymptotically approaches zero loss
- Continues optimizing even with large margins
- Smoother training curves
**Hinge**:
- Reaches zero loss at margin
- Stops optimizing once margin satisfied
- May have training plateaus
## Complete Loss Examples
### Example 1: Basic SimPO (Sigmoid)
**Config**:
```yaml
beta: 2.0
gamma_beta_ratio: 0.5
loss_type: sigmoid
label_smoothing: 0.0
sft_weight: 0.0
```
**Loss calculation**:
```python
# Step 1: Compute log probs
chosen_logps = avg_log_prob(policy(chosen)) # e.g., -1.2
rejected_logps = avg_log_prob(policy(rejected)) # e.g., -2.5
# Step 2: Log ratio and margin
pi_logratios = -1.2 - (-2.5) = 1.3
logits = 1.3 - 0.5 = 0.8
# Step 3: Sigmoid loss
loss = -log(sigmoid(2.0 * 0.8))
= -log(sigmoid(1.6))
= -log(0.832)
= 0.184
```
### Example 2: SimPO with SFT
**Config**:
```yaml
beta: 2.5
gamma_beta_ratio: 0.5
loss_type: sigmoid
sft_weight: 0.1
```
**Loss calculation**:
```python
# SimPO loss (as above)
simpo_loss = 0.184
# SFT loss
sft_loss = -chosen_logps = -(-1.2) = 1.2
# Total loss
total_loss = simpo_loss + 0.1 * sft_loss
= 0.184 + 0.12
= 0.304
```
## Debugging
### Check Reward Margins
**Low margin (< 0.5)**:
- Preferences not being learned
- Increase beta or gamma_beta_ratio
**High margin (> 5.0)**:
- May be overfitting
- Reduce beta or learning rate
**Monitor**:
```python
reward_margin = chosen_rewards.mean() - rejected_rewards.mean()
print(f"Reward margin: {reward_margin:.2f}")
```
### Check Log Probabilities
**Typical values**:
- Chosen: -1.0 to -2.0 (higher is better)
- Rejected: -2.0 to -4.0 (lower is worse)
**Warning signs**:
- Both very negative (< -10): Model not learning
- Both very positive (> 0): Numerical instability
## References
- SimPO paper: https://arxiv.org/abs/2405.14734
- DPO paper: https://arxiv.org/abs/2305.18290
- Implementation: https://github.com/princeton-nlp/SimPO

View file

@ -0,0 +1,467 @@
---
name: slime-rl-training
description: Provides guidance for LLM post-training with RL using slime, a Megatron+SGLang framework. Use when training GLM models, implementing custom data generation workflows, or needing tight Megatron-LM integration for RL scaling.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [sglang-router>=0.2.3, ray, torch>=2.0.0, transformers>=4.40.0]
metadata:
hermes:
tags: [Reinforcement Learning, Megatron-LM, SGLang, GRPO, Post-Training, GLM]
---
# slime: LLM Post-Training Framework for RL Scaling
slime is an LLM post-training framework from Tsinghua's THUDM team, powering GLM-4.5, GLM-4.6, and GLM-4.7. It connects Megatron-LM for training with SGLang for high-throughput rollout generation.
## When to Use slime
**Choose slime when you need:**
- Megatron-LM native training with SGLang inference
- Custom data generation workflows with flexible data buffers
- Training GLM, Qwen3, DeepSeek V3, or Llama 3 models
- Research-grade framework with production backing (Z.ai)
**Consider alternatives when:**
- You need enterprise-grade stability features → use **miles**
- You want flexible backend swapping → use **verl**
- You need PyTorch-native abstractions → use **torchforge**
## Key Features
- **Training**: Megatron-LM with full parallelism support (TP, PP, DP, SP)
- **Rollout**: SGLang-based high-throughput generation with router
- **Data Buffer**: Flexible prompt management and sample storage
- **Models**: GLM-4.x, Qwen3, DeepSeek V3/R1, Llama 3
## Architecture Overview
```
┌─────────────────────────────────────────────────────────┐
│ Data Buffer │
│ - Prompt initialization and management │
│ - Custom data generation and filtering │
│ - Rollout sample storage │
└─────────────┬───────────────────────────┬───────────────┘
│ │
┌─────────────▼───────────┐ ┌─────────────▼───────────────┐
│ Training (Megatron-LM) │ │ Rollout (SGLang + Router) │
│ - Actor model training │ │ - Response generation │
│ - Critic (optional) │ │ - Reward/verifier output │
│ - Weight sync to rollout│ │ - Multi-turn support │
└─────────────────────────┘ └─────────────────────────────┘
```
## Installation
```bash
# Recommended: Docker
docker pull slimerl/slime:latest
docker run --rm --gpus all --ipc=host --shm-size=16g \
-it slimerl/slime:latest /bin/bash
# Inside container
cd /root/slime && pip install -e . --no-deps
```
### From Source
```bash
git clone https://github.com/THUDM/slime.git
cd slime
pip install -r requirements.txt
pip install -e .
```
## Quick Start: GRPO Training
```bash
# Source model configuration
source scripts/models/qwen3-4B.sh
# Launch training
python train.py \
--actor-num-nodes 1 \
--actor-num-gpus-per-node 4 \
--rollout-num-gpus 4 \
--advantage-estimator grpo \
--use-kl-loss --kl-loss-coef 0.001 \
--rollout-batch-size 32 \
--n-samples-per-prompt 8 \
--global-batch-size 256 \
--num-rollout 3000 \
--prompt-data /path/to/data.jsonl \
${MODEL_ARGS[@]} ${CKPT_ARGS[@]}
```
---
## Workflow 1: Standard GRPO Training
Use this workflow for training reasoning models with group-relative advantages.
### Prerequisites Checklist
- [ ] Docker environment or Megatron-LM + SGLang installed
- [ ] Model checkpoint (HuggingFace or Megatron format)
- [ ] Training data in JSONL format
### Step 1: Prepare Data
```python
# data.jsonl format
{"prompt": "What is 2 + 2?", "label": "4"}
{"prompt": "Solve: 3x = 12", "label": "x = 4"}
```
Or with chat format:
```python
{
"prompt": [
{"role": "system", "content": "You are a math tutor."},
{"role": "user", "content": "What is 15 + 27?"}
],
"label": "42"
}
```
### Step 2: Configure Model
Choose a pre-configured model script:
```bash
# List available models
ls scripts/models/
# glm4-9B.sh, qwen3-4B.sh, qwen3-30B-A3B.sh, deepseek-v3.sh, llama3-8B.sh, ...
# Source your model
source scripts/models/qwen3-4B.sh
```
### Step 3: Launch Training
```bash
python train.py \
--actor-num-nodes 1 \
--actor-num-gpus-per-node 8 \
--rollout-num-gpus 8 \
--advantage-estimator grpo \
--use-kl-loss \
--kl-loss-coef 0.001 \
--prompt-data /path/to/train.jsonl \
--input-key prompt \
--label-key label \
--apply-chat-template \
--rollout-batch-size 32 \
--n-samples-per-prompt 8 \
--global-batch-size 256 \
--num-rollout 3000 \
--save-interval 100 \
--eval-interval 50 \
${MODEL_ARGS[@]}
```
### Step 4: Monitor Training
- [ ] Check TensorBoard: `tensorboard --logdir outputs/`
- [ ] Verify reward curves are increasing
- [ ] Monitor GPU utilization across nodes
---
## Workflow 2: Asynchronous Training
Use async mode for higher throughput by overlapping rollout and training.
### When to Use Async
- Large models with long generation times
- High GPU idle time in synchronous mode
- Sufficient memory for buffering
### Launch Async Training
```bash
python train_async.py \
--actor-num-nodes 1 \
--actor-num-gpus-per-node 8 \
--rollout-num-gpus 8 \
--advantage-estimator grpo \
--async-buffer-size 4 \
--prompt-data /path/to/train.jsonl \
${MODEL_ARGS[@]}
```
### Async-Specific Parameters
```bash
--async-buffer-size 4 # Number of rollouts to buffer
--update-weights-interval 2 # Sync weights every N rollouts
```
---
## Workflow 3: Multi-Turn Agentic Training
Use this workflow for training agents with tool use or multi-step reasoning.
### Prerequisites
- [ ] Custom generate function for multi-turn logic
- [ ] Tool/environment interface
### Step 1: Define Custom Generate Function
```python
# custom_generate.py
async def custom_generate(args, samples, evaluation=False):
"""Multi-turn generation with tool calling."""
for sample in samples:
conversation = sample.prompt
for turn in range(args.max_turns):
# Generate response
response = await generate_single(conversation)
# Check for tool call
tool_call = extract_tool_call(response)
if tool_call:
tool_result = execute_tool(tool_call)
conversation.append({"role": "assistant", "content": response})
conversation.append({"role": "tool", "content": tool_result})
else:
break
sample.response = response
sample.reward = compute_reward(sample)
return samples
```
### Step 2: Launch with Custom Function
```bash
python train.py \
--custom-generate-function-path custom_generate.py \
--max-turns 5 \
--prompt-data /path/to/agent_data.jsonl \
${MODEL_ARGS[@]}
```
See `examples/search-r1/` for a complete multi-turn search example.
---
## Configuration Reference
### Three Argument Categories
slime uses three types of arguments:
**1. Megatron Arguments** (passed directly):
```bash
--tensor-model-parallel-size 2
--pipeline-model-parallel-size 1
--num-layers 32
--hidden-size 4096
```
**2. SGLang Arguments** (prefixed with `--sglang-`):
```bash
--sglang-mem-fraction-static 0.8
--sglang-context-length 8192
--sglang-log-level INFO
```
**3. slime Arguments**:
```bash
# Resource allocation
--actor-num-nodes 1
--actor-num-gpus-per-node 8
--rollout-num-gpus 8
--colocate # Share GPUs between training/inference
# Data
--prompt-data /path/to/data.jsonl
--input-key prompt
--label-key label
# Training loop
--num-rollout 3000
--rollout-batch-size 32
--n-samples-per-prompt 8
--global-batch-size 256
# Algorithm
--advantage-estimator grpo # or: gspo, ppo, reinforce_plus_plus
--use-kl-loss
--kl-loss-coef 0.001
```
### Key Constraints
```
rollout_batch_size × n_samples_per_prompt = global_batch_size × num_steps_per_rollout
```
Example: 32 × 8 = 256 × 1
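It is worth checking this identity before launching a long run; a minimal sketch with the example values above:
```python
# Example values from above; plug in your own flag values before launching
rollout_batch_size, n_samples_per_prompt = 32, 8
global_batch_size, num_steps_per_rollout = 256, 1

assert rollout_batch_size * n_samples_per_prompt == global_batch_size * num_steps_per_rollout, \
    "rollout_batch_size * n_samples_per_prompt must equal global_batch_size * num_steps_per_rollout"
```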
---
## Data Buffer System
slime's data buffer enables flexible data management:
### Basic Data Source
```python
class RolloutDataSource:
def get_samples(self, num_samples):
"""Fetch prompts from dataset."""
return self.dataset.sample(num_samples)
def add_samples(self, samples):
"""Called after generation (no-op by default)."""
pass
```
### Buffered Data Source (Off-Policy)
```python
class RolloutDataSourceWithBuffer(RolloutDataSource):
def __init__(self):
self.buffer = []
def add_samples(self, samples):
"""Store generated samples for reuse."""
self.buffer.extend(samples)
def buffer_filter(self, args, buffer, num_samples):
"""Custom selection logic (prioritized, stratified, etc.)."""
return select_best(buffer, num_samples)
```
---
## Common Issues and Solutions
### Issue: SGLang Engine Crash
**Symptoms**: Inference engine dies mid-training
**Solutions**:
```bash
# Enable fault tolerance
--use-fault-tolerance
# Increase memory allocation
--sglang-mem-fraction-static 0.85
# Reduce batch size
--rollout-batch-size 16
```
### Issue: Weight Sync Timeout
**Symptoms**: Training hangs after rollout
**Solutions**:
```bash
# Increase sync interval
--update-weights-interval 5
# Use colocated mode (no network transfer)
--colocate
```
### Issue: OOM During Training
**Symptoms**: CUDA OOM in backward pass
**Solutions**:
```bash
# Enable gradient checkpointing
--recompute-activations
# Reduce micro-batch size
--micro-batch-size 1
# Enable sequence parallelism
--sequence-parallel
```
### Issue: Slow Data Loading
**Symptoms**: GPU idle during data fetch
**Solutions**:
```bash
# Increase data workers
--num-data-workers 4
# Use streaming dataset
--streaming-data
```
---
## Supported Models
| Model Family | Configurations |
|--------------|----------------|
| GLM | GLM-4.5, GLM-4.6, GLM-4.7, GLM-Z1-9B |
| Qwen | Qwen3 (4B, 8B, 30B-A3B), Qwen3-MoE, Qwen2.5 |
| DeepSeek | V3, V3.1, R1 |
| Llama | Llama 3 (8B, 70B) |
| Others | Kimi K2, Moonlight-16B |
Each model has pre-configured scripts in `scripts/models/`.
---
## Advanced Topics
### Co-location Mode
Share GPUs between training and inference to reduce memory:
```bash
python train.py \
--colocate \
--actor-num-gpus-per-node 8 \
--sglang-mem-fraction-static 0.4 \
${MODEL_ARGS[@]}
```
### Custom Reward Model
```python
# custom_rm.py
class CustomRewardModel:
def __init__(self, model_path):
self.model = load_model(model_path)
def compute_reward(self, prompts, responses):
inputs = self.tokenize(prompts, responses)
scores = self.model(inputs)
return scores.tolist()
```
```bash
--custom-rm-path custom_rm.py
```
### Evaluation Multi-Task
```bash
--eval-prompt-data aime /path/to/aime.jsonl \
--eval-prompt-data gsm8k /path/to/gsm8k.jsonl \
--n-samples-per-eval-prompt 16
```
---
## Resources
- **Documentation**: https://thudm.github.io/slime/
- **GitHub**: https://github.com/THUDM/slime
- **Blog**: https://lmsys.org/blog/2025-07-09-slime/
- **Examples**: See `examples/` directory for 14+ worked examples

View file

@ -0,0 +1,392 @@
# slime API Reference
## Architecture Overview
slime operates with a three-module architecture orchestrated by Ray:
```
┌─────────────────────────────────────────────────────────┐
│ Data Buffer │
│ - Prompt initialization and management │
│ - Custom data generation and filtering │
│ - Rollout sample storage │
└─────────────┬───────────────────────────┬───────────────┘
│ │
┌─────────────▼───────────┐ ┌─────────────▼───────────────┐
│ Training (Megatron-LM) │ │ Rollout (SGLang + Router) │
│ - Actor model training │ │ - Response generation │
│ - Critic (optional) │ │ - Reward/verifier output │
│ - Weight sync to rollout│ │ - Multi-turn support │
└─────────────────────────┘ └─────────────────────────────┘
```
## Core Data Structures
### Sample Object
The `Sample` object is the core data structure defined in `slime/utils/types.py`:
```python
from slime.utils.types import Sample
@dataclass
class Sample:
# Core fields
group_index: Optional[int] # Group index for batching
index: Optional[int] # Sample index
prompt: str | list[dict] = "" # Input prompt or chat history
tokens: list[int] = field(default_factory=list) # Token IDs
response: str = "" # Generated response
response_length: int = 0 # Response length in tokens
label: Optional[str] = None # Ground truth label
reward: Optional[float | dict] = None # RL reward signal
loss_mask: Optional[list[int]] = None # 1=compute loss, 0=mask
status: Status = Status.PENDING # Sample status
metadata: dict = field(default_factory=dict) # Custom data
# Multimodal support
multimodal_inputs: Optional[Any] = None # Raw multimodal data (images, videos)
multimodal_train_inputs: Optional[Any] = None # Processed multimodal data (pixel_values)
# Rollout tracking
weight_versions: list[str] = field(default_factory=list)
rollout_log_probs: Optional[list[float]] = None # Log probs from SGLang
rollout_routed_experts: Optional[list[list[int]]] = None # Expert routing (MoE)
# Control fields
remove_sample: bool = False
generate_function_path: Optional[str] = None
train_metadata: Optional[dict] = None
non_generation_time: float = 0.0
# Speculative decoding info (nested dataclass)
@dataclass
class SpecInfo:
spec_accept_token_num: int = 0
spec_draft_token_num: int = 0
spec_verify_ct: int = 0
completion_token_num: int = 0
```
### Status Enum
```python
class Status(Enum):
PENDING = "pending" # Not yet processed
COMPLETED = "completed" # Successfully generated
TRUNCATED = "truncated" # Hit max length
ABORTED = "aborted" # Failed generation
FAILED = "failed" # Generation failed
```
## Configuration System
slime uses three categories of command-line arguments:
### 1. Megatron Arguments
All Megatron-LM arguments are supported directly:
```bash
--tensor-model-parallel-size 2
--pipeline-model-parallel-size 1
--num-layers 32
--hidden-size 4096
--num-attention-heads 32
--seq-length 4096
--micro-batch-size 1
--global-batch-size 256
```
### 2. SGLang Arguments
SGLang arguments are prefixed with `--sglang-`:
```bash
--sglang-mem-fraction-static 0.8 # GPU memory for KV cache
--sglang-context-length 8192 # Maximum context length
--sglang-log-level INFO # Logging verbosity
--sglang-tp-size 2 # Tensor parallelism
--sglang-disable-cuda-graph # Disable CUDA graphs
```
### 3. slime-Specific Arguments
Defined in `slime/utils/arguments.py`:
```bash
# Resource Allocation
--actor-num-nodes 1 # Training nodes
--actor-num-gpus-per-node 8 # GPUs per training node
--rollout-num-gpus 8 # Total rollout GPUs
--rollout-num-gpus-per-engine 2 # GPUs per SGLang engine
--colocate # Share GPUs for train/inference
# Data Configuration
--prompt-data /path/to/data.jsonl # Training data path
--input-key prompt # Key for prompts in JSON
--label-key label # Key for labels in JSON
--apply-chat-template # Apply chat formatting
# Training Loop
--num-rollout 3000 # Total rollout iterations
--rollout-batch-size 32 # Prompts per rollout
--n-samples-per-prompt 8 # Responses per prompt
--global-batch-size 256 # Training batch size
--num-steps-per-rollout 1 # Training steps per rollout
# RL Algorithm
--advantage-estimator grpo # grpo, gspo, ppo, reinforce_plus_plus
--use-kl-loss # Enable KL loss
--kl-loss-coef 0.001 # KL coefficient
--calculate-per-token-loss # Token-level loss
# Off-Policy Options
--use-tis # Truncated Importance Sampling
--tis-threshold 0.9 # TIS threshold
--true-on-policy-mode # Force on-policy training
```
## Data Buffer System
### RolloutDataSource (Base Class)
```python
from slime.data import RolloutDataSource
class RolloutDataSource:
def __init__(self, dataset, args):
self.dataset = dataset
self.args = args
def get_samples(self, num_samples: int) -> list[Sample]:
"""Fetch prompts from dataset."""
return [Sample(prompt=p) for p in self.dataset.sample(num_samples)]
def add_samples(self, samples: list[Sample]) -> None:
"""Called after generation (no-op by default)."""
pass
```
### Buffered Data Source (Off-Policy)
```python
from slime.data import RolloutDataSourceWithBuffer
class RolloutDataSourceWithBuffer(RolloutDataSource):
def __init__(self, dataset, args):
super().__init__(dataset, args)
self.buffer = []
def add_samples(self, samples: list[Sample]) -> None:
"""Store generated samples for reuse."""
self.buffer.extend(samples)
def buffer_filter(self, args, buffer, num_samples) -> list[Sample]:
"""Custom selection logic."""
# Example: prioritized sampling based on reward
sorted_buffer = sorted(buffer, key=lambda s: s.reward, reverse=True)
return sorted_buffer[:num_samples]
```
## Custom Functions
### Custom Generate Function
For multi-turn or tool-calling scenarios:
```python
# custom_generate.py
from slime.data import Sample
async def custom_generate(args, samples: list[Sample], evaluation: bool = False) -> list[Sample]:
"""
Custom generation function for multi-turn interactions.
Args:
args: Training arguments
samples: List of Sample objects with prompts
evaluation: Whether this is an evaluation run
Returns:
List of Sample objects with responses and rewards
"""
for sample in samples:
conversation = sample.prompt if isinstance(sample.prompt, list) else [
{"role": "user", "content": sample.prompt}
]
for turn in range(args.max_turns):
# Generate response
response = await generate_single(conversation)
# Check for tool call
tool_call = extract_tool_call(response)
if tool_call:
# Execute tool
tool_result = await execute_tool(tool_call)
conversation.append({"role": "assistant", "content": response})
conversation.append({"role": "tool", "content": tool_result})
else:
# Final response
sample.response = response
break
# Compute reward
sample.reward = compute_reward(sample)
# Set loss mask (1 for model tokens, 0 for tool responses)
sample.loss_mask = build_loss_mask(sample)
return samples
```
Usage:
```bash
python train.py \
--custom-generate-function-path custom_generate.py \
--max-turns 5
```
### Custom Reward Function
```python
# custom_rm.py
from slime.data import Sample
async def reward_func(args, sample: Sample, **kwargs) -> float:
"""
Compute reward for a single sample.
Args:
args: Training arguments
sample: Sample object with response
Returns:
Reward score (float)
"""
response = sample.response
ground_truth = sample.label or sample.metadata.get("answer", "")
# Example: exact match reward
if response.strip() == ground_truth.strip():
return 1.0
return 0.0
# For batched processing (more efficient)
async def batched_custom_rm(args, samples: list[Sample]) -> list[float]:
"""Batch reward computation."""
rewards = []
for sample in samples:
reward = await reward_func(args, sample)
rewards.append(reward)
return rewards
```
Usage:
```bash
python train.py \
--custom-rm-path custom_rm.py \
--group-rm # Enable batched processing
```
## Model Configuration
### Pre-configured Model Scripts
Located in `scripts/models/`:
```bash
# List available models
ls scripts/models/
# glm4-9B.sh, qwen3-4B.sh, qwen3-30B-A3B.sh, deepseek-v3.sh, llama3-8B.sh
# Source model configuration
source scripts/models/qwen3-4B.sh
# This sets MODEL_ARGS and CKPT_ARGS arrays
```
### Example Model Script
```bash
# scripts/models/qwen3-4B.sh
export MODEL_ARGS=(
--num-layers 36
--hidden-size 2560
--num-attention-heads 20
--num-query-groups 4
--ffn-hidden-size 6912
--max-position-embeddings 32768
--rotary-percent 1.0
--rotary-base 1000000
--swiglu
--untie-embeddings-and-output-weights
--no-position-embedding
--normalization RMSNorm
--tokenizer-type HuggingFaceTokenizer
--bf16
)
export CKPT_ARGS=(
--hf-checkpoint /path/to/qwen3-4b-hf
--initial-megatron-checkpoint /path/to/megatron/ckpt
)
```
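A hedged sketch of how these arrays are typically consumed at launch time (exact flags beyond those shown earlier will vary by recipe):
```bash
source scripts/models/qwen3-4B.sh

python train.py \
  "${MODEL_ARGS[@]}" \
  "${CKPT_ARGS[@]}" \
  --prompt-data /path/to/data.jsonl \
  --rollout-batch-size 32 \
  --n-samples-per-prompt 8 \
  --global-batch-size 256
```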
## Async Training
### Enabling Async Mode
```bash
python train_async.py \
--actor-num-gpus-per-node 8 \
--rollout-num-gpus 8 \
--async-buffer-size 4 \
--update-weights-interval 2 \
${MODEL_ARGS[@]}
```
### Async-Specific Parameters
```bash
--async-buffer-size 4 # Number of rollouts to buffer
--update-weights-interval 2 # Sync weights every N rollouts
```
**Note**: Colocated mode (`--colocate`) is NOT supported with async training.
## Evaluation
### Multi-Task Evaluation
```bash
--eval-prompt-data aime /path/to/aime.jsonl \
--eval-prompt-data gsm8k /path/to/gsm8k.jsonl \
--n-samples-per-eval-prompt 16 \
--eval-interval 50
```
### Evaluation Configuration
```bash
--eval-interval 50 # Evaluate every N rollouts
--n-samples-per-eval-prompt 16 # Samples for evaluation
--eval-temperature 0.0 # Greedy decoding for eval
```
## Supported Models
| Model Family | Configurations |
|--------------|----------------|
| GLM | GLM-4.5, GLM-4.6, GLM-4.7, GLM-Z1-9B |
| Qwen | Qwen3 (4B, 8B, 30B-A3B), Qwen3-MoE, Qwen2.5 |
| DeepSeek | V3, V3.1, R1 |
| Llama | Llama 3 (8B, 70B) |
| Others | Kimi K2, Moonlight-16B |
## Resources
- Documentation: https://thudm.github.io/slime/
- GitHub: https://github.com/THUDM/slime
- Blog: https://lmsys.org/blog/2025-07-09-slime/
- Examples: `examples/` directory (14+ worked examples)

View file

@ -0,0 +1,386 @@
# slime Troubleshooting Guide
## Common Issues and Solutions
### SGLang Issues
#### Issue: SGLang Engine Crash
**Symptoms**: Inference engine dies mid-training, connection errors
**Solutions**:
1. **Enable fault tolerance**:
```bash
--use-fault-tolerance
```
2. **Increase memory allocation**:
```bash
--sglang-mem-fraction-static 0.85 # Increase from 0.8
```
3. **Reduce batch size**:
```bash
--rollout-batch-size 16 # Reduce from 32
```
4. **Disable CUDA graphs** (for debugging):
```bash
--sglang-disable-cuda-graph
```
#### Issue: SGLang Router Load Imbalance
**Symptoms**: Some SGLang engines overloaded while others idle
**Solutions**:
1. **Adjust routing strategy**:
```bash
--sglang-router-strategy round_robin
```
2. **Increase number of engines**:
```bash
--rollout-num-gpus-per-engine 1  # More engines, fewer GPUs each
```
### Weight Synchronization Issues
#### Issue: Weight Sync Timeout
**Symptoms**: Training hangs after rollout, timeout errors
**Solutions**:
1. **Increase sync interval** (async mode):
```bash
--update-weights-interval 5 # Increase from 2
```
2. **Use colocated mode** (eliminates network transfer):
```bash
--colocate
```
3. **Check network bandwidth**:
```bash
# Verify InfiniBand is enabled
ibstat
```
#### Issue: Weight Sync Failures in Multi-Node
**Symptoms**: Nodes fail to receive updated weights
**Solutions**:
1. **Set NCCL environment**:
```bash
export NCCL_DEBUG=INFO
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_DISABLE=0
```
2. **Increase timeout**:
```bash
export NCCL_TIMEOUT=1800
```
### Memory Issues
#### Issue: OOM During Training
**Symptoms**: CUDA OOM in backward pass
**Solutions**:
1. **Enable gradient checkpointing**:
```bash
--recompute-activations
```
2. **Reduce micro-batch size**:
```bash
--micro-batch-size 1
```
3. **Enable sequence parallelism**:
```bash
--sequence-parallel
```
4. **Reduce global batch size**:
```bash
--global-batch-size 128 # Reduce from 256
```
#### Issue: OOM in Colocated Mode
**Symptoms**: OOM when both training and inference run on same GPUs
**Solutions**:
1. **Reduce SGLang memory**:
```bash
--sglang-mem-fraction-static 0.4 # Reduce from 0.8
```
2. **Enable offloading**:
```bash
--offload-optimizer-states
```
3. **Use smaller sequence length**:
```bash
--seq-length 2048 # Reduce from 4096
```
### Data Loading Issues
#### Issue: Slow Data Loading
**Symptoms**: GPU idle during data fetch, low GPU utilization
**Solutions**:
1. **Increase data workers**:
```bash
--num-data-workers 4
```
2. **Use streaming dataset**:
```bash
--streaming-data
```
3. **Pre-tokenize data**:
```python
# Pre-process data offline
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("model_path")
# Save tokenized data
```
#### Issue: Data Format Errors
**Symptoms**: KeyError, missing fields, parsing failures
**Solutions**:
1. **Verify data format**:
```python
import json
with open("data.jsonl") as f:
for line in f:
data = json.loads(line)
assert "prompt" in data, "Missing prompt field"
assert "label" in data, "Missing label field"
```
2. **Check key names**:
```bash
--input-key prompt # Must match your data
--label-key label # Must match your data
```
### Training Stability Issues
#### Issue: Loss Explosion / NaN
**Symptoms**: Loss becomes NaN or explodes
**Solutions**:
1. **Reduce learning rate**:
```bash
--lr 1e-6 # Reduce from 5e-6
```
2. **Enable gradient clipping**:
```bash
--clip-grad 1.0
```
3. **Check for data issues**:
```python
# Verify no empty prompts or responses
for sample in dataset:
assert len(sample["prompt"]) > 0
```
4. **Use BF16 instead of FP16**:
```bash
--bf16 # More numerically stable
```
#### Issue: Reward Collapse
**Symptoms**: Reward drops to zero, model outputs garbage
**Solutions**:
1. **Increase KL penalty**:
```bash
--kl-loss-coef 0.01 # Increase from 0.001
```
2. **Reduce number of samples**:
```bash
--n-samples-per-prompt 4 # Reduce from 8
```
3. **Verify reward function**:
```python
# Test reward function independently
import asyncio
from custom_rm import reward_func
from slime.data import Sample
sample = Sample(prompt="test", response="test response")
reward = asyncio.run(reward_func(args, sample))  # reward_func is async
print(f"Reward: {reward}")  # Should be reasonable
```
### Async Training Issues
#### Issue: Async Training Not Supported with Colocate
**Symptoms**: Error when using `--colocate` with `train_async.py`
**Solution**: Colocated mode is NOT supported for async training. Use separate GPUs:
```bash
# Remove --colocate flag
python train_async.py \
--actor-num-gpus-per-node 4 \
--rollout-num-gpus 4 \
# No --colocate
```
#### Issue: Stale Weights in Async Mode
**Symptoms**: Policy divergence, inconsistent behavior
**Solutions**:
1. **Reduce async buffer size**:
```bash
--async-buffer-size 2 # Reduce from 4
```
2. **Increase weight update frequency**:
```bash
--update-weights-interval 1 # Sync every rollout
```
### Multi-Turn Training Issues
#### Issue: Tool Responses Included in Loss
**Symptoms**: Model learns to output tool responses verbatim
**Solution**: Properly set loss mask in custom generate function:
```python
def build_loss_mask(sample):
"""Create loss mask that excludes tool responses."""
mask = []
for i, token in enumerate(sample.tokens):
if is_tool_response(token, sample.metadata):
mask.append(0) # Don't compute loss
else:
mask.append(1) # Compute loss
return mask
```
#### Issue: Multi-Turn Context Too Long
**Symptoms**: OOM or truncation in multi-turn conversations
**Solutions**:
1. **Limit conversation history**:
```python
# In custom generate function
conversation = sample.prompt[-10:] # Keep last 10 turns
```
2. **Increase context length**:
```bash
--sglang-context-length 16384
```
### Checkpoint Issues
#### Issue: Checkpoint Loading Fails
**Symptoms**: Cannot load saved checkpoint
**Solutions**:
1. **Verify checkpoint path**:
```bash
ls -la /path/to/checkpoint/
```
2. **Check parallelism matches**:
```bash
# Checkpoint was saved with TP=2, must load with TP=2
--tensor-model-parallel-size 2
```
3. **Convert HuggingFace to Megatron** (if needed):
```bash
python tools/convert_hf_to_megatron.py \
--hf_model_path /path/to/hf/model \
--save_path /path/to/megatron/checkpoint
```
### Debugging Tips
#### Enable Verbose Logging
```bash
--log-level DEBUG
export SLIME_DEBUG=1
```
#### Check GPU Utilization
```bash
watch -n 1 nvidia-smi
```
#### Monitor Training
```bash
tensorboard --logdir outputs/
```
#### Test Custom Functions Independently
```python
# Test reward function
import asyncio
from custom_rm import reward_func
from slime.data import Sample
async def test():
sample = Sample(prompt="test", response="test", label="expected")
reward = await reward_func(args, sample)
print(f"Reward: {reward}")
asyncio.run(test())
```
## Constraint Reference
Key constraint to remember:
```
rollout_batch_size × n_samples_per_prompt = global_batch_size × num_steps_per_rollout
```
Example: `32 × 8 = 256 × 1`
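A tiny sanity check of this constraint, worth running before launching a long job (variable names mirror the flags above):
```python
def check_rollout_constraint(rollout_batch_size, n_samples_per_prompt,
                             global_batch_size, num_steps_per_rollout):
    generated = rollout_batch_size * n_samples_per_prompt
    consumed = global_batch_size * num_steps_per_rollout
    assert generated == consumed, (
        f"rollout produces {generated} samples, training consumes {consumed}"
    )

check_rollout_constraint(32, 8, 256, 1)  # 256 == 256, passes
```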
## Resources
- GitHub Issues: https://github.com/THUDM/slime/issues
- Documentation: https://thudm.github.io/slime/
- Examples: `examples/` directory

View file

@ -0,0 +1,190 @@
---
name: tensorrt-llm
description: Optimizes LLM inference with NVIDIA TensorRT for maximum throughput and lowest latency. Use for production deployment on NVIDIA GPUs (A100/H100), when you need 10-100x faster inference than PyTorch, or for serving models with quantization (FP8/INT4), in-flight batching, and multi-GPU scaling.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [tensorrt-llm, torch]
metadata:
hermes:
tags: [Inference Serving, TensorRT-LLM, NVIDIA, Inference Optimization, High Throughput, Low Latency, Production, FP8, INT4, In-Flight Batching, Multi-GPU]
---
# TensorRT-LLM
NVIDIA's open-source library for optimizing LLM inference with state-of-the-art performance on NVIDIA GPUs.
## When to use TensorRT-LLM
**Use TensorRT-LLM when:**
- Deploying on NVIDIA GPUs (A100, H100, GB200)
- Need maximum throughput (24,000+ tokens/sec on Llama 3)
- Require low latency for real-time applications
- Working with quantized models (FP8, INT4, FP4)
- Scaling across multiple GPUs or nodes
**Use vLLM instead when:**
- Need simpler setup and Python-first API
- Want PagedAttention without TensorRT compilation
- Working with AMD GPUs or non-NVIDIA hardware
**Use llama.cpp instead when:**
- Deploying on CPU or Apple Silicon
- Need edge deployment without NVIDIA GPUs
- Want simpler GGUF quantization format
## Quick start
### Installation
```bash
# Docker (recommended)
docker pull nvidia/tensorrt_llm:latest
# pip install
pip install tensorrt_llm==1.2.0rc3
# Requires CUDA 13.0.0, TensorRT 10.13.2, Python 3.10-3.12
```
### Basic inference
```python
from tensorrt_llm import LLM, SamplingParams
# Initialize model
llm = LLM(model="meta-llama/Meta-Llama-3-8B")
# Configure sampling
sampling_params = SamplingParams(
max_tokens=100,
temperature=0.7,
top_p=0.9
)
# Generate
prompts = ["Explain quantum computing"]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
print(output.text)
```
### Serving with trtllm-serve
```bash
# Start server (automatic model download and compilation)
# --tp_size 4 enables tensor parallelism across 4 GPUs
trtllm-serve meta-llama/Meta-Llama-3-8B \
  --tp_size 4 \
--max_batch_size 256 \
--max_num_tokens 4096
# Client request
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Meta-Llama-3-8B",
"messages": [{"role": "user", "content": "Hello!"}],
"temperature": 0.7,
"max_tokens": 100
}'
```
## Key features
### Performance optimizations
- **In-flight batching**: Dynamic batching during generation
- **Paged KV cache**: Efficient memory management
- **Flash Attention**: Optimized attention kernels
- **Quantization**: FP8, INT4, FP4 for 2-4× faster inference
- **CUDA graphs**: Reduced kernel launch overhead
### Parallelism
- **Tensor parallelism (TP)**: Split model across GPUs
- **Pipeline parallelism (PP)**: Layer-wise distribution
- **Expert parallelism**: For Mixture-of-Experts models
- **Multi-node**: Scale beyond single machine
### Advanced features
- **Speculative decoding**: Faster generation with draft models
- **LoRA serving**: Efficient multi-adapter deployment
- **Disaggregated serving**: Separate prefill and generation
## Common patterns
### Quantized model (FP8)
```python
from tensorrt_llm import LLM
# Load FP8 quantized model (2× faster, 50% memory)
llm = LLM(
model="meta-llama/Meta-Llama-3-70B",
dtype="fp8",
max_num_tokens=8192
)
# Inference same as before
outputs = llm.generate(["Summarize this article..."])
```
### Multi-GPU deployment
```python
# Tensor parallelism across 8 GPUs
llm = LLM(
model="meta-llama/Meta-Llama-3-405B",
tensor_parallel_size=8,
dtype="fp8"
)
```
### Batch inference
```python
# Process 100 prompts efficiently
prompts = [f"Question {i}: ..." for i in range(100)]
outputs = llm.generate(
prompts,
sampling_params=SamplingParams(max_tokens=200)
)
# Automatic in-flight batching for maximum throughput
```
## Performance benchmarks
**Meta Llama 3-8B** (H100 GPU):
- Throughput: 24,000 tokens/sec
- Latency: ~10ms per token
- vs PyTorch: **100× faster**
**Llama 3-70B** (8× A100 80GB):
- FP8 quantization: 2× faster than FP16
- Memory: 50% reduction with FP8
## Supported models
- **LLaMA family**: Llama 2, Llama 3, CodeLlama
- **GPT family**: GPT-2, GPT-J, GPT-NeoX
- **Qwen**: Qwen, Qwen2, QwQ
- **DeepSeek**: DeepSeek-V2, DeepSeek-V3
- **Mixtral**: Mixtral-8x7B, Mixtral-8x22B
- **Vision**: LLaVA, Phi-3-vision
- **100+ models** on HuggingFace
## References
- **[Optimization Guide](references/optimization.md)** - Quantization, batching, KV cache tuning
- **[Multi-GPU Setup](references/multi-gpu.md)** - Tensor/pipeline parallelism, multi-node
- **[Serving Guide](references/serving.md)** - Production deployment, monitoring, autoscaling
## Resources
- **Docs**: https://nvidia.github.io/TensorRT-LLM/
- **GitHub**: https://github.com/NVIDIA/TensorRT-LLM
- **Models**: https://huggingface.co/models?library=tensorrt_llm

View file

@ -0,0 +1,298 @@
# Multi-GPU Deployment Guide
Comprehensive guide to scaling TensorRT-LLM across multiple GPUs and nodes.
## Parallelism Strategies
### Tensor Parallelism (TP)
**What it does**: Splits model layers across GPUs horizontally.
**Use case**:
- Model fits in total GPU memory but not single GPU
- Need low latency (single forward pass)
- GPUs on same node (NVLink required for best performance)
**Example** (Llama 3-70B on 4× A100):
```python
from tensorrt_llm import LLM
llm = LLM(
model="meta-llama/Meta-Llama-3-70B",
tensor_parallel_size=4, # Split across 4 GPUs
dtype="fp16"
)
# Model automatically sharded across GPUs
# Single forward pass, low latency
```
**Performance**:
- Latency: ~Same as single GPU
- Throughput: 4× higher (4 GPUs)
- Communication: High (activations synced every layer)
### Pipeline Parallelism (PP)
**What it does**: Splits model layers across GPUs vertically (layer-wise).
**Use case**:
- Very large models (175B+)
- Can tolerate higher latency
- GPUs across multiple nodes
**Example** (Llama 3-405B on 8× H100):
```python
llm = LLM(
model="meta-llama/Meta-Llama-3-405B",
tensor_parallel_size=4, # TP=4 within nodes
pipeline_parallel_size=2, # PP=2 across nodes
dtype="fp8"
)
# Total: 8 GPUs (4×2)
# Layers 0-40: Node 1 (4 GPUs with TP)
# Layers 41-80: Node 2 (4 GPUs with TP)
```
**Performance**:
- Latency: Higher (sequential through pipeline)
- Throughput: High with micro-batching
- Communication: Lower than TP
### Expert Parallelism (EP)
**What it does**: Distributes MoE experts across GPUs.
**Use case**: Mixture-of-Experts models (Mixtral, DeepSeek-V2)
**Example** (Mixtral-8x22B on 8× A100):
```python
llm = LLM(
model="mistralai/Mixtral-8x22B",
tensor_parallel_size=4,
expert_parallel_size=2, # Distribute 8 experts across 2 groups
dtype="fp8"
)
```
## Configuration Examples
### Small model (7-13B) - Single GPU
```python
# Llama 3-8B on 1× A100 80GB
llm = LLM(
model="meta-llama/Meta-Llama-3-8B",
dtype="fp16" # or fp8 for H100
)
```
**Resources**:
- GPU: 1× A100 80GB
- Memory: ~16GB model + 30GB KV cache
- Throughput: 3,000-5,000 tokens/sec
### Medium model (70B) - Multi-GPU same node
```python
# Llama 3-70B on 4× A100 80GB (NVLink)
llm = LLM(
model="meta-llama/Meta-Llama-3-70B",
tensor_parallel_size=4,
dtype="fp8" # 70GB → 35GB per GPU
)
```
**Resources**:
- GPU: 4× A100 80GB with NVLink
- Memory: ~35GB per GPU (FP8)
- Throughput: 10,000-15,000 tokens/sec
- Latency: 15-20ms per token
### Large model (405B) - Multi-node
```python
# Llama 3-405B on 2 nodes × 8 H100 = 16 GPUs
llm = LLM(
model="meta-llama/Meta-Llama-3-405B",
tensor_parallel_size=8, # TP within each node
pipeline_parallel_size=2, # PP across 2 nodes
dtype="fp8"
)
```
**Resources**:
- GPU: 2 nodes × 8 H100 80GB
- Memory: ~25GB per GPU (FP8)
- Throughput: 20,000-30,000 tokens/sec
- Network: InfiniBand recommended
## Server Deployment
### Single-node multi-GPU
```bash
# Llama 3-70B on 4 GPUs (automatic TP)
trtllm-serve meta-llama/Meta-Llama-3-70B \
--tp_size 4 \
--max_batch_size 256 \
--dtype fp8
# Listens on http://localhost:8000
```
### Multi-node with Ray
```bash
# Node 1 (head node)
ray start --head --port=6379
# Node 2 (worker)
ray start --address='node1:6379'
# Deploy across cluster (--num_workers 2 = one worker per node)
trtllm-serve meta-llama/Meta-Llama-3-405B \
  --tp_size 8 \
  --pp_size 2 \
  --num_workers 2 \
--dtype fp8
```
### Kubernetes deployment
```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: tensorrt-llm-llama3-70b
spec:
replicas: 1
template:
spec:
containers:
- name: trtllm
image: nvidia/tensorrt_llm:latest
command:
- trtllm-serve
- meta-llama/Meta-Llama-3-70B
- --tp_size=4
- --max_batch_size=256
resources:
limits:
nvidia.com/gpu: 4 # Request 4 GPUs
```
## Parallelism Decision Tree
```
Model size < 20GB?
├─ YES: Single GPU (no parallelism)
└─ NO: Model size < 80GB?
├─ YES: TP=2 or TP=4 (same node)
└─ NO: Model size < 320GB?
├─ YES: TP=4 or TP=8 (same node, NVLink required)
└─ NO: TP=8 + PP=2 (multi-node)
```
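The same tree written as a small helper, a sketch only; thresholds match the diagram above and refer to total weight size after quantization:
```python
def suggest_parallelism(model_size_gb: float) -> dict:
    """Map approximate weight size to a TP/PP starting point."""
    if model_size_gb < 20:
        return {"tp": 1, "pp": 1}   # single GPU
    if model_size_gb < 80:
        return {"tp": 4, "pp": 1}   # TP=2 or TP=4, same node
    if model_size_gb < 320:
        return {"tp": 8, "pp": 1}   # TP=4 or TP=8, same node, NVLink required
    return {"tp": 8, "pp": 2}       # multi-node

print(suggest_parallelism(140))  # e.g. Llama 3-70B in FP16 -> {'tp': 8, 'pp': 1}
```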
## Communication Optimization
### NVLink vs PCIe
**NVLink** (DGX A100, HGX H100):
- Bandwidth: 600 GB/s (A100), 900 GB/s (H100)
- Ideal for TP (high communication)
- **Recommended for all multi-GPU setups**
**PCIe**:
- Bandwidth: 64 GB/s (PCIe 4.0 x16)
- 10× slower than NVLink
- Avoid TP, use PP instead
### InfiniBand for multi-node
**HDR InfiniBand** (200 Gb/s):
- Required for multi-node TP or PP
- Latency: <1μs
- **Essential for 405B+ models**
## Monitoring Multi-GPU
```bash
# Monitor GPU utilization
nvidia-smi dmon -s u
# Monitor memory
nvidia-smi dmon -s m
# Monitor NVLink utilization
nvidia-smi nvlink --status
# TensorRT-LLM built-in metrics
curl http://localhost:8000/metrics
```
**Key metrics**:
- GPU utilization: Target 80-95%
- Memory usage: Should be balanced across GPUs
- NVLink traffic: High for TP, low for PP
- Throughput: Tokens/sec across all GPUs
## Common Issues
### Imbalanced GPU memory
**Symptom**: GPU 0 has 90% memory, GPU 3 has 40%
**Solutions**:
- Verify TP/PP configuration
- Check model sharding (should be equal)
- Restart server to reset state
### Low NVLink utilization
**Symptom**: NVLink bandwidth <100 GB/s with TP=4
**Solutions**:
- Verify NVLink topology: `nvidia-smi topo -m`
- Check for PCIe fallback
- Ensure GPUs are on same NVSwitch
### OOM with multi-GPU
**Solutions**:
- Increase TP size (more GPUs)
- Reduce batch size
- Enable FP8 quantization
- Use pipeline parallelism
## Performance Scaling
### TP Scaling (Llama 3-70B, FP8)
| GPUs | TP Size | Throughput | Latency | Efficiency |
|------|---------|------------|---------|------------|
| 1 | 1 | OOM | - | - |
| 2 | 2 | 6,000 tok/s | 18ms | 85% |
| 4 | 4 | 11,000 tok/s | 16ms | 78% |
| 8 | 8 | 18,000 tok/s | 15ms | 64% |
**Note**: Efficiency drops with more GPUs due to communication overhead.
### PP Scaling (Llama 3-405B, FP8)
| Nodes | TP | PP | Total GPUs | Throughput |
|-------|----|----|------------|------------|
| 1 | 8 | 1 | 8 | OOM |
| 2 | 8 | 2 | 16 | 25,000 tok/s |
| 4 | 8 | 4 | 32 | 45,000 tok/s |
## Best Practices
1. **Prefer TP over PP** when possible (lower latency)
2. **Use NVLink** for all TP deployments
3. **Use InfiniBand** for multi-node deployments
4. **Start with smallest TP** that fits model in memory
5. **Monitor GPU balance** - all GPUs should have similar utilization
6. **Test with benchmark** before production
7. **Use FP8** on H100 for 2× speedup

View file

@ -0,0 +1,242 @@
# TensorRT-LLM Optimization Guide
Comprehensive guide to optimizing LLM inference with TensorRT-LLM.
## Quantization
### FP8 Quantization (Recommended for H100)
**Benefits**:
- 2× faster inference
- 50% memory reduction
- Minimal accuracy loss (<1% perplexity degradation)
**Usage**:
```python
from tensorrt_llm import LLM
# Automatic FP8 quantization
llm = LLM(
model="meta-llama/Meta-Llama-3-70B",
dtype="fp8",
quantization="fp8"
)
```
**Performance** (Llama 3-70B on 8× H100):
- FP16: 5,000 tokens/sec
- FP8: **10,000 tokens/sec** (2× speedup)
- Memory: 140GB → 70GB
### INT4 Quantization (Maximum compression)
**Benefits**:
- 4× memory reduction
- 3-4× faster inference
- Fits larger models on same hardware
**Usage**:
```python
# INT4 with AWQ calibration
llm = LLM(
model="meta-llama/Meta-Llama-3-405B",
dtype="int4_awq",
quantization="awq"
)
# INT4 with GPTQ calibration
llm = LLM(
model="meta-llama/Meta-Llama-3-405B",
dtype="int4_gptq",
quantization="gptq"
)
```
**Trade-offs**:
- Accuracy: 1-3% perplexity increase
- Speed: 3-4× faster than FP16
- Use case: When memory is critical
## In-Flight Batching
**What it does**: Dynamically batches requests during generation instead of waiting for all sequences to finish.
**Configuration**:
```bash
# Server configuration:
#   max_batch_size         - maximum concurrent sequences
#   max_num_tokens         - total tokens in a batch
#   enable_chunked_context - split long prompts into chunks
trtllm-serve meta-llama/Meta-Llama-3-8B \
  --max_batch_size 256 \
  --max_num_tokens 4096 \
  --enable_chunked_context \
  --scheduler_policy max_utilization
```
**Performance**:
- Throughput: **4-8× higher** vs static batching
- Latency: Lower P50/P99 for mixed workloads
- GPU utilization: 80-95% vs 40-60%
## Paged KV Cache
**What it does**: Manages KV cache memory like OS manages virtual memory (paging).
**Benefits**:
- 40-60% higher throughput
- No memory fragmentation
- Supports longer sequences
**Configuration**:
```python
# Automatic paged KV cache (default)
llm = LLM(
model="meta-llama/Meta-Llama-3-8B",
kv_cache_free_gpu_mem_fraction=0.9, # Use 90% GPU mem for cache
enable_prefix_caching=True # Cache common prefixes
)
```
## Speculative Decoding
**What it does**: Uses small draft model to predict multiple tokens, verified by target model in parallel.
**Speedup**: 2-3× faster for long generations
**Usage**:
```python
from tensorrt_llm import LLM
# Target model (Llama 3-70B)
llm = LLM(
model="meta-llama/Meta-Llama-3-70B",
speculative_model="meta-llama/Meta-Llama-3-8B", # Draft model
num_speculative_tokens=5 # Tokens to predict ahead
)
# Same API, 2-3× faster
outputs = llm.generate(prompts)
```
**Best models for drafting**:
- Target: Llama 3-70B → Draft: Llama 3-8B
- Target: Qwen2-72B → Draft: Qwen2-7B
- Same family, 8-10× smaller
## CUDA Graphs
**What it does**: Reduces kernel launch overhead by recording GPU operations.
**Benefits**:
- 10-20% lower latency
- More stable P99 latency
- Better for small batch sizes
**Configuration** (automatic by default):
```python
llm = LLM(
model="meta-llama/Meta-Llama-3-8B",
enable_cuda_graph=True, # Default: True
cuda_graph_cache_size=2 # Cache 2 graph variants
)
```
## Chunked Context
**What it does**: Splits long prompts into chunks to reduce memory spikes.
**Use case**: Prompts >8K tokens with limited GPU memory
**Configuration**:
```bash
trtllm-serve meta-llama/Meta-Llama-3-8B \
--max_num_tokens 4096 \
--enable_chunked_context \
--max_chunked_prefill_length 2048 # Process 2K tokens at a time
```
## Overlap Scheduling
**What it does**: Overlaps compute and memory operations.
**Benefits**:
- 15-25% higher throughput
- Better GPU utilization
- Default in v1.2.0+
**No configuration needed** - enabled automatically.
## Quantization Comparison Table
| Method | Memory | Speed | Accuracy | Use Case |
|--------|--------|-------|----------|----------|
| FP16 | 1× (baseline) | 1× | Best | High accuracy needed |
| FP8 | 0.5× | 2× | -0.5% ppl | **H100 default** |
| INT4 AWQ | 0.25× | 3-4× | -1.5% ppl | Memory critical |
| INT4 GPTQ | 0.25× | 3-4× | -2% ppl | Maximum speed |
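A rough weight-memory estimate matching the Memory column above (weights only; KV cache and activations are extra):
```python
QUANT_MEMORY_FACTOR = {"fp16": 1.0, "fp8": 0.5, "int4_awq": 0.25, "int4_gptq": 0.25}

def weight_memory_gb(params_billion: float, method: str) -> float:
    """FP16 baseline is ~2 bytes per parameter."""
    return params_billion * 2 * QUANT_MEMORY_FACTOR[method]

for method in QUANT_MEMORY_FACTOR:
    print(f"Llama 3-70B, {method}: ~{weight_memory_gb(70, method):.0f} GB")
# fp16 ~140 GB, fp8 ~70 GB, int4 ~35 GB
```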
## Tuning Workflow
1. **Start with defaults**:
```python
llm = LLM(model="meta-llama/Meta-Llama-3-70B")
```
2. **Enable FP8** (if H100):
```python
llm = LLM(model="...", dtype="fp8")
```
3. **Tune batch size**:
```bash
# Increase until OOM, then reduce 20%
trtllm-serve ... --max_batch_size 256
```
4. **Enable chunked context** (if long prompts):
```bash
--enable_chunked_context --max_chunked_prefill_length 2048
```
5. **Try speculative decoding** (if latency critical):
```python
llm = LLM(model="...", speculative_model="...")
```
## Benchmarking
```bash
# Install benchmark tool
pip install tensorrt_llm[benchmark]
# Run benchmark
python benchmarks/python/benchmark.py \
--model meta-llama/Meta-Llama-3-8B \
--batch_size 64 \
--input_len 128 \
--output_len 256 \
--dtype fp8
```
**Metrics to track**:
- Throughput (tokens/sec)
- Latency P50/P90/P99 (ms)
- GPU memory usage (GB)
- GPU utilization (%)
## Common Issues
**OOM errors**:
- Reduce `max_batch_size`
- Reduce `max_num_tokens`
- Enable INT4 quantization
- Increase `tensor_parallel_size`
**Low throughput**:
- Increase `max_batch_size`
- Enable in-flight batching
- Verify CUDA graphs enabled
- Check GPU utilization
**High latency**:
- Try speculative decoding
- Reduce `max_batch_size` (less queueing)
- Use FP8 instead of FP16

View file

@ -0,0 +1,470 @@
# Production Serving Guide
Comprehensive guide to deploying TensorRT-LLM in production environments.
## Server Modes
### trtllm-serve (Recommended)
**Features**:
- OpenAI-compatible API
- Automatic model download and compilation
- Built-in load balancing
- Prometheus metrics
- Health checks
**Basic usage**:
```bash
trtllm-serve meta-llama/Meta-Llama-3-8B \
--tp_size 1 \
--max_batch_size 256 \
--port 8000
```
**Advanced configuration**:
```bash
trtllm-serve meta-llama/Meta-Llama-3-70B \
--tp_size 4 \
--dtype fp8 \
--max_batch_size 256 \
--max_num_tokens 4096 \
--enable_chunked_context \
--scheduler_policy max_utilization \
--port 8000 \
--api_key $API_KEY # Optional authentication
```
### Python LLM API (For embedding)
```python
from tensorrt_llm import LLM
class LLMService:
def __init__(self):
self.llm = LLM(
model="meta-llama/Meta-Llama-3-8B",
dtype="fp8"
)
def generate(self, prompt, max_tokens=100):
from tensorrt_llm import SamplingParams
params = SamplingParams(
max_tokens=max_tokens,
temperature=0.7
)
outputs = self.llm.generate([prompt], params)
return outputs[0].text
# Use in FastAPI, Flask, etc
from fastapi import FastAPI
app = FastAPI()
service = LLMService()
@app.post("/generate")
def generate(prompt: str):
return {"response": service.generate(prompt)}
```
## OpenAI-Compatible API
### Chat Completions
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Meta-Llama-3-8B",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Explain quantum computing"}
],
"temperature": 0.7,
"max_tokens": 500,
"stream": false
}'
```
**Response**:
```json
{
"id": "chat-abc123",
"object": "chat.completion",
"created": 1234567890,
"model": "meta-llama/Meta-Llama-3-8B",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Quantum computing is..."
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 25,
"completion_tokens": 150,
"total_tokens": 175
}
}
```
### Streaming
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Meta-Llama-3-8B",
"messages": [{"role": "user", "content": "Count to 10"}],
"stream": true
}'
```
**Response** (SSE stream):
```
data: {"choices":[{"delta":{"content":"1"}}]}
data: {"choices":[{"delta":{"content":", 2"}}]}
data: {"choices":[{"delta":{"content":", 3"}}]}
data: [DONE]
```
### Completions
```bash
curl -X POST http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Meta-Llama-3-8B",
"prompt": "The capital of France is",
"max_tokens": 10,
"temperature": 0.0
}'
```
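Because the API is OpenAI-compatible, the official `openai` Python client can also be pointed at the server (a sketch; the API key only matters if the server was started with `--api_key`):
```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

resp = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B",
    messages=[{"role": "user", "content": "Explain quantum computing"}],
    temperature=0.7,
    max_tokens=200,
)
print(resp.choices[0].message.content)
```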
## Monitoring
### Prometheus Metrics
**Enable metrics**:
```bash
trtllm-serve meta-llama/Meta-Llama-3-8B \
--enable_metrics \
--metrics_port 9090
```
**Key metrics**:
```bash
# Scrape metrics
curl http://localhost:9090/metrics
# Important metrics:
# - trtllm_request_success_total - Total successful requests
# - trtllm_request_latency_seconds - Request latency histogram
# - trtllm_tokens_generated_total - Total tokens generated
# - trtllm_active_requests - Current active requests
# - trtllm_queue_size - Requests waiting in queue
# - trtllm_gpu_memory_usage_bytes - GPU memory usage
# - trtllm_kv_cache_usage_ratio - KV cache utilization
```
### Health Checks
```bash
# Readiness probe
curl http://localhost:8000/health/ready
# Liveness probe
curl http://localhost:8000/health/live
# Model info
curl http://localhost:8000/v1/models
```
**Kubernetes probes**:
```yaml
livenessProbe:
httpGet:
path: /health/live
port: 8000
initialDelaySeconds: 60
periodSeconds: 10
readinessProbe:
httpGet:
path: /health/ready
port: 8000
initialDelaySeconds: 30
periodSeconds: 5
```
## Production Deployment
### Docker Deployment
**Dockerfile**:
```dockerfile
FROM nvidia/tensorrt_llm:latest
# Copy any custom configs
COPY config.yaml /app/config.yaml
# Expose ports
EXPOSE 8000 9090
# Start server
CMD ["trtllm-serve", "meta-llama/Meta-Llama-3-8B", \
"--tp_size", "4", \
"--dtype", "fp8", \
"--max_batch_size", "256", \
"--enable_metrics", \
"--metrics_port", "9090"]
```
**Run container**:
```bash
docker run --gpus all -p 8000:8000 -p 9090:9090 \
tensorrt-llm:latest
```
### Kubernetes Deployment
**Complete deployment**:
```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: tensorrt-llm
spec:
replicas: 2 # Multiple replicas for HA
selector:
matchLabels:
app: tensorrt-llm
template:
metadata:
labels:
app: tensorrt-llm
spec:
containers:
- name: trtllm
image: nvidia/tensorrt_llm:latest
command:
- trtllm-serve
- meta-llama/Meta-Llama-3-70B
- --tp_size=4
- --dtype=fp8
- --max_batch_size=256
- --enable_metrics
ports:
- containerPort: 8000
name: http
- containerPort: 9090
name: metrics
resources:
limits:
nvidia.com/gpu: 4
livenessProbe:
httpGet:
path: /health/live
port: 8000
readinessProbe:
httpGet:
path: /health/ready
port: 8000
---
apiVersion: v1
kind: Service
metadata:
name: tensorrt-llm
spec:
selector:
app: tensorrt-llm
ports:
- name: http
port: 80
targetPort: 8000
- name: metrics
port: 9090
targetPort: 9090
type: LoadBalancer
```
### Load Balancing
**NGINX configuration**:
```nginx
upstream tensorrt_llm {
least_conn; # Route to least busy server
server trtllm-1:8000 max_fails=3 fail_timeout=30s;
server trtllm-2:8000 max_fails=3 fail_timeout=30s;
server trtllm-3:8000 max_fails=3 fail_timeout=30s;
}
server {
listen 80;
location / {
proxy_pass http://tensorrt_llm;
proxy_read_timeout 300s; # Long timeout for slow generations
proxy_connect_timeout 10s;
}
}
```
## Autoscaling
### Horizontal Pod Autoscaler (HPA)
```yaml
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: tensorrt-llm-hpa
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: tensorrt-llm
minReplicas: 2
maxReplicas: 10
metrics:
- type: Pods
pods:
metric:
name: trtllm_active_requests
target:
type: AverageValue
averageValue: "50" # Scale when avg >50 active requests
```
### Custom Metrics
```yaml
# Scale based on queue size
- type: Pods
pods:
metric:
name: trtllm_queue_size
target:
type: AverageValue
averageValue: "10"
```
## Cost Optimization
### GPU Selection
**A100 80GB** ($3-4/hour):
- Use for: 70B models with FP8
- Throughput: 10,000-15,000 tok/s (TP=4)
- Cost per 1M tokens: $0.20-0.30
**H100 80GB** ($6-8/hour):
- Use for: 70B models with FP8, 405B models
- Throughput: 20,000-30,000 tok/s (TP=4)
- Cost per 1M tokens: $0.15-0.25 (2× faster = lower cost)
**L4** ($0.50-1/hour):
- Use for: 7-8B models
- Throughput: 1,000-2,000 tok/s
- Cost per 1M tokens: $0.25-0.50
### Batch Size Tuning
**Impact on cost**:
- Batch size 1: 1,000 tok/s on a $3/hour GPU ≈ $0.83 per 1M tokens
- Batch size 64: 5,000 tok/s on the same GPU ≈ $0.17 per 1M tokens
- **5× cost reduction** with batching (cost scales inversely with throughput)
**Recommendation**: Target batch size 32-128 for cost efficiency.
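A quick way to back these numbers out for your own GPU price and measured throughput (illustrative only):
```python
def cost_per_million_tokens(gpu_dollars_per_hour: float, tokens_per_second: float) -> float:
    tokens_per_hour = tokens_per_second * 3600
    return gpu_dollars_per_hour / tokens_per_hour * 1_000_000

print(cost_per_million_tokens(3.0, 1_000))  # ~0.83 $/M tokens (batch size 1)
print(cost_per_million_tokens(3.0, 5_000))  # ~0.17 $/M tokens (batch size 64)
```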
## Security
### API Authentication
```bash
# Generate API key
export API_KEY=$(openssl rand -hex 32)
# Start server with authentication
trtllm-serve meta-llama/Meta-Llama-3-8B \
--api_key $API_KEY
# Client request
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Authorization: Bearer $API_KEY" \
-H "Content-Type: application/json" \
-d '{"model": "...", "messages": [...]}'
```
### Network Policies
```yaml
apiVersion: networking.k8s.io/v1
kind: NetworkPolicy
metadata:
name: tensorrt-llm-policy
spec:
podSelector:
matchLabels:
app: tensorrt-llm
policyTypes:
- Ingress
ingress:
- from:
- podSelector:
matchLabels:
app: api-gateway # Only allow from gateway
ports:
- protocol: TCP
port: 8000
```
## Troubleshooting
### High latency
**Diagnosis**:
```bash
# Check queue size
curl http://localhost:9090/metrics | grep queue_size
# Check active requests
curl http://localhost:9090/metrics | grep active_requests
```
**Solutions**:
- Scale horizontally (more replicas)
- Increase batch size (if GPU underutilized)
- Enable chunked context (if long prompts)
- Use FP8 quantization
### OOM crashes
**Solutions**:
- Reduce `max_batch_size`
- Reduce `max_num_tokens`
- Enable FP8 or INT4 quantization
- Increase `tensor_parallel_size`
### Timeout errors
**NGINX config**:
```nginx
proxy_read_timeout 600s; # 10 minutes for very long generations
proxy_send_timeout 600s;
```
## Best Practices
1. **Use FP8 on H100** for 2× speedup and 50% cost reduction
2. **Monitor metrics** - Set up Prometheus + Grafana
3. **Set readiness probes** - Prevent routing to unhealthy pods
4. **Use load balancing** - Distribute load across replicas
5. **Tune batch size** - Balance latency and throughput
6. **Enable streaming** - Better UX for chat applications
7. **Set up autoscaling** - Handle traffic spikes
8. **Use persistent volumes** - Cache compiled models
9. **Implement retries** - Handle transient failures
10. **Monitor costs** - Track cost per token

View file

@ -0,0 +1,361 @@
---
name: distributed-llm-pretraining-torchtitan
description: Provides PyTorch-native distributed LLM pretraining using torchtitan with 4D parallelism (FSDP2, TP, PP, CP). Use when pretraining Llama 3.1, DeepSeek V3, or custom models at scale from 8 to 512+ GPUs with Float8, torch.compile, and distributed checkpointing.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [torch>=2.6.0, torchtitan>=0.2.0, torchao>=0.5.0]
metadata:
hermes:
tags: [Model Architecture, Distributed Training, TorchTitan, FSDP2, Tensor Parallel, Pipeline Parallel, Context Parallel, Float8, Llama, Pretraining]
---
# TorchTitan - PyTorch Native Distributed LLM Pretraining
## Quick start
TorchTitan is PyTorch's official platform for large-scale LLM pretraining with composable 4D parallelism (FSDP2, TP, PP, CP), achieving 65%+ speedups over baselines on H100 GPUs.
**Installation**:
```bash
# From PyPI (stable)
pip install torchtitan
# From source (latest features, requires PyTorch nightly)
git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
```
**Download tokenizer**:
```bash
# Get HF token from https://huggingface.co/settings/tokens
python scripts/download_hf_assets.py --repo_id meta-llama/Llama-3.1-8B --assets tokenizer --hf_token=...
```
**Start training on 8 GPUs**:
```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
```
## Common workflows
### Workflow 1: Pretrain Llama 3.1 8B on single node
Copy this checklist:
```
Single Node Pretraining:
- [ ] Step 1: Download tokenizer
- [ ] Step 2: Configure training
- [ ] Step 3: Launch training
- [ ] Step 4: Monitor and checkpoint
```
**Step 1: Download tokenizer**
```bash
python scripts/download_hf_assets.py \
--repo_id meta-llama/Llama-3.1-8B \
--assets tokenizer \
--hf_token=YOUR_HF_TOKEN
```
**Step 2: Configure training**
Edit or create a TOML config file:
```toml
# llama3_8b_custom.toml
[job]
dump_folder = "./outputs"
description = "Llama 3.1 8B training"
[model]
name = "llama3"
flavor = "8B"
hf_assets_path = "./assets/hf/Llama-3.1-8B"
[optimizer]
name = "AdamW"
lr = 3e-4
[lr_scheduler]
warmup_steps = 200
[training]
local_batch_size = 2
seq_len = 8192
max_norm = 1.0
steps = 1000
dataset = "c4"
[parallelism]
data_parallel_shard_degree = -1 # Use all GPUs for FSDP
[activation_checkpoint]
mode = "selective"
selective_ac_option = "op"
[checkpoint]
enable = true
folder = "checkpoint"
interval = 500
```
**Step 3: Launch training**
```bash
# 8 GPUs on single node
CONFIG_FILE="./llama3_8b_custom.toml" ./run_train.sh
# Or explicitly with torchrun
torchrun --nproc_per_node=8 \
-m torchtitan.train \
--job.config_file ./llama3_8b_custom.toml
```
**Step 4: Monitor and checkpoint**
TensorBoard logs are saved to `./outputs/tb/`:
```bash
tensorboard --logdir ./outputs/tb
```
### Workflow 2: Multi-node training with SLURM
```
Multi-Node Training:
- [ ] Step 1: Configure parallelism for scale
- [ ] Step 2: Set up SLURM script
- [ ] Step 3: Submit job
- [ ] Step 4: Resume from checkpoint
```
**Step 1: Configure parallelism for scale**
For 70B model on 256 GPUs (32 nodes):
```toml
[parallelism]
data_parallel_shard_degree = 32 # FSDP across 32 ranks
tensor_parallel_degree = 8 # TP within node
pipeline_parallel_degree = 1 # No PP for 70B
context_parallel_degree = 1 # Increase for long sequences
```
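The configured degrees must multiply out to the number of GPUs in the job. A small pre-submit sanity check (illustrative; values match the layout above):
```python
def check_world_size(dp_shard: int, tp: int, pp: int, cp: int, world_size: int) -> None:
    ranks = dp_shard * tp * pp * cp
    assert ranks == world_size, f"{ranks} ranks configured but {world_size} GPUs available"

check_world_size(dp_shard=32, tp=8, pp=1, cp=1, world_size=256)  # 32*8 = 256, passes
```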
**Step 2: Set up SLURM script**
```bash
#!/bin/bash
#SBATCH --job-name=llama70b
#SBATCH --nodes=32
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-node=8
srun torchrun \
--nnodes=32 \
--nproc_per_node=8 \
--rdzv_backend=c10d \
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
-m torchtitan.train \
--job.config_file ./llama3_70b.toml
```
**Step 3: Submit job**
```bash
sbatch multinode_trainer.slurm
```
**Step 4: Resume from checkpoint**
Training auto-resumes if checkpoint exists in configured folder.
### Workflow 3: Enable Float8 training for H100s
Float8 provides 30-50% speedup on H100 GPUs.
```
Float8 Training:
- [ ] Step 1: Install torchao
- [ ] Step 2: Configure Float8
- [ ] Step 3: Launch with compile
```
**Step 1: Install torchao**
```bash
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
```
**Step 2: Configure Float8**
Add to your TOML config:
```toml
[model]
converters = ["quantize.linear.float8"]
[quantize.linear.float8]
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
filter_fqns = ["output"] # Exclude output layer
[compile]
enable = true
components = ["model", "loss"]
```
**Step 3: Launch with compile**
```bash
CONFIG_FILE="./llama3_8b.toml" ./run_train.sh \
--model.converters="quantize.linear.float8" \
--quantize.linear.float8.enable_fsdp_float8_all_gather \
--compile.enable
```
### Workflow 4: 4D parallelism for 405B models
```
4D Parallelism (FSDP + TP + PP + CP):
- [ ] Step 1: Create seed checkpoint
- [ ] Step 2: Configure 4D parallelism
- [ ] Step 3: Launch on 512 GPUs
```
**Step 1: Create seed checkpoint**
Required for consistent initialization across PP stages:
```bash
NGPU=1 CONFIG_FILE=./llama3_405b.toml ./run_train.sh \
--checkpoint.enable \
--checkpoint.create_seed_checkpoint \
--parallelism.data_parallel_shard_degree 1 \
--parallelism.tensor_parallel_degree 1 \
--parallelism.pipeline_parallel_degree 1
```
**Step 2: Configure 4D parallelism**
```toml
[parallelism]
data_parallel_shard_degree = 8 # FSDP
tensor_parallel_degree = 8 # TP within node
pipeline_parallel_degree = 8 # PP across nodes
context_parallel_degree = 1 # CP for long sequences
[training]
local_batch_size = 32
seq_len = 8192
```
**Step 3: Launch on 512 GPUs**
```bash
# 64 nodes x 8 GPUs = 512 GPUs
srun torchrun --nnodes=64 --nproc_per_node=8 \
-m torchtitan.train \
--job.config_file ./llama3_405b.toml
```
## When to use vs alternatives
**Use TorchTitan when:**
- Pretraining LLMs from scratch (8B to 405B+)
- Need PyTorch-native solution without third-party dependencies
- Require composable 4D parallelism (FSDP2, TP, PP, CP)
- Training on H100s with Float8 support
- Want interoperable checkpoints with torchtune/HuggingFace
**Use alternatives instead:**
- **Megatron-LM**: Maximum performance for NVIDIA-only deployments
- **DeepSpeed**: Broader ZeRO optimization ecosystem, inference support
- **Axolotl/TRL**: Fine-tuning rather than pretraining
- **LitGPT**: Educational, smaller-scale training
## Common issues
**Issue: Out of memory on large models**
Enable activation checkpointing and reduce batch size:
```toml
[activation_checkpoint]
mode = "full" # Instead of "selective"
[training]
local_batch_size = 1
```
Or use gradient accumulation:
```toml
[training]
local_batch_size = 1
global_batch_size = 32 # Accumulates gradients
```
**Issue: TP causes high memory with async collectives**
Set environment variable:
```bash
export TORCH_NCCL_AVOID_RECORD_STREAMS=1
```
**Issue: Float8 training not faster**
Float8 only benefits large GEMMs. Filter small layers:
```toml
[quantize.linear.float8]
filter_fqns = ["attention.wk", "attention.wv", "output", "auto_filter_small_kn"]
```
**Issue: Checkpoint loading fails after parallelism change**
Use DCP's resharding capability:
```bash
# Convert sharded checkpoint to single file
python -m torch.distributed.checkpoint.format_utils \
dcp_to_torch checkpoint/step-1000 checkpoint.pt
```
**Issue: Pipeline parallelism initialization**
Create seed checkpoint first (see Workflow 4, Step 1).
## Supported models
| Model | Sizes | Status |
|-------|-------|--------|
| Llama 3.1 | 8B, 70B, 405B | Production |
| Llama 4 | Various | Experimental |
| DeepSeek V3 | 16B, 236B, 671B (MoE) | Experimental |
| GPT-OSS | 20B, 120B (MoE) | Experimental |
| Qwen 3 | Various | Experimental |
| Flux | Diffusion | Experimental |
## Performance benchmarks (H100)
| Model | GPUs | Parallelism | TPS/GPU | Techniques |
|-------|------|-------------|---------|------------|
| Llama 8B | 8 | FSDP | 5,762 | Baseline |
| Llama 8B | 8 | FSDP+compile+FP8 | 8,532 | +48% |
| Llama 70B | 256 | FSDP+TP+AsyncTP | 876 | 2D parallel |
| Llama 405B | 512 | FSDP+TP+PP | 128 | 3D parallel |
## Advanced topics
**FSDP2 configuration**: See [references/fsdp.md](references/fsdp.md) for detailed FSDP2 vs FSDP1 comparison and ZeRO equivalents.
**Float8 training**: See [references/float8.md](references/float8.md) for tensorwise vs rowwise scaling recipes.
**Checkpointing**: See [references/checkpoint.md](references/checkpoint.md) for HuggingFace conversion and async checkpointing.
**Adding custom models**: See [references/custom-models.md](references/custom-models.md) for TrainSpec protocol.
## Resources
- GitHub: https://github.com/pytorch/torchtitan
- Paper: https://arxiv.org/abs/2410.06511
- ICLR 2025: https://iclr.cc/virtual/2025/poster/29620
- PyTorch Forum: https://discuss.pytorch.org/c/distributed/torchtitan/44

View file

@ -0,0 +1,181 @@
# Checkpointing in TorchTitan
TorchTitan uses PyTorch Distributed Checkpoint (DCP) for fault-tolerant, interoperable checkpointing.
## Basic Configuration
```toml
[checkpoint]
enable = true
folder = "checkpoint"
interval = 500
```
## Save Model Only (Smaller Checkpoints)
Exclude optimizer state and training metadata:
```toml
[checkpoint]
enable = true
last_save_model_only = true
export_dtype = "bfloat16" # Optional: export in lower precision
```
## Excluding Keys from Loading
Partial checkpoint loading for modified settings:
```toml
[checkpoint]
enable = true
exclude_from_loading = ["data_loader", "lr_scheduler"]
```
CLI equivalent:
```bash
--checkpoint.exclude_from_loading data_loader,lr_scheduler
```
## Creating Seed Checkpoints
Required for Pipeline Parallelism to ensure consistent initialization:
```bash
NGPU=1 CONFIG_FILE=<path_to_config> ./run_train.sh \
--checkpoint.enable \
--checkpoint.create_seed_checkpoint \
--parallelism.data_parallel_replicate_degree 1 \
--parallelism.data_parallel_shard_degree 1 \
--parallelism.tensor_parallel_degree 1 \
--parallelism.pipeline_parallel_degree 1 \
--parallelism.context_parallel_degree 1 \
--parallelism.expert_parallel_degree 1
```
This initializes on single CPU for reproducible initialization across any GPU count.
## Async Checkpointing
Reduce checkpoint overhead with async writes:
```toml
[checkpoint]
enable = true
async_mode = "async" # Options: "disabled", "async", "async_with_pinned_mem"
```
## HuggingFace Conversion
### During Training
Save directly in HuggingFace format:
```toml
[checkpoint]
last_save_in_hf = true
last_save_model_only = true
```
Load from HuggingFace:
```toml
[checkpoint]
initial_load_in_hf = true
[model]
hf_assets_path = "./path/to/hf/checkpoint"
```
### Offline Conversion
Convert without running training:
```bash
# HuggingFace -> TorchTitan
python ./scripts/checkpoint_conversion/convert_from_hf.py \
<input_dir> <output_dir> \
--model_name llama3 \
--model_flavor 8B
# TorchTitan -> HuggingFace
python ./scripts/checkpoint_conversion/convert_to_hf.py \
<input_dir> <output_dir> \
--hf_assets_path ./assets/hf/Llama3.1-8B \
--model_name llama3 \
--model_flavor 8B
```
### Example
```bash
python ./scripts/convert_from_hf.py \
~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920/ \
./initial_load_path/ \
--model_name llama3 \
--model_flavor 8B
```
## Converting to Single .pt File
Convert DCP sharded checkpoint to single PyTorch file:
```bash
python -m torch.distributed.checkpoint.format_utils \
dcp_to_torch \
torchtitan/outputs/checkpoint/step-1000 \
checkpoint.pt
```
## Checkpoint Structure
DCP saves sharded checkpoints that can be resharded for different parallelism configurations:
```
checkpoint/
├── step-500/
│ ├── .metadata
│ ├── __0_0.distcp
│ ├── __0_1.distcp
│ └── ...
└── step-1000/
└── ...
```
## Resume Training
Training auto-resumes from the latest checkpoint in the configured folder. To resume from a specific step:
```toml
[checkpoint]
load_step = 500 # Resume from step 500
```
## Interoperability with TorchTune
Checkpoints saved with `last_save_model_only = true` can be loaded directly into [torchtune](https://github.com/pytorch/torchtune) for fine-tuning.
## Full Configuration Example
```toml
[checkpoint]
enable = true
folder = "checkpoint"
interval = 500
load_step = -1 # -1 = latest, or specify step number
last_save_model_only = true
export_dtype = "bfloat16"
async_mode = "async"
exclude_from_loading = []
last_save_in_hf = false
initial_load_in_hf = false
create_seed_checkpoint = false
```
## Best Practices
1. **Large models**: Use `async_mode = "async"` to overlap checkpoint saves with training
2. **Fine-tuning export**: Enable `last_save_model_only` and `export_dtype = "bfloat16"` for smaller files
3. **Pipeline parallelism**: Always create seed checkpoint first
4. **Debugging**: Save frequent checkpoints during development, reduce for production
5. **HF interop**: Use conversion scripts for offline conversion, direct save/load for training workflows

View file

@ -0,0 +1,258 @@
# Adding Custom Models to TorchTitan
This guide explains how to add a new model to TorchTitan following the established patterns.
## Directory Structure
```
torchtitan/models/your_model/
├── model/
│ ├── __init__.py
│ ├── args.py # Model arguments
│ ├── model.py # Model definition
│ └── state_dict_adapter.py # HF conversion (optional)
├── infra/
│ ├── __init__.py
│ ├── parallelize.py # TP, FSDP, compile application
│ └── pipeline.py # PP application (optional)
├── train_configs/
│ ├── debug_model.toml
│ └── your_model_XB.toml
├── __init__.py # TrainSpec registration
└── README.md
```
## Step 1: Define Model Arguments
Inherit from `BaseModelArgs`:
```python
# model/args.py
from torchtitan.protocols.model import BaseModelArgs
from dataclasses import dataclass
@dataclass
class YourModelArgs(BaseModelArgs):
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
vocab_size: int = 128256
def get_nparams_and_flops(self, seq_len: int) -> tuple[int, int]:
"""Return (num_params, flops_per_token) for throughput calculation."""
nparams = self.vocab_size * self.dim + ... # Calculate params
flops = 6 * nparams # Approximate: 6 * params for forward+backward
return nparams, flops
def update_from_config(self, job_config) -> "YourModelArgs":
"""Update args from training config."""
# Override specific args from job_config if needed
return self
```
## Step 2: Define Model
Inherit from `ModelProtocol`:
```python
# model/model.py
import torch
import torch.nn as nn
from torchtitan.protocols.model import ModelProtocol
from .args import YourModelArgs
# TransformerBlock and RMSNorm are assumed to be defined elsewhere in this module
class YourModel(ModelProtocol):
def __init__(self, args: YourModelArgs):
super().__init__()
self.args = args
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
self.layers = nn.ModuleDict({
str(i): TransformerBlock(args) for i in range(args.n_layers)
})
self.norm = RMSNorm(args.dim)
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
h = self.tok_embeddings(tokens)
for layer in self.layers.values():
h = layer(h)
h = self.norm(h)
return self.output(h)
def init_weights(self):
"""Initialize weights recursively."""
for module in self.modules():
if hasattr(module, 'init_weights') and module is not self:
module.init_weights()
elif isinstance(module, nn.Linear):
nn.init.normal_(module.weight, std=0.02)
```
**Important guidelines**:
- Write single-device model code (parallelism applied externally)
- Use `nn.ModuleDict` for layers (preserves FQNs when deleting for PP)
- Make input/output layers optional for PP compatibility
- Define `init_weights()` recursively
## Step 3: Parallelize Function
```python
# infra/parallelize.py
import torch
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor.parallel import parallelize_module
# apply_tp, apply_ac, and apply_fsdp are helper functions defined elsewhere in this file
def parallelize_your_model(
model: YourModel,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
):
# Apply in this order: TP -> AC -> compile -> FSDP
# 1. Tensor Parallelism
if parallel_dims.tp_enabled:
apply_tp(model, world_mesh["tp"], job_config)
# 2. Activation Checkpointing
if job_config.activation_checkpoint.mode == "full":
apply_ac(model, job_config)
# 3. torch.compile
if job_config.compile.enable:
model = torch.compile(model)
# 4. FSDP
if parallel_dims.dp_enabled:
apply_fsdp(model, world_mesh["dp"], job_config)
return model
```
## Step 4: Create TrainSpec
```python
# __init__.py
from torchtitan.protocols.train_spec import TrainSpec, register_train_spec
from .model.model import YourModel
from .model.args import YourModelArgs
from .infra.parallelize import parallelize_your_model
MODEL_CONFIGS = {
"8B": YourModelArgs(dim=4096, n_layers=32, n_heads=32),
"70B": YourModelArgs(dim=8192, n_layers=80, n_heads=64),
}
def get_train_spec(flavor: str) -> TrainSpec:
return TrainSpec(
model_cls=YourModel,
model_args=MODEL_CONFIGS[flavor],
parallelize_fn=parallelize_your_model,
pipeline_fn=None, # Or your_pipeline_fn for PP
build_optimizer_fn=build_optimizer, # Reuse existing
build_lr_scheduler_fn=build_lr_scheduler, # Reuse existing
build_dataloader_fn=build_dataloader, # Reuse existing
build_tokenizer_fn=build_tokenizer, # Reuse existing
build_loss_fn=build_loss, # Reuse existing
state_dict_adapter=None, # Or YourStateDictAdapter
)
# Register so train.py can find it
register_train_spec("your_model", get_train_spec)
```
## Step 5: State Dict Adapter (Optional)
For HuggingFace checkpoint conversion:
```python
# model/state_dict_adapter.py
from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter
class YourStateDictAdapter(BaseStateDictAdapter):
def to_hf(self, state_dict: dict) -> dict:
"""Convert torchtitan state dict to HF format."""
hf_state_dict = {}
for key, value in state_dict.items():
hf_key = self._convert_key_to_hf(key)
hf_state_dict[hf_key] = value
return hf_state_dict
def from_hf(self, state_dict: dict) -> dict:
"""Convert HF state dict to torchtitan format."""
tt_state_dict = {}
for key, value in state_dict.items():
tt_key = self._convert_key_from_hf(key)
tt_state_dict[tt_key] = value
return tt_state_dict
```
## Step 6: Training Config
```toml
# train_configs/your_model_8b.toml
[job]
dump_folder = "./outputs"
description = "Your Model 8B training"
[model]
name = "your_model"
flavor = "8B"
[optimizer]
name = "AdamW"
lr = 3e-4
[training]
local_batch_size = 2
seq_len = 8192
steps = 1000
dataset = "c4"
[parallelism]
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
```
## Step 7: Register Model
Add to `torchtitan/models/__init__.py`:
```python
from .your_model import get_train_spec as get_your_model_train_spec
MODEL_REGISTRY["your_model"] = get_your_model_train_spec
```
## Testing
### Numerics Test
Compare output with HuggingFace implementation:
```python
def test_numerics():
# Load same checkpoint into both implementations
tt_model = YourModel(args).load_checkpoint(...)
hf_model = HFYourModel.from_pretrained(...)
# Compare outputs
input_ids = torch.randint(0, vocab_size, (1, 128))
tt_output = tt_model(input_ids)
hf_output = hf_model(input_ids).logits
torch.testing.assert_close(tt_output, hf_output, atol=1e-4, rtol=1e-4)
```
### Loss Convergence
Compare loss curves with verified baseline (see `docs/converging.md`).
### Performance Benchmark
Add benchmark config to `benchmarks/` folder.
## Guiding Principles
1. **Readability over flexibility**: Don't over-abstract
2. **Minimal model changes**: Parallelism applied externally
3. **Clean, minimal codebase**: Reuse existing components where possible
4. **Single-device semantics**: Model code should work on single GPU

View file

@ -0,0 +1,133 @@
# Float8 Training in TorchTitan
Float8 training provides substantial speedups for models where GEMMs are large enough that the FP8 tensorcore speedup outweighs dynamic quantization overhead.
## Hardware Requirements
- NVIDIA H100 or newer GPUs (FP8 Tensor Cores)
- Blackwell GPUs for MXFP8 training
## Installation
```bash
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
```
## Usage: Tensorwise Scaling
Standard Float8 with tensorwise dynamic scaling:
```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
--model.converters="quantize.linear.float8" \
--quantize.linear.float8.enable_fsdp_float8_all_gather \
--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp \
--compile.enable
```
### Key Arguments
| Argument | Description |
|----------|-------------|
| `--model.converters="quantize.linear.float8"` | Replace `nn.Linear` with `Float8Linear` |
| `--quantize.linear.float8.enable_fsdp_float8_all_gather` | Communicate in float8 to save bandwidth |
| `--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp` | Single all-reduce for all AMAX/scales |
| `--compile.enable` | Required - fuses float8 scaling/casting kernels |
## Usage: Rowwise Scaling
Higher accuracy than tensorwise scaling:
```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
--model.converters="quantize.linear.float8" \
--quantize.linear.float8.recipe_name rowwise \
--compile.enable
```
## Filtering Layers
Not all layers benefit from Float8. Filter small layers:
```bash
--quantize.linear.float8.filter_fqns="attention.wk,attention.wv,output"
```
### Auto-filtering
Automatically skip layers too small to benefit:
```bash
--quantize.linear.float8.filter_fqns="auto_filter_small_kn"
```
Thresholds are derived from H100 microbenchmarks where the Float8 speedup exceeds the quantization overhead.
## TOML Configuration
```toml
[model]
converters = ["quantize.linear.float8"]
[quantize.linear.float8]
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
filter_fqns = ["output", "auto_filter_small_kn"]
[compile]
enable = true
components = ["model", "loss"]
```
## How Float8 Works with Distributed Training
### Single Device
Cast input and weight to float8 inside forward before calling `torch._scaled_mm`:
```python
# Float8 matmul requires scales
torch._scaled_mm(input_fp8, weight_fp8, scale_a=scale_input, scale_b=scale_weight)
```
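For orientation, a sketch of roughly the same conversion done directly with torchao outside TorchTitan, assuming torchao's `convert_to_float8_training` entry point, a toy bf16 model, and an FP8-capable GPU; inside TorchTitan this is wired up for you via `--model.converters`:
```python
# Sketch only: what the "quantize.linear.float8" converter does at the torchao level.
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

model = nn.Sequential(
    nn.Linear(4096, 4096, bias=False),
    nn.ReLU(),
    nn.Linear(4096, 4096, bias=False),
).cuda().to(torch.bfloat16)

# Keep only linears large enough to benefit (mirrors filter_fqns / auto-filtering).
def keep_layer(module: nn.Module, fqn: str) -> bool:
    return isinstance(module, nn.Linear) and all(d >= 4096 for d in module.weight.shape)

convert_to_float8_training(model, module_filter_fn=keep_layer)
model = torch.compile(model)  # required: fuses the float8 scaling/casting kernels

out = model(torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16))
out.sum().backward()
```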
### FSDP + Float8
1. Cast sharded high-precision weights (1/N per rank) to float8
2. Perform float8 all-gather (saves bandwidth vs bf16/fp32)
3. Communicate `max(abs)` across ranks for scale computation
4. At forward start, have unsharded float8 weights ready
**Net benefit**: Float8 all-gather + amax communication can beat bf16/fp32 all-gather, depending on world size and message size.
### TP + Float8
- **Input**: Cast sharded input to float8, all-gather in float8
- **Weights**: Communicate `max(abs)` for sharded weights
- **Matmul**: Float8 input (unsharded) x float8 weight (sharded) with global scales
## Scaling Strategies
| Strategy | Status | Description |
|----------|--------|-------------|
| Tensorwise dynamic | Stable | Single scale per tensor |
| Rowwise dynamic | Alpha | Scale per row, higher accuracy |
## Performance Gains
From benchmarks on H100:
| Configuration | TPS/GPU | vs Baseline |
|---------------|---------|-------------|
| FSDP only | 5,762 | - |
| FSDP + compile | 6,667 | +16% |
| FSDP + compile + Float8 | 8,532 | +48% |
## Determining Float8 Benefit
Check [torchao microbenchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) for forward+backward pass speedups on "layer norm => linear => sigmoid" for different M,N,K sizes.
Rule of thumb: GEMMs with K,N > 4096 typically benefit from Float8.
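A small sketch that applies this rule of thumb to collect candidate `filter_fqns` for a model; the helper name and threshold are illustrative only:
```python
# Heuristic sketch: list linear layers whose (K, N) dims fall below the threshold,
# i.e. candidates for filter_fqns rather than Float8 conversion.
import torch.nn as nn

def small_linear_fqns(model: nn.Module, threshold: int = 4096) -> list[str]:
    fqns = []
    for fqn, module in model.named_modules():
        if isinstance(module, nn.Linear):
            out_features, in_features = module.weight.shape
            if in_features < threshold or out_features < threshold:
                fqns.append(fqn)
    return fqns

# Toy example; real usage would pass the actual training model.
toy = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 8192))
print(small_linear_fqns(toy))  # ['0'] -> too small to benefit from Float8
```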
## MXFP8 Training (Blackwell)
For NVIDIA Blackwell GPUs, TorchTitan supports MXFP8 (Microscaling FP8) for both dense and MoE models. See [docs/mxfp8.md](https://github.com/pytorch/torchtitan/blob/main/docs/mxfp8.md) for details.

View file

@ -0,0 +1,126 @@
# FSDP2 in TorchTitan
## Why FSDP2?
FSDP2 is a rewrite of PyTorch's Fully Sharded Data Parallel (FSDP) API, removing the `FlatParameter` abstraction for better composability and simpler implementation.
### Key improvements over FSDP1
- **DTensor-based sharding**: Sharded parameters are `DTensor`s on dim-0, enabling easy manipulation and communication-free sharded state dicts
- **Better memory management**: Deterministic and lower GPU memory (7% reduction) by avoiding `recordStream`
- **Simplified API**: Fewer arguments, no wrapper class
### Performance
On Llama-7B with 8x H100s, FSDP2 achieves higher MFU with 7% lower peak memory than FSDP1, matching the same loss curve.
## API Reference
```python
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, OffloadPolicy
@contract(state_cls=FSDPState)
def fully_shard(
module: nn.Module,
*,
mesh: Optional[DeviceMesh] = None,
reshard_after_forward: Union[bool, int] = True,
mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
offload_policy: OffloadPolicy = OffloadPolicy(),
) -> nn.Module:
```
## Sharding Strategies (ZeRO Equivalents)
| FSDP2 Configuration | FSDP1 Equivalent | DeepSpeed |
|---------------------|------------------|-----------|
| 1D mesh + `reshard_after_forward=True` | FULL_SHARD | ZeRO-3 |
| 1D mesh + `reshard_after_forward=False` | SHARD_GRAD_OP | ZeRO-2 |
| 2D mesh + `reshard_after_forward=True` | HYBRID_SHARD | MiCS |
| 1D/2D mesh + `reshard_after_forward=8` (int) | - | ZeRO++ hpZ |
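A sketch of how those rows map onto `fully_shard` calls, assuming a process group is already initialized and `model` is the module to shard; each call is an alternative configuration, not a sequence:
```python
# Sketch: alternative fully_shard configurations matching the table above.
# Assumes torch.distributed is initialized and `model` is the module to shard.
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

# ZeRO-3 equivalent (default): reshard parameters after forward.
fully_shard(model, reshard_after_forward=True)

# ZeRO-2 equivalent: keep unsharded parameters alive until backward.
# fully_shard(model, reshard_after_forward=False)

# HYBRID_SHARD / MiCS equivalent: replicate across 4 groups, shard within 8 GPUs.
# mesh_2d = init_device_mesh("cuda", (4, 8), mesh_dim_names=("replicate", "shard"))
# fully_shard(model, mesh=mesh_2d)

# ZeRO++ hpZ-style: reshard after forward to a smaller group of 8 ranks.
# fully_shard(model, reshard_after_forward=8)
```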
## Meta-Device Initialization
FSDP2 supports materializing tensors onto GPU _after_ sharding:
```python
# Initialize on meta device (no memory)
with torch.device("meta"):
model = Transformer()
# Apply FSDP2 sharding
for module in model.modules():
if isinstance(module, TransformerBlock):
fully_shard(module)
fully_shard(model)
# Parameters still on meta device
for tensor in itertools.chain(model.parameters(), model.buffers()):
assert tensor.device == torch.device("meta")
# Allocate sharded parameters on GPU
model.to_empty(device="cuda")
# Initialize weights
model.init_weights()
```
## State Dict Differences
| Operation | FSDP1 | FSDP2 |
|-----------|-------|-------|
| `model.state_dict()` | Full state dict | Sharded state dict (no communication) |
| `optim.state_dict()` | Local state dict | Sharded state dict (no communication) |
| `summon_full_params()` | Supported | Use `DTensor` APIs like `full_tensor()` |
| Gradient clipping | `FSDP.clip_grad_norm_()` | `nn.utils.clip_grad_norm_()` |
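A short sketch of the FSDP2 column, assuming `model` is already wrapped with `fully_shard` and a backward pass has run; the `"output.weight"` key is illustrative:
```python
# Sketch of the FSDP2 column: sharded state dicts plus standard gradient clipping.
import torch.nn as nn

# state_dict() returns sharded DTensors with no communication.
sharded_sd = model.state_dict()

# Materialize a full (unsharded) tensor only where one is actually needed.
full_weight = sharded_sd["output.weight"].full_tensor()

# Gradient clipping: the vanilla utility replaces FSDP.clip_grad_norm_().
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```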
## Mixed Precision
```python
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
output_dtype=torch.bfloat16,
cast_forward_inputs=True,
)
fully_shard(model, mp_policy=mp_policy)
```
## HSDP (Hybrid Sharded Data Parallel)
For 2D parallelism with replication + sharding:
```python
from torch.distributed.device_mesh import init_device_mesh
# Replicate across 4 groups, shard within 8 GPUs each
mesh = init_device_mesh("cuda", (4, 8), mesh_dim_names=("replicate", "shard"))
fully_shard(model, mesh=mesh)
```
## Configuration in TorchTitan
```toml
[parallelism]
# FSDP sharding degree (-1 = auto, use all available GPUs)
data_parallel_shard_degree = -1
# HSDP replication degree (1 = pure FSDP, >1 = HSDP)
data_parallel_replicate_degree = 1
```
## Removed Arguments from FSDP1
These FSDP1 arguments are no longer needed:
- `auto_wrap_policy`: Apply `fully_shard` directly to modules
- `backward_prefetch`: Always uses BACKWARD_PRE
- `param_init_fn`: Use meta-device initialization
- `device_id`: Uses mesh's device automatically
- `sync_module_states`: Not needed with DTensor
- `limit_all_gathers`: New memory management doesn't need it
- `use_orig_params`: Always true (no FlatParameter)
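For contrast, a minimal before/after sketch; `model` and `TransformerBlock` are placeholders, and the two halves are alternatives rather than steps to run together:
```python
# Contrast sketch; `model` and `TransformerBlock` are placeholders.

# FSDP1: wrapper class plus knobs that FSDP2 drops.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

fsdp1_model = FSDP(
    model,
    auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}),
    use_orig_params=True,
    sync_module_states=True,
)

# FSDP2: apply fully_shard directly; the removed arguments simply disappear.
from torch.distributed._composable.fsdp import fully_shard

for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module)
fully_shard(model)
```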

View file

@ -0,0 +1,96 @@
---
name: domain-intel
description: Passive domain reconnaissance using Python stdlib. Subdomain discovery, SSL certificate inspection, WHOIS lookups, DNS records, domain availability checks, and bulk multi-domain analysis. No API keys required.
---
# Domain Intelligence — Passive OSINT
Passive domain reconnaissance using only Python stdlib.
**Zero dependencies. Zero API keys. Works on Linux, macOS, and Windows.**
## Helper script
This skill includes `scripts/domain_intel.py` — a complete CLI tool for all domain intelligence operations.
```bash
# Subdomain discovery via Certificate Transparency logs
python3 SKILL_DIR/scripts/domain_intel.py subdomains example.com
# SSL certificate inspection (expiry, cipher, SANs, issuer)
python3 SKILL_DIR/scripts/domain_intel.py ssl example.com
# WHOIS lookup (registrar, dates, name servers — 100+ TLDs)
python3 SKILL_DIR/scripts/domain_intel.py whois example.com
# DNS records (A, AAAA, MX, NS, TXT, CNAME)
python3 SKILL_DIR/scripts/domain_intel.py dns example.com
# Domain availability check (passive: DNS + WHOIS + SSL signals)
python3 SKILL_DIR/scripts/domain_intel.py available coolstartup.io
# Bulk analysis — multiple domains, multiple checks in parallel
python3 SKILL_DIR/scripts/domain_intel.py bulk example.com github.com google.com
python3 SKILL_DIR/scripts/domain_intel.py bulk example.com github.com --checks ssl,dns
```
`SKILL_DIR` is the directory containing this SKILL.md file. All output is structured JSON.
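Because every command prints structured JSON, the helper is also easy to drive from Python; a minimal sketch, assuming the script path is `scripts/domain_intel.py` relative to the current directory:
```python
# Sketch: invoke the helper and parse its JSON output from Python.
import json
import subprocess

result = subprocess.run(
    ["python3", "scripts/domain_intel.py", "ssl", "example.com"],
    capture_output=True, text=True, check=True,
)
cert = json.loads(result.stdout)
print(cert["expiry_status"], "-", cert["not_after"])
```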
## Available commands
| Command | What it does | Data source |
|---------|-------------|-------------|
| `subdomains` | Find subdomains from certificate logs | crt.sh (HTTPS) |
| `ssl` | Inspect TLS certificate details | Direct TCP:443 to target |
| `whois` | Registration info, registrar, dates | WHOIS servers (TCP:43) |
| `dns` | A, AAAA, MX, NS, TXT, CNAME records | System DNS + Google DoH |
| `available` | Check if domain is registered | DNS + WHOIS + SSL signals |
| `bulk` | Run multiple checks on multiple domains | All of the above |
## When to use this vs built-in tools
- **Use this skill** for infrastructure questions: subdomains, SSL certs, WHOIS, DNS records, availability
- **Use `web_search`** for general research about what a domain/company does
- **Use `web_extract`** to get the actual content of a webpage
- **Use `terminal` with `curl -I`** for a simple "is this URL reachable" check
| Task | Better tool | Why |
|------|-------------|-----|
| "What does example.com do?" | `web_extract` | Gets page content, not DNS/WHOIS data |
| "Find info about a company" | `web_search` | General research, not domain-specific |
| "Is this website safe?" | `web_search` | Reputation checks need web context |
| "Check if a URL is reachable" | `terminal` with `curl -I` | Simple HTTP check |
| "Find subdomains of X" | **This skill** | Only passive source for this |
| "When does the SSL cert expire?" | **This skill** | Built-in tools can't inspect TLS |
| "Who registered this domain?" | **This skill** | WHOIS data not in web search |
| "Is coolstartup.io available?" | **This skill** | Passive availability via DNS+WHOIS+SSL |
## Platform compatibility
Pure Python stdlib (`socket`, `ssl`, `urllib`, `json`, `concurrent.futures`).
Works identically on Linux, macOS, and Windows with no dependencies.
- **crt.sh queries** use HTTPS (port 443) — works behind most firewalls
- **WHOIS queries** use TCP port 43 — may be blocked on restrictive networks
- **DNS queries** use Google DoH (HTTPS) for MX/NS/TXT — firewall-friendly
- **SSL checks** connect to the target on port 443 — the only "active" operation
## Data sources
All queries are **passive** — no port scanning, no vulnerability testing:
- **crt.sh** — Certificate Transparency logs (subdomain discovery, HTTPS only)
- **WHOIS servers** — Direct TCP to 100+ authoritative TLD registrars
- **Google DNS-over-HTTPS** — MX, NS, TXT, CNAME resolution (firewall-friendly)
- **System DNS** — A/AAAA record resolution
- **SSL check** is the only "active" operation (TCP connection to target:443)
## Notes
- WHOIS queries use TCP port 43 — may be blocked on restrictive networks
- Some WHOIS servers redact registrant info (GDPR) — mention this to the user
- crt.sh can be slow for very popular domains (thousands of certs) — set reasonable expectations
- The availability check is heuristic-based (3 passive signals) — not authoritative like a registrar API
---
*Contributed by [@FurkanL0](https://github.com/FurkanL0)*

View file

@ -0,0 +1,397 @@
#!/usr/bin/env python3
"""
Domain Intelligence — Passive OSINT via Python stdlib.
Usage:
python domain_intel.py subdomains example.com
python domain_intel.py ssl example.com
python domain_intel.py whois example.com
python domain_intel.py dns example.com
python domain_intel.py available example.com
python domain_intel.py bulk example.com github.com google.com --checks ssl,dns
All output is structured JSON. No dependencies beyond Python stdlib.
Works on Linux, macOS, and Windows.
"""
import json
import re
import socket
import ssl
import sys
import urllib.request
import urllib.parse
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
# ─── Subdomain Discovery (crt.sh) ──────────────────────────────────────────
def subdomains(domain, include_expired=False, limit=200):
"""Find subdomains via Certificate Transparency logs."""
url = f"https://crt.sh/?q=%25.{urllib.parse.quote(domain)}&output=json"
req = urllib.request.Request(url, headers={
"User-Agent": "domain-intel-skill/1.0", "Accept": "application/json",
})
with urllib.request.urlopen(req, timeout=15) as r:
entries = json.loads(r.read().decode())
seen, results = set(), []
now = datetime.now(timezone.utc)
for e in entries:
not_after = e.get("not_after", "")
if not include_expired and not_after:
try:
dt = datetime.strptime(not_after[:19], "%Y-%m-%dT%H:%M:%S").replace(tzinfo=timezone.utc)
if dt <= now:
continue
except ValueError:
pass
for name in e.get("name_value", "").splitlines():
name = name.strip().lower()
if name and name not in seen:
seen.add(name)
results.append({
"subdomain": name,
"issuer": e.get("issuer_name", ""),
"not_after": not_after,
})
results.sort(key=lambda r: (r["subdomain"].startswith("*"), r["subdomain"]))
return {"domain": domain, "count": min(len(results), limit), "subdomains": results[:limit]}
# ─── SSL Certificate Inspection ────────────────────────────────────────────
def check_ssl(host, port=443, timeout=10):
"""Inspect the TLS certificate of a host."""
def flat(rdns):
r = {}
for rdn in rdns:
for item in rdn:
if isinstance(item, (list, tuple)) and len(item) == 2:
r[item[0]] = item[1]
return r
    def parse_date(s):
        # Certificate dates look like "Jun  1 12:00:00 2025 GMT"; fall back to a
        # format without the timezone token if %Z parsing is not accepted.
        for fmt in ("%b %d %H:%M:%S %Y %Z", "%b %d %H:%M:%S %Y"):
            try:
                return datetime.strptime(s, fmt).replace(tzinfo=timezone.utc)
            except ValueError:
                pass
        return None
warning = None
try:
ctx = ssl.create_default_context()
with socket.create_connection((host, port), timeout=timeout) as sock:
with ctx.wrap_socket(sock, server_hostname=host) as s:
cert, cipher, proto = s.getpeercert(), s.cipher(), s.version()
except ssl.SSLCertVerificationError as e:
warning = str(e)
ctx = ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
with socket.create_connection((host, port), timeout=timeout) as sock:
with ctx.wrap_socket(sock, server_hostname=host) as s:
cert, cipher, proto = s.getpeercert(), s.cipher(), s.version()
not_after = parse_date(cert.get("notAfter", ""))
now = datetime.now(timezone.utc)
days = (not_after - now).days if not_after else None
is_expired = days is not None and days < 0
if is_expired:
status = f"EXPIRED ({abs(days)} days ago)"
elif days is not None and days <= 14:
status = f"CRITICAL — {days} day(s) left"
elif days is not None and days <= 30:
status = f"WARNING — {days} day(s) left"
else:
status = f"OK — {days} day(s) remaining" if days is not None else "unknown"
return {
"host": host, "port": port,
"subject": flat(cert.get("subject", [])),
"issuer": flat(cert.get("issuer", [])),
"subject_alt_names": [f"{t}:{v}" for t, v in cert.get("subjectAltName", [])],
"not_before": parse_date(cert.get("notBefore", "")).isoformat() if parse_date(cert.get("notBefore", "")) else "",
"not_after": not_after.isoformat() if not_after else "",
"days_remaining": days, "is_expired": is_expired, "expiry_status": status,
"tls_version": proto,
"cipher_suite": cipher[0] if cipher else None,
"serial_number": cert.get("serialNumber", ""),
"verification_warning": warning,
}
# ─── WHOIS Lookup ──────────────────────────────────────────────────────────
WHOIS_SERVERS = {
"com": "whois.verisign-grs.com", "net": "whois.verisign-grs.com",
"org": "whois.pir.org", "io": "whois.nic.io", "co": "whois.nic.co",
"ai": "whois.nic.ai", "dev": "whois.nic.google", "app": "whois.nic.google",
"tech": "whois.nic.tech", "shop": "whois.nic.shop", "store": "whois.nic.store",
"online": "whois.nic.online", "site": "whois.nic.site", "cloud": "whois.nic.cloud",
"digital": "whois.nic.digital", "media": "whois.nic.media", "blog": "whois.nic.blog",
"info": "whois.afilias.net", "biz": "whois.biz", "me": "whois.nic.me",
"tv": "whois.nic.tv", "cc": "whois.nic.cc", "ws": "whois.website.ws",
"uk": "whois.nic.uk", "co.uk": "whois.nic.uk", "de": "whois.denic.de",
"nl": "whois.domain-registry.nl", "fr": "whois.nic.fr", "it": "whois.nic.it",
"es": "whois.nic.es", "pl": "whois.dns.pl", "ru": "whois.tcinet.ru",
"se": "whois.iis.se", "no": "whois.norid.no", "fi": "whois.fi",
"ch": "whois.nic.ch", "at": "whois.nic.at", "be": "whois.dns.be",
"cz": "whois.nic.cz", "br": "whois.registro.br", "ca": "whois.cira.ca",
"mx": "whois.mx", "au": "whois.auda.org.au", "jp": "whois.jprs.jp",
"cn": "whois.cnnic.cn", "in": "whois.inregistry.net", "kr": "whois.kr",
"sg": "whois.sgnic.sg", "hk": "whois.hkirc.hk", "tr": "whois.nic.tr",
"ae": "whois.aeda.net.ae", "za": "whois.registry.net.za",
"space": "whois.nic.space", "zone": "whois.nic.zone", "ninja": "whois.nic.ninja",
"guru": "whois.nic.guru", "rocks": "whois.nic.rocks", "live": "whois.nic.live",
"game": "whois.nic.game", "games": "whois.nic.games",
}
def whois_lookup(domain):
"""Query WHOIS servers for domain registration info."""
parts = domain.split(".")
server = WHOIS_SERVERS.get(".".join(parts[-2:])) or WHOIS_SERVERS.get(parts[-1])
if not server:
return {"error": f"No WHOIS server for .{parts[-1]}"}
try:
with socket.create_connection((server, 43), timeout=10) as s:
s.sendall((domain + "\r\n").encode())
chunks = []
while True:
c = s.recv(4096)
if not c:
break
chunks.append(c)
raw = b"".join(chunks).decode("utf-8", errors="replace")
except Exception as e:
return {"error": str(e)}
patterns = {
"registrar": r"(?:Registrar|registrar):\s*(.+)",
"creation_date": r"(?:Creation Date|Created|created):\s*(.+)",
"expiration_date": r"(?:Registry Expiry Date|Expiration Date|Expiry Date):\s*(.+)",
"updated_date": r"(?:Updated Date|Last Modified):\s*(.+)",
"name_servers": r"(?:Name Server|nserver):\s*(.+)",
"status": r"(?:Domain Status|status):\s*(.+)",
"dnssec": r"DNSSEC:\s*(.+)",
}
result = {"domain": domain, "whois_server": server}
for key, pat in patterns.items():
matches = re.findall(pat, raw, re.IGNORECASE)
if matches:
if key in ("name_servers", "status"):
result[key] = list(dict.fromkeys(m.strip().lower() for m in matches))
else:
result[key] = matches[0].strip()
for field in ("creation_date", "expiration_date", "updated_date"):
if field in result:
for fmt in ("%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%d %H:%M:%S", "%Y-%m-%d"):
try:
dt = datetime.strptime(result[field][:19], fmt).replace(tzinfo=timezone.utc)
result[field] = dt.isoformat()
if field == "expiration_date":
days = (dt - datetime.now(timezone.utc)).days
result["expiration_days_remaining"] = days
result["is_expired"] = days < 0
break
except ValueError:
pass
return result
# ─── DNS Records ───────────────────────────────────────────────────────────
def dns_records(domain, types=None):
"""Resolve DNS records using system DNS + Google DoH."""
if not types:
types = ["A", "AAAA", "MX", "NS", "TXT", "CNAME"]
records = {}
for qtype in types:
if qtype == "A":
try:
records["A"] = list(dict.fromkeys(
i[4][0] for i in socket.getaddrinfo(domain, None, socket.AF_INET)
))
except Exception:
records["A"] = []
elif qtype == "AAAA":
try:
records["AAAA"] = list(dict.fromkeys(
i[4][0] for i in socket.getaddrinfo(domain, None, socket.AF_INET6)
))
except Exception:
records["AAAA"] = []
else:
url = f"https://dns.google/resolve?name={urllib.parse.quote(domain)}&type={qtype}"
try:
req = urllib.request.Request(url, headers={"User-Agent": "domain-intel-skill/1.0"})
with urllib.request.urlopen(req, timeout=10) as r:
data = json.loads(r.read())
records[qtype] = [
a.get("data", "").strip().rstrip(".")
for a in data.get("Answer", []) if a.get("data")
]
except Exception:
records[qtype] = []
return {"domain": domain, "records": records}
# ─── Domain Availability Check ─────────────────────────────────────────────
def check_available(domain):
"""Check domain availability using passive signals (DNS + WHOIS + SSL)."""
signals = {}
# DNS
try:
a = [i[4][0] for i in socket.getaddrinfo(domain, None, socket.AF_INET)]
except Exception:
a = []
try:
ns_url = f"https://dns.google/resolve?name={urllib.parse.quote(domain)}&type=NS"
req = urllib.request.Request(ns_url, headers={"User-Agent": "domain-intel-skill/1.0"})
with urllib.request.urlopen(req, timeout=10) as r:
ns = [x.get("data", "") for x in json.loads(r.read()).get("Answer", [])]
except Exception:
ns = []
signals["dns_a"] = a
signals["dns_ns"] = ns
dns_exists = bool(a or ns)
# SSL
ssl_up = False
try:
ctx = ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
with socket.create_connection((domain, 443), timeout=3) as s:
with ctx.wrap_socket(s, server_hostname=domain):
ssl_up = True
except Exception:
pass
signals["ssl_reachable"] = ssl_up
# WHOIS (quick check)
tld = domain.rsplit(".", 1)[-1]
server = WHOIS_SERVERS.get(tld)
whois_avail = None
whois_note = ""
if server:
try:
with socket.create_connection((server, 43), timeout=10) as s:
s.sendall((domain + "\r\n").encode())
raw = b""
while True:
c = s.recv(4096)
if not c:
break
raw += c
raw = raw.decode("utf-8", errors="replace").lower()
if any(p in raw for p in ["no match", "not found", "no data found", "status: free"]):
whois_avail = True
whois_note = "WHOIS: not found"
elif "registrar:" in raw or "creation date:" in raw:
whois_avail = False
whois_note = "WHOIS: registered"
else:
whois_note = "WHOIS: inconclusive"
except Exception as e:
whois_note = f"WHOIS error: {e}"
signals["whois_available"] = whois_avail
signals["whois_note"] = whois_note
if not dns_exists and whois_avail is True:
verdict, conf = "LIKELY AVAILABLE", "high"
elif dns_exists or whois_avail is False or ssl_up:
verdict, conf = "REGISTERED / IN USE", "high"
elif not dns_exists and whois_avail is None:
verdict, conf = "POSSIBLY AVAILABLE", "medium"
else:
verdict, conf = "UNCERTAIN", "low"
return {"domain": domain, "verdict": verdict, "confidence": conf, "signals": signals}
# ─── Bulk Analysis ─────────────────────────────────────────────────────────
COMMAND_MAP = {
"subdomains": subdomains,
"ssl": check_ssl,
"whois": whois_lookup,
"dns": dns_records,
"available": check_available,
}
def bulk_check(domains, checks=None, max_workers=5):
"""Run multiple checks across multiple domains in parallel."""
if not checks:
checks = ["ssl", "whois", "dns"]
def run_one(d):
entry = {"domain": d}
for check in checks:
fn = COMMAND_MAP.get(check)
if fn:
try:
entry[check] = fn(d)
except Exception as e:
entry[check] = {"error": str(e)}
return entry
results = []
with ThreadPoolExecutor(max_workers=min(max_workers, 10)) as ex:
futures = {ex.submit(run_one, d): d for d in domains[:20]}
for f in as_completed(futures):
results.append(f.result())
return {"total": len(results), "checks": checks, "results": results}
# ─── CLI Entry Point ───────────────────────────────────────────────────────
def main():
if len(sys.argv) < 3:
print(__doc__)
sys.exit(1)
command = sys.argv[1].lower()
args = sys.argv[2:]
if command == "bulk":
# Parse --checks flag
checks = None
domains = []
i = 0
while i < len(args):
if args[i] == "--checks" and i + 1 < len(args):
checks = [c.strip() for c in args[i + 1].split(",")]
i += 2
else:
domains.append(args[i])
i += 1
result = bulk_check(domains, checks)
elif command in COMMAND_MAP:
result = COMMAND_MAP[command](args[0])
else:
print(f"Unknown command: {command}")
print(f"Available: {', '.join(COMMAND_MAP.keys())}, bulk")
sys.exit(1)
print(json.dumps(result, indent=2))
if __name__ == "__main__":
main()

View file

@ -0,0 +1,237 @@
---
name: duckduckgo-search
description: Free web search via DuckDuckGo — text, news, images, videos. No API key needed. Prefer the `ddgs` CLI when installed; use the Python DDGS library only after verifying that `ddgs` is available in the current runtime.
version: 1.3.0
author: gamedevCloudy
license: MIT
metadata:
hermes:
tags: [search, duckduckgo, web-search, free, fallback]
related_skills: [arxiv]
fallback_for_toolsets: [web]
---
# DuckDuckGo Search
Free web search using DuckDuckGo. **No API key required.**
Preferred when `web_search` is unavailable or unsuitable (for example when `FIRECRAWL_API_KEY` is not set). Can also be used as a standalone search path when DuckDuckGo results are specifically desired.
## Detection Flow
Check what is actually available before choosing an approach:
```bash
# Check CLI availability
command -v ddgs >/dev/null && echo "DDGS_CLI=installed" || echo "DDGS_CLI=missing"
```
Decision tree:
1. If `ddgs` CLI is installed, prefer `terminal` + `ddgs`
2. If `ddgs` CLI is missing, do not assume `execute_code` can import `ddgs`
3. If the user wants DuckDuckGo specifically, install `ddgs` first in the relevant environment
4. Otherwise fall back to built-in web/browser tools
Important runtime note:
- Terminal and `execute_code` are separate runtimes
- A successful shell install does not guarantee `execute_code` can import `ddgs`
- Never assume third-party Python packages are preinstalled inside `execute_code`
## Installation
Install `ddgs` only when DuckDuckGo search is specifically needed and the runtime does not already provide it.
```bash
# Python package + CLI entrypoint
pip install ddgs
# Verify CLI
ddgs --help
```
If a workflow depends on Python imports, verify that same runtime can import `ddgs` before using `from ddgs import DDGS`.
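A stdlib-only sketch of that verification, run inside the Python runtime in question:
```python
# Sketch: confirm ddgs is importable in this runtime before using the Python API.
import importlib.util

if importlib.util.find_spec("ddgs") is None:
    print("ddgs is not importable here; use the ddgs CLI or install the package first")
else:
    from ddgs import DDGS
    print("ddgs import OK")
```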
## Method 1: CLI Search (Preferred)
Use the `ddgs` command via `terminal` when it exists. This is the preferred path because it avoids assuming the `execute_code` sandbox has the `ddgs` Python package installed.
```bash
# Text search
ddgs text -k "python async programming" -m 5
# News search
ddgs news -k "artificial intelligence" -m 5
# Image search
ddgs images -k "landscape photography" -m 10
# Video search
ddgs videos -k "python tutorial" -m 5
# With region filter
ddgs text -k "best restaurants" -m 5 -r us-en
# Recent results only (d=day, w=week, m=month, y=year)
ddgs text -k "latest AI news" -m 5 -t w
# JSON output for parsing
ddgs text -k "fastapi tutorial" -m 5 -o json
```
### CLI Flags
| Flag | Description | Example |
|------|-------------|---------|
| `-k` | Keywords (query) — **required** | `-k "search terms"` |
| `-m` | Max results | `-m 5` |
| `-r` | Region | `-r us-en` |
| `-t` | Time limit | `-t w` (week) |
| `-s` | Safe search | `-s off` |
| `-o` | Output format | `-o json` |
## Method 2: Python API (Only After Verification)
Use the `DDGS` class in `execute_code` or another Python runtime only after verifying that `ddgs` is installed there. Do not assume `execute_code` includes third-party packages by default.
Safe wording:
- "Use `execute_code` with `ddgs` after installing or verifying the package if needed"
Avoid saying:
- "`execute_code` includes `ddgs`"
- "DuckDuckGo search works by default in `execute_code`"
**Important:** `max_results` must always be passed as a **keyword argument** — positional usage raises an error on all methods.
### Text Search
Best for: general research, companies, documentation.
```python
from ddgs import DDGS
with DDGS() as ddgs:
for r in ddgs.text("python async programming", max_results=5):
print(r["title"])
print(r["href"])
print(r.get("body", "")[:200])
print()
```
Returns: `title`, `href`, `body`
### News Search
Best for: current events, breaking news, latest updates.
```python
from ddgs import DDGS
with DDGS() as ddgs:
for r in ddgs.news("AI regulation 2026", max_results=5):
print(r["date"], "-", r["title"])
print(r.get("source", ""), "|", r["url"])
print(r.get("body", "")[:200])
print()
```
Returns: `date`, `title`, `body`, `url`, `image`, `source`
### Image Search
Best for: visual references, product images, diagrams.
```python
from ddgs import DDGS
with DDGS() as ddgs:
for r in ddgs.images("semiconductor chip", max_results=5):
print(r["title"])
print(r["image"])
print(r.get("thumbnail", ""))
print(r.get("source", ""))
print()
```
Returns: `title`, `image`, `thumbnail`, `url`, `height`, `width`, `source`
### Video Search
Best for: tutorials, demos, explainers.
```python
from ddgs import DDGS
with DDGS() as ddgs:
for r in ddgs.videos("FastAPI tutorial", max_results=5):
print(r["title"])
print(r.get("content", ""))
print(r.get("duration", ""))
print(r.get("provider", ""))
print(r.get("published", ""))
print()
```
Returns: `title`, `content`, `description`, `duration`, `provider`, `published`, `statistics`, `uploader`
### Quick Reference
| Method | Use When | Key Fields |
|--------|----------|------------|
| `text()` | General research, companies | title, href, body |
| `news()` | Current events, updates | date, title, source, body, url |
| `images()` | Visuals, diagrams | title, image, thumbnail, url |
| `videos()` | Tutorials, demos | title, content, duration, provider |
## Workflow: Search then Extract
DuckDuckGo returns titles, URLs, and snippets — not full page content. To get full page content, search first and then extract the most relevant URL with `web_extract`, browser tools, or curl.
CLI example:
```bash
ddgs text -k "fastapi deployment guide" -m 3 -o json
```
Python example, only after verifying `ddgs` is installed in that runtime:
```python
from ddgs import DDGS
with DDGS() as ddgs:
results = list(ddgs.text("fastapi deployment guide", max_results=3))
for r in results:
print(r["title"], "->", r["href"])
```
Then extract the best URL with `web_extract` or another content-retrieval tool.
## Limitations
- **Rate limiting**: DuckDuckGo may throttle after many rapid requests. Add a short delay between searches if needed.
- **No content extraction**: `ddgs` returns snippets, not full page content. Use `web_extract`, browser tools, or curl for the full article/page.
- **Results quality**: Generally good but less configurable than Firecrawl's search.
- **Availability**: DuckDuckGo may block requests from some cloud IPs. If searches return empty, try different keywords or wait a few seconds.
- **Field variability**: Return fields may vary between results or `ddgs` versions. Use `.get()` for optional fields to avoid `KeyError`.
- **Separate runtimes**: A successful `ddgs` install in terminal does not automatically mean `execute_code` can import it.
## Troubleshooting
| Problem | Likely Cause | What To Do |
|---------|--------------|------------|
| `ddgs: command not found` | CLI not installed in the shell environment | Install `ddgs`, or use built-in web/browser tools instead |
| `ModuleNotFoundError: No module named 'ddgs'` | Python runtime does not have the package installed | Do not use Python DDGS there until that runtime is prepared |
| Search returns nothing | Temporary rate limiting or poor query | Wait a few seconds, retry, or adjust the query |
| CLI works but `execute_code` import fails | Terminal and `execute_code` are different runtimes | Keep using CLI, or separately prepare the Python runtime |
## Pitfalls
- **`max_results` is keyword-only**: `ddgs.text("query", 5)` raises an error. Use `ddgs.text("query", max_results=5)`.
- **Do not assume the CLI exists**: Check `command -v ddgs` before using it.
- **Do not assume `execute_code` can import `ddgs`**: `from ddgs import DDGS` may fail with `ModuleNotFoundError` unless that runtime was prepared separately.
- **Package name**: The package is `ddgs` (previously `duckduckgo-search`). Install with `pip install ddgs`.
- **Don't confuse `-k` and `-m`** (CLI): `-k` is for keywords, `-m` is for max results count.
- **Empty results**: If `ddgs` returns nothing, it may be rate-limited. Wait a few seconds and retry.
## Validated With
Validated examples against `ddgs==9.11.2` semantics. Skill guidance now treats CLI availability and Python import availability as separate concerns so the documented workflow matches actual runtime behavior.

View file

@ -0,0 +1,28 @@
#!/bin/bash
# DuckDuckGo Search Helper Script
# Wrapper around ddgs CLI with sensible defaults
# Usage: ./duckduckgo.sh <query> [max_results]
set -e
QUERY="$1"
MAX_RESULTS="${2:-5}"
if [ -z "$QUERY" ]; then
echo "Usage: $0 <query> [max_results]"
echo ""
echo "Examples:"
echo " $0 'python async programming' 5"
echo " $0 'latest AI news' 10"
echo ""
echo "Requires: pip install ddgs"
exit 1
fi
# Check if ddgs is available
if ! command -v ddgs &> /dev/null; then
echo "Error: ddgs not found. Install with: pip install ddgs"
exit 1
fi
ddgs text -k "$QUERY" -m "$MAX_RESULTS"