Mirror of https://github.com/NousResearch/hermes-agent.git, synced 2026-04-25 00:51:20 +00:00
feat(gateway): skill-aware slash commands, paginated /commands, Telegram 100-cap (#3934)
* feat(gateway): skill-aware slash commands, paginated /commands, Telegram 100-cap

  Map active skills to Telegram's slash command menu so users can discover and invoke skills directly. Three changes:

  1. Telegram menu now includes active skill commands alongside built-in commands, capped at 100 entries (Telegram Bot API limit). Overflow commands remain callable but hidden from the picker. Logged at startup when cap is hit.
  2. New /commands [page] gateway command for paginated browsing of all commands + skills. /help now shows first 10 skill commands and points to /commands for the full list.
  3. When a user types a slash command that matches a disabled or uninstalled skill, they get actionable guidance:
     - Disabled: 'Enable it with: hermes skills config'
     - Optional (not installed): 'Install with: hermes skills install official/<path>'

  Built on ideas from PR #3921 by @kshitijk4poor.

* chore: move 21 niche skills to optional-skills

  Move specialized/niche skills from built-in (skills/) to optional (optional-skills/) to reduce the default skill count. Users can install them with: hermes skills install official/<category>/<name>

  Moved skills (21):
  - mlops: accelerate, chroma, faiss, flash-attention, hermes-atropos-environments, huggingface-tokenizers, instructor, lambda-labs, llava, nemo-curator, pinecone, pytorch-lightning, qdrant, saelens, simpo, slime, tensorrt-llm, torchtitan
  - research: domain-intel, duckduckgo-search
  - devops: inference-sh cli

  Built-in skills: 96 → 75
  Optional skills: 22 → 43

* fix: only include repo built-in skills in Telegram menu, not user-installed

  User-installed skills (from hub or manually added) stay accessible via /skills and by typing the command directly, but don't get registered in the Telegram slash command picker. Only skills whose SKILL.md is under the repo's skills/ directory are included in the menu.

  This keeps the Telegram menu focused on the curated built-in set while user-installed skills remain discoverable through /skills and /commands.
This commit is contained in:
parent 97d6813f51
commit 5ceed021dc
73 changed files with 163 additions and 4 deletions
155
optional-skills/devops/cli/SKILL.md
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
---
|
||||
name: inference-sh-cli
|
||||
description: "Run 150+ AI apps via inference.sh CLI (infsh) — image generation, video creation, LLMs, search, 3D, social automation. Uses the terminal tool. Triggers: inference.sh, infsh, ai apps, flux, veo, image generation, video generation, seedream, seedance, tavily"
|
||||
version: 1.0.0
|
||||
author: okaris
|
||||
license: MIT
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [AI, image-generation, video, LLM, search, inference, FLUX, Veo, Claude]
|
||||
related_skills: []
|
||||
---
|
||||
|
||||
# inference.sh CLI
|
||||
|
||||
Run 150+ AI apps in the cloud with a simple CLI. No GPU required.
|
||||
|
||||
All commands use the **terminal tool** to run `infsh` commands.
|
||||
|
||||
## When to Use
|
||||
|
||||
- User asks to generate images (FLUX, Reve, Seedream, Grok, Gemini image)
|
||||
- User asks to generate video (Veo, Wan, Seedance, OmniHuman)
|
||||
- User asks about inference.sh or infsh
|
||||
- User wants to run AI apps without managing individual provider APIs
|
||||
- User asks for AI-powered search (Tavily, Exa)
|
||||
- User needs avatar/lipsync generation
|
||||
|
||||
## Prerequisites
|
||||
|
||||
The `infsh` CLI must be installed and authenticated. Check with:
|
||||
|
||||
```bash
|
||||
infsh me
|
||||
```
|
||||
|
||||
If not installed:
|
||||
|
||||
```bash
|
||||
curl -fsSL https://cli.inference.sh | sh
|
||||
infsh login
|
||||
```
|
||||
|
||||
See `references/authentication.md` for full setup details.
|
||||
|
||||
## Workflow
|
||||
|
||||
### 1. Always Search First
|
||||
|
||||
Never guess app names — always search to find the correct app ID:
|
||||
|
||||
```bash
|
||||
infsh app list --search flux
|
||||
infsh app list --search video
|
||||
infsh app list --search image
|
||||
```
|
||||
|
||||
### 2. Run an App
|
||||
|
||||
Use the exact app ID from the search results. Always use `--json` for machine-readable output:
|
||||
|
||||
```bash
|
||||
infsh app run <app-id> --input '{"prompt": "your prompt here"}' --json
|
||||
```
|
||||
|
||||
### 3. Parse the Output
|
||||
|
||||
The JSON output contains URLs to generated media. Present these to the user with `MEDIA:<url>` for inline display.
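A minimal sketch of this run-and-parse step (assumes `--json` makes `infsh` print a single JSON object to stdout; the `images` key matches the sample output in `references/running-apps.md`, other apps return different keys):

```python
import json
import subprocess

def run_app(app_id: str, payload: dict) -> dict:
    """Run an inference.sh app and return its parsed JSON output."""
    result = subprocess.run(
        ["infsh", "app", "run", app_id, "--input", json.dumps(payload), "--json"],
        capture_output=True, text=True, check=True,
    )
    return json.loads(result.stdout)

output = run_app("falai/flux-dev-lora", {"prompt": "sunset over mountains", "num_images": 1})

# Surface any returned media URLs inline
for item in output.get("images", []):
    print(f"MEDIA:{item['url']}")
```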
|
||||
|
||||
## Common Commands
|
||||
|
||||
### Image Generation
|
||||
|
||||
```bash
|
||||
# Search for image apps
|
||||
infsh app list --search image
|
||||
|
||||
# FLUX Dev with LoRA
|
||||
infsh app run falai/flux-dev-lora --input '{"prompt": "sunset over mountains", "num_images": 1}' --json
|
||||
|
||||
# Gemini image generation
|
||||
infsh app run google/gemini-2-5-flash-image --input '{"prompt": "futuristic city", "num_images": 1}' --json
|
||||
|
||||
# Seedream (ByteDance)
|
||||
infsh app run bytedance/seedream-5-lite --input '{"prompt": "nature scene"}' --json
|
||||
|
||||
# Grok Imagine (xAI)
|
||||
infsh app run xai/grok-imagine-image --input '{"prompt": "abstract art"}' --json
|
||||
```
|
||||
|
||||
### Video Generation
|
||||
|
||||
```bash
|
||||
# Search for video apps
|
||||
infsh app list --search video
|
||||
|
||||
# Veo 3.1 (Google)
|
||||
infsh app run google/veo-3-1-fast --input '{"prompt": "drone shot of coastline"}' --json
|
||||
|
||||
# Seedance (ByteDance)
|
||||
infsh app run bytedance/seedance-1-5-pro --input '{"prompt": "dancing figure", "resolution": "1080p"}' --json
|
||||
|
||||
# Wan 2.5
|
||||
infsh app run falai/wan-2-5 --input '{"prompt": "person walking through city"}' --json
|
||||
```
|
||||
|
||||
### Local File Uploads
|
||||
|
||||
The CLI automatically uploads local files when you provide a path:
|
||||
|
||||
```bash
|
||||
# Upscale a local image
|
||||
infsh app run falai/topaz-image-upscaler --input '{"image": "/path/to/photo.jpg", "upscale_factor": 2}' --json
|
||||
|
||||
# Image-to-video from local file
|
||||
infsh app run falai/wan-2-5-i2v --input '{"image": "/path/to/image.png", "prompt": "make it move"}' --json
|
||||
|
||||
# Avatar with audio
|
||||
infsh app run bytedance/omnihuman-1-5 --input '{"audio": "/path/to/audio.mp3", "image": "/path/to/face.jpg"}' --json
|
||||
```
|
||||
|
||||
### Search & Research
|
||||
|
||||
```bash
|
||||
infsh app list --search search
|
||||
infsh app run tavily/tavily-search --input '{"query": "latest AI news"}' --json
|
||||
infsh app run exa/exa-search --input '{"query": "machine learning papers"}' --json
|
||||
```
|
||||
|
||||
### Other Categories
|
||||
|
||||
```bash
|
||||
# 3D generation
|
||||
infsh app list --search 3d
|
||||
|
||||
# Audio / TTS
|
||||
infsh app list --search tts
|
||||
|
||||
# Twitter/X automation
|
||||
infsh app list --search twitter
|
||||
```
|
||||
|
||||
## Pitfalls
|
||||
|
||||
1. **Never guess app IDs** — always run `infsh app list --search <term>` first. App IDs change and new apps are added frequently.
|
||||
2. **Always use `--json`** — raw output is hard to parse. The `--json` flag gives structured output with URLs.
|
||||
3. **Check authentication** — if commands fail with auth errors, run `infsh login` or verify `INFSH_API_KEY` is set.
|
||||
4. **Long-running apps** — video generation can take 30-120 seconds. The terminal tool timeout should be sufficient, but warn the user it may take a moment.
|
||||
5. **Input format** — the `--input` flag takes a JSON string. Make sure to properly escape quotes.
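A sketch of pitfall 5: build the JSON with a serializer instead of hand-escaping quotes (app ID and fields are illustrative):

```python
import json
import shlex

payload = {"prompt": 'a "neon" sign that says it\'s open', "num_images": 1}

# json.dumps produces valid JSON; shlex.quote makes it safe to paste into a POSIX shell
input_arg = shlex.quote(json.dumps(payload))
print(f"infsh app run falai/flux-dev-lora --input {input_arg} --json")
```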
|
||||
|
||||
## Reference Docs
|
||||
|
||||
- `references/authentication.md` — Setup, login, API keys
|
||||
- `references/app-discovery.md` — Searching and browsing the app catalog
|
||||
- `references/running-apps.md` — Running apps, input formats, output handling
|
||||
- `references/cli-reference.md` — Complete CLI command reference
|
||||
112
optional-skills/devops/cli/references/app-discovery.md
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
# Discovering Apps
|
||||
|
||||
## List All Apps
|
||||
|
||||
```bash
|
||||
infsh app list
|
||||
```
|
||||
|
||||
## Pagination
|
||||
|
||||
```bash
|
||||
infsh app list --page 2
|
||||
```
|
||||
|
||||
## Filter by Category
|
||||
|
||||
```bash
|
||||
infsh app list --category image
|
||||
infsh app list --category video
|
||||
infsh app list --category audio
|
||||
infsh app list --category text
|
||||
infsh app list --category other
|
||||
```
|
||||
|
||||
## Search
|
||||
|
||||
```bash
|
||||
infsh app search "flux"
|
||||
infsh app search "video generation"
|
||||
infsh app search "tts" -l
|
||||
infsh app search "image" --category image
|
||||
```
|
||||
|
||||
Or use the flag form:
|
||||
|
||||
```bash
|
||||
infsh app list --search "flux"
|
||||
infsh app list --search "video generation"
|
||||
infsh app list --search "tts"
|
||||
```
|
||||
|
||||
## Featured Apps
|
||||
|
||||
```bash
|
||||
infsh app list --featured
|
||||
```
|
||||
|
||||
## Newest First
|
||||
|
||||
```bash
|
||||
infsh app list --new
|
||||
```
|
||||
|
||||
## Detailed View
|
||||
|
||||
```bash
|
||||
infsh app list -l
|
||||
```
|
||||
|
||||
Shows a table with the app name, category, description, and featured status.
|
||||
|
||||
## Save to File
|
||||
|
||||
```bash
|
||||
infsh app list --save apps.json
|
||||
```
|
||||
|
||||
## Your Apps
|
||||
|
||||
List apps you've deployed:
|
||||
|
||||
```bash
|
||||
infsh app my
|
||||
infsh app my -l # detailed
|
||||
```
|
||||
|
||||
## Get App Details
|
||||
|
||||
```bash
|
||||
infsh app get falai/flux-dev-lora
|
||||
infsh app get falai/flux-dev-lora --json
|
||||
```
|
||||
|
||||
Shows the full app info, including the input/output schema.
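A hedged sketch for pulling the schema programmatically (the `input_schema`/`properties` field names are an assumption; inspect the real payload of `infsh app get <app> --json` before relying on them):

```python
import json
import subprocess

raw = subprocess.run(
    ["infsh", "app", "get", "falai/flux-dev-lora", "--json"],
    capture_output=True, text=True, check=True,
).stdout
app = json.loads(raw)

# Field names below are assumptions about the JSON layout
schema = app.get("input_schema", {})
print("required:", schema.get("required", []))
print("fields:  ", list(schema.get("properties", {}).keys()))
```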
|
||||
|
||||
## Popular Apps by Category
|
||||
|
||||
### Image Generation
|
||||
- `falai/flux-dev-lora` - FLUX.2 Dev (high quality)
|
||||
- `falai/flux-2-klein-lora` - FLUX.2 Klein (fastest)
|
||||
- `infsh/sdxl` - Stable Diffusion XL
|
||||
- `google/gemini-3-pro-image-preview` - Gemini 3 Pro
|
||||
- `xai/grok-imagine-image` - Grok image generation
|
||||
|
||||
### Video Generation
|
||||
- `google/veo-3-1-fast` - Veo 3.1 Fast
|
||||
- `google/veo-3` - Veo 3
|
||||
- `bytedance/seedance-1-5-pro` - Seedance 1.5 Pro
|
||||
- `infsh/ltx-video-2` - LTX Video 2 (with audio)
|
||||
- `bytedance/omnihuman-1-5` - OmniHuman avatar
|
||||
|
||||
### Audio
|
||||
- `infsh/dia-tts` - Conversational TTS
|
||||
- `infsh/kokoro-tts` - Kokoro TTS
|
||||
- `infsh/fast-whisper-large-v3` - Fast transcription
|
||||
- `infsh/diffrythm` - Music generation
|
||||
|
||||
## Documentation
|
||||
|
||||
- [Browsing the Grid](https://inference.sh/docs/apps/browsing-grid) - Visual app browsing
|
||||
- [Apps Overview](https://inference.sh/docs/apps/overview) - Understanding apps
|
||||
- [Running Apps](https://inference.sh/docs/apps/running) - How to run apps
|
||||
59
optional-skills/devops/cli/references/authentication.md
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
# Authentication & Setup
|
||||
|
||||
## Install the CLI
|
||||
|
||||
```bash
|
||||
curl -fsSL https://cli.inference.sh | sh
|
||||
```
|
||||
|
||||
## Login
|
||||
|
||||
```bash
|
||||
infsh login
|
||||
```
|
||||
|
||||
This opens a browser for authentication. After login, credentials are stored locally.
|
||||
|
||||
## Check Authentication
|
||||
|
||||
```bash
|
||||
infsh me
|
||||
```
|
||||
|
||||
Shows your user info if authenticated.
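If you are scripting the check, a minimal sketch (treats any failure as "needs setup"; exact exit codes are not documented here):

```python
import subprocess

try:
    subprocess.run(["infsh", "me"], check=True, capture_output=True, text=True)
    print("infsh is installed and authenticated")
except FileNotFoundError:
    print("infsh not installed: run  curl -fsSL https://cli.inference.sh | sh")
except subprocess.CalledProcessError:
    print("not authenticated: run  infsh login")
```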
|
||||
|
||||
## Environment Variable
|
||||
|
||||
For CI/CD or scripts, set your API key:
|
||||
|
||||
```bash
|
||||
export INFSH_API_KEY=your-api-key
|
||||
```
|
||||
|
||||
The environment variable overrides the config file.
|
||||
|
||||
## Update CLI
|
||||
|
||||
```bash
|
||||
infsh update
|
||||
```
|
||||
|
||||
Or reinstall:
|
||||
|
||||
```bash
|
||||
curl -fsSL https://cli.inference.sh | sh
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Error | Solution |
|
||||
|-------|----------|
|
||||
| "not authenticated" | Run `infsh login` |
|
||||
| "command not found" | Reinstall CLI or add to PATH |
|
||||
| "API key invalid" | Check `INFSH_API_KEY` or re-login |
|
||||
|
||||
## Documentation
|
||||
|
||||
- [CLI Setup](https://inference.sh/docs/extend/cli-setup) - Complete CLI installation guide
|
||||
- [API Authentication](https://inference.sh/docs/api/authentication) - API key management
|
||||
- [Secrets](https://inference.sh/docs/secrets/overview) - Managing credentials
|
||||
104
optional-skills/devops/cli/references/cli-reference.md
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
# CLI Reference
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
curl -fsSL https://cli.inference.sh | sh
|
||||
```
|
||||
|
||||
## Global Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `infsh help` | Show help |
|
||||
| `infsh version` | Show CLI version |
|
||||
| `infsh update` | Update CLI to latest |
|
||||
| `infsh login` | Authenticate |
|
||||
| `infsh me` | Show current user |
|
||||
|
||||
## App Commands
|
||||
|
||||
### Discovery
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `infsh app list` | List available apps |
|
||||
| `infsh app list --category <cat>` | Filter by category (image, video, audio, text, other) |
|
||||
| `infsh app search <query>` | Search apps |
|
||||
| `infsh app list --search <query>` | Search apps (flag form) |
|
||||
| `infsh app list --featured` | Show featured apps |
|
||||
| `infsh app list --new` | Sort by newest |
|
||||
| `infsh app list --page <n>` | Pagination |
|
||||
| `infsh app list -l` | Detailed table view |
|
||||
| `infsh app list --save <file>` | Save to JSON file |
|
||||
| `infsh app my` | List your deployed apps |
|
||||
| `infsh app get <app>` | Get app details |
|
||||
| `infsh app get <app> --json` | Get app details as JSON |
|
||||
|
||||
### Execution
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `infsh app run <app> --input <file>` | Run app with input file |
|
||||
| `infsh app run <app> --input '<json>'` | Run with inline JSON |
|
||||
| `infsh app run <app> --input <file> --no-wait` | Run without waiting for completion |
|
||||
| `infsh app sample <app>` | Show sample input |
|
||||
| `infsh app sample <app> --save <file>` | Save sample to file |
|
||||
|
||||
## Task Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `infsh task get <task-id>` | Get task status and result |
|
||||
| `infsh task get <task-id> --json` | Get task as JSON |
|
||||
| `infsh task get <task-id> --save <file>` | Save task result to file |
|
||||
|
||||
### Development
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `infsh app init` | Create new app (interactive) |
|
||||
| `infsh app init <name>` | Create new app with name |
|
||||
| `infsh app test --input <file>` | Test app locally |
|
||||
| `infsh app deploy` | Deploy app |
|
||||
| `infsh app deploy --dry-run` | Validate without deploying |
|
||||
| `infsh app pull <id>` | Pull app source |
|
||||
| `infsh app pull --all` | Pull all your apps |
|
||||
|
||||
## Environment Variables
|
||||
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `INFSH_API_KEY` | API key (overrides config) |
|
||||
|
||||
## Shell Completions
|
||||
|
||||
```bash
|
||||
# Bash
|
||||
infsh completion bash > /etc/bash_completion.d/infsh
|
||||
|
||||
# Zsh
|
||||
infsh completion zsh > "${fpath[1]}/_infsh"
|
||||
|
||||
# Fish
|
||||
infsh completion fish > ~/.config/fish/completions/infsh.fish
|
||||
```
|
||||
|
||||
## App Name Format
|
||||
|
||||
Apps use the format `namespace/app-name`:
|
||||
|
||||
- `falai/flux-dev-lora` - fal.ai's FLUX 2 Dev
|
||||
- `google/veo-3` - Google's Veo 3
|
||||
- `infsh/sdxl` - inference.sh's SDXL
|
||||
- `bytedance/seedance-1-5-pro` - ByteDance's Seedance
|
||||
- `xai/grok-imagine-image` - xAI's Grok
|
||||
|
||||
Version pinning: `namespace/app-name@version`
|
||||
|
||||
## Documentation
|
||||
|
||||
- [CLI Setup](https://inference.sh/docs/extend/cli-setup) - Complete CLI installation guide
|
||||
- [Running Apps](https://inference.sh/docs/apps/running) - How to run apps via CLI
|
||||
- [Creating an App](https://inference.sh/docs/extend/creating-app) - Build your own apps
|
||||
- [Deploying](https://inference.sh/docs/extend/deploying) - Deploy apps to the cloud
|
||||
171
optional-skills/devops/cli/references/running-apps.md
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
# Running Apps
|
||||
|
||||
## Basic Run
|
||||
|
||||
```bash
|
||||
infsh app run user/app-name --input input.json
|
||||
```
|
||||
|
||||
## Inline JSON
|
||||
|
||||
```bash
|
||||
infsh app run falai/flux-dev-lora --input '{"prompt": "a sunset over mountains"}'
|
||||
```
|
||||
|
||||
## Version Pinning
|
||||
|
||||
```bash
|
||||
infsh app run user/app-name@1.0.0 --input input.json
|
||||
```
|
||||
|
||||
## Local File Uploads
|
||||
|
||||
The CLI automatically uploads local files when you provide a file path instead of a URL. Any field that accepts a URL also accepts a local path:
|
||||
|
||||
```bash
|
||||
# Upscale a local image
|
||||
infsh app run falai/topaz-image-upscaler --input '{"image": "/path/to/photo.jpg", "upscale_factor": 2}'
|
||||
|
||||
# Image-to-video from local file
|
||||
infsh app run falai/wan-2-5-i2v --input '{"image": "./my-image.png", "prompt": "make it move"}'
|
||||
|
||||
# Avatar with local audio and image
|
||||
infsh app run bytedance/omnihuman-1-5 --input '{"audio": "/path/to/speech.mp3", "image": "/path/to/face.jpg"}'
|
||||
|
||||
# Post tweet with local media
|
||||
infsh app run x/post-create --input '{"text": "Check this out!", "media": "./screenshot.png"}'
|
||||
```
|
||||
|
||||
Supported paths:
|
||||
- Absolute paths: `/home/user/images/photo.jpg`
|
||||
- Relative paths: `./image.png`, `../data/video.mp4`
|
||||
- Home directory: `~/Pictures/photo.jpg`
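When building inputs programmatically, it can help to normalize the path first. A small sketch (the field names are just the ones used in the examples above):

```python
import json
from pathlib import Path

def build_input(image: str, prompt: str) -> str:
    """Expand ~, ./ and ../ and fail early if the file is missing."""
    path = Path(image).expanduser().resolve()
    if not path.is_file():
        raise FileNotFoundError(f"no such file: {path}")
    return json.dumps({"image": str(path), "prompt": prompt})

# Example: build_input("~/Pictures/photo.jpg", "make it move")
```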
|
||||
|
||||
## Generate Sample Input
|
||||
|
||||
Before running, generate a sample input file:
|
||||
|
||||
```bash
|
||||
infsh app sample falai/flux-dev-lora
|
||||
```
|
||||
|
||||
Save to file:
|
||||
|
||||
```bash
|
||||
infsh app sample falai/flux-dev-lora --save input.json
|
||||
```
|
||||
|
||||
Then edit `input.json` and run:
|
||||
|
||||
```bash
|
||||
infsh app run falai/flux-dev-lora --input input.json
|
||||
```
|
||||
|
||||
## Workflow Example
|
||||
|
||||
### Image Generation with FLUX
|
||||
|
||||
```bash
|
||||
# 1. Get app details
|
||||
infsh app get falai/flux-dev-lora
|
||||
|
||||
# 2. Generate sample input
|
||||
infsh app sample falai/flux-dev-lora --save input.json
|
||||
|
||||
# 3. Edit input.json
|
||||
# {
|
||||
# "prompt": "a cat astronaut floating in space",
|
||||
# "num_images": 1,
|
||||
# "image_size": "landscape_16_9"
|
||||
# }
|
||||
|
||||
# 4. Run
|
||||
infsh app run falai/flux-dev-lora --input input.json
|
||||
```
|
||||
|
||||
### Video Generation with Veo
|
||||
|
||||
```bash
|
||||
# 1. Generate sample
|
||||
infsh app sample google/veo-3-1-fast --save input.json
|
||||
|
||||
# 2. Edit prompt
|
||||
# {
|
||||
# "prompt": "A drone shot flying over a forest at sunset"
|
||||
# }
|
||||
|
||||
# 3. Run
|
||||
infsh app run google/veo-3-1-fast --input input.json
|
||||
```
|
||||
|
||||
### Text-to-Speech
|
||||
|
||||
```bash
|
||||
# Quick inline run
|
||||
infsh app run falai/kokoro-tts --input '{"text": "Hello, this is a test."}'
|
||||
```
|
||||
|
||||
## Task Tracking
|
||||
|
||||
When you run an app, the CLI shows the task ID:
|
||||
|
||||
```
|
||||
Running falai/flux-dev-lora
|
||||
Task ID: abc123def456
|
||||
```
|
||||
|
||||
For long-running tasks, you can check status anytime:
|
||||
|
||||
```bash
|
||||
# Check task status
|
||||
infsh task get abc123def456
|
||||
|
||||
# Get result as JSON
|
||||
infsh task get abc123def456 --json
|
||||
|
||||
# Save result to file
|
||||
infsh task get abc123def456 --save result.json
|
||||
```
|
||||
|
||||
### Run Without Waiting
|
||||
|
||||
For very long tasks, run in background:
|
||||
|
||||
```bash
|
||||
# Submit and return immediately
|
||||
infsh app run google/veo-3 --input input.json --no-wait
|
||||
|
||||
# Check later
|
||||
infsh task get <task-id>
|
||||
```
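A sketch of the submit-then-poll pattern in a script (parsing the `Task ID:` line and the `status` field are assumptions based on the output shown above; adjust to what your CLI version actually prints):

```python
import json
import subprocess
import time

def run_and_poll(app_id: str, input_file: str, interval: int = 10) -> dict:
    # Submit without waiting and scan stdout for the "Task ID: ..." line
    submitted = subprocess.run(
        ["infsh", "app", "run", app_id, "--input", input_file, "--no-wait"],
        capture_output=True, text=True, check=True,
    ).stdout
    task_id = next(
        line.split(":", 1)[1].strip()
        for line in submitted.splitlines()
        if line.lower().startswith("task id")
    )

    # Poll until the task reaches a terminal state (status values are assumed)
    while True:
        task = json.loads(subprocess.run(
            ["infsh", "task", "get", task_id, "--json"],
            capture_output=True, text=True, check=True,
        ).stdout)
        if task.get("status") in ("completed", "failed"):
            return task
        time.sleep(interval)
```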
|
||||
|
||||
## Output
|
||||
|
||||
The CLI returns the app output directly. For file outputs (images, videos, audio), you'll receive URLs to download.
|
||||
|
||||
Example output:
|
||||
|
||||
```json
|
||||
{
|
||||
"images": [
|
||||
{
|
||||
"url": "https://cloud.inference.sh/...",
|
||||
"content_type": "image/png"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
| Error | Cause | Solution |
|
||||
|-------|-------|----------|
|
||||
| "invalid input" | Schema mismatch | Check `infsh app get` for required fields |
|
||||
| "app not found" | Wrong app name | Check `infsh app list --search` |
|
||||
| "quota exceeded" | Out of credits | Check account balance |
|
||||
|
||||
## Documentation
|
||||
|
||||
- [Running Apps](https://inference.sh/docs/apps/running) - Complete running apps guide
|
||||
- [Streaming Results](https://inference.sh/docs/api/sdk/streaming) - Real-time progress updates
|
||||
- [Setup Parameters](https://inference.sh/docs/apps/setup-parameters) - Configuring app inputs
|
||||
335
optional-skills/mlops/accelerate/SKILL.md
Normal file
|
|
@ -0,0 +1,335 @@
|
|||
---
|
||||
name: huggingface-accelerate
|
||||
description: Simplest distributed training API. 4 lines to add distributed support to any PyTorch script. Unified API for DeepSpeed/FSDP/Megatron/DDP. Automatic device placement, mixed precision (FP16/BF16/FP8). Interactive config, single launch command. HuggingFace ecosystem standard.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [accelerate, torch, transformers]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Distributed Training, HuggingFace, Accelerate, DeepSpeed, FSDP, Mixed Precision, PyTorch, DDP, Unified API, Simple]
|
||||
|
||||
---
|
||||
|
||||
# HuggingFace Accelerate - Unified Distributed Training
|
||||
|
||||
## Quick start
|
||||
|
||||
Accelerate simplifies distributed training to 4 lines of code.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
**Convert PyTorch script** (4 lines):
|
||||
```python
|
||||
import torch
|
||||
+ from accelerate import Accelerator
|
||||
|
||||
+ accelerator = Accelerator()
|
||||
|
||||
model = torch.nn.Transformer()
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
dataloader = torch.utils.data.DataLoader(dataset)
|
||||
|
||||
+ model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
for batch in dataloader:
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch)
|
||||
- loss.backward()
|
||||
+ accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Run** (single command):
|
||||
```bash
|
||||
accelerate launch train.py
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: From single GPU to multi-GPU
|
||||
|
||||
**Original script**:
|
||||
```python
|
||||
# train.py
|
||||
import torch
|
||||
|
||||
model = torch.nn.Linear(10, 2).to('cuda')
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
|
||||
|
||||
for epoch in range(10):
|
||||
for batch in dataloader:
|
||||
batch = batch.to('cuda')
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**With Accelerate** (4 lines added):
|
||||
```python
|
||||
# train.py
|
||||
import torch
|
||||
from accelerate import Accelerator # +1
|
||||
|
||||
accelerator = Accelerator() # +2
|
||||
|
||||
model = torch.nn.Linear(10, 2)
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) # +3
|
||||
|
||||
for epoch in range(10):
|
||||
for batch in dataloader:
|
||||
# No .to('cuda') needed - automatic!
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch).mean()
|
||||
accelerator.backward(loss) # +4
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Configure** (interactive):
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
**Questions**:
|
||||
- Which machine? (single/multi GPU/TPU/CPU)
|
||||
- How many machines? (1)
|
||||
- Mixed precision? (no/fp16/bf16/fp8)
|
||||
- DeepSpeed? (no/yes)
|
||||
|
||||
**Launch** (works on any setup):
|
||||
```bash
|
||||
# Single GPU
|
||||
accelerate launch train.py
|
||||
|
||||
# Multi-GPU (8 GPUs)
|
||||
accelerate launch --multi_gpu --num_processes 8 train.py
|
||||
|
||||
# Multi-node
|
||||
accelerate launch --multi_gpu --num_processes 16 \
|
||||
--num_machines 2 --machine_rank 0 \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
train.py
|
||||
```
|
||||
|
||||
### Workflow 2: Mixed precision training
|
||||
|
||||
**Enable FP16/BF16**:
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
# FP16 (with gradient scaling)
|
||||
accelerator = Accelerator(mixed_precision='fp16')
|
||||
|
||||
# BF16 (no scaling, more stable)
|
||||
accelerator = Accelerator(mixed_precision='bf16')
|
||||
|
||||
# FP8 (H100+)
|
||||
accelerator = Accelerator(mixed_precision='fp8')
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
# Everything else is automatic!
|
||||
for batch in dataloader:
|
||||
with accelerator.autocast(): # Optional, done automatically
|
||||
loss = model(batch)
|
||||
accelerator.backward(loss)
|
||||
```
|
||||
|
||||
### Workflow 3: DeepSpeed ZeRO integration
|
||||
|
||||
**Enable DeepSpeed ZeRO-2**:
|
||||
```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# deepspeed_plugin expects a DeepSpeedPlugin instance, not a raw dict
ds_plugin = DeepSpeedPlugin(
    zero_stage=2,                       # ZeRO-2
    offload_optimizer_device="none",    # keep optimizer state on GPU
    gradient_accumulation_steps=4,
)

accelerator = Accelerator(
    mixed_precision='bf16',
    deepspeed_plugin=ds_plugin,
)

# Same code as before!
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```
|
||||
|
||||
**Or via config**:
|
||||
```bash
|
||||
accelerate config
|
||||
# Select: DeepSpeed → ZeRO-2
|
||||
```
|
||||
|
||||
**deepspeed_config.json**:
|
||||
```json
|
||||
{
|
||||
"fp16": {"enabled": false},
|
||||
"bf16": {"enabled": true},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {"device": "cpu"},
|
||||
"allgather_bucket_size": 5e8,
|
||||
"reduce_bucket_size": 5e8
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Launch**:
|
||||
```bash
|
||||
accelerate launch --config_file deepspeed_config.json train.py
|
||||
```
|
||||
|
||||
### Workflow 4: FSDP (Fully Sharded Data Parallel)
|
||||
|
||||
**Enable FSDP**:
|
||||
```python
|
||||
from accelerate import Accelerator, FullyShardedDataParallelPlugin
|
||||
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
sharding_strategy="FULL_SHARD", # ZeRO-3 equivalent
|
||||
auto_wrap_policy="TRANSFORMER_AUTO_WRAP",
|
||||
cpu_offload=False
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
fsdp_plugin=fsdp_plugin
|
||||
)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
```
|
||||
|
||||
**Or via config**:
|
||||
```bash
|
||||
accelerate config
|
||||
# Select: FSDP → Full Shard → No CPU Offload
|
||||
```
|
||||
|
||||
### Workflow 5: Gradient accumulation
|
||||
|
||||
**Accumulate gradients**:
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
accelerator = Accelerator(gradient_accumulation_steps=4)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
for batch in dataloader:
|
||||
with accelerator.accumulate(model): # Handles accumulation
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Effective batch size**: `batch_size * num_gpus * gradient_accumulation_steps`
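A worked example of that formula (values are illustrative):

```python
batch_size = 8                     # per-device batch size from the DataLoader
num_gpus = 4                       # processes started by `accelerate launch`
gradient_accumulation_steps = 4    # as passed to Accelerator(...)

effective_batch = batch_size * num_gpus * gradient_accumulation_steps
print(effective_batch)  # 8 * 4 * 4 = 128
```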
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use Accelerate when**:
|
||||
- Want simplest distributed training
|
||||
- Need single script for any hardware
|
||||
- Use HuggingFace ecosystem
|
||||
- Want flexibility (DDP/DeepSpeed/FSDP/Megatron)
|
||||
- Need quick prototyping
|
||||
|
||||
**Key advantages**:
|
||||
- **4 lines**: Minimal code changes
|
||||
- **Unified API**: Same code for DDP, DeepSpeed, FSDP, Megatron
|
||||
- **Automatic**: Device placement, mixed precision, sharding
|
||||
- **Interactive config**: No manual launcher setup
|
||||
- **Single launch**: Works everywhere
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **PyTorch Lightning**: Need callbacks, high-level abstractions
|
||||
- **Ray Train**: Multi-node orchestration, hyperparameter tuning
|
||||
- **DeepSpeed**: Direct API control, advanced features
|
||||
- **Raw DDP**: Maximum control, minimal abstraction
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Wrong device placement**
|
||||
|
||||
Don't manually move to device:
|
||||
```python
|
||||
# WRONG
|
||||
batch = batch.to('cuda')
|
||||
|
||||
# CORRECT
|
||||
# Accelerate handles it automatically after prepare()
|
||||
```
|
||||
|
||||
**Issue: Gradient accumulation not working**
|
||||
|
||||
Use context manager:
|
||||
```python
|
||||
# CORRECT
|
||||
with accelerator.accumulate(model):
|
||||
optimizer.zero_grad()
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Issue: Checkpointing in distributed**
|
||||
|
||||
Use accelerator methods:
|
||||
```python
|
||||
# Save only on main process
|
||||
if accelerator.is_main_process:
|
||||
accelerator.save_state('checkpoint/')
|
||||
|
||||
# Load on all processes
|
||||
accelerator.load_state('checkpoint/')
|
||||
```
|
||||
|
||||
**Issue: Different results with FSDP**
|
||||
|
||||
Ensure same random seed:
|
||||
```python
|
||||
from accelerate.utils import set_seed
|
||||
set_seed(42)
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Megatron integration**: See [references/megatron-integration.md](references/megatron-integration.md) for tensor parallelism, pipeline parallelism, and sequence parallelism setup.
|
||||
|
||||
**Custom plugins**: See [references/custom-plugins.md](references/custom-plugins.md) for creating custom distributed plugins and advanced configuration.
|
||||
|
||||
**Performance tuning**: See [references/performance.md](references/performance.md) for profiling, memory optimization, and best practices.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **CPU**: Works (slow)
|
||||
- **Single GPU**: Works
|
||||
- **Multi-GPU**: DDP (default), DeepSpeed, or FSDP
|
||||
- **Multi-node**: DDP, DeepSpeed, FSDP, Megatron
|
||||
- **TPU**: Supported
|
||||
- **Apple MPS**: Supported
|
||||
|
||||
**Launcher requirements**:
|
||||
- **DDP**: `torch.distributed.run` (built-in)
|
||||
- **DeepSpeed**: `deepspeed` (pip install deepspeed)
|
||||
- **FSDP**: PyTorch 1.12+ (built-in)
|
||||
- **Megatron**: Custom setup
|
||||
|
||||
## Resources
|
||||
|
||||
- Docs: https://huggingface.co/docs/accelerate
|
||||
- GitHub: https://github.com/huggingface/accelerate
|
||||
- Version: 1.11.0+
|
||||
- Tutorial: "Accelerate your scripts"
|
||||
- Examples: https://github.com/huggingface/accelerate/tree/main/examples
|
||||
- Used by: HuggingFace Transformers, TRL, PEFT, all HF libraries
|
||||
|
||||
|
||||
|
||||
453
optional-skills/mlops/accelerate/references/custom-plugins.md
Normal file
|
|
@ -0,0 +1,453 @@
|
|||
# Custom Plugins for Accelerate
|
||||
|
||||
## Overview
|
||||
|
||||
Accelerate allows creating **custom plugins** to extend distributed training strategies beyond built-in options (DDP, FSDP, DeepSpeed).
|
||||
|
||||
## Plugin Architecture
|
||||
|
||||
### Base Plugin Structure
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
"""Custom training plugin."""
|
||||
|
||||
# Plugin configuration
|
||||
param1: int = 1
|
||||
param2: str = "default"
|
||||
|
||||
def __post_init__(self):
|
||||
# Validation logic
|
||||
if self.param1 < 1:
|
||||
raise ValueError("param1 must be >= 1")
|
||||
```
|
||||
|
||||
### Using Custom Plugin
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
# Create plugin
|
||||
custom_plugin = CustomPlugin(param1=4, param2="value")
|
||||
|
||||
# Pass to Accelerator
|
||||
accelerator = Accelerator(
|
||||
custom_plugin=custom_plugin # Not a real parameter, example only
|
||||
)
|
||||
```
|
||||
|
||||
## Built-In Plugin Examples
|
||||
|
||||
### 1. GradScalerKwargs (FP16 Configuration)
|
||||
|
||||
```python
|
||||
from accelerate.utils import GradScalerKwargs
|
||||
|
||||
# Configure gradient scaler for FP16
|
||||
scaler_kwargs = GradScalerKwargs(
|
||||
init_scale=2.**16, # Initial loss scale
|
||||
growth_factor=2.0, # Scale growth rate
|
||||
backoff_factor=0.5, # Scale backoff rate
|
||||
growth_interval=2000, # Steps between scale increases
|
||||
enabled=True # Enable scaler
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp16',
|
||||
kwargs_handlers=[scaler_kwargs] # Pass as kwargs handler
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Fine-tune FP16 gradient scaling behavior
|
||||
|
||||
### 2. DistributedDataParallelKwargs
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
# Configure DDP behavior
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
bucket_cap_mb=25, # Gradient bucketing size
|
||||
find_unused_parameters=False, # Find unused params (slower)
|
||||
check_reduction=False, # Check gradient reduction
|
||||
gradient_as_bucket_view=True, # Memory optimization
|
||||
static_graph=False # Static computation graph
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
kwargs_handlers=[ddp_kwargs]
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Optimize DDP performance for specific models
|
||||
|
||||
### 3. FP8RecipeKwargs (H100 FP8)
|
||||
|
||||
```python
|
||||
from accelerate.utils import FP8RecipeKwargs
|
||||
|
||||
# Configure FP8 training (H100)
|
||||
fp8_recipe = FP8RecipeKwargs(
|
||||
backend="te", # TransformerEngine backend
|
||||
margin=0, # Scaling margin
|
||||
interval=1, # Scaling interval
|
||||
fp8_format="HYBRID", # E4M3 + E5M2 hybrid
|
||||
amax_history_len=1024, # AMAX history length
|
||||
amax_compute_algo="max" # AMAX computation algorithm
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp8',
|
||||
kwargs_handlers=[fp8_recipe]
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Ultra-fast training on H100 GPUs
|
||||
|
||||
## Custom DeepSpeed Configuration
|
||||
|
||||
### ZeRO-3 with CPU Offload
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DeepSpeedPlugin
|
||||
|
||||
# Custom DeepSpeed config
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=3, # ZeRO-3
|
||||
offload_optimizer_device="cpu", # CPU offload optimizer
|
||||
offload_param_device="cpu", # CPU offload parameters
|
||||
zero3_init_flag=True, # ZeRO-3 initialization
|
||||
zero3_save_16bit_model=True, # Save FP16 weights
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
deepspeed_plugin=ds_plugin,
|
||||
mixed_precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
### ZeRO-2 with NVMe Offload
|
||||
|
||||
```python
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=2,
|
||||
offload_optimizer_device="nvme", # NVMe offload
|
||||
offload_param_device="nvme",
|
||||
nvme_path="/local_nvme", # NVMe mount path
|
||||
)
|
||||
```
|
||||
|
||||
### Custom JSON Config
|
||||
|
||||
```python
|
||||
import json
|
||||
|
||||
# Load custom DeepSpeed config
|
||||
with open('deepspeed_config.json', 'r') as f:
|
||||
ds_config = json.load(f)
|
||||
|
||||
ds_plugin = DeepSpeedPlugin(hf_ds_config=ds_config)
|
||||
|
||||
accelerator = Accelerator(deepspeed_plugin=ds_plugin)
|
||||
```
|
||||
|
||||
**Example config** (`deepspeed_config.json`):
|
||||
```json
|
||||
{
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": 1.0,
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e9,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"stage3_prefetch_bucket_size": 5e8,
|
||||
"stage3_param_persistence_threshold": 1e6,
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"steps_per_print": 100,
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
```
|
||||
|
||||
## Custom FSDP Configuration
|
||||
|
||||
### FSDP with Custom Auto-Wrap Policy
|
||||
|
||||
```python
|
||||
from accelerate.utils import FullyShardedDataParallelPlugin
|
||||
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
|
||||
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
|
||||
import functools
|
||||
|
||||
# Custom wrap policy (size-based)
|
||||
wrap_policy = functools.partial(
|
||||
size_based_auto_wrap_policy,
|
||||
min_num_params=1e6 # Wrap layers with 1M+ params
|
||||
)
|
||||
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO-3 equivalent
|
||||
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Prefetch strategy
|
||||
mixed_precision_policy=None, # Use Accelerator's mixed precision
|
||||
auto_wrap_policy=wrap_policy, # Custom wrapping
|
||||
cpu_offload=False,
|
||||
ignored_modules=None, # Modules to not wrap
|
||||
state_dict_type="FULL_STATE_DICT", # Save format
|
||||
optim_state_dict_config=None,
|
||||
limit_all_gathers=False,
|
||||
use_orig_params=True, # Use original param shapes
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
fsdp_plugin=fsdp_plugin,
|
||||
mixed_precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
### FSDP with Transformer Auto-Wrap
|
||||
|
||||
```python
|
||||
import functools

from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
|
||||
|
||||
# Wrap at transformer block level
|
||||
wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls={GPT2Block} # Wrap GPT2Block layers
|
||||
)
|
||||
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
auto_wrap_policy=wrap_policy
|
||||
)
|
||||
```
|
||||
|
||||
## Creating Custom Training Strategy
|
||||
|
||||
### Example: Custom Gradient Accumulation
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
class CustomGradientAccumulation:
|
||||
    def __init__(self, steps=4, adaptive=False, threshold=2.0):
        self.steps = steps
        self.adaptive = adaptive
        self.threshold = threshold  # loss above this triggers an early sync (default is illustrative)
        self.current_step = 0
|
||||
|
||||
def should_sync(self, loss):
|
||||
"""Decide whether to sync gradients."""
|
||||
self.current_step += 1
|
||||
|
||||
# Adaptive: sync on high loss
|
||||
        if self.adaptive and loss > self.threshold:
|
||||
self.current_step = 0
|
||||
return True
|
||||
|
||||
# Regular: sync every N steps
|
||||
if self.current_step >= self.steps:
|
||||
self.current_step = 0
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# Usage
|
||||
custom_accum = CustomGradientAccumulation(steps=8, adaptive=True)
|
||||
accelerator = Accelerator()
|
||||
|
||||
for batch in dataloader:
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
|
||||
# Scale loss
|
||||
loss = loss / custom_accum.steps
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Conditional sync
|
||||
if custom_accum.should_sync(loss.item()):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
### Example: Custom Mixed Precision
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class CustomMixedPrecision:
|
||||
"""Custom mixed precision with dynamic loss scaling."""
|
||||
|
||||
def __init__(self, init_scale=2**16, scale_window=2000):
|
||||
self.scaler = torch.cuda.amp.GradScaler(
|
||||
init_scale=init_scale,
|
||||
growth_interval=scale_window
|
||||
)
|
||||
self.scale_history = []
|
||||
|
||||
def scale_loss(self, loss):
|
||||
"""Scale loss for backward."""
|
||||
return self.scaler.scale(loss)
|
||||
|
||||
def unscale_and_clip(self, optimizer, max_norm=1.0):
|
||||
"""Unscale gradients and clip."""
|
||||
self.scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
optimizer.param_groups[0]['params'],
|
||||
max_norm
|
||||
)
|
||||
|
||||
def step(self, optimizer):
|
||||
"""Optimizer step with scaler update."""
|
||||
scale_before = self.scaler.get_scale()
|
||||
self.scaler.step(optimizer)
|
||||
self.scaler.update()
|
||||
scale_after = self.scaler.get_scale()
|
||||
|
||||
# Track scale changes
|
||||
if scale_before != scale_after:
|
||||
self.scale_history.append(scale_after)
|
||||
|
||||
# Usage
|
||||
custom_mp = CustomMixedPrecision()
|
||||
|
||||
for batch in dataloader:
|
||||
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||
loss = model(**batch).loss
|
||||
|
||||
scaled_loss = custom_mp.scale_loss(loss)
|
||||
scaled_loss.backward()
|
||||
|
||||
custom_mp.unscale_and_clip(optimizer, max_norm=1.0)
|
||||
custom_mp.step(optimizer)
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
## Advanced: Custom Distributed Backend
|
||||
|
||||
### Custom AllReduce Strategy
|
||||
|
||||
```python
|
||||
import torch
import torch.distributed as dist
|
||||
|
||||
class CustomAllReduce:
|
||||
"""Custom all-reduce with compression."""
|
||||
|
||||
def __init__(self, compression_ratio=0.1):
|
||||
self.compression_ratio = compression_ratio
|
||||
|
||||
def compress_gradients(self, tensor):
|
||||
"""Top-k gradient compression."""
|
||||
        k = max(1, int(tensor.numel() * self.compression_ratio))
        # Pick the largest-magnitude entries, then keep their original signed values
        _, indices = torch.topk(tensor.abs().view(-1), k)
        values = tensor.view(-1)[indices]
        return values, indices
|
||||
|
||||
def all_reduce_compressed(self, tensor):
|
||||
"""All-reduce with gradient compression."""
|
||||
# Compress
|
||||
values, indices = self.compress_gradients(tensor)
|
||||
|
||||
# All-reduce compressed gradients
|
||||
dist.all_reduce(values, op=dist.ReduceOp.SUM)
|
||||
|
||||
# Decompress
|
||||
tensor_compressed = torch.zeros_like(tensor).view(-1)
|
||||
tensor_compressed[indices] = values / dist.get_world_size()
|
||||
|
||||
return tensor_compressed.view_as(tensor)
|
||||
|
||||
# Usage in training loop
|
||||
custom_ar = CustomAllReduce(compression_ratio=0.1)
|
||||
|
||||
for batch in dataloader:
|
||||
loss = model(**batch).loss
|
||||
loss.backward()
|
||||
|
||||
# Custom all-reduce
|
||||
for param in model.parameters():
|
||||
if param.grad is not None:
|
||||
param.grad.data = custom_ar.all_reduce_compressed(param.grad.data)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
## Plugin Best Practices
|
||||
|
||||
### 1. Validation in `__post_init__`
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
learning_rate: float = 1e-3
|
||||
warmup_steps: int = 1000
|
||||
|
||||
def __post_init__(self):
|
||||
# Validate parameters
|
||||
if self.learning_rate <= 0:
|
||||
raise ValueError("learning_rate must be positive")
|
||||
if self.warmup_steps < 0:
|
||||
raise ValueError("warmup_steps must be non-negative")
|
||||
|
||||
# Compute derived values
|
||||
self.min_lr = self.learning_rate * 0.1
|
||||
```
|
||||
|
||||
### 2. Compatibility Checks
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
feature_enabled: bool = True
|
||||
|
||||
def is_compatible(self, accelerator):
|
||||
"""Check if plugin is compatible with accelerator config."""
|
||||
if self.feature_enabled and accelerator.mixed_precision == 'fp8':
|
||||
raise ValueError("Custom plugin not compatible with FP8")
|
||||
return True
|
||||
```
|
||||
|
||||
### 3. State Management
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
counter: int = 0
|
||||
history: list = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.history is None:
|
||||
self.history = []
|
||||
|
||||
def update_state(self, value):
|
||||
"""Update plugin state during training."""
|
||||
self.counter += 1
|
||||
self.history.append(value)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Accelerate Plugins: https://huggingface.co/docs/accelerate/package_reference/kwargs
|
||||
- DeepSpeed Config: https://www.deepspeed.ai/docs/config-json/
|
||||
- FSDP Guide: https://pytorch.org/docs/stable/fsdp.html
|
||||
- Custom Training Loops: https://huggingface.co/docs/accelerate/usage_guides/training_tpu
|
||||
489
optional-skills/mlops/accelerate/references/megatron-integration.md
Normal file
@ -0,0 +1,489 @@
|
|||
# Megatron Integration with Accelerate
|
||||
|
||||
## Overview
|
||||
|
||||
Accelerate supports Megatron-LM for massive model training with tensor parallelism and pipeline parallelism.
|
||||
|
||||
**Megatron capabilities**:
|
||||
- **Tensor Parallelism (TP)**: Split layers across GPUs
|
||||
- **Pipeline Parallelism (PP)**: Split model depth across GPUs
|
||||
- **Data Parallelism (DP)**: Replicate model across GPU groups
|
||||
- **Sequence Parallelism**: Split sequences for long contexts
|
||||
|
||||
## Setup
|
||||
|
||||
### Install Megatron-LM
|
||||
|
||||
```bash
|
||||
# Clone Megatron-LM repository
|
||||
git clone https://github.com/NVIDIA/Megatron-LM.git
|
||||
cd Megatron-LM
|
||||
pip install -e .
|
||||
|
||||
# Install Apex (NVIDIA optimizations)
|
||||
git clone https://github.com/NVIDIA/apex
|
||||
cd apex
|
||||
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
|
||||
--config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
|
||||
```
|
||||
|
||||
### Accelerate Configuration
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
**Questions**:
|
||||
```
|
||||
In which compute environment are you running?
|
||||
> This machine
|
||||
|
||||
Which type of machine are you using?
|
||||
> Multi-GPU
|
||||
|
||||
How many different machines will you use?
|
||||
> 1
|
||||
|
||||
Do you want to use DeepSpeed/FSDP?
|
||||
> No
|
||||
|
||||
Do you want to use Megatron-LM?
|
||||
> Yes
|
||||
|
||||
What is the Tensor Parallelism degree? [1-8]
|
||||
> 2
|
||||
|
||||
Do you want to enable Sequence Parallelism?
|
||||
> No
|
||||
|
||||
What is the Pipeline Parallelism degree? [1-8]
|
||||
> 2
|
||||
|
||||
What is the Data Parallelism degree? [1-8]
|
||||
> 2
|
||||
|
||||
Where to perform activation checkpointing? ['SELECTIVE', 'FULL', 'NONE']
|
||||
> SELECTIVE
|
||||
|
||||
Where to perform activation partitioning? ['SEQUENTIAL', 'UNIFORM']
|
||||
> SEQUENTIAL
|
||||
```
|
||||
|
||||
**Generated config** (`~/.cache/huggingface/accelerate/default_config.yaml`):
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: MEGATRON_LM
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
megatron_lm_config:
|
||||
megatron_lm_gradient_clipping: 1.0
|
||||
megatron_lm_learning_rate_decay_iters: 320000
|
||||
megatron_lm_num_micro_batches: 1
|
||||
megatron_lm_pp_degree: 2
|
||||
megatron_lm_recompute_activations: true
|
||||
megatron_lm_sequence_parallelism: false
|
||||
megatron_lm_tp_degree: 2
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
|
||||
## Parallelism Strategies
|
||||
|
||||
### Tensor Parallelism (TP)
|
||||
|
||||
**Splits each transformer layer across GPUs**:
|
||||
|
||||
```python
|
||||
# Layer split across 2 GPUs
|
||||
# GPU 0: First half of attention heads
|
||||
# GPU 1: Second half of attention heads
|
||||
|
||||
# Each GPU computes partial outputs
|
||||
# All-reduce combines results
|
||||
```
|
||||
|
||||
**TP degree recommendations**:
|
||||
- **TP=1**: No tensor parallelism (single GPU per layer)
|
||||
- **TP=2**: 2 GPUs per layer (good for 7-13B models)
|
||||
- **TP=4**: 4 GPUs per layer (good for 20-40B models)
|
||||
- **TP=8**: 8 GPUs per layer (good for 70B+ models)
|
||||
|
||||
**Benefits**:
|
||||
- Reduces memory per GPU
|
||||
- All-reduce communication (fast)
|
||||
|
||||
**Drawbacks**:
|
||||
- Requires fast inter-GPU bandwidth (NVLink)
|
||||
- Communication overhead per layer
|
||||
|
||||
### Pipeline Parallelism (PP)
|
||||
|
||||
**Splits model depth across GPUs**:
|
||||
|
||||
```python
|
||||
# 12-layer model, PP=4
|
||||
# GPU 0: Layers 0-2
|
||||
# GPU 1: Layers 3-5
|
||||
# GPU 2: Layers 6-8
|
||||
# GPU 3: Layers 9-11
|
||||
```
|
||||
|
||||
**PP degree recommendations**:
|
||||
- **PP=1**: No pipeline parallelism
|
||||
- **PP=2**: 2 pipeline stages (good for 20-40B models)
|
||||
- **PP=4**: 4 pipeline stages (good for 70B+ models)
|
||||
- **PP=8**: 8 pipeline stages (good for 175B+ models)
|
||||
|
||||
**Benefits**:
|
||||
- Linear memory reduction (4× PP = 4× less memory)
|
||||
- Works across nodes (slower interconnect OK)
|
||||
|
||||
**Drawbacks**:
|
||||
- Pipeline bubbles (idle time)
|
||||
- Requires micro-batching
|
||||
|
||||
### Data Parallelism (DP)
|
||||
|
||||
**Replicates model across GPU groups**:
|
||||
|
||||
```python
|
||||
# 8 GPUs, TP=2, PP=2, DP=2
|
||||
# Group 0 (GPUs 0-3): Full model replica
|
||||
# Group 1 (GPUs 4-7): Full model replica
|
||||
```
|
||||
|
||||
**DP degree**:
|
||||
- `DP = total_gpus / (TP × PP)`
|
||||
- Example: 8 GPUs, TP=2, PP=2 → DP=2
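A small helper for that arithmetic (checks that the degrees divide the GPU count evenly):

```python
def data_parallel_degree(total_gpus: int, tp: int, pp: int) -> int:
    """DP = total_gpus / (TP * PP); the split must be exact."""
    if total_gpus % (tp * pp) != 0:
        raise ValueError(f"{total_gpus} GPUs cannot be split into TP={tp} x PP={pp} groups")
    return total_gpus // (tp * pp)

print(data_parallel_degree(8, tp=2, pp=2))    # 2, the example above
print(data_parallel_degree(64, tp=8, pp=2))   # 4, the 70B row in the sizing table below
```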
|
||||
|
||||
**Benefits**:
|
||||
- Increases throughput
|
||||
- Scales batch size
|
||||
|
||||
### Sequence Parallelism
|
||||
|
||||
**Splits long sequences across GPUs** (extends TP):
|
||||
|
||||
```python
|
||||
# 8K sequence, TP=2, Sequence Parallel=True
|
||||
# GPU 0: Tokens 0-4095
|
||||
# GPU 1: Tokens 4096-8191
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- Enables very long sequences (100K+ tokens)
|
||||
- Reduces activation memory
|
||||
|
||||
**Requirements**:
|
||||
- Must use with TP > 1
|
||||
- RoPE/ALiBi position encodings work best
|
||||
|
||||
## Accelerate Code Example
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import MegatronLMPlugin
|
||||
|
||||
# Configure Megatron
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
tp_degree=2, # Tensor parallelism degree
|
||||
pp_degree=2, # Pipeline parallelism degree
|
||||
num_micro_batches=4, # Micro-batches for pipeline
|
||||
gradient_clipping=1.0, # Gradient clipping value
|
||||
sequence_parallelism=False, # Enable sequence parallelism
|
||||
recompute_activations=True, # Activation checkpointing
|
||||
use_distributed_optimizer=True, # Distributed optimizer
|
||||
custom_prepare_model_function=None, # Custom model prep
|
||||
)
|
||||
|
||||
# Initialize accelerator
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
megatron_lm_plugin=megatron_plugin
|
||||
)
|
||||
|
||||
# Prepare model and optimizer
|
||||
model, optimizer, train_dataloader = accelerator.prepare(
|
||||
model, optimizer, train_dataloader
|
||||
)
|
||||
|
||||
# Training loop (same as DDP!)
|
||||
for batch in train_dataloader:
|
||||
optimizer.zero_grad()
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
### Full Training Script
|
||||
|
||||
```python
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import MegatronLMPlugin
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
|
||||
def main():
|
||||
# Megatron configuration
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
tp_degree=2,
|
||||
pp_degree=2,
|
||||
num_micro_batches=4,
|
||||
gradient_clipping=1.0,
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
gradient_accumulation_steps=8,
|
||||
megatron_lm_plugin=megatron_plugin
|
||||
)
|
||||
|
||||
# Model
|
||||
config = GPT2Config(
|
||||
n_layer=24,
|
||||
n_head=16,
|
||||
n_embd=1024,
|
||||
)
|
||||
model = GPT2LMHeadModel(config)
|
||||
|
||||
# Optimizer
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
|
||||
|
||||
# Prepare
|
||||
model, optimizer, train_loader = accelerator.prepare(
|
||||
model, optimizer, train_loader
|
||||
)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(num_epochs):
|
||||
for batch in train_loader:
|
||||
with accelerator.accumulate(model):
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Save checkpoint
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.save_state(f'checkpoint-epoch-{epoch}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
```
|
||||
|
||||
### Launch Command
|
||||
|
||||
```bash
|
||||
# 8 GPUs, TP=2, PP=2, DP=2
|
||||
accelerate launch --multi_gpu --num_processes 8 train.py
|
||||
|
||||
# Multi-node (2 nodes, 8 GPUs each)
|
||||
# Node 0
|
||||
accelerate launch --multi_gpu --num_processes 16 \
|
||||
--num_machines 2 --machine_rank 0 \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
--main_process_port 29500 \
|
||||
train.py
|
||||
|
||||
# Node 1
|
||||
accelerate launch --multi_gpu --num_processes 16 \
|
||||
--num_machines 2 --machine_rank 1 \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
--main_process_port 29500 \
|
||||
train.py
|
||||
```
|
||||
|
||||
## Activation Checkpointing
|
||||
|
||||
**Reduces memory by recomputing activations**:
|
||||
|
||||
```python
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
recompute_activations=True, # Enable checkpointing
|
||||
checkpoint_num_layers=1, # Checkpoint every N layers
|
||||
distribute_checkpointed_activations=True, # Distribute across TP
|
||||
partition_activations=True, # Partition in PP
|
||||
check_for_nan_in_loss_and_grad=True, # Stability check
|
||||
)
|
||||
```
|
||||
|
||||
**Strategies**:
|
||||
- `SELECTIVE`: Checkpoint transformer blocks only
|
||||
- `FULL`: Checkpoint all layers
|
||||
- `NONE`: No checkpointing
|
||||
|
||||
**Memory savings**: 30-50% with 10-15% slowdown
|
||||
|
||||
## Distributed Optimizer
|
||||
|
||||
**Shards optimizer state across DP ranks**:
|
||||
|
||||
```python
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
use_distributed_optimizer=True, # Enable sharded optimizer
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- Reduces optimizer memory by DP degree
|
||||
- Example: DP=4 → 4× less optimizer memory per GPU
|
||||
|
||||
**Compatible with**:
|
||||
- AdamW, Adam, SGD
|
||||
- Mixed precision training
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Micro-Batch Size
|
||||
|
||||
```python
|
||||
# Pipeline parallelism requires micro-batching
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
pp_degree=4,
|
||||
num_micro_batches=16, # 16 micro-batches per pipeline
|
||||
)
|
||||
|
||||
# Effective batch = num_micro_batches × micro_batch_size × DP
|
||||
# Example: 16 × 2 × 4 = 128
|
||||
```
|
||||
|
||||
**Recommendations**:
|
||||
- More micro-batches → less pipeline bubble
|
||||
- Typical: 4-16 micro-batches
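Why more micro-batches help, as a rough estimate (the standard bubble-fraction formula from the Megatron-LM literature; real schedules and communication overlap change the exact numbers):

```python
def bubble_fraction(pp: int, num_micro_batches: int) -> float:
    """Approximate share of pipeline time spent idle: (pp - 1) / (m + pp - 1)."""
    return (pp - 1) / (num_micro_batches + pp - 1)

print(f"{bubble_fraction(pp=4, num_micro_batches=4):.0%}")   # ~43% idle
print(f"{bubble_fraction(pp=4, num_micro_batches=16):.0%}")  # ~16% idle
```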
|
||||
|
||||
### Sequence Length
|
||||
|
||||
```python
|
||||
# For long sequences, enable sequence parallelism
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
tp_degree=4,
|
||||
sequence_parallelism=True, # Required: TP > 1
|
||||
)
|
||||
|
||||
# Enables sequences up to TP × normal limit
|
||||
# Example: TP=4, 8K normal → 32K with sequence parallel
|
||||
```
|
||||
|
||||
### GPU Topology
|
||||
|
||||
**NVLink required for TP**:
|
||||
```bash
|
||||
# Check NVLink topology
|
||||
nvidia-smi topo -m
|
||||
|
||||
# Good topology (NVLink between all GPUs)
|
||||
# GPU0 - GPU1: NV12 (fast)
|
||||
# GPU0 - GPU2: NV12 (fast)
|
||||
|
||||
# Bad topology (PCIe only)
|
||||
# GPU0 - GPU4: PHB (slow, avoid TP across these)
|
||||
```
|
||||
|
||||
**Recommendations**:
|
||||
- **TP**: Within same node (NVLink)
|
||||
- **PP**: Across nodes (slower interconnect OK)
|
||||
- **DP**: Any topology
|
||||
|
||||
## Model Size Guidelines
|
||||
|
||||
| Model Size | GPUs | TP | PP | DP | Micro-Batches |
|
||||
|------------|------|----|----|----|--------------|
|
||||
| 7B | 8 | 1 | 1 | 8 | 1 |
|
||||
| 13B | 8 | 2 | 1 | 4 | 1 |
|
||||
| 20B | 16 | 4 | 1 | 4 | 1 |
|
||||
| 40B | 32 | 4 | 2 | 4 | 4 |
|
||||
| 70B | 64 | 8 | 2 | 4 | 8 |
|
||||
| 175B | 128 | 8 | 4 | 4 | 16 |
|
||||
|
||||
**Assumptions**: BF16, 2K sequence length, A100 80GB
|
||||
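For instance, the 40B row maps onto the same plugin options used throughout this guide roughly as follows; the values are illustrative assumptions to tune for your cluster, not a verified recipe:

```python
from accelerate import Accelerator
from accelerate.utils import MegatronLMPlugin

# Illustrative config for the 40B row: 32 GPUs split as TP=4 × PP=2 × DP=4
megatron_plugin = MegatronLMPlugin(
    tp_degree=4,
    pp_degree=2,
    num_micro_batches=4,
    sequence_parallelism=True,       # pairs with TP > 1 (see Sequence Length above)
    use_distributed_optimizer=True,  # shard optimizer state across the DP=4 ranks
)

accelerator = Accelerator(
    megatron_lm_plugin=megatron_plugin,
    mixed_precision='bf16',
)
# Launch with: accelerate launch --multi_gpu --num_processes 32 train.py
```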
|
||||
## Checkpointing
|
||||
|
||||
### Save Checkpoint
|
||||
|
||||
```python
|
||||
# Save full model state
|
||||
accelerator.save_state('checkpoint-1000')
|
||||
|
||||
# Megatron saves separate files per rank
|
||||
# checkpoint-1000/
|
||||
# pytorch_model_tp_0_pp_0.bin
|
||||
# pytorch_model_tp_0_pp_1.bin
|
||||
# pytorch_model_tp_1_pp_0.bin
|
||||
# pytorch_model_tp_1_pp_1.bin
|
||||
# optimizer_tp_0_pp_0.bin
|
||||
# ...
|
||||
```
|
||||
|
||||
### Load Checkpoint
|
||||
|
||||
```python
|
||||
# Resume training
|
||||
accelerator.load_state('checkpoint-1000')
|
||||
|
||||
# Automatically loads correct shard per rank
|
||||
```
|
||||
|
||||
### Convert to Standard PyTorch
|
||||
|
||||
```bash
|
||||
# Merge Megatron checkpoint to single file
|
||||
python merge_megatron_checkpoint.py \
|
||||
--checkpoint-dir checkpoint-1000 \
|
||||
--output pytorch_model.bin
|
||||
```
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Issue: OOM with Pipeline Parallelism
|
||||
|
||||
**Solution**: Increase micro-batches
|
||||
```python
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
pp_degree=4,
|
||||
num_micro_batches=16, # Increase from 4
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Slow Training
|
||||
|
||||
**Check 1**: Pipeline bubbles (PP too high)
|
||||
```python
|
||||
# Reduce PP, increase TP
|
||||
tp_degree=4 # Increase
|
||||
pp_degree=2 # Decrease
|
||||
```
|
||||
|
||||
**Check 2**: Micro-batch size too small
|
||||
```python
|
||||
num_micro_batches=8 # Increase
|
||||
```
|
||||
|
||||
### Issue: NVLink Not Detected
|
||||
|
||||
```bash
|
||||
# Verify NVLink
|
||||
nvidia-smi nvlink -s
|
||||
|
||||
# If no NVLink, avoid TP > 1
|
||||
# Use PP or DP instead
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Megatron-LM: https://github.com/NVIDIA/Megatron-LM
|
||||
- Accelerate Megatron docs: https://huggingface.co/docs/accelerate/usage_guides/megatron_lm
|
||||
- Paper: "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism"
|
||||
- NVIDIA Apex: https://github.com/NVIDIA/apex
|
||||
525
optional-skills/mlops/accelerate/references/performance.md
Normal file
|
|
@@ -0,0 +1,525 @@
|
|||
# Accelerate Performance Tuning
|
||||
|
||||
## Profiling
|
||||
|
||||
### Basic Profiling
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
import time
|
||||
|
||||
accelerator = Accelerator()
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
batch = next(iter(dataloader))
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Profile training loop
|
||||
start = time.time()
|
||||
total_batches = 100
|
||||
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= total_batches:
|
||||
break
|
||||
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
accelerator.wait_for_everyone() # Sync all processes
|
||||
elapsed = time.time() - start
|
||||
|
||||
# Metrics
|
||||
batches_per_sec = total_batches / elapsed
|
||||
samples_per_sec = (total_batches * batch_size * accelerator.num_processes) / elapsed
|
||||
|
||||
print(f"Throughput: {samples_per_sec:.2f} samples/sec")
|
||||
print(f"Batches/sec: {batches_per_sec:.2f}")
|
||||
```
|
||||
|
||||
### PyTorch Profiler Integration
|
||||
|
||||
```python
|
||||
from torch.profiler import profile, ProfilerActivity
|
||||
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=True
|
||||
) as prof:
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= 10: # Profile first 10 batches
|
||||
break
|
||||
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Print profiling results
|
||||
print(prof.key_averages().table(
|
||||
sort_by="cuda_time_total", row_limit=20
|
||||
))
|
||||
|
||||
# Export to Chrome tracing
|
||||
prof.export_chrome_trace("trace.json")
|
||||
# View at chrome://tracing
|
||||
```
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
### 1. Gradient Accumulation
|
||||
|
||||
**Problem**: Large batch size causes OOM
|
||||
|
||||
**Solution**: Accumulate gradients across micro-batches
|
||||
|
||||
```python
|
||||
accelerator = Accelerator(gradient_accumulation_steps=8)
|
||||
|
||||
# Effective batch = batch_size × accumulation_steps × num_gpus
|
||||
# Example: 4 × 8 × 8 = 256
|
||||
|
||||
for batch in dataloader:
|
||||
with accelerator.accumulate(model): # Handles accumulation logic
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
**Memory savings**: 8× less activation memory (with 8 accumulation steps)
|
||||
|
||||
### 2. Gradient Checkpointing
|
||||
|
||||
**Enable in model**:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"gpt2",
|
||||
use_cache=False # Required for gradient checkpointing
|
||||
)
|
||||
|
||||
# Enable checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Prepare with Accelerate
|
||||
model = accelerator.prepare(model)
|
||||
```
|
||||
|
||||
**Memory savings**: 30-50% with 10-15% slowdown
|
||||
|
||||
### 3. Mixed Precision
|
||||
|
||||
**BF16 (A100/H100)**:
|
||||
```python
|
||||
accelerator = Accelerator(mixed_precision='bf16')
|
||||
|
||||
# Automatic mixed precision
|
||||
for batch in dataloader:
|
||||
outputs = model(**batch) # Forward in BF16
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss) # Backward in FP32
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**FP16 (V100, older GPUs)**:
|
||||
```python
|
||||
from accelerate.utils import GradScalerKwargs
|
||||
|
||||
scaler_kwargs = GradScalerKwargs(
|
||||
init_scale=2.**16,
|
||||
growth_interval=2000
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp16',
|
||||
kwargs_handlers=[scaler_kwargs]
|
||||
)
|
||||
```
|
||||
|
||||
**Memory savings**: 50% compared to FP32
|
||||
|
||||
### 4. CPU Offloading (DeepSpeed)
|
||||
|
||||
```python
|
||||
from accelerate.utils import DeepSpeedPlugin
|
||||
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=3,
|
||||
offload_optimizer_device="cpu", # Offload optimizer to CPU
|
||||
offload_param_device="cpu", # Offload parameters to CPU
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
deepspeed_plugin=ds_plugin,
|
||||
mixed_precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
**Memory savings**: 10-20× for optimizer state, 5-10× for parameters
|
||||
|
||||
**Trade-off**: 20-30% slower due to CPU-GPU transfers
|
||||
|
||||
### 5. Flash Attention
|
||||
|
||||
```python
|
||||
# Install flash-attn
|
||||
# pip install flash-attn
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"gpt2",
|
||||
attn_implementation="flash_attention_2" # Enable Flash Attention 2
|
||||
)
|
||||
|
||||
model = accelerator.prepare(model)
|
||||
```
|
||||
|
||||
**Memory savings**: 50% for attention, 2× faster
|
||||
|
||||
**Requirements**: Ampere or newer GPU (e.g. A100/H100); head dimension ≤ 256
|
||||
|
||||
## Communication Optimization
|
||||
|
||||
### 1. Gradient Bucketing (DDP)
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
bucket_cap_mb=25, # Bucket size for gradient reduction
|
||||
gradient_as_bucket_view=True, # Reduce memory copies
|
||||
static_graph=False # Set True if model doesn't change
|
||||
)
|
||||
|
||||
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
|
||||
```
|
||||
|
||||
**Recommended bucket sizes**:
|
||||
- Small models (<1B): 25 MB
|
||||
- Medium models (1-10B): 50-100 MB
|
||||
- Large models (>10B): 100-200 MB
|
||||
|
||||
### 2. Find Unused Parameters
|
||||
|
||||
```python
|
||||
# Only enable if model has unused parameters (slower!)
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
find_unused_parameters=True
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Models with conditional branches (e.g., mixture of experts)
|
||||
|
||||
**Cost**: 10-20% slower
|
||||
|
||||
### 3. NCCL Tuning
|
||||
|
||||
```bash
|
||||
# Set environment variables before launch
|
||||
export NCCL_DEBUG=INFO # Debug info
|
||||
export NCCL_IB_DISABLE=0 # Enable InfiniBand
|
||||
export NCCL_SOCKET_IFNAME=eth0 # Network interface
|
||||
export NCCL_P2P_LEVEL=NVL # Use NVLink
|
||||
|
||||
accelerate launch train.py
|
||||
```
|
||||
|
||||
**NCCL_P2P_LEVEL options**:
|
||||
- `NVL`: NVLink (fastest, within node)
|
||||
- `PIX`: PCIe (fast, within node)
|
||||
- `PHB`: PCIe host bridge (slower; traffic goes through the CPU's host bridge)
|
||||
|
||||
## Data Loading Optimization
|
||||
|
||||
### 1. DataLoader Workers
|
||||
|
||||
```python
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
train_loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=32,
|
||||
num_workers=4, # Parallel data loading
|
||||
pin_memory=True, # Pin memory for faster GPU transfer
|
||||
prefetch_factor=2, # Prefetch batches per worker
|
||||
persistent_workers=True # Keep workers alive between epochs
|
||||
)
|
||||
|
||||
train_loader = accelerator.prepare(train_loader)
|
||||
```
|
||||
|
||||
**Recommendations**:
|
||||
- `num_workers`: 2-4 per GPU (8 GPUs → 16-32 workers)
|
||||
- `pin_memory`: Always True for GPU training
|
||||
- `prefetch_factor`: 2-4 (higher for slow data loading)
|
||||
|
||||
### 2. Data Preprocessing
|
||||
|
||||
```python
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
# Bad: Preprocess during training (slow)
|
||||
dataset = load_dataset("openwebtext")
|
||||
|
||||
for batch in dataset:
|
||||
tokens = tokenizer(batch['text']) # Slow!
|
||||
...
|
||||
|
||||
# Good: Preprocess once, save
|
||||
dataset = load_dataset("openwebtext")
|
||||
tokenized = dataset.map(
|
||||
lambda x: tokenizer(x['text']),
|
||||
batched=True,
|
||||
num_proc=8, # Parallel preprocessing
|
||||
remove_columns=['text']
|
||||
)
|
||||
tokenized.save_to_disk("preprocessed_data")
|
||||
|
||||
# Load preprocessed
|
||||
dataset = load_from_disk("preprocessed_data")
|
||||
```
|
||||
|
||||
### 3. Faster Tokenization
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
# Let the fast (Rust) tokenizers use parallel threads
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"gpt2",
|
||||
use_fast=True # Use fast Rust tokenizer
|
||||
)
|
||||
```
|
||||
|
||||
## Compilation (PyTorch 2.0+)
|
||||
|
||||
### Compile Model
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Compile model for faster execution
|
||||
model = torch.compile(
|
||||
model,
|
||||
mode="reduce-overhead", # Options: default, reduce-overhead, max-autotune
|
||||
fullgraph=False, # Compile entire graph (stricter)
|
||||
dynamic=True # Support dynamic shapes
|
||||
)
|
||||
|
||||
model = accelerator.prepare(model)
|
||||
```
|
||||
|
||||
**Speedup**: 10-50% depending on model
|
||||
|
||||
**Compilation modes**:
|
||||
- `default`: Balanced (best for most cases)
|
||||
- `reduce-overhead`: Min overhead (best for small batches)
|
||||
- `max-autotune`: Max performance (slow compile, best for production)
|
||||
|
||||
### Compilation Best Practices
|
||||
|
||||
```python
|
||||
# Bad: Compile after prepare (won't work)
|
||||
model = accelerator.prepare(model)
|
||||
model = torch.compile(model) # Error!
|
||||
|
||||
# Good: Compile before prepare
|
||||
model = torch.compile(model)
|
||||
model = accelerator.prepare(model)
|
||||
|
||||
# Training loop
|
||||
for batch in dataloader:
|
||||
# First iteration: slow (compilation)
|
||||
# Subsequent iterations: fast (compiled)
|
||||
outputs = model(**batch)
|
||||
...
|
||||
```
|
||||
|
||||
## Benchmarking Different Strategies
|
||||
|
||||
### Script Template
|
||||
|
||||
```python
|
||||
import time
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
|
||||
def benchmark_strategy(strategy_name, accelerator_kwargs):
|
||||
"""Benchmark a specific training strategy."""
|
||||
accelerator = Accelerator(**accelerator_kwargs)
|
||||
|
||||
# Setup
|
||||
model = create_model()
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
|
||||
dataloader = create_dataloader()
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(
|
||||
model, optimizer, dataloader
|
||||
)
|
||||
|
||||
# Warmup
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= 10:
|
||||
break
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Benchmark
|
||||
accelerator.wait_for_everyone()
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
|
||||
num_batches = 100
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= num_batches:
|
||||
break
|
||||
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
torch.cuda.synchronize()
|
||||
elapsed = time.time() - start
|
||||
|
||||
# Metrics
|
||||
throughput = (num_batches * batch_size * accelerator.num_processes) / elapsed
|
||||
memory_used = torch.cuda.max_memory_allocated() / 1e9 # GB
|
||||
|
||||
if accelerator.is_main_process:
|
||||
print(f"\n{strategy_name}:")
|
||||
print(f" Throughput: {throughput:.2f} samples/sec")
|
||||
print(f" Memory: {memory_used:.2f} GB")
|
||||
print(f" Time: {elapsed:.2f} sec")
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Benchmark different strategies
|
||||
strategies = [
|
||||
("DDP + FP32", {}),
|
||||
("DDP + BF16", {"mixed_precision": "bf16"}),
|
||||
("DDP + BF16 + GradAccum", {"mixed_precision": "bf16", "gradient_accumulation_steps": 4}),
|
||||
("FSDP", {"fsdp_plugin": fsdp_plugin}),
|
||||
("DeepSpeed ZeRO-2", {"deepspeed_plugin": ds_plugin_stage2}),
|
||||
("DeepSpeed ZeRO-3", {"deepspeed_plugin": ds_plugin_stage3}),
|
||||
]
|
||||
|
||||
for name, kwargs in strategies:
|
||||
benchmark_strategy(name, kwargs)
|
||||
```
|
||||
|
||||
## Performance Checklist
|
||||
|
||||
**Before training**:
|
||||
- [ ] Use BF16/FP16 mixed precision
|
||||
- [ ] Enable gradient checkpointing (if OOM)
|
||||
- [ ] Set appropriate `num_workers` (2-4 per GPU)
|
||||
- [ ] Enable `pin_memory=True`
|
||||
- [ ] Preprocess data once, not during training
|
||||
- [ ] Compile model with `torch.compile` (PyTorch 2.0+)
|
||||
|
||||
**For large models**:
|
||||
- [ ] Use FSDP or DeepSpeed ZeRO-3
|
||||
- [ ] Enable CPU offloading (if still OOM)
|
||||
- [ ] Use Flash Attention
|
||||
- [ ] Increase gradient accumulation
|
||||
|
||||
**For multi-node**:
|
||||
- [ ] Check network topology (InfiniBand > Ethernet)
|
||||
- [ ] Tune NCCL settings
|
||||
- [ ] Use larger bucket sizes for DDP
|
||||
- [ ] Verify NVLink for tensor parallelism
|
||||
|
||||
**Profiling**:
|
||||
- [ ] Profile first 10-100 batches
|
||||
- [ ] Check GPU utilization (`nvidia-smi dmon`)
|
||||
- [ ] Check data loading time (should be <5% of iteration)
|
||||
- [ ] Identify communication bottlenecks
|
||||
|
||||
## Common Performance Issues
|
||||
|
||||
### Issue: Low GPU Utilization (<80%)
|
||||
|
||||
**Cause 1**: Data loading bottleneck
|
||||
```python
|
||||
# Solution: Increase workers and prefetch
|
||||
num_workers=8
|
||||
prefetch_factor=4
|
||||
```
|
||||
|
||||
**Cause 2**: Small batch size
|
||||
```python
|
||||
# Solution: Increase batch size or use gradient accumulation
|
||||
batch_size=32 # Increase
|
||||
gradient_accumulation_steps=4 # Or accumulate
|
||||
```
|
||||
|
||||
### Issue: High Memory Usage
|
||||
|
||||
**Solution 1**: Gradient checkpointing
|
||||
```python
|
||||
model.gradient_checkpointing_enable()
|
||||
```
|
||||
|
||||
**Solution 2**: Reduce batch size, increase accumulation
|
||||
```python
|
||||
batch_size=8 # Reduce from 32
|
||||
gradient_accumulation_steps=16 # Maintain effective batch
|
||||
```
|
||||
|
||||
**Solution 3**: Use FSDP or DeepSpeed ZeRO-3
|
||||
```python
|
||||
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
|
||||
```
|
||||
|
||||
### Issue: Slow Multi-GPU Training
|
||||
|
||||
**Cause**: Communication bottleneck
|
||||
|
||||
**Check 1**: Gradient bucket size
|
||||
```python
|
||||
ddp_kwargs = DistributedDataParallelKwargs(bucket_cap_mb=100)
|
||||
```
|
||||
|
||||
**Check 2**: NCCL settings
|
||||
```bash
|
||||
export NCCL_DEBUG=INFO
|
||||
# Check for "Using NVLS" (good) vs "Using PHB" (bad)
|
||||
```
|
||||
|
||||
**Check 3**: Network bandwidth
|
||||
```bash
|
||||
# Test inter-GPU bandwidth
|
||||
nvidia-smi nvlink -s
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Accelerate Performance: https://huggingface.co/docs/accelerate/usage_guides/performance
|
||||
- PyTorch Profiler: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
|
||||
- NCCL Tuning: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html
|
||||
- Flash Attention: https://github.com/Dao-AILab/flash-attention
|
||||
409
optional-skills/mlops/chroma/SKILL.md
Normal file
|
|
@@ -0,0 +1,409 @@
|
|||
---
|
||||
name: chroma
|
||||
description: Open-source embedding database for AI applications. Store embeddings and metadata, perform vector and full-text search, filter by metadata. Simple 4-function API. Scales from notebooks to production clusters. Use for semantic search, RAG applications, or document retrieval. Best for local development and open-source projects.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [chromadb, sentence-transformers]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [RAG, Chroma, Vector Database, Embeddings, Semantic Search, Open Source, Self-Hosted, Document Retrieval, Metadata Filtering]
|
||||
|
||||
---
|
||||
|
||||
# Chroma - Open-Source Embedding Database
|
||||
|
||||
The AI-native database for building LLM applications with memory.
|
||||
|
||||
## When to use Chroma
|
||||
|
||||
**Use Chroma when:**
|
||||
- Building RAG (retrieval-augmented generation) applications
|
||||
- Need local/self-hosted vector database
|
||||
- Want open-source solution (Apache 2.0)
|
||||
- Prototyping in notebooks
|
||||
- Semantic search over documents
|
||||
- Storing embeddings with metadata
|
||||
|
||||
**Metrics**:
|
||||
- **24,300+ GitHub stars**
|
||||
- **1,900+ forks**
|
||||
- **v1.3.3** (stable, weekly releases)
|
||||
- **Apache 2.0 license**
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **Pinecone**: Managed cloud, auto-scaling
|
||||
- **FAISS**: Pure similarity search, no metadata
|
||||
- **Weaviate**: Production ML-native database
|
||||
- **Qdrant**: High performance, Rust-based
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Python
|
||||
pip install chromadb
|
||||
|
||||
# JavaScript/TypeScript
|
||||
npm install chromadb @chroma-core/default-embed
|
||||
```
|
||||
|
||||
### Basic usage (Python)
|
||||
|
||||
```python
|
||||
import chromadb
|
||||
|
||||
# Create client
|
||||
client = chromadb.Client()
|
||||
|
||||
# Create collection
|
||||
collection = client.create_collection(name="my_collection")
|
||||
|
||||
# Add documents
|
||||
collection.add(
|
||||
documents=["This is document 1", "This is document 2"],
|
||||
metadatas=[{"source": "doc1"}, {"source": "doc2"}],
|
||||
ids=["id1", "id2"]
|
||||
)
|
||||
|
||||
# Query
|
||||
results = collection.query(
|
||||
query_texts=["document about topic"],
|
||||
n_results=2
|
||||
)
|
||||
|
||||
print(results)
|
||||
```
|
||||
|
||||
## Core operations
|
||||
|
||||
### 1. Create collection
|
||||
|
||||
```python
|
||||
# Simple collection
|
||||
collection = client.create_collection("my_docs")
|
||||
|
||||
# With custom embedding function
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key="your-key",
|
||||
model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
collection = client.create_collection(
|
||||
name="my_docs",
|
||||
embedding_function=openai_ef
|
||||
)
|
||||
|
||||
# Get existing collection
|
||||
collection = client.get_collection("my_docs")
|
||||
|
||||
# Delete collection
|
||||
client.delete_collection("my_docs")
|
||||
```
|
||||
|
||||
### 2. Add documents
|
||||
|
||||
```python
|
||||
# Add documents with metadata and explicit IDs
|
||||
collection.add(
|
||||
documents=["Doc 1", "Doc 2", "Doc 3"],
|
||||
metadatas=[
|
||||
{"source": "web", "category": "tutorial"},
|
||||
{"source": "pdf", "page": 5},
|
||||
{"source": "api", "timestamp": "2025-01-01"}
|
||||
],
|
||||
ids=["id1", "id2", "id3"]
|
||||
)
|
||||
|
||||
# Add with custom embeddings
|
||||
collection.add(
|
||||
embeddings=[[0.1, 0.2, ...], [0.3, 0.4, ...]],
|
||||
documents=["Doc 1", "Doc 2"],
|
||||
ids=["id1", "id2"]
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Query (similarity search)
|
||||
|
||||
```python
|
||||
# Basic query
|
||||
results = collection.query(
|
||||
query_texts=["machine learning tutorial"],
|
||||
n_results=5
|
||||
)
|
||||
|
||||
# Query with filters
|
||||
results = collection.query(
|
||||
query_texts=["Python programming"],
|
||||
n_results=3,
|
||||
where={"source": "web"}
|
||||
)
|
||||
|
||||
# Query with metadata filters
|
||||
results = collection.query(
|
||||
query_texts=["advanced topics"],
|
||||
where={
|
||||
"$and": [
|
||||
{"category": "tutorial"},
|
||||
{"difficulty": {"$gte": 3}}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
# Access results
|
||||
print(results["documents"]) # List of matching documents
|
||||
print(results["metadatas"]) # Metadata for each doc
|
||||
print(results["distances"]) # Similarity scores
|
||||
print(results["ids"]) # Document IDs
|
||||
```
|
||||
|
||||
### 4. Get documents
|
||||
|
||||
```python
|
||||
# Get by IDs
|
||||
docs = collection.get(
|
||||
ids=["id1", "id2"]
|
||||
)
|
||||
|
||||
# Get with filters
|
||||
docs = collection.get(
|
||||
where={"category": "tutorial"},
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Get all documents
|
||||
docs = collection.get()
|
||||
```
|
||||
|
||||
### 5. Update documents
|
||||
|
||||
```python
|
||||
# Update document content
|
||||
collection.update(
|
||||
ids=["id1"],
|
||||
documents=["Updated content"],
|
||||
metadatas=[{"source": "updated"}]
|
||||
)
|
||||
```
|
||||
|
||||
### 6. Delete documents
|
||||
|
||||
```python
|
||||
# Delete by IDs
|
||||
collection.delete(ids=["id1", "id2"])
|
||||
|
||||
# Delete with filter
|
||||
collection.delete(
|
||||
where={"source": "outdated"}
|
||||
)
|
||||
```
|
||||
|
||||
## Persistent storage
|
||||
|
||||
```python
|
||||
# Persist to disk
|
||||
client = chromadb.PersistentClient(path="./chroma_db")
|
||||
|
||||
collection = client.create_collection("my_docs")
|
||||
collection.add(documents=["Doc 1"], ids=["id1"])
|
||||
|
||||
# Data persisted automatically
|
||||
# Reload later with same path
|
||||
client = chromadb.PersistentClient(path="./chroma_db")
|
||||
collection = client.get_collection("my_docs")
|
||||
```
|
||||
|
||||
## Embedding functions
|
||||
|
||||
### Default (Sentence Transformers)
|
||||
|
||||
```python
|
||||
# Uses sentence-transformers by default
|
||||
collection = client.create_collection("my_docs")
|
||||
# Default model: all-MiniLM-L6-v2
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
```python
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key="your-key",
|
||||
model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
collection = client.create_collection(
|
||||
name="openai_docs",
|
||||
embedding_function=openai_ef
|
||||
)
|
||||
```
|
||||
|
||||
### HuggingFace
|
||||
|
||||
```python
|
||||
huggingface_ef = embedding_functions.HuggingFaceEmbeddingFunction(
|
||||
api_key="your-key",
|
||||
model_name="sentence-transformers/all-mpnet-base-v2"
|
||||
)
|
||||
|
||||
collection = client.create_collection(
|
||||
name="hf_docs",
|
||||
embedding_function=huggingface_ef
|
||||
)
|
||||
```
|
||||
|
||||
### Custom embedding function
|
||||
|
||||
```python
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
|
||||
class MyEmbeddingFunction(EmbeddingFunction):
    def __init__(self):
        # Example encoder only — any model that returns one vector per document works;
        # all-MiniLM-L6-v2 is the same model Chroma uses by default
        from sentence_transformers import SentenceTransformer
        self._model = SentenceTransformer("all-MiniLM-L6-v2")

    def __call__(self, input: Documents) -> Embeddings:
        # Return one embedding (list of floats) per input document
        return self._model.encode(list(input)).tolist()
|
||||
|
||||
my_ef = MyEmbeddingFunction()
|
||||
collection = client.create_collection(
|
||||
name="custom_docs",
|
||||
embedding_function=my_ef
|
||||
)
|
||||
```
|
||||
|
||||
## Metadata filtering
|
||||
|
||||
```python
|
||||
# Exact match
|
||||
results = collection.query(
|
||||
query_texts=["query"],
|
||||
where={"category": "tutorial"}
|
||||
)
|
||||
|
||||
# Comparison operators
|
||||
results = collection.query(
|
||||
query_texts=["query"],
|
||||
where={"page": {"$gt": 10}} # $gt, $gte, $lt, $lte, $ne
|
||||
)
|
||||
|
||||
# Logical operators
|
||||
results = collection.query(
|
||||
query_texts=["query"],
|
||||
where={
|
||||
"$and": [
|
||||
{"category": "tutorial"},
|
||||
{"difficulty": {"$lte": 3}}
|
||||
]
|
||||
} # Also: $or
|
||||
)
|
||||
|
||||
# Contains
|
||||
results = collection.query(
|
||||
query_texts=["query"],
|
||||
where={"tags": {"$in": ["python", "ml"]}}
|
||||
)
|
||||
```
|
||||
|
||||
## LangChain integration
|
||||
|
||||
```python
|
||||
from langchain_chroma import Chroma
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
# Split documents
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000)
|
||||
docs = text_splitter.split_documents(documents)
|
||||
|
||||
# Create Chroma vector store
|
||||
vectorstore = Chroma.from_documents(
|
||||
documents=docs,
|
||||
embedding=OpenAIEmbeddings(),
|
||||
persist_directory="./chroma_db"
|
||||
)
|
||||
|
||||
# Query
|
||||
results = vectorstore.similarity_search("machine learning", k=3)
|
||||
|
||||
# As retriever
|
||||
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
|
||||
```
|
||||
|
||||
## LlamaIndex integration
|
||||
|
||||
```python
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.core import VectorStoreIndex, StorageContext
|
||||
import chromadb
|
||||
|
||||
# Initialize Chroma
|
||||
db = chromadb.PersistentClient(path="./chroma_db")
|
||||
collection = db.get_or_create_collection("my_collection")
|
||||
|
||||
# Create vector store
|
||||
vector_store = ChromaVectorStore(chroma_collection=collection)
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
|
||||
# Create index
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context
|
||||
)
|
||||
|
||||
# Query
|
||||
query_engine = index.as_query_engine()
|
||||
response = query_engine.query("What is machine learning?")
|
||||
```
|
||||
|
||||
## Server mode
|
||||
|
||||
```python
|
||||
# Run Chroma server
|
||||
# Terminal: chroma run --path ./chroma_db --port 8000
|
||||
|
||||
# Connect to server
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
client = chromadb.HttpClient(
|
||||
host="localhost",
|
||||
port=8000,
|
||||
settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
|
||||
# Use as normal
|
||||
collection = client.get_or_create_collection("my_docs")
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use persistent client** - Don't lose data on restart
|
||||
2. **Add metadata** - Enables filtering and tracking
|
||||
3. **Batch operations** - Add multiple docs at once
|
||||
4. **Choose right embedding model** - Balance speed/quality
|
||||
5. **Use filters** - Narrow search space
|
||||
6. **Unique IDs** - Avoid collisions
|
||||
7. **Regular backups** - Copy chroma_db directory
|
||||
8. **Monitor collection size** - Scale up if needed
|
||||
9. **Test embedding functions** - Ensure quality
|
||||
10. **Use server mode for production** - Better for multi-user
|
||||
|
||||
## Performance
|
||||
|
||||
| Operation | Latency | Notes |
|
||||
|-----------|---------|-------|
|
||||
| Add 100 docs | ~1-3s | With embedding |
|
||||
| Query (top 10) | ~50-200ms | Depends on collection size |
|
||||
| Metadata filter | ~10-50ms | Fast with proper indexing |
|
||||
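Add throughput is dominated by embedding time, so large corpora are usually ingested in fixed-size batches rather than one giant call. A minimal sketch (the batch size and ID scheme are assumptions to adapt):

```python
# documents: your full list of strings to index
batch_size = 500  # assumed value; tune for document length and embedding model

for start in range(0, len(documents), batch_size):
    chunk = documents[start:start + batch_size]
    collection.add(
        documents=chunk,
        ids=[f"doc-{start + i}" for i in range(len(chunk))],  # any unique ID scheme works
    )
```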
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/chroma-core/chroma ⭐ 24,300+
|
||||
- **Docs**: https://docs.trychroma.com
|
||||
- **Discord**: https://discord.gg/MMeYNTmh3x
|
||||
- **Version**: 1.3.3+
|
||||
- **License**: Apache 2.0
|
||||
|
||||
|
||||
38
optional-skills/mlops/chroma/references/integration.md
Normal file
|
|
@@ -0,0 +1,38 @@
|
|||
# Chroma Integration Guide
|
||||
|
||||
Integration with LangChain, LlamaIndex, and frameworks.
|
||||
|
||||
## LangChain
|
||||
|
||||
```python
|
||||
from langchain_chroma import Chroma
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
vectorstore = Chroma.from_documents(
|
||||
documents=docs,
|
||||
embedding=OpenAIEmbeddings(),
|
||||
persist_directory="./chroma_db"
|
||||
)
|
||||
|
||||
# Query
|
||||
results = vectorstore.similarity_search("query", k=3)
|
||||
|
||||
# As retriever
|
||||
retriever = vectorstore.as_retriever()
|
||||
```
|
||||
|
||||
## LlamaIndex
|
||||
|
||||
```python
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
import chromadb
|
||||
|
||||
db = chromadb.PersistentClient(path="./chroma_db")
|
||||
collection = db.get_or_create_collection("docs")
|
||||
|
||||
vector_store = ChromaVectorStore(chroma_collection=collection)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Docs**: https://docs.trychroma.com
|
||||
224
optional-skills/mlops/faiss/SKILL.md
Normal file
|
|
@@ -0,0 +1,224 @@
|
|||
---
|
||||
name: faiss
|
||||
description: Facebook's library for efficient similarity search and clustering of dense vectors. Supports billions of vectors, GPU acceleration, and various index types (Flat, IVF, HNSW). Use for fast k-NN search, large-scale vector retrieval, or when you need pure similarity search without metadata. Best for high-performance applications.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [faiss-cpu, faiss-gpu, numpy]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [RAG, FAISS, Similarity Search, Vector Search, Facebook AI, GPU Acceleration, Billion-Scale, K-NN, HNSW, High Performance, Large Scale]
|
||||
|
||||
---
|
||||
|
||||
# FAISS - Efficient Similarity Search
|
||||
|
||||
Facebook AI's library for billion-scale vector similarity search.
|
||||
|
||||
## When to use FAISS
|
||||
|
||||
**Use FAISS when:**
|
||||
- Need fast similarity search on large vector datasets (millions/billions)
|
||||
- GPU acceleration required
|
||||
- Pure vector similarity (no metadata filtering needed)
|
||||
- High throughput, low latency critical
|
||||
- Offline/batch processing of embeddings
|
||||
|
||||
**Metrics**:
|
||||
- **31,700+ GitHub stars**
|
||||
- Meta/Facebook AI Research
|
||||
- **Handles billions of vectors**
|
||||
- **C++** with Python bindings
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **Chroma/Pinecone**: Need metadata filtering
|
||||
- **Weaviate**: Need full database features
|
||||
- **Annoy**: Simpler, fewer features
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# CPU only
|
||||
pip install faiss-cpu
|
||||
|
||||
# GPU support
|
||||
pip install faiss-gpu
|
||||
```
|
||||
|
||||
### Basic usage
|
||||
|
||||
```python
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
# Create sample data (1000 vectors, 128 dimensions)
|
||||
d = 128
|
||||
nb = 1000
|
||||
vectors = np.random.random((nb, d)).astype('float32')
|
||||
|
||||
# Create index
|
||||
index = faiss.IndexFlatL2(d) # L2 distance
|
||||
index.add(vectors) # Add vectors
|
||||
|
||||
# Search
|
||||
k = 5 # Find 5 nearest neighbors
|
||||
query = np.random.random((1, d)).astype('float32')
|
||||
distances, indices = index.search(query, k)
|
||||
|
||||
print(f"Nearest neighbors: {indices}")
|
||||
print(f"Distances: {distances}")
|
||||
```
|
||||
|
||||
## Index types
|
||||
|
||||
### 1. Flat (exact search)
|
||||
|
||||
```python
|
||||
# L2 (Euclidean) distance
|
||||
index = faiss.IndexFlatL2(d)
|
||||
|
||||
# Inner product (cosine similarity if normalized)
|
||||
index = faiss.IndexFlatIP(d)
|
||||
|
||||
# Slowest, most accurate
|
||||
```
|
||||
|
||||
### 2. IVF (inverted file) - Fast approximate
|
||||
|
||||
```python
|
||||
# Create quantizer
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
|
||||
# IVF index with 100 clusters
|
||||
nlist = 100
|
||||
index = faiss.IndexIVFFlat(quantizer, d, nlist)
|
||||
|
||||
# Train on data
|
||||
index.train(vectors)
|
||||
|
||||
# Add vectors
|
||||
index.add(vectors)
|
||||
|
||||
# Search (nprobe = clusters to search)
|
||||
index.nprobe = 10
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
### 3. HNSW (Hierarchical NSW) - Best quality/speed
|
||||
|
||||
```python
|
||||
# HNSW index
|
||||
M = 32 # Number of connections per layer
|
||||
index = faiss.IndexHNSWFlat(d, M)
|
||||
|
||||
# No training needed
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
### 4. Product Quantization - Memory efficient
|
||||
|
||||
```python
|
||||
# PQ reduces memory by 16-32×
|
||||
m = 8 # Number of subquantizers
|
||||
nbits = 8
|
||||
index = faiss.IndexPQ(d, m, nbits)
|
||||
|
||||
# Train and add
|
||||
index.train(vectors)
|
||||
index.add(vectors)
|
||||
```
|
||||
|
||||
## Save and load
|
||||
|
||||
```python
|
||||
# Save index
|
||||
faiss.write_index(index, "large.index")
|
||||
|
||||
# Load index
|
||||
index = faiss.read_index("large.index")
|
||||
|
||||
# Continue using
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
## GPU acceleration
|
||||
|
||||
```python
|
||||
# Single GPU
|
||||
res = faiss.StandardGpuResources()
|
||||
index_cpu = faiss.IndexFlatL2(d)
|
||||
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0
|
||||
|
||||
# Multi-GPU
|
||||
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
|
||||
|
||||
# 10-100× faster than CPU
|
||||
```
|
||||
|
||||
## LangChain integration
|
||||
|
||||
```python
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
# Create FAISS vector store
|
||||
vectorstore = FAISS.from_documents(docs, OpenAIEmbeddings())
|
||||
|
||||
# Save
|
||||
vectorstore.save_local("faiss_index")
|
||||
|
||||
# Load
|
||||
vectorstore = FAISS.load_local(
|
||||
"faiss_index",
|
||||
OpenAIEmbeddings(),
|
||||
allow_dangerous_deserialization=True
|
||||
)
|
||||
|
||||
# Search
|
||||
results = vectorstore.similarity_search("query", k=5)
|
||||
```
|
||||
|
||||
## LlamaIndex integration
|
||||
|
||||
```python
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
import faiss
|
||||
|
||||
# Create FAISS index
|
||||
d = 1536
|
||||
faiss_index = faiss.IndexFlatL2(d)
|
||||
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Choose right index type** - Flat for <10K, IVF for 10K-1M, HNSW for quality
|
||||
2. **Normalize for cosine** - Use IndexFlatIP with normalized vectors
|
||||
3. **Use GPU for large datasets** - 10-100× faster
|
||||
4. **Save trained indices** - Training is expensive
|
||||
5. **Tune nprobe/ef_search** - Balance speed/accuracy
|
||||
6. **Monitor memory** - PQ for large datasets
|
||||
7. **Batch queries** - Better GPU utilization (see the sketch after this list)
|
||||
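A minimal sketch of practice 7: stack queries into a single array and call `search` once rather than looping; the setup mirrors the quick start above.

```python
import faiss
import numpy as np

d, k = 128, 5
index = faiss.IndexFlatL2(d)
index.add(np.random.random((1000, d)).astype('float32'))

# One batched call over 256 queries instead of 256 separate searches
queries = np.random.random((256, d)).astype('float32')
distances, indices = index.search(queries, k)  # both results have shape (256, k)
```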
|
||||
## Performance
|
||||
|
||||
| Index Type | Build Time | Search Time | Memory | Accuracy |
|
||||
|------------|------------|-------------|--------|----------|
|
||||
| Flat | Fast | Slow | High | 100% |
|
||||
| IVF | Medium | Fast | Medium | 95-99% |
|
||||
| HNSW | Slow | Fastest | High | 99% |
|
||||
| PQ | Medium | Fast | Low | 90-95% |
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/faiss ⭐ 31,700+
|
||||
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
|
||||
- **License**: MIT
|
||||
|
||||
|
||||
280
optional-skills/mlops/faiss/references/index_types.md
Normal file
|
|
@@ -0,0 +1,280 @@
|
|||
# FAISS Index Types Guide
|
||||
|
||||
Complete guide to choosing and using FAISS index types.
|
||||
|
||||
## Index selection guide
|
||||
|
||||
| Dataset Size | Index Type | Training | Accuracy | Speed |
|
||||
|--------------|------------|----------|----------|-------|
|
||||
| < 10K | Flat | No | 100% | Slow |
|
||||
| 10K-1M | IVF | Yes | 95-99% | Fast |
|
||||
| 1M-10M | HNSW | No | 99% | Fastest |
|
||||
| > 10M | IVF+PQ | Yes | 90-95% | Fast, low memory |
|
||||
|
||||
## Flat indices (exact search)
|
||||
|
||||
### IndexFlatL2 - L2 (Euclidean) distance
|
||||
|
||||
```python
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
d = 128 # Dimension
|
||||
index = faiss.IndexFlatL2(d)
|
||||
|
||||
# Add vectors
|
||||
vectors = np.random.random((1000, d)).astype('float32')
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
k = 5
|
||||
query = np.random.random((1, d)).astype('float32')
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Dataset < 10,000 vectors
|
||||
- Need 100% accuracy
|
||||
- Serving as baseline
|
||||
|
||||
### IndexFlatIP - Inner product (cosine similarity)
|
||||
|
||||
```python
|
||||
# For cosine similarity, normalize vectors first
|
||||
import faiss
|
||||
|
||||
d = 128
|
||||
index = faiss.IndexFlatIP(d)
|
||||
|
||||
# Normalize vectors (required for cosine similarity)
|
||||
faiss.normalize_L2(vectors)
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
faiss.normalize_L2(query)
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Need cosine similarity
|
||||
- Recommendation systems
|
||||
- Text embeddings
|
||||
|
||||
## IVF indices (inverted file)
|
||||
|
||||
### IndexIVFFlat - Cluster-based search
|
||||
|
||||
```python
|
||||
# Create quantizer
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
|
||||
# Create IVF index with 100 clusters
|
||||
nlist = 100 # Number of clusters
|
||||
index = faiss.IndexIVFFlat(quantizer, d, nlist)
|
||||
|
||||
# Train on data (required!)
|
||||
index.train(vectors)
|
||||
|
||||
# Add vectors
|
||||
index.add(vectors)
|
||||
|
||||
# Search (nprobe = clusters to search)
|
||||
index.nprobe = 10 # Search 10 closest clusters
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `nlist`: Number of clusters (√N to 4√N recommended)
|
||||
- `nprobe`: Clusters to search (1-nlist, higher = more accurate)
|
||||
|
||||
**Use when:**
|
||||
- Dataset 10K-1M vectors
|
||||
- Need fast approximate search
|
||||
- Can afford training time
|
||||
|
||||
### Tuning nprobe
|
||||
|
||||
```python
|
||||
# Test different nprobe values
|
||||
for nprobe in [1, 5, 10, 20, 50]:
|
||||
index.nprobe = nprobe
|
||||
distances, indices = index.search(query, k)
|
||||
# Measure recall/speed trade-off
|
||||
```
|
||||
|
||||
**Guidelines:**
|
||||
- `nprobe=1`: Fastest, ~50% recall
|
||||
- `nprobe=10`: Good balance, ~95% recall
|
||||
- `nprobe=nlist`: Exact search (same as Flat)
|
||||
|
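To see the trade-off concretely, recall at each `nprobe` can be measured against an exact flat index built on the same vectors. A minimal sketch (data sizes are illustrative):

```python
import time
import faiss
import numpy as np

d, k = 128, 10
vectors = np.random.random((100_000, d)).astype('float32')
queries = np.random.random((100, d)).astype('float32')

# Ground truth from an exact flat index
flat = faiss.IndexFlatL2(d)
flat.add(vectors)
_, truth = flat.search(queries, k)

# IVF index to evaluate
quantizer = faiss.IndexFlatL2(d)
ivf = faiss.IndexIVFFlat(quantizer, d, 100)
ivf.train(vectors)
ivf.add(vectors)

for nprobe in [1, 5, 10, 20, 50]:
    ivf.nprobe = nprobe
    start = time.time()
    _, approx = ivf.search(queries, k)
    elapsed = time.time() - start
    # Recall@k = fraction of the exact neighbors that the IVF search recovered
    recall = np.mean([len(set(a) & set(t)) / k for a, t in zip(approx, truth)])
    print(f"nprobe={nprobe}: recall@{k}={recall:.2f}, {elapsed * 1000:.1f} ms")
```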
||||
## HNSW indices (graph-based)
|
||||
|
||||
### IndexHNSWFlat - Hierarchical NSW
|
||||
|
||||
```python
|
||||
# HNSW index
|
||||
M = 32 # Number of connections per layer (16-64)
|
||||
index = faiss.IndexHNSWFlat(d, M)
|
||||
|
||||
# Optional: Set ef_construction (build time parameter)
|
||||
index.hnsw.efConstruction = 40 # Higher = better quality, slower build
|
||||
|
||||
# Add vectors (no training needed!)
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
index.hnsw.efSearch = 16 # Search time parameter
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `M`: Connections per layer (16-64, default 32)
|
||||
- `efConstruction`: Build quality (40-200, higher = better)
|
||||
- `efSearch`: Search quality (16-512, higher = more accurate)
|
||||
|
||||
**Use when:**
|
||||
- Need best quality approximate search
|
||||
- Can afford higher memory (more connections)
|
||||
- Dataset 1M-10M vectors
|
||||
|
||||
## PQ indices (product quantization)
|
||||
|
||||
### IndexPQ - Memory-efficient
|
||||
|
||||
```python
|
||||
# PQ reduces memory by 16-32×
|
||||
m = 8 # Number of subquantizers (divides d)
|
||||
nbits = 8 # Bits per subquantizer
|
||||
|
||||
index = faiss.IndexPQ(d, m, nbits)
|
||||
|
||||
# Train (required!)
|
||||
index.train(vectors)
|
||||
|
||||
# Add vectors
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `m`: Subquantizers (d must be divisible by m)
|
||||
- `nbits`: Bits per code (8 or 16)
|
||||
|
||||
**Memory savings:**
|
||||
- Original: d × 4 bytes (float32)
|
||||
- PQ: m bytes
|
||||
- Compression ratio: 4d/m
|
||||
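For example, with `d = 128`, `m = 8`, and `nbits = 8` (as in the snippet above), each vector shrinks from 512 bytes of float32 to 8 one-byte codes, a 64× reduction.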
|
||||
**Use when:**
|
||||
- Limited memory
|
||||
- Large datasets (> 10M vectors)
|
||||
- Can accept ~90-95% accuracy
|
||||
|
||||
### IndexIVFPQ - IVF + PQ combined
|
||||
|
||||
```python
|
||||
# Best for very large datasets
|
||||
nlist = 4096
|
||||
m = 8
|
||||
nbits = 8
|
||||
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits)
|
||||
|
||||
# Train
|
||||
index.train(vectors)
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
index.nprobe = 32
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Dataset > 10M vectors
|
||||
- Need fast search + low memory
|
||||
- Can accept 90-95% accuracy
|
||||
|
||||
## GPU indices
|
||||
|
||||
### Single GPU
|
||||
|
||||
```python
|
||||
import faiss
|
||||
|
||||
# Create CPU index
|
||||
index_cpu = faiss.IndexFlatL2(d)
|
||||
|
||||
# Move to GPU
|
||||
res = faiss.StandardGpuResources() # GPU resources
|
||||
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0
|
||||
|
||||
# Use normally
|
||||
index_gpu.add(vectors)
|
||||
distances, indices = index_gpu.search(query, k)
|
||||
```
|
||||
|
||||
### Multi-GPU
|
||||
|
||||
```python
|
||||
# Use all available GPUs
|
||||
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
|
||||
|
||||
# Or specific GPUs
|
||||
gpus = [0, 1, 2, 3] # Use GPUs 0-3
|
||||
index_gpu = faiss.index_cpu_to_gpus_list(index_cpu, gpus)
|
||||
```
|
||||
|
||||
**Speedup:**
|
||||
- Single GPU: 10-50× faster than CPU
|
||||
- Multi-GPU: Near-linear scaling
|
||||
|
||||
## Index factory
|
||||
|
||||
```python
|
||||
# Easy index creation with string descriptors
|
||||
index = faiss.index_factory(d, "IVF100,Flat")
|
||||
index = faiss.index_factory(d, "HNSW32")
|
||||
index = faiss.index_factory(d, "IVF4096,PQ8")
|
||||
|
||||
# Train and use
|
||||
index.train(vectors)
|
||||
index.add(vectors)
|
||||
```
|
||||
|
||||
**Common descriptors:**
|
||||
- `"Flat"`: Exact search
|
||||
- `"IVF100,Flat"`: IVF with 100 clusters
|
||||
- `"HNSW32"`: HNSW with M=32
|
||||
- `"IVF4096,PQ8"`: IVF + PQ compression
|
||||
|
||||
## Performance comparison
|
||||
|
||||
### Search speed (1M vectors, k=10)
|
||||
|
||||
| Index | Build Time | Search Time | Memory | Recall |
|
||||
|-------|------------|-------------|--------|--------|
|
||||
| Flat | 0s | 50ms | 512 MB | 100% |
|
||||
| IVF100 | 5s | 2ms | 512 MB | 95% |
|
||||
| HNSW32 | 60s | 1ms | 1GB | 99% |
|
||||
| IVF4096+PQ8 | 30s | 3ms | 32 MB | 90% |
|
||||
|
||||
*CPU (16 cores), 128-dim vectors*
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with Flat** - Baseline for comparison
|
||||
2. **Use IVF for medium datasets** - Good balance
|
||||
3. **Use HNSW for best quality** - If memory allows
|
||||
4. **Add PQ for memory savings** - Large datasets
|
||||
5. **GPU for > 100K vectors** - 10-50× speedup
|
||||
6. **Tune nprobe/efSearch** - Trade-off speed/accuracy
|
||||
7. **Train on representative data** - Better clustering
|
||||
8. **Save trained indices** - Avoid retraining
|
||||
|
||||
## Resources
|
||||
|
||||
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
|
||||
- **Paper**: https://arxiv.org/abs/1702.08734
|
||||
370
optional-skills/mlops/flash-attention/SKILL.md
Normal file
|
|
@@ -0,0 +1,370 @@
|
|||
---
|
||||
name: optimizing-attention-flash
|
||||
description: Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster inference. Supports PyTorch native SDPA, flash-attn library, H100 FP8, and sliding window attention.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [flash-attn, torch, transformers]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers]
|
||||
|
||||
---
|
||||
|
||||
# Flash Attention - Fast Memory-Efficient Attention
|
||||
|
||||
## Quick start
|
||||
|
||||
Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
|
||||
|
||||
**PyTorch native (easiest, PyTorch 2.2+)**:
|
||||
```python
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) # [batch, heads, seq, dim]
|
||||
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
|
||||
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
|
||||
|
||||
# Automatically uses Flash Attention if available
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
```
|
||||
|
||||
**flash-attn library (more features)**:
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# q, k, v: [batch, seqlen, nheads, headdim]
|
||||
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Enable in existing PyTorch model
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
Flash Attention Integration:
|
||||
- [ ] Step 1: Check PyTorch version (≥2.2)
|
||||
- [ ] Step 2: Enable Flash Attention backend
|
||||
- [ ] Step 3: Verify speedup with profiling
|
||||
- [ ] Step 4: Test accuracy matches baseline
|
||||
```
|
||||
|
||||
**Step 1: Check PyTorch version**
|
||||
|
||||
```bash
|
||||
python -c "import torch; print(torch.__version__)"
|
||||
# Should be ≥2.2.0
|
||||
```
|
||||
|
||||
If <2.2, upgrade:
|
||||
```bash
|
||||
pip install --upgrade torch
|
||||
```
|
||||
|
||||
**Step 2: Enable Flash Attention backend**
|
||||
|
||||
Replace standard attention:
|
||||
```python
|
||||
# Before (standard attention)
|
||||
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
|
||||
out = attn_weights @ v
|
||||
|
||||
# After (Flash Attention)
|
||||
import torch.nn.functional as F
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
|
||||
```
|
||||
|
||||
Force Flash Attention backend:
|
||||
```python
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True,
|
||||
enable_math=False,
|
||||
enable_mem_efficient=False
|
||||
):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
```
|
||||
|
||||
**Step 3: Verify speedup with profiling**
|
||||
|
||||
```python
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
def test_attention(use_flash):
|
||||
q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
|
||||
|
||||
if use_flash:
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=True):
|
||||
return F.scaled_dot_product_attention(q, k, v)
|
||||
else:
|
||||
attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
|
||||
return attn @ v
|
||||
|
||||
# Benchmark
|
||||
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
|
||||
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
|
||||
|
||||
print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
|
||||
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
|
||||
```
|
||||
|
||||
Expected: 2-4x speedup for sequences >512 tokens.
|
||||
|
||||
**Step 4: Test accuracy matches baseline**
|
||||
|
||||
```python
|
||||
# Compare outputs
|
||||
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
|
||||
|
||||
# Flash Attention
|
||||
out_flash = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
# Standard attention
|
||||
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
|
||||
out_standard = attn_weights @ v
|
||||
|
||||
# Check difference
|
||||
diff = (out_flash - out_standard).abs().max()
|
||||
print(f"Max difference: {diff:.6f}")
|
||||
# Should be <1e-3 for float16
|
||||
```
|
||||
|
||||
### Workflow 2: Use flash-attn library for advanced features
|
||||
|
||||
For multi-query attention, sliding window, or H100 FP8.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
flash-attn Library Setup:
|
||||
- [ ] Step 1: Install flash-attn library
|
||||
- [ ] Step 2: Modify attention code
|
||||
- [ ] Step 3: Enable advanced features
|
||||
- [ ] Step 4: Benchmark performance
|
||||
```
|
||||
|
||||
**Step 1: Install flash-attn library**
|
||||
|
||||
```bash
|
||||
# NVIDIA GPUs (CUDA 12.0+)
|
||||
pip install flash-attn --no-build-isolation
|
||||
|
||||
# Verify installation
|
||||
python -c "from flash_attn import flash_attn_func; print('Success')"
|
||||
```
|
||||
|
||||
**Step 2: Modify attention code**
|
||||
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# Input: [batch_size, seq_len, num_heads, head_dim]
|
||||
# Transpose from [batch, heads, seq, dim] if needed
|
||||
q = q.transpose(1, 2) # [batch, seq, heads, dim]
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
out = flash_attn_func(
|
||||
q, k, v,
|
||||
dropout_p=0.1,
|
||||
causal=True, # For autoregressive models
|
||||
window_size=(-1, -1), # No sliding window
|
||||
softmax_scale=None # Auto-scale
|
||||
)
|
||||
|
||||
out = out.transpose(1, 2) # Back to [batch, heads, seq, dim]
|
||||
```
|
||||
|
||||
**Step 3: Enable advanced features**
|
||||
|
||||
Multi-query attention (shared K/V across heads):
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# q: [batch, seq, num_q_heads, dim]
|
||||
# k, v: [batch, seq, num_kv_heads, dim] # Fewer KV heads
|
||||
out = flash_attn_func(q, k, v) # Automatically handles MQA
|
||||
```
|
||||
|
||||
Sliding window attention (local attention):
|
||||
```python
|
||||
# Only attend to window of 256 tokens before/after
|
||||
out = flash_attn_func(
|
||||
q, k, v,
|
||||
window_size=(256, 256), # (left, right) window
|
||||
causal=True
|
||||
)
|
||||
```
|
||||
|
||||
**Step 4: Benchmark performance**
|
||||
|
||||
```python
|
||||
import torch
|
||||
from flash_attn import flash_attn_func
|
||||
import time
|
||||
|
||||
q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
_ = flash_attn_func(q, k, v)
|
||||
|
||||
# Benchmark
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
for _ in range(100):
|
||||
out = flash_attn_func(q, k, v)
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
|
||||
print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
|
||||
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
|
||||
```
|
||||
|
||||
### Workflow 3: H100 FP8 optimization (FlashAttention-3)
|
||||
|
||||
For maximum performance on H100 GPUs.
|
||||
|
||||
```
|
||||
FP8 Setup:
|
||||
- [ ] Step 1: Verify H100 GPU available
|
||||
- [ ] Step 2: Install flash-attn with FP8 support
|
||||
- [ ] Step 3: Convert inputs to FP8
|
||||
- [ ] Step 4: Run with FP8 attention
|
||||
```
|
||||
|
||||
**Step 1: Verify H100 GPU**
|
||||
|
||||
```bash
|
||||
nvidia-smi --query-gpu=name --format=csv
|
||||
# Should show "H100" or "H800"
|
||||
```
|
||||
|
||||
**Step 2: Install flash-attn with FP8 support**
|
||||
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
# FP8 support included for H100
|
||||
```
|
||||
|
||||
**Step 3: Convert inputs to FP8**
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
|
||||
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
|
||||
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
|
||||
|
||||
# Convert to float8_e4m3 (FP8)
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
k_fp8 = k.to(torch.float8_e4m3fn)
|
||||
v_fp8 = v.to(torch.float8_e4m3fn)
|
||||
```
|
||||
|
||||
**Step 4: Run with FP8 attention**
|
||||
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# FlashAttention-3 automatically uses FP8 kernels on H100
|
||||
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
|
||||
# Result: ~1.2 PFLOPS, 1.5-2x faster than FP16
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use Flash Attention when:**
|
||||
- Training transformers with sequences >512 tokens
|
||||
- Running inference with long context (>2K tokens)
|
||||
- GPU memory constrained (OOM with standard attention)
|
||||
- Need 2-4x speedup without accuracy loss
|
||||
- Using PyTorch 2.2+ or can install flash-attn
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Standard attention**: Sequences <256 tokens (overhead not worth it)
|
||||
- **xFormers**: Need more attention variants (not just speed)
|
||||
- **Memory-efficient attention**: CPU inference (Flash Attention needs GPU)
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: ImportError: cannot import flash_attn**
|
||||
|
||||
Install with no-build-isolation flag:
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Or install CUDA toolkit first:
|
||||
```bash
|
||||
conda install cuda -c nvidia
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
**Issue: Slower than expected (no speedup)**
|
||||
|
||||
Flash Attention benefits increase with sequence length:
|
||||
- <512 tokens: Minimal speedup (10-20%)
|
||||
- 512-2K tokens: 2-3x speedup
|
||||
- >2K tokens: 3-4x speedup
|
||||
|
||||
Check sequence length is sufficient.
|
||||
|
||||
**Issue: RuntimeError: CUDA error**
|
||||
|
||||
Verify GPU supports Flash Attention:
|
||||
```python
|
||||
import torch
|
||||
print(torch.cuda.get_device_capability())
|
||||
# Should be ≥(8, 0) for Ampere+ (FlashAttention-2); (7, 5) Turing only works with FlashAttention 1.x
|
||||
```
|
||||
|
||||
Flash Attention requires:
|
||||
- Ampere (A100, A10): ✅ Full support
|
||||
- Turing (T4): ⚠️ FlashAttention 1.x only (FlashAttention-2 needs Ampere or newer)
|
||||
- Volta (V100): ❌ Not supported
|
||||
|
||||
**Issue: Accuracy degradation**
|
||||
|
||||
Check dtype is float16 or bfloat16 (not float32):
|
||||
```python
|
||||
q = q.to(torch.float16) # Or torch.bfloat16
|
||||
```
|
||||
|
||||
Flash Attention uses float16/bfloat16 for speed. Float32 not supported.
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Integration with HuggingFace Transformers**: See [references/transformers-integration.md](references/transformers-integration.md) for enabling Flash Attention in BERT, GPT, Llama models.
|
||||
|
||||
**Performance benchmarks**: See [references/benchmarks.md](references/benchmarks.md) for detailed speed and memory comparisons across GPUs and sequence lengths.
|
||||
|
||||
**Algorithm details**: See [references/algorithm.md](references/algorithm.md) for tiling strategy, recomputation, and IO complexity analysis.
|
||||
|
||||
**Advanced features**: See [references/advanced-features.md](references/advanced-features.md) for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
|
||||
- **VRAM**: No more than standard attention (the attention matrix is never materialized, so peak usage is typically lower)
|
||||
- **CUDA**: 11.8 minimum (12.0+ recommended)
|
||||
- **PyTorch**: 2.2+ for native support
|
||||
|
||||
**Not supported**: V100 (Volta), CPU inference
|
||||
|
||||
## Resources
|
||||
|
||||
- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
|
||||
- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
|
||||
- Blog: https://tridao.me/blog/2024/flash3/
|
||||
- GitHub: https://github.com/Dao-AILab/flash-attention
|
||||
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
|
||||
|
||||
|
||||
215
optional-skills/mlops/flash-attention/references/benchmarks.md
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
# Performance Benchmarks
|
||||
|
||||
## Contents
|
||||
- Speed comparisons across GPUs
|
||||
- Memory usage analysis
|
||||
- Scaling with sequence length
|
||||
- Training vs inference performance
|
||||
- Flash Attention versions comparison
|
||||
|
||||
## Speed comparisons across GPUs
|
||||
|
||||
### A100 80GB (Ampere)
|
||||
|
||||
**Forward pass time** (milliseconds, batch=8, heads=32, dim=64):
|
||||
|
||||
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 | Speedup (FA2) |
|
||||
|------------|----------|--------------|--------------|---------------|
|
||||
| 512 | 1.2 | 0.9 | N/A | 1.3x |
|
||||
| 1024 | 3.8 | 1.4 | N/A | 2.7x |
|
||||
| 2048 | 14.2 | 4.8 | N/A | 3.0x |
|
||||
| 4096 | 55.1 | 17.3 | N/A | 3.2x |
|
||||
| 8192 | 218.5 | 66.2 | N/A | 3.3x |
|
||||
|
||||
### H100 80GB (Hopper)
|
||||
|
||||
**Forward pass time** (milliseconds, same config):
|
||||
|
||||
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 (FP16) | Flash Attn 3 (FP8) | Best Speedup |
|
||||
|------------|----------|--------------|---------------------|--------------------|--------------|
|
||||
| 512 | 0.8 | 0.6 | 0.4 | 0.3 | 2.7x |
|
||||
| 1024 | 2.6 | 1.0 | 0.6 | 0.4 | 6.5x |
|
||||
| 2048 | 9.8 | 3.4 | 2.0 | 1.3 | 7.5x |
|
||||
| 4096 | 38.2 | 12.5 | 7.2 | 4.8 | 8.0x |
|
||||
| 8192 | 151.4 | 47.8 | 27.1 | 18.2 | 8.3x |
|
||||
|
||||
**Key insight**: Flash Attention 3 on H100 with FP8 achieves ~1.2 PFLOPS (75% of theoretical max).
|
||||
|
||||
### A10G 24GB (Ampere)
|
||||
|
||||
**Forward pass time** (milliseconds, batch=4):
|
||||
|
||||
| Seq Length | Standard | Flash Attn 2 | Speedup |
|
||||
|------------|----------|--------------|---------|
|
||||
| 512 | 2.1 | 1.6 | 1.3x |
|
||||
| 1024 | 6.8 | 2.8 | 2.4x |
|
||||
| 2048 | 25.9 | 9.4 | 2.8x |
|
||||
| 4096 | 102.1 | 35.2 | 2.9x |
|
||||
|
||||
## Memory usage analysis
|
||||
|
||||
### GPU memory consumption (batch=8, heads=32, dim=64)
|
||||
|
||||
**Standard attention memory**:
|
||||
|
||||
| Seq Length | Attention Matrix | KV Cache | Total | Notes |
|
||||
|------------|------------------|----------|-------|-------|
|
||||
| 512 | 8 MB | 32 MB | 40 MB | Manageable |
|
||||
| 2048 | 128 MB | 128 MB | 256 MB | Growing |
|
||||
| 8192 | 2048 MB (2 GB) | 512 MB | 2.5 GB | Large |
|
||||
| 32768 | 32768 MB (32 GB) | 2048 MB | 34 GB | OOM on 24GB GPUs |
|
||||
|
||||
**Flash Attention 2 memory**:
|
||||
|
||||
| Seq Length | Attention (on-chip) | KV Cache | Total | Reduction |
|
||||
|------------|---------------------|----------|-------|-----------|
|
||||
| 512 | 0 MB (recomputed) | 32 MB | 32 MB | 20% |
|
||||
| 2048 | 0 MB | 128 MB | 128 MB | 50% |
|
||||
| 8192 | 0 MB | 512 MB | 512 MB | 80% |
|
||||
| 32768 | 0 MB | 2048 MB | 2 GB | 94% |
|
||||
|
||||
**Key insight**: Flash Attention doesn't materialize attention matrix, saving O(N²) memory.
|
||||
|
||||
### Memory scaling comparison
|
||||
|
||||
**Llama 2 7B model memory** (float16, batch=1):
|
||||
|
||||
| Context Length | Standard Attention | Flash Attention 2 | Can Fit 24GB GPU? |
|
||||
|----------------|-------------------|-------------------|-------------------|
|
||||
| 2K | 3.2 GB | 2.1 GB | Both: Yes |
|
||||
| 4K | 5.8 GB | 2.8 GB | Both: Yes |
|
||||
| 8K | 12.1 GB | 4.2 GB | Both: Yes |
|
||||
| 16K | 26.3 GB (OOM) | 7.8 GB | Only Flash: Yes |
|
||||
| 32K | OOM | 14.2 GB | Only Flash: Yes |
|
||||
|
||||
### Training memory (Llama 2 7B, batch=4)
|
||||
|
||||
| Context | Standard (GB) | Flash Attn (GB) | Reduction |
|
||||
|---------|---------------|-----------------|-----------|
|
||||
| 2K | 18.2 | 12.4 | 32% |
|
||||
| 4K | 34.8 | 16.8 | 52% |
|
||||
| 8K | OOM (>40GB) | 26.2 | Fits! |
|
||||
|
||||
## Scaling with sequence length
|
||||
|
||||
### Computational complexity
|
||||
|
||||
**Standard attention**:
|
||||
- Time: O(N² × d)
|
||||
- Memory: O(N² + N × d)
|
||||
|
||||
**Flash Attention**:
|
||||
- Time: O(N² × d) (same, but with better constants)
|
||||
- Memory: O(N × d) (linear!)
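
A back-of-envelope sketch of where the O(N²) term comes from (fp16 element size assumed; framework overhead and other activations ignored, so the numbers are illustrative rather than a reproduction of the tables above):

```python
def attention_memory_gb(batch, heads, seq, head_dim, bytes_per_el=2):
    # Standard attention materializes a (batch, heads, N, N) score matrix.
    score_matrix = batch * heads * seq * seq * bytes_per_el
    # Both variants keep q, k, v and the output, each (batch, heads, N, d).
    qkv_out = 4 * batch * heads * seq * head_dim * bytes_per_el
    return {"standard_gb": (score_matrix + qkv_out) / 1e9,
            "flash_gb": qkv_out / 1e9}

for seq in (512, 2048, 8192):
    print(seq, attention_memory_gb(batch=1, heads=32, seq=seq, head_dim=64))
```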
|
||||
|
||||
### Empirical scaling (A100, batch=1, heads=32, dim=64)
|
||||
|
||||
**Time per token (milliseconds)**:
|
||||
|
||||
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|
||||
|----------|-----|-----|-----|-----|-----|------|
|
||||
| Standard | 0.15 | 0.37 | 1.11 | 3.44 | 13.4 | 52.8 |
|
||||
| Flash Attn 2 | 0.11 | 0.14 | 0.24 | 0.43 | 0.83 | 1.64 |
|
||||
| Speedup | 1.4x | 2.6x | 4.6x | 8.0x | 16.1x | 32.2x |
|
||||
|
||||
**Observation**: Speedup grows roughly linearly with sequence length, doubling each time the sequence length doubles.
|
||||
|
||||
### Memory per token (MB)
|
||||
|
||||
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|
||||
|----------|-----|-----|-----|-----|-----|------|
|
||||
| Standard | 0.08 | 0.13 | 0.25 | 0.64 | 2.05 | 8.13 |
|
||||
| Flash Attn 2 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 |
|
||||
|
||||
**Observation**: Flash Attention memory per token is constant!
|
||||
|
||||
## Training vs inference performance
|
||||
|
||||
### Training (forward + backward, Llama 2 7B, A100)
|
||||
|
||||
| Batch × Seq | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|
||||
|-------------|------------------------|--------------------------|---------|
|
||||
| 4 × 2K | 1.2 | 3.1 | 2.6x |
|
||||
| 8 × 2K | 2.1 | 5.8 | 2.8x |
|
||||
| 4 × 4K | 0.4 | 1.3 | 3.3x |
|
||||
| 8 × 4K | OOM | 2.4 | Enabled |
|
||||
| 2 × 8K | 0.1 | 0.4 | 4.0x |
|
||||
|
||||
### Inference (generation, Llama 2 7B, A100)
|
||||
|
||||
| Context Length | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|
||||
|----------------|----------------------|-------------------------|---------|
|
||||
| 512 | 48 | 52 | 1.1x |
|
||||
| 2K | 42 | 62 | 1.5x |
|
||||
| 4K | 31 | 58 | 1.9x |
|
||||
| 8K | 18 | 51 | 2.8x |
|
||||
| 16K | OOM | 42 | Enabled |
|
||||
|
||||
**Note**: Inference speedup less dramatic than training because generation is memory-bound (KV cache accesses).
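
A rough KV-cache size estimate makes the memory-bound behaviour concrete. Llama 2 7B-style shapes are assumed (32 layers, 32 KV heads, head_dim 128, fp16); treat the numbers as illustrative:

```python
def kv_cache_gb(seq_len, batch=1, layers=32, kv_heads=32, head_dim=128, bytes_per_el=2):
    # K and V are cached per layer, hence the factor of 2.
    return 2 * layers * kv_heads * head_dim * seq_len * batch * bytes_per_el / 1e9

for ctx in (2048, 8192, 16384, 32768):
    print(f"{ctx:>6} tokens -> {kv_cache_gb(ctx):.2f} GB of KV cache")
```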
|
||||
|
||||
## Flash Attention versions comparison
|
||||
|
||||
### Flash Attention 1 vs 2 vs 3 (H100, seq=4096, batch=8)
|
||||
|
||||
| Metric | FA1 | FA2 | FA3 (FP16) | FA3 (FP8) |
|
||||
|--------|-----|-----|------------|-----------|
|
||||
| Forward time (ms) | 28.4 | 12.5 | 7.2 | 4.8 |
|
||||
| Memory (GB) | 4.8 | 4.2 | 4.2 | 2.8 |
|
||||
| TFLOPS | 180 | 420 | 740 | 1150 |
|
||||
| GPU util % | 35% | 55% | 75% | 82% |
|
||||
|
||||
**Key improvements**:
|
||||
- FA2: 2.3x faster than FA1 (better parallelism)
|
||||
- FA3 (FP16): 1.7x faster than FA2 (H100 async optimizations)
|
||||
- FA3 (FP8): 2.6x faster than FA2 (low precision)
|
||||
|
||||
### Features by version
|
||||
|
||||
| Feature | FA1 | FA2 | FA3 |
|
||||
|---------|-----|-----|-----|
|
||||
| Basic attention | ✅ | ✅ | ✅ |
|
||||
| Causal masking | ✅ | ✅ | ✅ |
|
||||
| Multi-query attention | ❌ | ✅ | ✅ |
|
||||
| Sliding window | ❌ | ✅ | ✅ |
|
||||
| Paged KV cache | ❌ | ✅ | ✅ |
|
||||
| FP8 support | ❌ | ❌ | ✅ (H100 only) |
|
||||
| Work partitioning | Basic | Advanced | Optimal |
|
||||
|
||||
## Real-world model benchmarks
|
||||
|
||||
### Llama 2 models (A100 80GB, batch=4, seq=2048)
|
||||
|
||||
| Model | Params | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|
||||
|-------|--------|------------------------|--------------------------|---------|
|
||||
| Llama 2 7B | 7B | 1.2 | 3.1 | 2.6x |
|
||||
| Llama 2 13B | 13B | 0.6 | 1.7 | 2.8x |
|
||||
| Llama 2 70B | 70B | 0.12 | 0.34 | 2.8x |
|
||||
|
||||
### GPT-style models (seq=1024)
|
||||
|
||||
| Model | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|
||||
|-------|----------------------|-------------------------|---------|
|
||||
| GPT-2 (124M) | 520 | 680 | 1.3x |
|
||||
| GPT-J (6B) | 42 | 98 | 2.3x |
|
||||
| GPT-NeoX (20B) | 8 | 22 | 2.75x |
|
||||
|
||||
## Recommendations by use case
|
||||
|
||||
**Training large models (>7B parameters)**:
|
||||
- Use Flash Attention 2 on A100
|
||||
- Use Flash Attention 3 FP8 on H100 for maximum speed
|
||||
- Expected: 2.5-3x speedup
|
||||
|
||||
**Long context inference (>4K tokens)**:
|
||||
- Flash Attention essential (enables contexts standard attention can't handle)
|
||||
- Expected: 2-4x speedup, 5-10x memory reduction
|
||||
|
||||
**Short sequences (<512 tokens)**:
|
||||
- Flash Attention provides 1.2-1.5x speedup
|
||||
- Minimal memory benefit
|
||||
- Still worth enabling (no downside)
|
||||
|
||||
**Multi-user serving**:
|
||||
- Flash Attention reduces per-request memory
|
||||
- Allows higher concurrent batch sizes
|
||||
- Can serve 2-3x more users on same hardware
|
||||
|
|
@ -0,0 +1,293 @@
|
|||
# HuggingFace Transformers Integration
|
||||
|
||||
## Contents
|
||||
- Enabling Flash Attention in Transformers
|
||||
- Supported model architectures
|
||||
- Configuration examples
|
||||
- Performance comparisons
|
||||
- Troubleshooting model-specific issues
|
||||
|
||||
## Enabling Flash Attention in Transformers
|
||||
|
||||
HuggingFace Transformers (v4.36+) supports Flash Attention 2 natively.
|
||||
|
||||
**Simple enable for any supported model**:
|
||||
```python
|
||||
import torch
from transformers import AutoModel
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
**Install requirements**:
|
||||
```bash
|
||||
pip install transformers>=4.36
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
## Supported model architectures
|
||||
|
||||
As of Transformers 4.40:
|
||||
|
||||
**Fully supported**:
|
||||
- Llama / Llama 2 / Llama 3
|
||||
- Mistral / Mixtral
|
||||
- Falcon
|
||||
- GPT-NeoX
|
||||
- Phi / Phi-2 / Phi-3
|
||||
- Qwen / Qwen2
|
||||
- Gemma
|
||||
- Starcoder2
|
||||
- GPT-J
|
||||
- OPT
|
||||
- BLOOM
|
||||
|
||||
**Partially supported** (encoder-decoder):
|
||||
- BART
|
||||
- T5 / Flan-T5
|
||||
- Whisper
|
||||
|
||||
**Check support**:
|
||||
```python
|
||||
from transformers import AutoConfig
|
||||
|
||||
config = AutoConfig.from_pretrained("model-name")
|
||||
print(config._attn_implementation_internal)
|
||||
# 'flash_attention_2' if supported
|
||||
```
|
||||
|
||||
## Configuration examples
|
||||
|
||||
### Llama 2 with Flash Attention
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-2-7b-hf"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# Generate
|
||||
inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
|
||||
outputs = model.generate(**inputs, max_length=100)
|
||||
print(tokenizer.decode(outputs[0]))
|
||||
```
|
||||
|
||||
### Mistral with Flash Attention for long context
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"mistralai/Mistral-7B-v0.1",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16, # Better for long context
|
||||
device_map="auto",
|
||||
max_position_embeddings=32768 # Extended context
|
||||
)
|
||||
|
||||
# Process long document (32K tokens)
|
||||
long_text = "..." * 10000
|
||||
inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to("cuda")
|
||||
outputs = model.generate(**inputs, max_new_tokens=512)
|
||||
```
|
||||
|
||||
### Fine-tuning with Flash Attention
|
||||
|
||||
```python
|
||||
from transformers import Trainer, TrainingArguments
|
||||
from transformers import AutoModelForCausalLM
import torch
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
num_train_epochs=3,
|
||||
fp16=True, # Must match model dtype
|
||||
optim="adamw_torch_fused" # Fast optimizer
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Multi-GPU training
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
import torch
|
||||
|
||||
# Model parallelism with Flash Attention
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-13b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto", # Automatic multi-GPU placement
|
||||
max_memory={0: "20GB", 1: "20GB"} # Limit per GPU
|
||||
)
|
||||
```
|
||||
|
||||
## Performance comparisons
|
||||
|
||||
### Memory usage (Llama 2 7B, batch=1)
|
||||
|
||||
| Sequence Length | Standard Attention | Flash Attention 2 | Reduction |
|
||||
|-----------------|-------------------|-------------------|-----------|
|
||||
| 512 | 1.2 GB | 0.9 GB | 25% |
|
||||
| 2048 | 3.8 GB | 1.4 GB | 63% |
|
||||
| 8192 | 14.2 GB | 3.2 GB | 77% |
|
||||
| 32768 | OOM (>24GB) | 10.8 GB | Fits! |
|
||||
|
||||
### Speed (tokens/sec, A100 80GB)
|
||||
|
||||
| Model | Standard | Flash Attn 2 | Speedup |
|
||||
|-------|----------|--------------|---------|
|
||||
| Llama 2 7B (seq=2048) | 42 | 118 | 2.8x |
|
||||
| Llama 2 13B (seq=4096) | 18 | 52 | 2.9x |
|
||||
| Llama 2 70B (seq=2048) | 4 | 11 | 2.75x |
|
||||
|
||||
### Training throughput (samples/sec)
|
||||
|
||||
| Model | Batch Size | Standard | Flash Attn 2 | Speedup |
|
||||
|-------|------------|----------|--------------|---------|
|
||||
| Llama 2 7B | 4 | 1.2 | 3.1 | 2.6x |
|
||||
| Llama 2 7B | 8 | 2.1 | 5.8 | 2.8x |
|
||||
| Llama 2 13B | 2 | 0.6 | 1.7 | 2.8x |
|
||||
|
||||
## Troubleshooting model-specific issues
|
||||
|
||||
### Issue: Model doesn't support Flash Attention
|
||||
|
||||
Check support list above. If not supported, use PyTorch SDPA as fallback:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="sdpa", # PyTorch native (still faster)
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: CUDA out of memory during loading
|
||||
|
||||
Reduce memory footprint:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
max_memory={0: "18GB"}, # Reserve memory for KV cache
|
||||
low_cpu_mem_usage=True
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Slower inference than expected
|
||||
|
||||
Ensure dtype matches:
|
||||
|
||||
```python
|
||||
# Model and inputs must both be float16/bfloat16
|
||||
model = model.to(torch.float16)
|
||||
inputs = tokenizer(..., return_tensors="pt").to("cuda")
|
||||
inputs = {k: v.to(torch.float16) if v.dtype == torch.float32 else v
|
||||
for k, v in inputs.items()}
|
||||
```
|
||||
|
||||
### Issue: Different outputs vs standard attention
|
||||
|
||||
Flash Attention is numerically equivalent but uses different computation order. Small differences (<1e-3) are normal:
|
||||
|
||||
```python
|
||||
# Compare outputs
|
||||
model_standard = AutoModelForCausalLM.from_pretrained("model-name", torch_dtype=torch.float16).to("cuda")
|
||||
model_flash = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
inputs = tokenizer("Test", return_tensors="pt").to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
out_standard = model_standard(**inputs).logits
|
||||
out_flash = model_flash(**inputs).logits
|
||||
|
||||
diff = (out_standard - out_flash).abs().max()
|
||||
print(f"Max diff: {diff:.6f}") # Should be ~1e-3 to 1e-4
|
||||
```
|
||||
|
||||
### Issue: ImportError during model loading
|
||||
|
||||
Install flash-attn:
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Or disable Flash Attention:
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="eager", # Standard PyTorch
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Always use float16/bfloat16** with Flash Attention (not float32)
|
||||
2. **Set device_map="auto"** for automatic memory management
|
||||
3. **Use bfloat16 for long context** (better numerical stability)
|
||||
4. **Enable gradient checkpointing** for training large models
|
||||
5. **Monitor memory** with `torch.cuda.max_memory_allocated()`
|
||||
|
||||
**Example with all best practices**:
|
||||
```python
|
||||
import torch
from transformers import AutoModelForCausalLM, TrainingArguments
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16, # Better for training
|
||||
device_map="auto",
|
||||
low_cpu_mem_usage=True
|
||||
)
|
||||
|
||||
# Enable gradient checkpointing for memory
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Training with optimizations
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
per_device_train_batch_size=8,
|
||||
gradient_accumulation_steps=2,
|
||||
bf16=True, # Match model dtype
|
||||
optim="adamw_torch_fused",
|
||||
gradient_checkpointing=True
|
||||
)
|
||||
```
|
||||
302
optional-skills/mlops/hermes-atropos-environments/SKILL.md
Normal file
|
|
@ -0,0 +1,302 @@
|
|||
---
|
||||
name: hermes-atropos-environments
|
||||
description: Build, test, and debug Hermes Agent RL environments for Atropos training. Covers the HermesAgentBaseEnv interface, reward functions, agent loop integration, evaluation with tools, wandb logging, and the three CLI modes (serve/process/evaluate). Use when creating, reviewing, or fixing RL environments in the hermes-agent repo.
|
||||
version: 1.1.0
|
||||
author: Hermes Agent
|
||||
license: MIT
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [atropos, rl, environments, training, reinforcement-learning, reward-functions]
|
||||
related_skills: [axolotl, grpo-rl-training, trl-fine-tuning, lm-evaluation-harness]
|
||||
---
|
||||
|
||||
# Hermes Agent Atropos Environments
|
||||
|
||||
Guide for building RL environments in the hermes-agent repo that integrate with the Atropos training framework.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
Atropos BaseEnv (atroposlib/envs/base.py)
|
||||
└── HermesAgentBaseEnv (environments/hermes_base_env.py)
|
||||
├── Handles agent loop orchestration
|
||||
├── Handles tool resolution per group
|
||||
├── Handles ToolContext for reward verification
|
||||
└── YOUR ENVIRONMENT (environments/your_env.py)
|
||||
Only implements: setup, get_next_item, format_prompt,
|
||||
compute_reward, evaluate, wandb_log
|
||||
```
|
||||
|
||||
Hermes environments are special because they run a **multi-turn agent loop with tool calling** — not just single-turn completions. The base env handles the loop; you implement the task and scoring.
|
||||
|
||||
## File Locations
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `environments/hermes_base_env.py` | Base class with agent loop + tool resolution |
|
||||
| `environments/agent_loop.py` | `HermesAgentLoop` + `AgentResult` dataclass |
|
||||
| `environments/tool_context.py` | `ToolContext` for reward verification |
|
||||
| `environments/tool_call_parsers.py` | Phase 2 tool call parsers (hermes, mistral, etc.) |
|
||||
| `environments/your_env.py` | Your environment implementation |
|
||||
|
||||
## Inference Setup — Ask the User First
|
||||
|
||||
**IMPORTANT:** Before running any test, evaluation, or data generation command, always ask the user how they want to handle inference. Do NOT assume OpenRouter or any specific endpoint. Present these options:
|
||||
|
||||
1. **OpenRouter** — Ask which model they want to use (e.g., `anthropic/claude-sonnet-4.5`, `google/gemini-2.5-pro`, `meta-llama/llama-3.3-70b-instruct`, etc.). Requires `OPENROUTER_API_KEY` in environment.
|
||||
2. **Self-hosted VLLM endpoint** — Ask for their base URL (e.g., `http://localhost:8000/v1`) and model name. Set `--openai.server_type vllm`.
|
||||
3. **Other OpenAI-compatible API** — Ask for the base URL, model name, and any required API key. Set `--openai.server_type openai` and `--openai.health_check false`.
|
||||
4. **Local Atropos training server** — For `serve` mode with a live training loop. Default `http://localhost:8000/v1`.
|
||||
|
||||
Once the user tells you their setup, use those values in all CLI commands for that session. Example prompts:
|
||||
|
||||
> "Before I run this, how would you like to handle inference?
|
||||
> 1. OpenRouter (I'll need your preferred model, e.g. claude-sonnet-4.5)
|
||||
> 2. A self-hosted VLLM endpoint (give me the URL and model name)
|
||||
> 3. Another OpenAI-compatible API (give me the URL, model, and any auth details)
|
||||
> 4. Local Atropos training server (serve mode)"
|
||||
|
||||
### Key flags by provider:
|
||||
|
||||
| Provider | `--openai.server_type` | `--openai.health_check` | `--openai.api_key` |
|
||||
|----------|----------------------|------------------------|-------------------|
|
||||
| OpenRouter | `openai` | `false` | `$OPENROUTER_API_KEY` |
|
||||
| VLLM (self-hosted) | `vllm` | (default) | (not needed) |
|
||||
| Other OpenAI-compatible | `openai` | `false` | As needed |
|
||||
| Local Atropos | (default) | (default) | (not needed) |
|
||||
|
||||
## Required Methods
|
||||
|
||||
### 1. `setup()` — Load dataset and initialize state
|
||||
|
||||
```python
|
||||
async def setup(self) -> None:
|
||||
"""Called once at startup. Load datasets, initialize state."""
|
||||
# Try HuggingFace first, fallback to built-in samples
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
ds = load_dataset("your/dataset", split="test")
|
||||
self._items = [...]
|
||||
except Exception:
|
||||
self._items = BUILTIN_SAMPLES
|
||||
|
||||
# Always split into train/eval
|
||||
random.shuffle(self._items)
|
||||
eval_size = max(20, int(len(self._items) * 0.1))
|
||||
self._eval_items = self._items[:eval_size]
|
||||
self._items = self._items[eval_size:]
|
||||
```
|
||||
|
||||
### 2. `get_next_item()` — Return next training item
|
||||
|
||||
```python
|
||||
async def get_next_item(self) -> dict:
|
||||
"""Return next item, cycling through dataset."""
|
||||
item = self._items[self._index % len(self._items)]
|
||||
self._index += 1
|
||||
return item
|
||||
```
|
||||
|
||||
### 3. `format_prompt(item)` — Convert item to user message
|
||||
|
||||
```python
|
||||
def format_prompt(self, item: dict) -> str:
|
||||
"""Convert a dataset item into the user-facing prompt."""
|
||||
return f"Research this question: {item['question']}"
|
||||
```
|
||||
|
||||
### 4. `compute_reward(item, result, ctx)` — Score the rollout
|
||||
|
||||
**CRITICAL**: `result` is an `AgentResult`, NOT a dict. It has these attributes:
|
||||
- `result.messages` — List of message dicts (OpenAI format)
|
||||
- `result.turns_used` — Number of LLM calls made
|
||||
- `result.finished_naturally` — True if model stopped voluntarily
|
||||
- `result.tool_errors` — List of ToolError objects
|
||||
|
||||
**AgentResult does NOT have**: `final_response`, `tool_calls`, `tools_used`.
|
||||
You must extract these from `result.messages`:
|
||||
|
||||
```python
|
||||
async def compute_reward(self, item, result: AgentResult, ctx: ToolContext) -> float:
|
||||
# Extract final response (last assistant message with content)
|
||||
final_response = ""
|
||||
tools_used = []
|
||||
for msg in reversed(result.messages):
|
||||
if msg.get("role") == "assistant" and msg.get("content") and not final_response:
|
||||
final_response = msg["content"]
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
|
||||
name = fn.get("name", "")
|
||||
if name:
|
||||
tools_used.append(name)
|
||||
|
||||
# Score using LLM judge, heuristic, or ToolContext verification
|
||||
correctness = await self._llm_judge(item, final_response)
|
||||
return correctness
|
||||
```
|
||||
|
||||
`ctx` (ToolContext) gives you terminal/file access to the agent's sandbox for verification:
|
||||
```python
|
||||
# Run tests in the agent's sandbox
|
||||
result = ctx.terminal("pytest /workspace/test.py")
|
||||
return 1.0 if result["exit_code"] == 0 else 0.0
|
||||
```
|
||||
|
||||
### 5. `evaluate()` — Periodic evaluation with full agent loop
|
||||
|
||||
**MUST use the full agent loop with tools**, not single-turn chat_completion.
|
||||
The whole point of hermes-agent environments is agentic evaluation:
|
||||
|
||||
```python
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
import time, uuid
|
||||
from environments.agent_loop import HermesAgentLoop
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
start_time = time.time()
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
samples = []
|
||||
|
||||
for item in self._eval_items[:self.config.eval_size]:
|
||||
task_id = str(uuid.uuid4())
|
||||
messages = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append({"role": "user", "content": self.format_prompt(item)})
|
||||
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=0.0, # Deterministic for eval
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
reward = await self.compute_reward(item, result, ctx)
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
|
||||
samples.append({"prompt": ..., "response": ..., "reward": reward})
|
||||
|
||||
eval_metrics = {"eval/mean_reward": ...}
|
||||
await self.evaluate_log(metrics=eval_metrics, samples=samples,
|
||||
start_time=start_time, end_time=time.time())
|
||||
```
|
||||
|
||||
### 6. `wandb_log()` — Custom metrics logging
|
||||
|
||||
Always call `super().wandb_log()` at the end:
|
||||
|
||||
```python
|
||||
async def wandb_log(self, wandb_metrics=None):
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
if self._reward_buffer:
|
||||
n = len(self._reward_buffer)
|
||||
wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
|
||||
self._reward_buffer.clear()
|
||||
await super().wandb_log(wandb_metrics) # MUST call super
|
||||
```
|
||||
|
||||
**Pitfall**: `compute_reward` appends to metric buffers. During eval, this pollutes training metrics. Roll back buffer entries added during eval.
|
||||
|
||||
## Config Class
|
||||
|
||||
Always create a custom config subclass with Pydantic Field descriptors. Key inherited fields you can tune: `enabled_toolsets`, `max_agent_turns`, `agent_temperature`, `system_prompt`, `terminal_backend`, `group_size`, `steps_per_eval`, `total_steps`.
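
A minimal sketch of such a subclass. The base config class name is an assumption; use whatever `environments/hermes_base_env.py` actually exports, and treat the custom fields as examples:

```python
from pydantic import Field

from environments.hermes_base_env import HermesAgentBaseEnvConfig  # assumed name


class MyEnvConfig(HermesAgentBaseEnvConfig):
    dataset_name: str = Field(
        "your/dataset", description="HuggingFace dataset loaded in setup()"
    )
    eval_size: int = Field(20, description="Items per evaluate() run")
    # Inherited fields (enabled_toolsets, max_agent_turns, agent_temperature,
    # system_prompt, terminal_backend, group_size, steps_per_eval, total_steps)
    # are usually tuned via CLI/YAML rather than redefined here.
```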
|
||||
|
||||
## config_init() — Default Configuration
|
||||
|
||||
Classmethod returning `(YourEnvConfig, [APIServerConfig(...)])`. Set server_type to "openai" for OpenRouter/external APIs. Load API key from environment variable.
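
A sketch of what that classmethod can look like, building on the config sketch above. The `APIServerConfig` import path and field names are assumptions inferred from the CLI flags in this skill; adjust to the actual atroposlib API:

```python
import os

from atroposlib.envs.base import APIServerConfig  # assumed import path
from environments.hermes_base_env import HermesAgentBaseEnv


class MyEnv(HermesAgentBaseEnv):
    env_config_cls = MyEnvConfig

    @classmethod
    def config_init(cls):
        env_config = MyEnvConfig(group_size=4, steps_per_eval=100)
        server_configs = [
            APIServerConfig(
                model_name="anthropic/claude-sonnet-4.5",
                base_url="https://openrouter.ai/api/v1",
                api_key=os.environ.get("OPENROUTER_API_KEY", ""),
                server_type="openai",  # external OpenAI-compatible API
            )
        ]
        return env_config, server_configs
```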
|
||||
|
||||
## Three CLI Modes
|
||||
|
||||
```bash
|
||||
# SERVE — Full training loop (connects to Atropos API server)
|
||||
python environments/my_env.py serve --openai.base_url http://localhost:8000/v1
|
||||
|
||||
# PROCESS — Offline data generation (saves JSONL)
|
||||
python environments/my_env.py process --env.total_steps 10 --env.group_size 1 \
|
||||
--env.use_wandb false --env.data_path_to_save_groups output.jsonl \
|
||||
--openai.base_url "<USER_BASE_URL>" \
|
||||
--openai.model_name "<USER_MODEL>" \
|
||||
--openai.server_type <USER_SERVER_TYPE> --openai.health_check false
|
||||
|
||||
# EVALUATE — Standalone eval (runs setup + evaluate only)
|
||||
python environments/my_env.py evaluate --env.eval_size 20 \
|
||||
--env.data_dir_to_save_evals /tmp/eval_results \
|
||||
--openai.base_url "<USER_BASE_URL>" \
|
||||
--openai.model_name "<USER_MODEL>" \
|
||||
--openai.server_type <USER_SERVER_TYPE> --openai.health_check false
|
||||
```
|
||||
|
||||
Config priority: CLI args > YAML file > config_init() defaults.
|
||||
|
||||
## Common Pitfalls
|
||||
|
||||
1. **AgentResult has .messages, not .final_response** — Extract the final response by iterating reversed(result.messages) looking for the last assistant message with content.
|
||||
|
||||
2. **evaluate() must use HermesAgentLoop, not chat_completion** — Single-turn chat_completion has no tools. The whole point of hermes-agent benchmarks is agentic evaluation with tool use.
|
||||
|
||||
3. **Don't call _llm_judge twice** — If compute_reward already calls it, extract the score from the buffer instead of calling judge separately in evaluate().
|
||||
|
||||
4. **Eval pollutes training buffers** — compute_reward appends to metric buffers. During eval, roll back buffer entries to keep training metrics clean.
|
||||
|
||||
5. **Always set health_check=false for OpenRouter** — OpenRouter has no /health endpoint.
|
||||
|
||||
6. **Set data_dir_to_save_evals in evaluate mode** — Without it, results aren't saved.
|
||||
|
||||
7. **default_toolsets class variable vs enabled_toolsets config** — The class variable is a hint; the config field is what actually controls tool resolution.
|
||||
|
||||
8. **Tool call parsing in messages** — Tool calls are dicts with `{"function": {"name": ..., "arguments": ...}}`. Always check `isinstance(tc, dict)`.
|
||||
|
||||
9. **ToolContext.cleanup()** — Always call in a finally block to release sandbox resources.
|
||||
|
||||
10. **server_type must be "openai" for external APIs** — Without it, Atropos assumes a local VLLM server.
|
||||
|
||||
11. **Always ask the user for their inference setup** — Never hardcode or assume a specific provider/model. See the "Inference Setup" section above.
|
||||
|
||||
## Reward Function Patterns
|
||||
|
||||
### LLM Judge (for open-ended tasks)
|
||||
Use `self.server.chat_completion()` with a scoring prompt. Parse JSON response for score float. Always include a heuristic fallback (keyword overlap) for when the judge call fails.
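
A sketch of that pattern as a method on your environment class. The return shape of `chat_completion()` is assumed to be OpenAI-style (`choices[0].message.content`); adjust to the actual client:

```python
import json

async def _llm_judge(self, item: dict, final_response: str) -> float:
    prompt = (
        "Score the answer from 0.0 to 1.0 for correctness.\n"
        f"Question: {item['question']}\nReference: {item['answer']}\n"
        f"Answer: {final_response}\n"
        'Reply with JSON only: {"score": <float>}'
    )
    try:
        completion = await self.server.chat_completion(
            messages=[{"role": "user", "content": prompt}], temperature=0.0
        )
        text = completion.choices[0].message.content  # assumed OpenAI-style shape
        return max(0.0, min(1.0, float(json.loads(text)["score"])))
    except Exception:
        # Heuristic fallback: keyword overlap with the reference answer
        ref = set(str(item.get("answer", "")).lower().split())
        got = set(final_response.lower().split())
        return len(ref & got) / max(1, len(ref))
```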
|
||||
|
||||
### Binary Verification (for code/terminal tasks)
|
||||
Use `ctx.terminal("pytest test.py -q")` to run tests in the agent's sandbox. Return 1.0 for pass, 0.0 for fail.
|
||||
|
||||
### Multi-Signal (combine multiple indicators)
|
||||
Weight correctness (0.6) + tool usage (0.2) + efficiency (0.2) + optional bonuses. Clamp to [0, 1].
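
A small sketch of those weights, using the values extracted from `result` in the `compute_reward` example above:

```python
def combine_signals(correctness: float, tools_used: list,
                    turns_used: int, max_turns: int) -> float:
    tool_signal = 1.0 if tools_used else 0.0
    # Fewer turns relative to the budget scores higher.
    efficiency = 1.0 - min(1.0, turns_used / max(1, max_turns))
    reward = 0.6 * correctness + 0.2 * tool_signal + 0.2 * efficiency
    return max(0.0, min(1.0, reward))  # clamp to [0, 1]
```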
|
||||
|
||||
## Testing Your Environment
|
||||
|
||||
1. **Import test**: `python -c "from environments.my_env import MyEnv; print('OK')"`
|
||||
2. **Ask the user for inference setup** (see "Inference Setup" section above)
|
||||
3. **Process mode** (1 item): Verify JSONL output has valid tokens, masks, scores
|
||||
4. **Evaluate mode**: Verify full agent loop runs with tools, metrics logged correctly
|
||||
5. **Check reward range**: Scores should be in [0, 1], not all identical
|
||||
|
||||
## Minimum Implementation Checklist
|
||||
|
||||
```python
|
||||
class MyEnv(HermesAgentBaseEnv):
|
||||
name = "my-env"
|
||||
env_config_cls = MyEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls): ... # Default server + env config
|
||||
async def setup(self): ... # Load dataset + train/eval split
|
||||
async def get_next_item(self): ... # Cycle through training items
|
||||
def format_prompt(self, item): ... # Item → user message string
|
||||
async def compute_reward(self, item, result, ctx): ... # Score rollout
|
||||
async def evaluate(self, *args, **kwargs): ... # Full agent loop eval
|
||||
async def wandb_log(self, metrics=None): ... # Custom metrics + super()
|
||||
|
||||
if __name__ == "__main__":
|
||||
MyEnv.cli()
|
||||
```
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
# AgentResult Fields Reference
|
||||
|
||||
`AgentResult` is defined in `environments/agent_loop.py` as a dataclass.
|
||||
|
||||
## Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `messages` | `List[Dict[str, Any]]` | Full conversation history in OpenAI message format |
|
||||
| `managed_state` | `Optional[Dict]` | ManagedServer.get_state() if Phase 2, else None |
|
||||
| `turns_used` | `int` | Number of LLM calls made during the loop |
|
||||
| `finished_naturally` | `bool` | True if model stopped calling tools on its own |
|
||||
| `reasoning_per_turn` | `List[Optional[str]]` | Extracted reasoning content per turn |
|
||||
| `tool_errors` | `List[ToolError]` | Tool errors encountered during the loop |
|
||||
|
||||
## ToolError Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `turn` | `int` | Turn on which the error occurred |
|
||||
| `tool_name` | `str` | Name of the tool that failed |
|
||||
| `arguments` | `str` | Arguments passed to the tool |
|
||||
| `error` | `str` | Error message |
|
||||
| `tool_result` | `str` | The result returned to the model |
|
||||
|
||||
## Extracting Data from Messages
|
||||
|
||||
Messages follow OpenAI format. Common patterns:
|
||||
|
||||
```python
|
||||
# Get final assistant response
|
||||
for msg in reversed(result.messages):
|
||||
if msg.get("role") == "assistant" and msg.get("content"):
|
||||
final_response = msg["content"]
|
||||
break
|
||||
|
||||
# Get all tool names used
|
||||
tools = []
|
||||
for msg in result.messages:
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
|
||||
tools.append(fn.get("name", ""))
|
||||
|
||||
# Get tool results
|
||||
for msg in result.messages:
|
||||
if msg.get("role") == "tool":
|
||||
tool_output = msg.get("content", "")
|
||||
call_id = msg.get("tool_call_id", "")
|
||||
```
|
||||
|
||||
## Fields that DO NOT EXIST
|
||||
|
||||
These are common mistakes — AgentResult does NOT have:
|
||||
- `final_response` — extract from messages
|
||||
- `tool_calls` — extract from messages
|
||||
- `tools_used` — extract from messages
|
||||
- `output` — extract from messages
|
||||
- `response` — extract from messages
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
# Atropos BaseEnv Reference
|
||||
|
||||
Source: `atroposlib/envs/base.py` (~2124 lines)
|
||||
|
||||
## Abstract Methods (MUST implement)
|
||||
|
||||
| Method | Signature | Description |
|
||||
|--------|-----------|-------------|
|
||||
| `get_next_item()` | `async def get_next_item(self) -> Item` | Return next item for trajectory. Return None to pause. |
|
||||
| `evaluate()` | `async def evaluate(self, *args, **kwargs)` | Called every steps_per_eval steps. |
|
||||
| `setup()` | `async def setup(self)` | Called once at start. Load datasets, init models. |
|
||||
| `collect_trajectory()` | `async def collect_trajectory(self, item) -> Tuple[Optional[ScoredDataItem], List[Item]]` | Single rollout. Or override collect_trajectories instead. |
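
A minimal stub tying the four abstract methods together (import path assumed from the Source line above; bodies are placeholders, not working logic):

```python
from atroposlib.envs.base import BaseEnv  # assumed import path


class MyRawEnv(BaseEnv):
    async def setup(self):
        self.items = []  # load datasets here

    async def get_next_item(self):
        return self.items[0] if self.items else None  # None pauses collection

    async def collect_trajectory(self, item):
        scored_item = None  # build a ScoredDataItem for this rollout
        backlog = []        # extra items to queue
        return scored_item, backlog

    async def evaluate(self, *args, **kwargs):
        pass  # called every steps_per_eval steps
```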
|
||||
|
||||
## Overridable Methods
|
||||
|
||||
| Method | Default Behavior | Override When |
|
||||
|--------|-----------------|---------------|
|
||||
| `collect_trajectories()` | Runs collect_trajectory group_size times in parallel | Batch generation, MCTS, coupled rollouts |
|
||||
| `wandb_log()` | Logs completion lengths, rollout table, perf stats | Add custom metrics (always call super) |
|
||||
| `config_init()` | Returns (env_config_cls(), ServerBaseline()) | Custom defaults + server configs |
|
||||
| `postprocess_histories()` | Passthrough | Final processing before sending to trainer |
|
||||
| `save_checkpoint()` | Saves JSON to checkpoint_dir | Custom serialization |
|
||||
| `cleanup()` | No-op | Release resources after each rollout |
|
||||
|
||||
## ScoredDataGroup Structure
|
||||
|
||||
```python
|
||||
ScoredDataGroup = TypedDict with:
|
||||
tokens: List[List[int]] # Token IDs per rollout
|
||||
masks: List[List[int]] # -100=prompt, token_id=completion
|
||||
scores: List[float] # Score per rollout
|
||||
advantages: Optional[...] # Per-token advantages
|
||||
ref_logprobs: Optional[...] # Reference model logprobs
|
||||
messages: Optional[...] # OpenAI-format messages
|
||||
inference_logprobs: Optional[...] # Inference logprobs
|
||||
```
|
||||
|
||||
## BaseEnvConfig Key Fields
|
||||
|
||||
| Field | Default | Description |
|
||||
|-------|---------|-------------|
|
||||
| `group_size` | 4 | Responses grouped for scoring |
|
||||
| `steps_per_eval` | 100 | Steps between evaluations |
|
||||
| `max_token_length` | 2048 | Max token length for generations |
|
||||
| `total_steps` | 1000 | Total training steps |
|
||||
| `use_wandb` | True | Enable wandb logging |
|
||||
| `tokenizer_name` | DeepHermes-3 | Tokenizer for token encoding |
|
||||
| `ensure_scores_are_not_same` | True | Skip groups with identical scores |
|
||||
| `worker_timeout` | 600 | Task timeout seconds |
|
||||
|
||||
## Data Flow
|
||||
|
||||
```
|
||||
env_manager() → add_train_workers() → handle_env()
|
||||
→ collect_trajectories() → postprocess_histories()
|
||||
→ handle_send_to_api() → training server
|
||||
```
|
||||
|
||||
## Atropos Environment Statistics (82 environments analyzed)
|
||||
|
||||
- 95% implement setup, collect_trajectories, evaluate, get_next_item
|
||||
- 76% override wandb_log
|
||||
- 54% have custom config class
|
||||
- Most use collect_trajectories (plural), not collect_trajectory (singular)
|
||||
- Common reward patterns: LLM-judge (~40), regex-extract (~35), code-exec (~12)
|
||||
|
|
@ -0,0 +1,199 @@
|
|||
# Usage Patterns — Testing Environments and Evaluating Models
|
||||
|
||||
## Pattern 1: Test Your Environment Works (process mode)
|
||||
|
||||
Use `process` mode to verify your environment runs end-to-end before
|
||||
committing. This generates trajectories without needing an Atropos
|
||||
training server.
|
||||
|
||||
**Before running:** Ask the user for their inference setup (see SKILL.md "Inference Setup" section). Replace `<BASE_URL>`, `<MODEL>`, and `<SERVER_TYPE>` below with their chosen values.
|
||||
|
||||
### Step 1: Run 1 trajectory
|
||||
|
||||
```bash
|
||||
cd ~/.hermes/hermes-agent
|
||||
source venv/bin/activate
|
||||
|
||||
python environments/your_env.py process \
|
||||
--env.total_steps 1 \
|
||||
--env.group_size 1 \
|
||||
--env.use_wandb false \
|
||||
--env.data_path_to_save_groups /tmp/test_output.jsonl \
|
||||
--openai.base_url "<BASE_URL>" \
|
||||
--openai.model_name "<MODEL>" \
|
||||
--openai.server_type <SERVER_TYPE> \
|
||||
--openai.health_check false
|
||||
```
|
||||
|
||||
### Step 2: Verify the output
|
||||
|
||||
```python
|
||||
import json
|
||||
for line in open("/tmp/test_output.jsonl"):
|
||||
data = json.loads(line)
|
||||
print(f"Scores: {data.get('scores', [])}")
|
||||
print(f"Token sequences: {len(data.get('tokens', []))}")
|
||||
# Check messages include tool calls
|
||||
for msg_list in data.get("messages", []):
|
||||
roles = [m.get("role") for m in msg_list]
|
||||
print(f"Roles: {roles}")
|
||||
for m in reversed(msg_list):
|
||||
if m.get("role") == "assistant" and m.get("content"):
|
||||
print(f"Response: {m['content'][:200]}...")
|
||||
break
|
||||
```
|
||||
|
||||
### What to check:
|
||||
- **Scores are not all 0.0** — if so, compute_reward is broken
|
||||
- **Scores are in [0, 1]** — not negative, not >1
|
||||
- **Messages include "tool" role entries** — agent used tools
|
||||
- **Token sequences are non-empty**
|
||||
- **An HTML visualization is generated** next to the .jsonl
|
||||
|
||||
### Common failures:
|
||||
- `'AgentResult' object has no attribute 'X'` — accessing a field that doesn't exist. See agentresult-fields.md.
|
||||
- Score always 0.0 — reward function erroring silently
|
||||
- Score always 1.0 — verification too lenient or not running
|
||||
|
||||
|
||||
## Pattern 2: Evaluate a Model (evaluate mode)
|
||||
|
||||
Use `evaluate` mode to benchmark a model on your environment's eval
|
||||
split. This runs the full agent loop with tools for each eval item.
|
||||
|
||||
### Step 1: Run evaluation
|
||||
|
||||
```bash
|
||||
python environments/your_env.py evaluate \
|
||||
--env.eval_size 20 \
|
||||
--env.use_wandb false \
|
||||
--env.data_dir_to_save_evals /tmp/eval_results \
|
||||
--openai.base_url "<BASE_URL>" \
|
||||
--openai.model_name "<MODEL>" \
|
||||
--openai.server_type <SERVER_TYPE> \
|
||||
--openai.health_check false
|
||||
```
|
||||
|
||||
### Step 2: Read results
|
||||
|
||||
Stdout shows a lighteval-compatible table:
|
||||
|
||||
```
|
||||
Evaluation Results: your-env_eval
|
||||
|Metric | Value|
|
||||
|mean correctness| 0.850 |
|
||||
|mean reward | 0.920 |
|
||||
|mean tool calls | 4.300 |
|
||||
|n items | 20 |
|
||||
Evaluation completed in 367 seconds
|
||||
```
|
||||
|
||||
JSON results saved to the eval directory:
|
||||
|
||||
```python
|
||||
import json
|
||||
data = json.load(open("/tmp/eval_results/metrics.json"))
|
||||
for metric, value in data["results"]["all"].items():
|
||||
print(f"{metric}: {value}")
|
||||
```
|
||||
|
||||
### Step 3: Compare models
|
||||
|
||||
Run evaluate with different models and compare the metrics.json files.
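
A small sketch for the comparison, assuming each run wrote the `results["all"]` structure shown in Step 2 (paths are placeholders):

```python
import json

def load_metrics(path):
    return json.load(open(path))["results"]["all"]

a = load_metrics("/tmp/eval_results_model_a/metrics.json")
b = load_metrics("/tmp/eval_results_model_b/metrics.json")

for metric in sorted(set(a) | set(b)):
    print(f"{metric:20} A={a.get(metric)}  B={b.get(metric)}")
```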
|
||||
|
||||
### What to check:
|
||||
- **"data_dir_to_save_evals is not set"** — you forgot the flag, results won't be saved
|
||||
- **Tool usage rate = 0** — evaluate() is using chat_completion instead of HermesAgentLoop
|
||||
- **All scores identical** — judge failing, falling back to heuristic
|
||||
- **Very slow** — each item runs a full agent loop (~30-90s). Use `--env.eval_size 5` for quick checks.
|
||||
|
||||
|
||||
## Pattern 3: Generate Training Data (process mode, larger scale)
|
||||
|
||||
Generate trajectory data for offline training or analysis:
|
||||
|
||||
```bash
|
||||
python environments/your_env.py process \
|
||||
--env.total_steps 50 \
|
||||
--env.group_size 4 \
|
||||
--env.use_wandb false \
|
||||
--env.data_path_to_save_groups data/trajectories.jsonl \
|
||||
--openai.base_url "<BASE_URL>" \
|
||||
--openai.model_name "<MODEL>" \
|
||||
--openai.server_type <SERVER_TYPE> \
|
||||
--openai.health_check false
|
||||
```
|
||||
|
||||
### Analyze the distribution:
|
||||
|
||||
```python
|
||||
import json
|
||||
scores = []
|
||||
for line in open("data/trajectories.jsonl"):
|
||||
data = json.loads(line)
|
||||
scores.extend(data.get("scores", []))
|
||||
|
||||
print(f"Total: {len(scores)}, Mean: {sum(scores)/len(scores):.3f}")
|
||||
for bucket in [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]:
|
||||
count = sum(1 for s in scores if abs(s - bucket) < 0.1)
|
||||
print(f" {bucket:.1f}: {'█' * count} ({count})")
|
||||
```
|
||||
|
||||
### What to check:
|
||||
- **Score distribution has variance** — RL needs score variance. All-same scores are useless.
|
||||
|
||||
|
||||
## Pattern 4: Full RL Training (serve mode)
|
||||
|
||||
For actual RL training with Atropos:
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start Atropos API server
|
||||
run-api
|
||||
|
||||
# Terminal 2: Start your environment
|
||||
python environments/your_env.py serve \
|
||||
--config environments/your_env/default.yaml
|
||||
```
|
||||
|
||||
For Phase 2 with VLLM:
|
||||
|
||||
```bash
|
||||
# Terminal 1: VLLM server
|
||||
python -m vllm.entrypoints.openai.api_server --model your-model --port 8000
|
||||
|
||||
# Terminal 2: Atropos API
|
||||
run-api
|
||||
|
||||
# Terminal 3: Environment
|
||||
python environments/your_env.py serve \
|
||||
--openai.base_url http://localhost:8000/v1 \
|
||||
--openai.model_name your-model \
|
||||
--openai.server_type vllm
|
||||
```
|
||||
|
||||
|
||||
## Pattern 5: Quick Smoke Test
|
||||
|
||||
Verify imports and config before spending money on API calls:
|
||||
|
||||
```python
|
||||
from environments.your_env import YourEnv
|
||||
print(f"Name: {YourEnv.name}")
|
||||
cfg, servers = YourEnv.config_init()
|
||||
print(f"Toolsets: {cfg.enabled_toolsets}")
|
||||
print(f"Server: {servers[0].model_name}")
|
||||
print("All imports OK")
|
||||
```
|
||||
|
||||
|
||||
## Timing Expectations
|
||||
|
||||
| Mode | Items | Time per item | Total |
|
||||
|------|-------|--------------|-------|
|
||||
| process (1 item) | 1 | 30-90s | ~1 min |
|
||||
| evaluate (5 items) | 5 | 30-90s | ~5 min |
|
||||
| evaluate (20 items) | 20 | 30-90s | ~15-30 min |
|
||||
| process (50 items) | 50 | 30-90s | ~30-75 min |
|
||||
|
||||
Times are for cloud APIs with Claude Sonnet-class models. Local models may be faster or slower depending on hardware.
|
||||
519
optional-skills/mlops/huggingface-tokenizers/SKILL.md
Normal file
|
|
@ -0,0 +1,519 @@
|
|||
---
|
||||
name: huggingface-tokenizers
|
||||
description: Fast tokenizers optimized for research and production. Rust-based implementation tokenizes 1GB in <20 seconds. Supports BPE, WordPiece, and Unigram algorithms. Train custom vocabularies, track alignments, handle padding/truncation. Integrates seamlessly with transformers. Use when you need high-performance tokenization or custom tokenizer training.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [tokenizers, transformers, datasets]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Tokenization, HuggingFace, BPE, WordPiece, Unigram, Fast Tokenization, Rust, Custom Tokenizer, Alignment Tracking, Production]
|
||||
|
||||
---
|
||||
|
||||
# HuggingFace Tokenizers - Fast Tokenization for NLP
|
||||
|
||||
Fast, production-ready tokenizers with Rust performance and Python ease-of-use.
|
||||
|
||||
## When to use HuggingFace Tokenizers
|
||||
|
||||
**Use HuggingFace Tokenizers when:**
|
||||
- Need extremely fast tokenization (<20s per GB of text)
|
||||
- Training custom tokenizers from scratch
|
||||
- Want alignment tracking (token → original text position)
|
||||
- Building production NLP pipelines
|
||||
- Need to tokenize large corpora efficiently
|
||||
|
||||
**Performance**:
|
||||
- **Speed**: <20 seconds to tokenize 1GB on CPU
|
||||
- **Implementation**: Rust core with Python/Node.js bindings
|
||||
- **Efficiency**: 10-100× faster than pure Python implementations
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **SentencePiece**: Language-independent, used by T5/ALBERT
|
||||
- **tiktoken**: OpenAI's BPE tokenizer for GPT models
|
||||
- **transformers AutoTokenizer**: Loading pretrained only (uses this library internally)
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Install tokenizers
|
||||
pip install tokenizers
|
||||
|
||||
# With transformers integration
|
||||
pip install tokenizers transformers
|
||||
```
|
||||
|
||||
### Load pretrained tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
|
||||
# Load from HuggingFace Hub
|
||||
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Encode text
|
||||
output = tokenizer.encode("Hello, how are you?")
|
||||
print(output.tokens) # ['hello', ',', 'how', 'are', 'you', '?']
|
||||
print(output.ids) # [7592, 1010, 2129, 2024, 2017, 1029]
|
||||
|
||||
# Decode back
|
||||
text = tokenizer.decode(output.ids)
|
||||
print(text) # "hello, how are you?"
|
||||
```
|
||||
|
||||
### Train custom BPE tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from tokenizers.pre_tokenizers import Whitespace
|
||||
|
||||
# Initialize tokenizer with BPE model
|
||||
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
|
||||
tokenizer.pre_tokenizer = Whitespace()
|
||||
|
||||
# Configure trainer
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=30000,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
|
||||
min_frequency=2
|
||||
)
|
||||
|
||||
# Train on files
|
||||
files = ["train.txt", "validation.txt"]
|
||||
tokenizer.train(files, trainer)
|
||||
|
||||
# Save
|
||||
tokenizer.save("my-tokenizer.json")
|
||||
```
|
||||
|
||||
**Training time**: ~1-2 minutes for 100MB corpus, ~10-20 minutes for 1GB
|
||||
|
||||
### Batch encoding with padding
|
||||
|
||||
```python
|
||||
# Enable padding
|
||||
tokenizer.enable_padding(pad_id=3, pad_token="[PAD]")
|
||||
|
||||
# Encode batch
|
||||
texts = ["Hello world", "This is a longer sentence"]
|
||||
encodings = tokenizer.encode_batch(texts)
|
||||
|
||||
for encoding in encodings:
|
||||
print(encoding.ids)
|
||||
# [101, 7592, 2088, 102, 3, 3, 3]
|
||||
# [101, 2023, 2003, 1037, 2936, 6251, 102]
|
||||
```
|
||||
|
||||
## Tokenization algorithms
|
||||
|
||||
### BPE (Byte-Pair Encoding)
|
||||
|
||||
**How it works**:
|
||||
1. Start with character-level vocabulary
|
||||
2. Find most frequent character pair
|
||||
3. Merge into new token, add to vocabulary
|
||||
4. Repeat until vocabulary size reached
|
||||
|
||||
**Used by**: GPT-2, GPT-3, RoBERTa, BART, DeBERTa
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
|
||||
tokenizer = Tokenizer(BPE(unk_token="<|endoftext|>"))
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=50257,
|
||||
special_tokens=["<|endoftext|>"],
|
||||
min_frequency=2
|
||||
)
|
||||
|
||||
tokenizer.train(files=["data.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
**Advantages**:
|
||||
- Handles OOV words well (breaks into subwords)
|
||||
- Flexible vocabulary size
|
||||
- Good for morphologically rich languages
|
||||
|
||||
**Trade-offs**:
|
||||
- Tokenization depends on merge order
|
||||
- May split common words unexpectedly
|
||||
|
||||
### WordPiece
|
||||
|
||||
**How it works**:
|
||||
1. Start with character vocabulary
|
||||
2. Score merge pairs: `frequency(pair) / (frequency(first) × frequency(second))`
|
||||
3. Merge highest scoring pair
|
||||
4. Repeat until vocabulary size reached
|
||||
|
||||
**Used by**: BERT, DistilBERT, MobileBERT
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import WordPiece
|
||||
from tokenizers.trainers import WordPieceTrainer
|
||||
from tokenizers.pre_tokenizers import Whitespace
|
||||
from tokenizers.normalizers import BertNormalizer
|
||||
|
||||
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
|
||||
tokenizer.normalizer = BertNormalizer(lowercase=True)
|
||||
tokenizer.pre_tokenizer = Whitespace()
|
||||
|
||||
trainer = WordPieceTrainer(
|
||||
vocab_size=30522,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
|
||||
continuing_subword_prefix="##"
|
||||
)
|
||||
|
||||
tokenizer.train(files=["corpus.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
**Advantages**:
|
||||
- Prioritizes meaningful merges (high score = semantically related)
|
||||
- Used successfully in BERT (state-of-the-art results)
|
||||
|
||||
**Trade-offs**:
|
||||
- Unknown words become `[UNK]` if no subword match
|
||||
- Saves vocabulary, not merge rules (larger files)
|
||||
|
||||
### Unigram
|
||||
|
||||
**How it works**:
|
||||
1. Start with large vocabulary (all substrings)
|
||||
2. Compute loss for corpus with current vocabulary
|
||||
3. Remove tokens with minimal impact on loss
|
||||
4. Repeat until vocabulary size reached
|
||||
|
||||
**Used by**: ALBERT, T5, mBART, XLNet (via SentencePiece)
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import Unigram
|
||||
from tokenizers.trainers import UnigramTrainer
|
||||
|
||||
tokenizer = Tokenizer(Unigram())
|
||||
|
||||
trainer = UnigramTrainer(
|
||||
vocab_size=8000,
|
||||
special_tokens=["<unk>", "<s>", "</s>"],
|
||||
unk_token="<unk>"
|
||||
)
|
||||
|
||||
tokenizer.train(files=["data.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
**Advantages**:
|
||||
- Probabilistic (finds most likely tokenization)
|
||||
- Works well for languages without word boundaries
|
||||
- Handles diverse linguistic contexts
|
||||
|
||||
**Trade-offs**:
|
||||
- Computationally expensive to train
|
||||
- More hyperparameters to tune
|
||||
|
||||
## Tokenization pipeline
|
||||
|
||||
Complete pipeline: **Normalization → Pre-tokenization → Model → Post-processing**
|
||||
|
||||
### Normalization
|
||||
|
||||
Clean and standardize text:
|
||||
|
||||
```python
|
||||
from tokenizers.normalizers import NFD, StripAccents, Lowercase, Sequence
|
||||
|
||||
tokenizer.normalizer = Sequence([
|
||||
NFD(), # Unicode normalization (decompose)
|
||||
Lowercase(), # Convert to lowercase
|
||||
StripAccents() # Remove accents
|
||||
])
|
||||
|
||||
# Input: "Héllo WORLD"
|
||||
# After normalization: "hello world"
|
||||
```
|
||||
|
||||
**Common normalizers**:
|
||||
- `NFD`, `NFC`, `NFKD`, `NFKC` - Unicode normalization forms
|
||||
- `Lowercase()` - Convert to lowercase
|
||||
- `StripAccents()` - Remove accents (é → e)
|
||||
- `Strip()` - Remove whitespace
|
||||
- `Replace(pattern, content)` - Regex replacement
|
||||
|
||||
### Pre-tokenization
|
||||
|
||||
Split text into word-like units:
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence, ByteLevel
|
||||
|
||||
# Split on whitespace and punctuation
|
||||
tokenizer.pre_tokenizer = Sequence([
|
||||
Whitespace(),
|
||||
Punctuation()
|
||||
])
|
||||
|
||||
# Input: "Hello, world!"
|
||||
# After pre-tokenization: ["Hello", ",", "world", "!"]
|
||||
```
|
||||
|
||||
**Common pre-tokenizers**:
|
||||
- `Whitespace()` - Split on spaces, tabs, newlines
|
||||
- `ByteLevel()` - GPT-2 style byte-level splitting
|
||||
- `Punctuation()` - Isolate punctuation
|
||||
- `Digits(individual_digits=True)` - Split digits individually
|
||||
- `Metaspace()` - Replace spaces with ▁ (SentencePiece style)
|
||||
|
||||
### Post-processing
|
||||
|
||||
Add special tokens for model input:
|
||||
|
||||
```python
|
||||
from tokenizers.processors import TemplateProcessing
|
||||
|
||||
# BERT-style: [CLS] sentence [SEP]
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="[CLS] $A [SEP]",
|
||||
pair="[CLS] $A [SEP] $B [SEP]",
|
||||
special_tokens=[
|
||||
("[CLS]", 1),
|
||||
("[SEP]", 2),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
**Common patterns**:
|
||||
```python
|
||||
# GPT-2: sentence <|endoftext|>
|
||||
TemplateProcessing(
|
||||
single="$A <|endoftext|>",
|
||||
special_tokens=[("<|endoftext|>", 50256)]
|
||||
)
|
||||
|
||||
# RoBERTa: <s> sentence </s>
|
||||
TemplateProcessing(
|
||||
single="<s> $A </s>",
|
||||
pair="<s> $A </s> </s> $B </s>",
|
||||
special_tokens=[("<s>", 0), ("</s>", 2)]
|
||||
)
|
||||
```
|
||||
|
||||
## Alignment tracking
|
||||
|
||||
Track token positions in original text:
|
||||
|
||||
```python
|
||||
text = "Hello, world!"
output = tokenizer.encode(text)
|
||||
|
||||
# Get token offsets
|
||||
for token, offset in zip(output.tokens, output.offsets):
|
||||
start, end = offset
|
||||
print(f"{token:10} → [{start:2}, {end:2}): {text[start:end]!r}")
|
||||
|
||||
# Output:
|
||||
# hello → [ 0, 5): 'Hello'
|
||||
# , → [ 5, 6): ','
|
||||
# world → [ 7, 12): 'world'
|
||||
# ! → [12, 13): '!'
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- Named entity recognition (map predictions back to text)
|
||||
- Question answering (extract answer spans)
|
||||
- Token classification (align labels to original positions)
|
||||
|
||||
## Integration with transformers
|
||||
|
||||
### Load with AutoTokenizer
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# AutoTokenizer automatically uses fast tokenizers
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Check if using fast tokenizer
|
||||
print(tokenizer.is_fast) # True
|
||||
|
||||
# Access underlying tokenizers.Tokenizer
|
||||
fast_tokenizer = tokenizer.backend_tokenizer
|
||||
print(type(fast_tokenizer)) # <class 'tokenizers.Tokenizer'>
|
||||
```
|
||||
|
||||
### Convert custom tokenizer to transformers
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
from tokenizers.models import BPE
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
# Train custom tokenizer
|
||||
tokenizer = Tokenizer(BPE())
|
||||
# ... train tokenizer ...
|
||||
tokenizer.save("my-tokenizer.json")
|
||||
|
||||
# Wrap for transformers
|
||||
transformers_tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_file="my-tokenizer.json",
|
||||
unk_token="[UNK]",
|
||||
pad_token="[PAD]",
|
||||
cls_token="[CLS]",
|
||||
sep_token="[SEP]",
|
||||
mask_token="[MASK]"
|
||||
)
|
||||
|
||||
# Use like any transformers tokenizer
|
||||
outputs = transformers_tokenizer(
|
||||
"Hello world",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
return_tensors="pt"
|
||||
)
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Train from iterator (large datasets)
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
|
||||
|
||||
# Create batch iterator
|
||||
def batch_iterator(batch_size=1000):
|
||||
for i in range(0, len(dataset), batch_size):
|
||||
yield dataset[i:i + batch_size]["text"]
|
||||
|
||||
# Train tokenizer
|
||||
tokenizer.train_from_iterator(
|
||||
batch_iterator(),
|
||||
trainer=trainer,
|
||||
length=len(dataset) # For progress bar
|
||||
)
|
||||
```
|
||||
|
||||
**Performance**: Processes 1GB in ~10-20 minutes
|
||||
|
||||
### Enable truncation and padding
|
||||
|
||||
```python
|
||||
# Enable truncation
|
||||
tokenizer.enable_truncation(max_length=512)
|
||||
|
||||
# Enable padding
|
||||
tokenizer.enable_padding(
|
||||
pad_id=tokenizer.token_to_id("[PAD]"),
|
||||
pad_token="[PAD]",
|
||||
length=512 # Fixed length, or None for batch max
|
||||
)
|
||||
|
||||
# Encode with both
|
||||
output = tokenizer.encode("This is a long sentence that will be truncated...")
|
||||
print(len(output.ids)) # 512
|
||||
```
|
||||
|
||||
### Multi-processing
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from multiprocessing import Pool
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = Tokenizer.from_file("tokenizer.json")
|
||||
|
||||
def encode_batch(texts):
|
||||
return tokenizer.encode_batch(texts)
|
||||
|
||||
# Process large corpus in parallel
|
||||
with Pool(8) as pool:
|
||||
# Split corpus into chunks
|
||||
chunk_size = 1000
|
||||
chunks = [corpus[i:i+chunk_size] for i in range(0, len(corpus), chunk_size)]
|
||||
|
||||
# Encode in parallel
|
||||
results = pool.map(encode_batch, chunks)
|
||||
```
|
||||
|
||||
**Speedup**: 5-8× with 8 cores
|
||||
|
||||
## Performance benchmarks
|
||||
|
||||
### Training speed
|
||||
|
||||
| Corpus Size | BPE (30k vocab) | WordPiece (30k) | Unigram (8k) |
|
||||
|-------------|-----------------|-----------------|--------------|
|
||||
| 10 MB | 15 sec | 18 sec | 25 sec |
|
||||
| 100 MB | 1.5 min | 2 min | 4 min |
|
||||
| 1 GB | 15 min | 20 min | 40 min |
|
||||
|
||||
**Hardware**: 16-core CPU, tested on English Wikipedia
|
||||
|
||||
### Tokenization speed
|
||||
|
||||
| Implementation | 1 GB corpus | Throughput |
|
||||
|----------------|-------------|---------------|
|
||||
| Pure Python | ~20 minutes | ~50 MB/min |
|
||||
| HF Tokenizers | ~15 seconds | ~4 GB/min |
|
||||
| **Speedup** | **80×** | **80×** |
|
||||
|
||||
**Test**: English text, average sentence length 20 words
|
||||
|
||||
### Memory usage
|
||||
|
||||
| Task | Memory |
|
||||
|-------------------------|---------|
|
||||
| Load tokenizer | ~10 MB |
|
||||
| Train BPE (30k vocab) | ~200 MB |
|
||||
| Encode 1M sentences | ~500 MB |
|
||||
|
||||
## Supported models
|
||||
|
||||
Pre-trained tokenizers available via `from_pretrained()`:
|
||||
|
||||
**BERT family**:
|
||||
- `bert-base-uncased`, `bert-large-cased`
|
||||
- `distilbert-base-uncased`
|
||||
- `roberta-base`, `roberta-large`
|
||||
|
||||
**GPT family**:
|
||||
- `gpt2`, `gpt2-medium`, `gpt2-large`
|
||||
- `distilgpt2`
|
||||
|
||||
**T5 family**:
|
||||
- `t5-small`, `t5-base`, `t5-large`
|
||||
- `google/flan-t5-xxl`
|
||||
|
||||
**Other**:
|
||||
- `facebook/bart-base`, `facebook/mbart-large-cc25`
|
||||
- `albert-base-v2`, `albert-xlarge-v2`
|
||||
- `xlm-roberta-base`, `xlm-roberta-large`
|
||||
|
||||
Browse all: https://huggingface.co/models?library=tokenizers
|
||||
|
||||
## References
|
||||
|
||||
- **[Training Guide](references/training.md)** - Train custom tokenizers, configure trainers, handle large datasets
|
||||
- **[Algorithms Deep Dive](references/algorithms.md)** - BPE, WordPiece, Unigram explained in detail
|
||||
- **[Pipeline Components](references/pipeline.md)** - Normalizers, pre-tokenizers, post-processors, decoders
|
||||
- **[Transformers Integration](references/integration.md)** - AutoTokenizer, PreTrainedTokenizerFast, special tokens
|
||||
|
||||
## Resources
|
||||
|
||||
- **Docs**: https://huggingface.co/docs/tokenizers
|
||||
- **GitHub**: https://github.com/huggingface/tokenizers ⭐ 9,000+
|
||||
- **Version**: 0.20.0+
|
||||
- **Course**: https://huggingface.co/learn/nlp-course/chapter6/1
|
||||
- **Paper**: BPE (Sennrich et al., 2016), WordPiece (Schuster & Nakajima, 2012)
|
||||
|
||||
|
||||
|
|
@@ -0,0 +1,653 @@
|
|||
# Tokenization Algorithms Deep Dive
|
||||
|
||||
Comprehensive explanation of BPE, WordPiece, and Unigram algorithms.
|
||||
|
||||
## Byte-Pair Encoding (BPE)
|
||||
|
||||
### Algorithm overview
|
||||
|
||||
BPE iteratively merges the most frequent pair of tokens in a corpus.
|
||||
|
||||
**Training process**:
|
||||
1. Initialize vocabulary with all characters
|
||||
2. Count frequency of all adjacent token pairs
|
||||
3. Merge most frequent pair into new token
|
||||
4. Add new token to vocabulary
|
||||
5. Update corpus with new token
|
||||
6. Repeat until vocabulary size reached
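
A minimal pure-Python sketch of one merge step (steps 2–5 above) over a word-frequency dictionary; this is illustrative only — the library implements this loop in Rust:

```python
from collections import Counter

# Toy corpus: word (as a tuple of symbols) -> frequency
corpus = {("l", "o", "w"): 5, ("l", "o", "w", "e", "r"): 2,
          ("n", "e", "w", "e", "s", "t"): 6, ("w", "i", "d", "e", "s", "t"): 3}

def most_frequent_pair(corpus):
    pairs = Counter()
    for word, freq in corpus.items():
        for a, b in zip(word, word[1:]):
            pairs[(a, b)] += freq
    return pairs.most_common(1)[0][0]

def merge_pair(corpus, pair):
    merged = {}
    for word, freq in corpus.items():
        out, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and (word[i], word[i + 1]) == pair:
                out.append(word[i] + word[i + 1])
                i += 2
            else:
                out.append(word[i])
                i += 1
        merged[tuple(out)] = freq
    return merged

pair = most_frequent_pair(corpus)   # ('e', 's') or ('s', 't') — both occur 9 times here
corpus = merge_pair(corpus, pair)
```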
|
||||
|
||||
### Step-by-step example
|
||||
|
||||
**Corpus**:
|
||||
```
|
||||
low: 5
|
||||
lower: 2
|
||||
newest: 6
|
||||
widest: 3
|
||||
```
|
||||
|
||||
**Iteration 1**:
|
||||
```
|
||||
Count pairs:
|
||||
'e' + 's': 9 (newest: 6, widest: 3) ← most frequent
|
||||
'l' + 'o': 7
|
||||
'o' + 'w': 7
|
||||
...
|
||||
|
||||
Merge: 'e' + 's' → 'es'
|
||||
|
||||
Updated corpus:
|
||||
low: 5
|
||||
lower: 2
|
||||
newest: 6 → n e w es t
widest: 3 → w i d es t
|
||||
|
||||
Vocabulary: [a-z] + ['es']
|
||||
```
|
||||
|
||||
**Iteration 2**:
|
||||
```
|
||||
Count pairs:
|
||||
'es' + 't': 9 ← most frequent
|
||||
'l' + 'o': 7
|
||||
...
|
||||
|
||||
Merge: 'es' + 't' → 'est'
|
||||
|
||||
Updated corpus:
|
||||
low: 5
|
||||
lower: 2
|
||||
newest: 6 → n e w est
widest: 3 → w i d est
|
||||
|
||||
Vocabulary: [a-z] + ['es', 'est']
|
||||
```
|
||||
|
||||
**Continue until desired vocabulary size...**
|
||||
|
||||
### Tokenization with trained BPE
|
||||
|
||||
Given vocabulary: `['l', 'o', 'w', 'e', 'r', 'n', 's', 't', 'i', 'd', 'es', 'est', 'lo', 'low', 'ne', 'new', 'newest', 'wi', 'wid', 'widest']`
|
||||
|
||||
Tokenize "lowest":
|
||||
```
|
||||
Step 1: Split into characters
|
||||
['l', 'o', 'w', 'e', 's', 't']
|
||||
|
||||
Step 2: Apply merges in order learned during training
|
||||
- Merge 'l' + 'o' → 'lo' (if this merge was learned)
|
||||
- Merge 'lo' + 'w' → 'low' (if learned)
|
||||
- Merge 'e' + 's' → 'es' (learned)
|
||||
- Merge 'es' + 't' → 'est' (learned)
|
||||
|
||||
Final: ['low', 'est']
|
||||
```
|
||||
|
||||
### Implementation
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from tokenizers.pre_tokenizers import Whitespace
|
||||
|
||||
# Initialize
|
||||
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
|
||||
tokenizer.pre_tokenizer = Whitespace()
|
||||
|
||||
# Configure trainer
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=1000,
|
||||
min_frequency=2,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
|
||||
)
|
||||
|
||||
# Train
|
||||
corpus = [
|
||||
"This is a sample corpus for BPE training.",
|
||||
"BPE learns subword units from the training data.",
|
||||
# ... more sentences
|
||||
]
|
||||
|
||||
tokenizer.train_from_iterator(corpus, trainer=trainer)
|
||||
|
||||
# Use
|
||||
output = tokenizer.encode("This is tokenization")
|
||||
print(output.tokens) # ['This', 'is', 'token', 'ization']
|
||||
```
|
||||
|
||||
### Byte-level BPE (GPT-2 variant)
|
||||
|
||||
**Problem**: Character-level BPE needs a base vocabulary covering every character it may encounter — with full Unicode that means tens of thousands of base symbols

**Solution**: Operate at the byte level, so the base vocabulary is just 256 symbols
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
|
||||
|
||||
tokenizer = Tokenizer(BPE())
|
||||
|
||||
# Byte-level pre-tokenization
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
tokenizer.decoder = ByteLevelDecoder()
|
||||
|
||||
# This handles ALL possible characters, including emojis
|
||||
text = "Hello 🌍 世界"
|
||||
tokens = tokenizer.encode(text).tokens
|
||||
```
|
||||
|
||||
**Advantages**:
|
||||
- Handles any Unicode character (every string decomposes into the 256 byte values)
|
||||
- No unknown tokens (worst case: bytes)
|
||||
- Used by GPT-2, GPT-3, BART
|
||||
|
||||
**Trade-offs**:
|
||||
- Slightly worse compression (bytes vs characters)
|
||||
- More tokens for non-ASCII text
|
||||
|
||||
### BPE variants
|
||||
|
||||
**SentencePiece BPE**:
|
||||
- Language-independent (no pre-tokenization)
|
||||
- Treats input as raw byte stream
|
||||
- Used by T5, ALBERT, XLNet
|
||||
|
||||
**BPE-dropout**:
|
||||
- Dropout during training (randomly skip merges)
|
||||
- More robust tokenization at inference
|
||||
- Reduces overfitting to training data
|
||||
|
||||
## WordPiece
|
||||
|
||||
### Algorithm overview
|
||||
|
||||
WordPiece is similar to BPE but uses a different merge selection criterion.
|
||||
|
||||
**Training process**:
|
||||
1. Initialize vocabulary with all characters
|
||||
2. Count frequency of all token pairs
|
||||
3. Score each pair: `score = freq(pair) / (freq(first) × freq(second))`
|
||||
4. Merge pair with highest score
|
||||
5. Repeat until vocabulary size reached
|
||||
|
||||
### Why different scoring?
|
||||
|
||||
**BPE**: Merges most frequent pairs
|
||||
- "aa" appears 100 times → high priority
|
||||
- Even if 'a' appears 1000 times alone
|
||||
|
||||
**WordPiece**: Merges pairs that are semantically related
|
||||
- "aa" appears 100 times, 'a' appears 1000 times → low score (100 / (1000 × 1000))
|
||||
- "th" appears 50 times, 't' appears 60 times, 'h' appears 55 times → high score (50 / (60 × 55))
|
||||
- Prioritizes pairs that appear together more than expected
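
A tiny sketch of the scoring rule with the counts above (pure Python, illustrative):

```python
def wordpiece_score(pair_freq, first_freq, second_freq):
    # score = freq(pair) / (freq(first) × freq(second))
    return pair_freq / (first_freq * second_freq)

# Frequent pair whose parts also occur everywhere on their own
print(wordpiece_score(100, 1000, 1000))  # 0.0001

# Rarer pair whose parts almost always occur together
print(wordpiece_score(50, 60, 55))       # ≈ 0.0152
```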
|
||||
|
||||
### Step-by-step example
|
||||
|
||||
**Corpus**:
|
||||
```
|
||||
low: 5
|
||||
lower: 2
|
||||
newest: 6
|
||||
widest: 3
|
||||
```
|
||||
|
||||
**Iteration 1**:
|
||||
```
Count frequencies:
  'e': 17   (lower: 2, newest: 2×6, widest: 3)
  's': 9
  't': 9
  'l': 7
  'o': 7
  'i': 3
  'd': 3
  ...

Count pairs:
  'e' + 's': 9   (newest: 6, widest: 3)
  's' + 't': 9
  'l' + 'o': 7
  'i' + 'd': 3
  ...

Compute scores:
  score('e' + 's') = 9 / (17 × 9) = 0.059
  score('s' + 't') = 9 / (9 × 9)  = 0.111
  score('l' + 'o') = 7 / (7 × 7)  = 0.143
  score('i' + 'd') = 3 / (3 × 3)  = 0.333  ← highest score

Merge: 'i' + 'd' → 'id'
```
|
||||
|
||||
**Key difference**: WordPiece prioritizes pairs whose parts rarely occur apart, not simply the most frequent pairs — here the rare 'i' + 'd' beats the far more frequent 'e' + 's'.
|
||||
|
||||
### Tokenization with WordPiece
|
||||
|
||||
Given vocabulary: `['##e', '##s', '##t', 'l', 'o', 'w', 'new', '##est', 'low']`
|
||||
|
||||
Tokenize "lowest":
|
||||
```
|
||||
Step 1: Find longest matching prefix
|
||||
'lowest' → 'low' (matches)
|
||||
|
||||
Step 2: Find longest match for remainder
|
||||
'est' → '##est' (matches as a continuation token)
|
||||
|
||||
Final: ['low', '##est']
|
||||
```
|
||||
|
||||
**If no match**:
|
||||
```
|
||||
Tokenize "unknownword":
|
||||
'unknownword' → no match
|
||||
'unknown' → no match
|
||||
'unkn' → no match
|
||||
'un' → no match
|
||||
'u' → no match
|
||||
→ [UNK]
|
||||
```
|
||||
|
||||
### Implementation
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import WordPiece
|
||||
from tokenizers.trainers import WordPieceTrainer
|
||||
from tokenizers.normalizers import BertNormalizer
|
||||
from tokenizers.pre_tokenizers import BertPreTokenizer
|
||||
|
||||
# Initialize BERT-style tokenizer
|
||||
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
|
||||
|
||||
# Normalization (lowercase, accent stripping)
|
||||
tokenizer.normalizer = BertNormalizer(lowercase=True)
|
||||
|
||||
# Pre-tokenization (whitespace + punctuation)
|
||||
tokenizer.pre_tokenizer = BertPreTokenizer()
|
||||
|
||||
# Configure trainer
|
||||
trainer = WordPieceTrainer(
|
||||
vocab_size=30522, # BERT vocab size
|
||||
min_frequency=2,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
|
||||
continuing_subword_prefix="##" # BERT uses ##
|
||||
)
|
||||
|
||||
# Train
|
||||
tokenizer.train_from_iterator(corpus, trainer=trainer)
|
||||
|
||||
# Use
|
||||
output = tokenizer.encode("Tokenization works great!")
|
||||
print(output.tokens) # ['token', '##ization', 'works', 'great', '!']
|
||||
```
|
||||
|
||||
### Subword prefix
|
||||
|
||||
**BERT uses `##` prefix**:
|
||||
```
|
||||
"unbelievable" → ['un', '##believ', '##able']
|
||||
```
|
||||
|
||||
**Why?**
|
||||
- Indicates token is a continuation
|
||||
- Allows reconstruction: remove ##, concatenate
|
||||
- Helps model distinguish word boundaries
|
||||
|
||||
### WordPiece advantages
|
||||
|
||||
**Semantic merges**:
|
||||
- Prioritizes meaningful combinations
|
||||
- "qu" has high score (always together)
|
||||
- "qx" has low score (rare combination)
|
||||
|
||||
**Better for morphology**:
|
||||
- Captures affixes: un-, -ing, -ed
|
||||
- Preserves word stems
|
||||
|
||||
**Trade-offs**:
|
||||
- Slower training than BPE
|
||||
- More memory (stores vocabulary, not merges)
|
||||
- Original implementation not open-source (HF reimplementation)
|
||||
|
||||
## Unigram
|
||||
|
||||
### Algorithm overview
|
||||
|
||||
Unigram works backward: start with large vocabulary, remove tokens.
|
||||
|
||||
**Training process**:
|
||||
1. Initialize with large vocabulary (all substrings)
|
||||
2. Estimate probability of each token (frequency-based)
|
||||
3. For each token, compute loss increase if removed
|
||||
4. Remove 10-20% of tokens with lowest loss impact
|
||||
5. Re-estimate probabilities
|
||||
6. Repeat until desired vocabulary size
|
||||
|
||||
### Probabilistic tokenization
|
||||
|
||||
**Unigram assumption**: Each token is independent.
|
||||
|
||||
Given vocabulary with probabilities:
|
||||
```
|
||||
P('low') = 0.02
|
||||
P('l') = 0.01
|
||||
P('o') = 0.015
|
||||
P('w') = 0.01
|
||||
P('est') = 0.03
|
||||
P('e') = 0.02
|
||||
P('s') = 0.015
|
||||
P('t') = 0.015
|
||||
```
|
||||
|
||||
Tokenize "lowest":
|
||||
```
|
||||
Option 1: ['low', 'est']
|
||||
P = P('low') × P('est') = 0.02 × 0.03 = 0.0006
|
||||
|
||||
Option 2: ['l', 'o', 'w', 'est']
|
||||
P = 0.01 × 0.015 × 0.01 × 0.03 = 0.000000045
|
||||
|
||||
Option 3: ['low', 'e', 's', 't']
|
||||
P = 0.02 × 0.02 × 0.015 × 0.015 = 0.00000009
|
||||
|
||||
Choose option 1 (highest probability)
|
||||
```
|
||||
|
||||
### Viterbi algorithm
|
||||
|
||||
Finding best tokenization is expensive (exponential possibilities).
|
||||
|
||||
**Viterbi algorithm** (dynamic programming):
|
||||
```python
|
||||
from math import log

def tokenize_viterbi(word, vocab, probs):
    n = len(word)
    # dp[i] = (best_log_prob, best_tokens) for word[:i]
    dp = [(float('-inf'), []) for _ in range(n + 1)]
    dp[0] = (0.0, [])  # log-probability of the empty prefix

    for i in range(1, n + 1):
        best_prob = float('-inf')
        best_tokens = []

        # Try all possible last tokens word[j:i]
        for j in range(i):
            token = word[j:i]
            if token in vocab:
                prob = dp[j][0] + log(probs[token])
                if prob > best_prob:
                    best_prob = prob
                    best_tokens = dp[j][1] + [token]

        dp[i] = (best_prob, best_tokens)

    return dp[n][1]
|
||||
```
|
||||
|
||||
**Time complexity**: O(n²) substring lookups vs O(2^n) for brute-force enumeration
|
||||
|
||||
### Implementation
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import Unigram
|
||||
from tokenizers.trainers import UnigramTrainer
|
||||
|
||||
# Initialize
|
||||
tokenizer = Tokenizer(Unigram())
|
||||
|
||||
# Configure trainer
|
||||
trainer = UnigramTrainer(
|
||||
vocab_size=8000,
|
||||
special_tokens=["<unk>", "<s>", "</s>"],
|
||||
unk_token="<unk>",
|
||||
max_piece_length=16, # Max token length
|
||||
n_sub_iterations=2, # EM iterations
|
||||
shrinking_factor=0.75 # Remove 25% each iteration
|
||||
)
|
||||
|
||||
# Train
|
||||
tokenizer.train_from_iterator(corpus, trainer=trainer)
|
||||
|
||||
# Use
|
||||
output = tokenizer.encode("Tokenization with Unigram")
|
||||
print(output.tokens) # ['▁Token', 'ization', '▁with', '▁Un', 'igram']
|
||||
```
|
||||
|
||||
### Unigram advantages
|
||||
|
||||
**Probabilistic**:
|
||||
- Multiple valid tokenizations
|
||||
- Can sample different tokenizations (data augmentation)
|
||||
|
||||
**Subword regularization**:
|
||||
Sampling alternative segmentations is exposed through SentencePiece (`enable_sampling`) rather than the `tokenizers` `encode()` call, which is deterministic. A minimal sketch with the sentencepiece API (the model file name is illustrative):

```python
import sentencepiece as spm

# Assumes a Unigram model trained with SentencePiece
sp = spm.SentencePieceProcessor(model_file="unigram.model")

# Sample a different segmentation on each call
for _ in range(3):
    print(sp.encode("tokenization", out_type=str,
                    enable_sampling=True, alpha=0.1, nbest_size=-1))

# Example output (varies between calls):
# ['▁token', 'ization']
# ['▁tok', 'en', 'ization']
# ['▁token', 'iz', 'ation']
```
|
||||
|
||||
**Language-independent**:
|
||||
- No word boundaries needed
|
||||
- Works for CJK languages (Chinese, Japanese, Korean)
|
||||
- Treats input as character stream
|
||||
|
||||
**Trade-offs**:
|
||||
- Slower training (EM algorithm)
|
||||
- More hyperparameters
|
||||
- Larger model (stores probabilities)
|
||||
|
||||
## Algorithm comparison
|
||||
|
||||
### Training speed
|
||||
|
||||
| Algorithm | Small (10MB) | Medium (100MB) | Large (1GB) |
|
||||
|------------|--------------|----------------|-------------|
|
||||
| BPE | 10-15 sec | 1-2 min | 10-20 min |
|
||||
| WordPiece | 15-20 sec | 2-3 min | 15-30 min |
|
||||
| Unigram | 20-30 sec | 3-5 min | 30-60 min |
|
||||
|
||||
**Tested on**: 16-core CPU, 30k vocab
|
||||
|
||||
### Tokenization quality
|
||||
|
||||
Tested on English Wikipedia:
|
||||
|
||||
| Algorithm | Vocab Size | Tokens/Word | Unknown Rate |
|
||||
|------------|------------|-------------|--------------|
|
||||
| BPE | 30k | 1.3 | 0.5% |
|
||||
| WordPiece | 30k | 1.2 | 1.2% |
|
||||
| Unigram | 8k | 1.5 | 0.3% |
|
||||
|
||||
**Key observations**:
|
||||
- WordPiece: Slightly better compression
|
||||
- BPE: Lower unknown rate
|
||||
- Unigram: Smallest vocab, good coverage
|
||||
|
||||
### Compression ratio
|
||||
|
||||
Characters per token (higher = better compression):
|
||||
|
||||
| Language | BPE (30k) | WordPiece (30k) | Unigram (8k) |
|
||||
|----------|-----------|-----------------|--------------|
|
||||
| English | 4.2 | 4.5 | 3.8 |
|
||||
| Chinese | 2.1 | 2.3 | 2.5 |
|
||||
| Arabic | 3.5 | 3.8 | 3.2 |
|
||||
|
||||
**Best for each**:
|
||||
- English: WordPiece
|
||||
- Chinese: Unigram (language-independent)
|
||||
- Arabic: WordPiece
|
||||
|
||||
### Use case recommendations
|
||||
|
||||
**BPE** - Best for:
|
||||
- English language models
|
||||
- Code (handles symbols well)
|
||||
- Fast training needed
|
||||
- **Models**: GPT-2, GPT-3, RoBERTa, BART
|
||||
|
||||
**WordPiece** - Best for:
|
||||
- Masked language modeling (BERT-style)
|
||||
- Morphologically rich languages
|
||||
- Semantic understanding tasks
|
||||
- **Models**: BERT, DistilBERT, ELECTRA
|
||||
|
||||
**Unigram** - Best for:
|
||||
- Multilingual models
|
||||
- Languages without word boundaries (CJK)
|
||||
- Data augmentation via subword regularization
|
||||
- **Models**: T5, ALBERT, XLNet (via SentencePiece)
|
||||
|
||||
## Advanced topics
|
||||
|
||||
### Handling rare words
|
||||
|
||||
**BPE approach**:
|
||||
```
|
||||
"antidisestablishmentarianism"
|
||||
→ ['anti', 'dis', 'establish', 'ment', 'arian', 'ism']
|
||||
```
|
||||
|
||||
**WordPiece approach**:
|
||||
```
|
||||
"antidisestablishmentarianism"
|
||||
→ ['anti', '##dis', '##establish', '##ment', '##arian', '##ism']
|
||||
```
|
||||
|
||||
**Unigram approach**:
|
||||
```
|
||||
"antidisestablishmentarianism"
|
||||
→ ['▁anti', 'dis', 'establish', 'ment', 'arian', 'ism']
|
||||
```
|
||||
|
||||
### Handling numbers
|
||||
|
||||
**Challenge**: Infinite number combinations
|
||||
|
||||
**BPE solution**: Byte-level (handles any digit sequence)
|
||||
```python
|
||||
tokenizer = Tokenizer(BPE())
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
|
||||
# Handles any number
|
||||
"123456789" → byte-level tokens
|
||||
```
|
||||
|
||||
**WordPiece solution**: Digit pre-tokenization
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Digits
|
||||
|
||||
# Split digits individually or as groups
|
||||
tokenizer.pre_tokenizer = Digits(individual_digits=True)
|
||||
|
||||
"123" → ['1', '2', '3']
|
||||
```
|
||||
|
||||
**Unigram solution**: Learns common number patterns
|
||||
```python
|
||||
# Learns patterns during training
|
||||
"2023" → ['202', '3'] or ['20', '23']
|
||||
```
|
||||
|
||||
### Handling case sensitivity
|
||||
|
||||
**Lowercase (BERT)**:
|
||||
```python
|
||||
from tokenizers.normalizers import Lowercase
|
||||
|
||||
tokenizer.normalizer = Lowercase()
|
||||
|
||||
"Hello WORLD" → "hello world" → ['hello', 'world']
|
||||
```
|
||||
|
||||
**Preserve case (GPT-2)**:
|
||||
```python
|
||||
# No case normalization
|
||||
tokenizer.normalizer = None
|
||||
|
||||
"Hello WORLD" → ['Hello', 'WORLD']
|
||||
```
|
||||
|
||||
**Cased tokens (RoBERTa)**:
|
||||
```python
|
||||
# Learns separate tokens for different cases
|
||||
Vocabulary: ['Hello', 'hello', 'HELLO', 'world', 'WORLD']
|
||||
```
|
||||
|
||||
### Handling emojis and special characters
|
||||
|
||||
**Byte-level (GPT-2)**:
|
||||
```python
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
|
||||
"Hello 🌍 👋" → byte-level representation (always works)
|
||||
```
|
||||
|
||||
**Unicode normalization**:
|
||||
```python
|
||||
from tokenizers.normalizers import NFKC
|
||||
|
||||
tokenizer.normalizer = NFKC()
|
||||
|
||||
"é" (composed) ↔ "é" (decomposed) → normalized to one form
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Poor subword splitting
|
||||
|
||||
**Symptom**:
|
||||
```
|
||||
"running" → ['r', 'u', 'n', 'n', 'i', 'n', 'g'] (too granular)
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Increase vocabulary size
|
||||
2. Train longer (more merge iterations)
|
||||
3. Lower `min_frequency` threshold
|
||||
|
||||
### Issue: Too many unknown tokens
|
||||
|
||||
**Symptom**:
|
||||
```
|
||||
5% of tokens are [UNK]
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Increase vocabulary size
|
||||
2. Use byte-level BPE (no UNK possible)
|
||||
3. Verify training corpus is representative
|
||||
|
||||
### Issue: Inconsistent tokenization
|
||||
|
||||
**Symptom**:
|
||||
```
|
||||
"running" → ['run', 'ning']
|
||||
"runner" → ['r', 'u', 'n', 'n', 'e', 'r']
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Check normalization consistency
|
||||
2. Ensure pre-tokenization is deterministic
|
||||
3. If sampling with Unigram (subword regularization), disable sampling for deterministic output
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Match algorithm to model architecture**:
|
||||
- BERT-style → WordPiece
|
||||
- GPT-style → BPE
|
||||
- T5-style → Unigram
|
||||
|
||||
2. **Use byte-level for multilingual**:
|
||||
- Handles any Unicode
|
||||
- No unknown tokens
|
||||
|
||||
3. **Test on representative data**:
|
||||
- Measure compression ratio
|
||||
- Check unknown token rate
|
||||
- Inspect sample tokenizations
|
||||
|
||||
4. **Version control tokenizers**:
|
||||
- Save with model
|
||||
- Document special tokens
|
||||
- Track vocabulary changes
|
||||
|
|
@@ -0,0 +1,637 @@
|
|||
# Transformers Integration
|
||||
|
||||
Complete guide to using HuggingFace Tokenizers with the Transformers library.
|
||||
|
||||
## AutoTokenizer
|
||||
|
||||
The easiest way to load tokenizers.
|
||||
|
||||
### Loading pretrained tokenizers
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Load from HuggingFace Hub
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Check if using fast tokenizer (Rust-based)
|
||||
print(tokenizer.is_fast) # True
|
||||
|
||||
# Access underlying tokenizers.Tokenizer
|
||||
if tokenizer.is_fast:
|
||||
fast_tokenizer = tokenizer.backend_tokenizer
|
||||
print(type(fast_tokenizer)) # <class 'tokenizers.Tokenizer'>
|
||||
```
|
||||
|
||||
### Fast vs slow tokenizers
|
||||
|
||||
| Feature | Fast (Rust) | Slow (Python) |
|
||||
|--------------------------|----------------|---------------|
|
||||
| Speed | 5-10× faster | Baseline |
|
||||
| Alignment tracking | ✅ Full support | ❌ Limited |
|
||||
| Batch processing | ✅ Optimized | ⚠️ Slower |
|
||||
| Offset mapping | ✅ Yes | ❌ No |
|
||||
| Installation | `tokenizers` | Built-in |
|
||||
|
||||
**Always use fast tokenizers when available.**
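
A quick way to see the difference on your own machine (absolute timings vary with hardware and input length):

```python
import time
from transformers import AutoTokenizer

texts = ["The quick brown fox jumps over the lazy dog."] * 10_000

for use_fast in (False, True):
    tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=use_fast)
    start = time.perf_counter()
    tok(texts)
    print(f"use_fast={use_fast}: {time.perf_counter() - start:.2f}s")
```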
|
||||
|
||||
### Check available tokenizers
|
||||
|
||||
```python
|
||||
from transformers import TOKENIZER_MAPPING
|
||||
|
||||
# List all fast tokenizers
|
||||
for config_class, (slow, fast) in TOKENIZER_MAPPING.items():
|
||||
if fast is not None:
|
||||
print(f"{config_class.__name__}: {fast.__name__}")
|
||||
```
|
||||
|
||||
## PreTrainedTokenizerFast
|
||||
|
||||
Wrap custom tokenizers for transformers.
|
||||
|
||||
### Convert custom tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
# Train custom tokenizer
|
||||
tokenizer = Tokenizer(BPE())
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=30000,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
|
||||
)
|
||||
tokenizer.train(files=["corpus.txt"], trainer=trainer)
|
||||
|
||||
# Save tokenizer
|
||||
tokenizer.save("my-tokenizer.json")
|
||||
|
||||
# Wrap for transformers
|
||||
transformers_tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_file="my-tokenizer.json",
|
||||
unk_token="[UNK]",
|
||||
sep_token="[SEP]",
|
||||
pad_token="[PAD]",
|
||||
cls_token="[CLS]",
|
||||
mask_token="[MASK]"
|
||||
)
|
||||
|
||||
# Save in transformers format
|
||||
transformers_tokenizer.save_pretrained("my-tokenizer")
|
||||
```
|
||||
|
||||
**Result**: Directory with `tokenizer.json` + `tokenizer_config.json` + `special_tokens_map.json`
|
||||
|
||||
### Use like any transformers tokenizer
|
||||
|
||||
```python
|
||||
# Load
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("my-tokenizer")
|
||||
|
||||
# Encode with all transformers features
|
||||
outputs = tokenizer(
|
||||
"Hello world",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=128,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
print(outputs.keys())
|
||||
# dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
|
||||
```
|
||||
|
||||
## Special tokens
|
||||
|
||||
### Default special tokens
|
||||
|
||||
| Model Family | CLS/BOS | SEP/EOS | PAD | UNK | MASK |
|
||||
|--------------|---------|---------------|---------|---------|---------|
|
||||
| BERT | [CLS] | [SEP] | [PAD] | [UNK] | [MASK] |
|
||||
| GPT-2 | <\|endoftext\|> | <\|endoftext\|> | - (often set to <\|endoftext\|>) | <\|endoftext\|> | - |
|
||||
| RoBERTa | <s> | </s> | <pad> | <unk> | <mask> |
|
||||
| T5 | - | </s> | <pad> | <unk> | - |
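
To check what a given checkpoint actually uses, inspect `special_tokens_map` (output shown for roberta-base):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("roberta-base")
print(tok.special_tokens_map)
# {'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>',
#  'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}
```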
|
||||
|
||||
### Adding special tokens
|
||||
|
||||
```python
|
||||
# Add new special tokens
|
||||
special_tokens_dict = {
|
||||
"additional_special_tokens": ["<|image|>", "<|video|>", "<|audio|>"]
|
||||
}
|
||||
|
||||
num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
||||
print(f"Added {num_added_tokens} tokens")
|
||||
|
||||
# Resize model embeddings
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Use new tokens
|
||||
text = "This is an image: <|image|>"
|
||||
tokens = tokenizer.encode(text)
|
||||
```
|
||||
|
||||
### Adding regular tokens
|
||||
|
||||
```python
|
||||
# Add domain-specific tokens
|
||||
new_tokens = ["COVID-19", "mRNA", "vaccine"]
|
||||
num_added = tokenizer.add_tokens(new_tokens)
|
||||
|
||||
# These are NOT special tokens (can be split if needed)
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=False)
|
||||
|
||||
# These ARE special tokens (never split)
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
```
|
||||
|
||||
## Encoding and decoding
|
||||
|
||||
### Basic encoding
|
||||
|
||||
```python
|
||||
# Single sentence
|
||||
text = "Hello, how are you?"
|
||||
encoded = tokenizer(text)
|
||||
|
||||
print(encoded)
|
||||
# {'input_ids': [101, 7592, 1010, 2129, 2024, 2017, 1029, 102],
|
||||
# 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0],
|
||||
# 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
|
||||
```
|
||||
|
||||
### Batch encoding
|
||||
|
||||
```python
|
||||
# Multiple sentences
|
||||
texts = ["Hello world", "How are you?", "I am fine"]
|
||||
encoded = tokenizer(texts, padding=True, truncation=True, max_length=10)
|
||||
|
||||
print(encoded['input_ids'])
|
||||
# [[101, 7592, 2088, 102, 0, 0, 0, 0, 0, 0],
|
||||
# [101, 2129, 2024, 2017, 1029, 102, 0, 0, 0, 0],
|
||||
# [101, 1045, 2572, 2986, 102, 0, 0, 0, 0, 0]]
|
||||
```
|
||||
|
||||
### Return tensors
|
||||
|
||||
```python
|
||||
# Return PyTorch tensors
|
||||
outputs = tokenizer("Hello world", return_tensors="pt")
|
||||
print(outputs['input_ids'].shape)  # torch.Size([1, 4]) — [CLS] hello world [SEP]
|
||||
|
||||
# Return TensorFlow tensors
|
||||
outputs = tokenizer("Hello world", return_tensors="tf")
|
||||
|
||||
# Return NumPy arrays
|
||||
outputs = tokenizer("Hello world", return_tensors="np")
|
||||
|
||||
# Return lists (default)
|
||||
outputs = tokenizer("Hello world", return_tensors=None)
|
||||
```
|
||||
|
||||
### Decoding
|
||||
|
||||
```python
|
||||
# Decode token IDs
|
||||
ids = [101, 7592, 2088, 102]
|
||||
text = tokenizer.decode(ids)
|
||||
print(text) # "[CLS] hello world [SEP]"
|
||||
|
||||
# Skip special tokens
|
||||
text = tokenizer.decode(ids, skip_special_tokens=True)
|
||||
print(text) # "hello world"
|
||||
|
||||
# Batch decode
|
||||
batch_ids = [[101, 7592, 102], [101, 2088, 102]]
|
||||
texts = tokenizer.batch_decode(batch_ids, skip_special_tokens=True)
|
||||
print(texts) # ["hello", "world"]
|
||||
```
|
||||
|
||||
## Padding and truncation
|
||||
|
||||
### Padding strategies
|
||||
|
||||
```python
|
||||
# Pad to max length in batch
|
||||
tokenizer(texts, padding="longest")
|
||||
|
||||
# Pad to model max length
|
||||
tokenizer(texts, padding="max_length", max_length=128)
|
||||
|
||||
# No padding
|
||||
tokenizer(texts, padding=False)
|
||||
|
||||
# Pad to multiple of value (for efficient computation)
|
||||
tokenizer(texts, padding="max_length", max_length=128, pad_to_multiple_of=8)
|
||||
# Result: length will be 128 (already multiple of 8)
|
||||
```
|
||||
|
||||
### Truncation strategies
|
||||
|
||||
```python
|
||||
# Truncate to max length
|
||||
tokenizer(text, truncation=True, max_length=10)
|
||||
|
||||
# Only truncate first sequence (for pairs)
|
||||
tokenizer(text1, text2, truncation="only_first", max_length=20)
|
||||
|
||||
# Only truncate second sequence
|
||||
tokenizer(text1, text2, truncation="only_second", max_length=20)
|
||||
|
||||
# Truncate longest first (default for pairs)
|
||||
tokenizer(text1, text2, truncation="longest_first", max_length=20)
|
||||
|
||||
# No truncation (error if too long)
|
||||
tokenizer(text, truncation=False)
|
||||
```
|
||||
|
||||
### Stride for long documents
|
||||
|
||||
```python
|
||||
# For documents longer than max_length
|
||||
text = "Very long document " * 1000
|
||||
|
||||
# Encode with overlap
|
||||
encodings = tokenizer(
|
||||
text,
|
||||
max_length=512,
|
||||
stride=128, # Overlap between chunks
|
||||
truncation=True,
|
||||
return_overflowing_tokens=True,
|
||||
return_offsets_mapping=True
|
||||
)
|
||||
|
||||
# Get all chunks
|
||||
num_chunks = len(encodings['input_ids'])
|
||||
print(f"Split into {num_chunks} chunks")
|
||||
|
||||
# Each chunk overlaps by stride tokens
|
||||
for i, chunk in enumerate(encodings['input_ids']):
|
||||
print(f"Chunk {i}: {len(chunk)} tokens")
|
||||
```
|
||||
|
||||
**Use case**: Long document QA, sliding window inference
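
A sketch of relating overflowing chunks back to their source examples; `long_doc_a` and `long_doc_b` are assumed to be strings longer than `max_length`:

```python
batch = [long_doc_a, long_doc_b]

enc = tokenizer(
    batch,
    max_length=512,
    stride=128,
    truncation=True,
    return_overflowing_tokens=True,
)

# One entry per chunk, pointing at the source example it came from,
# e.g. [0, 0, 1] if the first document produced two chunks
print(enc["overflow_to_sample_mapping"])
```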
|
||||
|
||||
## Alignment and offsets
|
||||
|
||||
### Offset mapping
|
||||
|
||||
```python
|
||||
# Get character offsets for each token
|
||||
encoded = tokenizer("Hello, world!", return_offsets_mapping=True)
|
||||
|
||||
for token, (start, end) in zip(
|
||||
encoded.tokens(),
|
||||
encoded['offset_mapping']
|
||||
):
|
||||
print(f"{token:10s} → [{start:2d}, {end:2d})")
|
||||
|
||||
# Output:
|
||||
# [CLS] → [ 0, 0)
|
||||
# Hello → [ 0, 5)
|
||||
# , → [ 5, 6)
|
||||
# world → [ 7, 12)
|
||||
# ! → [12, 13)
|
||||
# [SEP] → [ 0, 0)
|
||||
```
|
||||
|
||||
### Word IDs
|
||||
|
||||
```python
|
||||
# Get word index for each token
|
||||
encoded = tokenizer("Hello world", return_offsets_mapping=True)
|
||||
word_ids = encoded.word_ids()
|
||||
|
||||
print(word_ids)
|
||||
# [None, 0, 1, None]
|
||||
# None = special token, 0 = first word, 1 = second word
|
||||
```
|
||||
|
||||
**Use case**: Token classification (NER, POS tagging)
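
A sketch of the usual label-alignment pattern for token classification (the words and label IDs are illustrative):

```python
# Align word-level NER labels to tokens
words = ["Hugging", "Face", "is", "in", "New", "York"]
word_labels = [1, 2, 0, 0, 3, 4]  # e.g. B-ORG, I-ORG, O, O, B-LOC, I-LOC

encoded = tokenizer(words, is_split_into_words=True)
token_labels = [
    -100 if idx is None else word_labels[idx]  # -100 is ignored by the loss
    for idx in encoded.word_ids()
]
```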
|
||||
|
||||
### Character to token mapping
|
||||
|
||||
```python
|
||||
text = "Machine learning is awesome"
|
||||
encoded = tokenizer(text, return_offsets_mapping=True)
|
||||
|
||||
# Find token for character position
|
||||
char_pos = 8 # "l" in "learning"
|
||||
token_idx = encoded.char_to_token(char_pos)
|
||||
|
||||
print(f"Character {char_pos} is in token {token_idx}: {encoded.tokens()[token_idx]}")
|
||||
# Character 8 is in token 2: learning
|
||||
```
|
||||
|
||||
**Use case**: Question answering (map answer character span to tokens)
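
For QA preprocessing, the same mapping converts a character-level answer span into token positions; a sketch with an illustrative span:

```python
context = "Machine learning is awesome"
answer_start, answer_end = 0, 16  # characters covering "Machine learning"

encoded = tokenizer(context, return_offsets_mapping=True)
start_token = encoded.char_to_token(answer_start)
end_token = encoded.char_to_token(answer_end - 1)
print(start_token, end_token)  # token indices spanning the answer (e.g. 1 and 2)
```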
|
||||
|
||||
### Sequence pairs
|
||||
|
||||
```python
|
||||
# Encode sentence pair
|
||||
encoded = tokenizer("Question here", "Answer here", return_offsets_mapping=True)
|
||||
|
||||
# Get sequence IDs (which sequence each token belongs to)
|
||||
sequence_ids = encoded.sequence_ids()
|
||||
print(sequence_ids)
|
||||
# [None, 0, 0, None, 1, 1, None]
|
||||
# None = special token, 0 = question, 1 = answer
|
||||
```
|
||||
|
||||
## Model integration
|
||||
|
||||
### Use with transformers models
|
||||
|
||||
```python
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
import torch
|
||||
|
||||
# Load model and tokenizer
|
||||
model = AutoModel.from_pretrained("bert-base-uncased")
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Tokenize
|
||||
text = "Hello world"
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
|
||||
# Forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Get embeddings
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
print(last_hidden_state.shape) # [1, seq_len, hidden_size]
|
||||
```
|
||||
|
||||
### Custom model with custom tokenizer
|
||||
|
||||
```python
|
||||
from transformers import BertConfig, BertModel
|
||||
|
||||
# Train custom tokenizer
|
||||
from tokenizers import Tokenizer, models, trainers
|
||||
tokenizer = Tokenizer(models.BPE())
|
||||
trainer = trainers.BpeTrainer(vocab_size=30000)
|
||||
tokenizer.train(files=["data.txt"], trainer=trainer)
|
||||
|
||||
# Wrap for transformers
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
fast_tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=tokenizer,
|
||||
unk_token="[UNK]",
|
||||
pad_token="[PAD]"
|
||||
)
|
||||
|
||||
# Create model with custom vocab size
|
||||
config = BertConfig(vocab_size=30000)
|
||||
model = BertModel(config)
|
||||
|
||||
# Use together
|
||||
inputs = fast_tokenizer("Hello world", return_tensors="pt")
|
||||
outputs = model(**inputs)
|
||||
```
|
||||
|
||||
### Save and load together
|
||||
|
||||
```python
|
||||
# Save both
|
||||
model.save_pretrained("my-model")
|
||||
tokenizer.save_pretrained("my-model")
|
||||
|
||||
# Directory structure:
|
||||
# my-model/
|
||||
# ├── config.json
|
||||
# ├── pytorch_model.bin
|
||||
# ├── tokenizer.json
|
||||
# ├── tokenizer_config.json
|
||||
# └── special_tokens_map.json
|
||||
|
||||
# Load both
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
model = AutoModel.from_pretrained("my-model")
|
||||
tokenizer = AutoTokenizer.from_pretrained("my-model")
|
||||
```
|
||||
|
||||
## Advanced features
|
||||
|
||||
### Multimodal tokenization
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# LLaVA-style (image + text)
|
||||
tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
|
||||
# Add image placeholder token
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
|
||||
|
||||
# Use in prompt
|
||||
text = "Describe this image: <image>"
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
```
|
||||
|
||||
### Template formatting
|
||||
|
||||
```python
|
||||
# Chat template
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Hi! How can I help?"},
|
||||
{"role": "user", "content": "What's the weather?"}
|
||||
]
|
||||
|
||||
# Apply chat template (if tokenizer has one)
|
||||
if hasattr(tokenizer, "apply_chat_template"):
|
||||
text = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
```
|
||||
|
||||
### Custom template
|
||||
|
||||
```python
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
|
||||
|
||||
# Define chat template
|
||||
tokenizer.chat_template = """
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'system' %}
|
||||
System: {{ message['content'] }}\\n
|
||||
{%- elif message['role'] == 'user' %}
|
||||
User: {{ message['content'] }}\\n
|
||||
{%- elif message['role'] == 'assistant' %}
|
||||
Assistant: {{ message['content'] }}\\n
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
Assistant:
|
||||
"""
|
||||
|
||||
# Use template
|
||||
text = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
```
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### Batch processing
|
||||
|
||||
```python
|
||||
# Process large datasets efficiently
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("imdb", split="train[:1000]")
|
||||
|
||||
# Tokenize in batches
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(
|
||||
examples["text"],
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=512
|
||||
)
|
||||
|
||||
# Map over dataset (batched)
|
||||
tokenized_dataset = dataset.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
batch_size=1000,
|
||||
num_proc=4 # Parallel processing
|
||||
)
|
||||
```
|
||||
|
||||
### Caching
|
||||
|
||||
```python
|
||||
# Enable caching for repeated tokenization
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"bert-base-uncased",
|
||||
use_fast=True,
|
||||
cache_dir="./cache" # Cache tokenizer files
|
||||
)
|
||||
|
||||
# Tokenize with caching
|
||||
from functools import lru_cache
|
||||
|
||||
@lru_cache(maxsize=10000)
|
||||
def cached_tokenize(text):
|
||||
return tuple(tokenizer.encode(text))
|
||||
|
||||
# Reuses cached results for repeated inputs
|
||||
```
|
||||
|
||||
### Memory efficiency
|
||||
|
||||
```python
|
||||
# For very large datasets, use streaming
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("pile", split="train", streaming=True)
|
||||
|
||||
def process_batch(batch):
|
||||
# Tokenize
|
||||
tokens = tokenizer(batch["text"], truncation=True, max_length=512)
|
||||
|
||||
# Process tokens...
|
||||
|
||||
return tokens
|
||||
|
||||
# Process in chunks (memory efficient)
|
||||
for batch in dataset.batch(batch_size=1000):
|
||||
processed = process_batch(batch)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Tokenizer not fast
|
||||
|
||||
**Symptom**:
|
||||
```python
|
||||
tokenizer.is_fast # False
|
||||
```
|
||||
|
||||
**Solution**: Install tokenizers library
|
||||
```bash
|
||||
pip install tokenizers
|
||||
```
|
||||
|
||||
### Issue: Special tokens not working
|
||||
|
||||
**Symptom**: Special tokens are split into subwords
|
||||
|
||||
**Solution**: Add as special tokens, not regular tokens
|
||||
```python
|
||||
# Wrong
|
||||
tokenizer.add_tokens(["<|image|>"])
|
||||
|
||||
# Correct
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": ["<|image|>"]})
|
||||
```
|
||||
|
||||
### Issue: Offset mapping not available
|
||||
|
||||
**Symptom**:
|
||||
```python
|
||||
tokenizer("text", return_offsets_mapping=True)
|
||||
# Error: return_offsets_mapping not supported
|
||||
```
|
||||
|
||||
**Solution**: Use fast tokenizer
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Load fast version
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
|
||||
```
|
||||
|
||||
### Issue: Padding inconsistent
|
||||
|
||||
**Symptom**: Some sequences padded, others not
|
||||
|
||||
**Solution**: Specify padding strategy
|
||||
```python
|
||||
# Explicit padding
|
||||
tokenizer(
|
||||
texts,
|
||||
padding="max_length", # or "longest"
|
||||
max_length=128
|
||||
)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Always use fast tokenizers**:
|
||||
- 5-10× faster
|
||||
- Full alignment tracking
|
||||
- Better batch processing
|
||||
|
||||
2. **Save tokenizer with model**:
|
||||
- Ensures reproducibility
|
||||
- Prevents version mismatches
|
||||
|
||||
3. **Use batch processing for datasets**:
|
||||
- Tokenize with `.map(batched=True)`
|
||||
- Set `num_proc` for parallelism
|
||||
|
||||
4. **Enable caching for repeated inputs**:
|
||||
- Use `lru_cache` for inference
|
||||
- Cache tokenizer files with `cache_dir`
|
||||
|
||||
5. **Handle special tokens properly**:
|
||||
- Use `add_special_tokens()` for never-split tokens
|
||||
- Resize embeddings after adding tokens
|
||||
|
||||
6. **Test alignment for downstream tasks**:
|
||||
- Verify `offset_mapping` is correct
|
||||
- Test `char_to_token()` on samples
|
||||
|
||||
7. **Version control tokenizer config**:
|
||||
- Save `tokenizer_config.json`
|
||||
- Document custom templates
|
||||
- Track vocabulary changes
|
||||
|
|
@@ -0,0 +1,723 @@
|
|||
# Tokenization Pipeline Components
|
||||
|
||||
Complete guide to normalizers, pre-tokenizers, models, post-processors, and decoders.
|
||||
|
||||
## Pipeline overview
|
||||
|
||||
**Full tokenization pipeline**:
|
||||
```
|
||||
Raw Text
|
||||
↓
|
||||
Normalization (cleaning, lowercasing)
|
||||
↓
|
||||
Pre-tokenization (split into words)
|
||||
↓
|
||||
Model (apply BPE/WordPiece/Unigram)
|
||||
↓
|
||||
Post-processing (add special tokens)
|
||||
↓
|
||||
Token IDs
|
||||
```
|
||||
|
||||
**Decoding reverses the process**:
|
||||
```
|
||||
Token IDs
|
||||
↓
|
||||
Decoder (handle special encodings)
|
||||
↓
|
||||
Raw Text
|
||||
```
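
Each stage can also be run on its own, which is handy for debugging a pipeline. A small sketch using a pretrained BERT tokenizer (exact offsets depend on the checkpoint):

```python
from tokenizers import Tokenizer

tok = Tokenizer.from_pretrained("bert-base-uncased")

print(tok.normalizer.normalize_str("Héllo, World!"))
# "hello, world!"

print(tok.pre_tokenizer.pre_tokenize_str("hello, world!"))
# [('hello', (0, 5)), (',', (5, 6)), ('world', (7, 12)), ('!', (12, 13))]
```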
|
||||
|
||||
## Normalizers
|
||||
|
||||
Clean and standardize input text.
|
||||
|
||||
### Common normalizers
|
||||
|
||||
**Lowercase**:
|
||||
```python
|
||||
from tokenizers.normalizers import Lowercase
|
||||
|
||||
tokenizer.normalizer = Lowercase()
|
||||
|
||||
# Input: "Hello WORLD"
|
||||
# Output: "hello world"
|
||||
```
|
||||
|
||||
**Unicode normalization**:
|
||||
```python
|
||||
from tokenizers.normalizers import NFD, NFC, NFKD, NFKC
|
||||
|
||||
# NFD: Canonical decomposition
|
||||
tokenizer.normalizer = NFD()
|
||||
# "é" → "e" + "́" (separate characters)
|
||||
|
||||
# NFC: Canonical composition (default)
|
||||
tokenizer.normalizer = NFC()
|
||||
# "e" + "́" → "é" (composed)
|
||||
|
||||
# NFKD: Compatibility decomposition
|
||||
tokenizer.normalizer = NFKD()
|
||||
# "fi" → "f" + "i"
|
||||
|
||||
# NFKC: Compatibility composition
|
||||
tokenizer.normalizer = NFKC()
|
||||
# Most aggressive normalization
|
||||
```
|
||||
|
||||
**Strip accents**:
|
||||
```python
|
||||
from tokenizers.normalizers import StripAccents
|
||||
|
||||
tokenizer.normalizer = StripAccents()
|
||||
|
||||
# Input: "café"
|
||||
# Output: "cafe"
|
||||
```
|
||||
|
||||
**Whitespace handling**:
|
||||
```python
|
||||
from tokenizers.normalizers import Strip, StripAccents
|
||||
|
||||
# Remove leading/trailing whitespace
|
||||
tokenizer.normalizer = Strip()
|
||||
|
||||
# Input: " hello "
|
||||
# Output: "hello"
|
||||
```
|
||||
|
||||
**Replace patterns**:
|
||||
```python
|
||||
from tokenizers.normalizers import Replace
|
||||
|
||||
# Replace newlines with spaces
|
||||
tokenizer.normalizer = Replace("\\n", " ")
|
||||
|
||||
# Input: "hello\\nworld"
|
||||
# Output: "hello world"
|
||||
```
|
||||
|
||||
### Combining normalizers
|
||||
|
||||
```python
|
||||
from tokenizers.normalizers import Sequence, NFD, Lowercase, StripAccents
|
||||
|
||||
# BERT-style normalization
|
||||
tokenizer.normalizer = Sequence([
|
||||
NFD(), # Unicode decomposition
|
||||
Lowercase(), # Convert to lowercase
|
||||
StripAccents() # Remove accents
|
||||
])
|
||||
|
||||
# Input: "Café au Lait"
|
||||
# After NFD: "Café au Lait" (e + ́)
|
||||
# After Lowercase: "café au lait"
|
||||
# After StripAccents: "cafe au lait"
|
||||
```
|
||||
|
||||
### Use case examples
|
||||
|
||||
**Case-insensitive model (BERT)**:
|
||||
```python
|
||||
from tokenizers.normalizers import BertNormalizer
|
||||
|
||||
# All-in-one BERT normalization
|
||||
tokenizer.normalizer = BertNormalizer(
|
||||
clean_text=True, # Remove control characters
|
||||
handle_chinese_chars=True, # Add spaces around Chinese
|
||||
strip_accents=True, # Remove accents
|
||||
lowercase=True # Lowercase
|
||||
)
|
||||
```
|
||||
|
||||
**Case-sensitive model (GPT-2)**:
|
||||
```python
|
||||
# Minimal normalization
|
||||
tokenizer.normalizer = NFC() # Only normalize Unicode
|
||||
```
|
||||
|
||||
**Multilingual (mBERT)**:
|
||||
```python
|
||||
# Preserve scripts, normalize form
|
||||
tokenizer.normalizer = NFKC()
|
||||
```
|
||||
|
||||
## Pre-tokenizers
|
||||
|
||||
Split text into word-like units before tokenization.
|
||||
|
||||
### Whitespace splitting
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Whitespace
|
||||
|
||||
tokenizer.pre_tokenizer = Whitespace()
|
||||
|
||||
# Input: "Hello world! How are you?"
|
||||
# Output: [("Hello", (0, 5)), ("world!", (6, 12)), ("How", (13, 16)), ("are", (17, 20)), ("you?", (21, 25))]
|
||||
```
|
||||
|
||||
### Punctuation isolation
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Punctuation
|
||||
|
||||
tokenizer.pre_tokenizer = Punctuation()
|
||||
|
||||
# Input: "Hello, world!"
|
||||
# Output: [("Hello", ...), (",", ...), ("world", ...), ("!", ...)]
|
||||
```
|
||||
|
||||
### Byte-level (GPT-2)
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
|
||||
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
|
||||
|
||||
# Input: "Hello world"
|
||||
# Output: Byte-level tokens with Ġ prefix for spaces
|
||||
# [("ĠHello", ...), ("Ġworld", ...)]
|
||||
```
|
||||
|
||||
**Key feature**: Handles ALL Unicode characters (256 byte combinations)
|
||||
|
||||
### Metaspace (SentencePiece)
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Metaspace
|
||||
|
||||
tokenizer.pre_tokenizer = Metaspace(replacement="▁", add_prefix_space=True)
|
||||
|
||||
# Input: "Hello world"
|
||||
# Output: [("▁Hello", ...), ("▁world", ...)]
|
||||
```
|
||||
|
||||
**Used by**: T5, ALBERT (via SentencePiece)
|
||||
|
||||
### Digits splitting
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Digits
|
||||
|
||||
# Split digits individually
|
||||
tokenizer.pre_tokenizer = Digits(individual_digits=True)
|
||||
|
||||
# Input: "Room 123"
|
||||
# Output: [("Room", ...), ("1", ...), ("2", ...), ("3", ...)]
|
||||
|
||||
# Keep digits together
|
||||
tokenizer.pre_tokenizer = Digits(individual_digits=False)
|
||||
|
||||
# Input: "Room 123"
|
||||
# Output: [("Room", ...), ("123", ...)]
|
||||
```
|
||||
|
||||
### BERT pre-tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import BertPreTokenizer
|
||||
|
||||
tokenizer.pre_tokenizer = BertPreTokenizer()
|
||||
|
||||
# Splits on whitespace and punctuation, preserves CJK
|
||||
# Input: "Hello, 世界!"
|
||||
# Output: [("Hello", ...), (",", ...), ("世", ...), ("界", ...), ("!", ...)]
|
||||
```
|
||||
|
||||
### Combining pre-tokenizers
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Sequence, WhitespaceSplit, Punctuation

tokenizer.pre_tokenizer = Sequence([
    WhitespaceSplit(),  # Split on whitespace only
    Punctuation()       # Then isolate punctuation
])

# Input: "Hello, world!"
# After WhitespaceSplit: [("Hello,", ...), ("world!", ...)]
# After Punctuation: [("Hello", ...), (",", ...), ("world", ...), ("!", ...)]
|
||||
```
|
||||
|
||||
### Pre-tokenizer comparison
|
||||
|
||||
| Pre-tokenizer | Use Case | Example |
|
||||
|-------------------|---------------------------------|--------------------------------------------|
|
||||
| Whitespace | Simple English | "Hello world" → ["Hello", "world"] |
|
||||
| Punctuation | Isolate symbols | "world!" → ["world", "!"] |
|
||||
| ByteLevel | Multilingual, emojis | "🌍" → byte tokens |
|
||||
| Metaspace | SentencePiece-style | "Hello" → ["▁Hello"] |
|
||||
| BertPreTokenizer | BERT-style (CJK aware) | "世界" → ["世", "界"] |
|
||||
| Digits | Handle numbers | "123" → ["1", "2", "3"] or ["123"] |
|
||||
|
||||
## Models
|
||||
|
||||
Core tokenization algorithms.
|
||||
|
||||
### BPE Model
|
||||
|
||||
```python
|
||||
from tokenizers.models import BPE
|
||||
|
||||
model = BPE(
|
||||
vocab=None, # Or provide pre-built vocab
|
||||
merges=None, # Or provide merge rules
|
||||
unk_token="[UNK]", # Unknown token
|
||||
continuing_subword_prefix="",
|
||||
end_of_word_suffix="",
|
||||
fuse_unk=False # Keep unknown tokens separate
|
||||
)
|
||||
|
||||
tokenizer = Tokenizer(model)
|
||||
```
|
||||
|
||||
**Parameters**:
|
||||
- `vocab`: Dict of token → id
|
||||
- `merges`: List of merge rules `["a b", "ab c"]`
|
||||
- `unk_token`: Token for unknown words
|
||||
- `continuing_subword_prefix`: Prefix for subwords (empty for GPT-2)
|
||||
- `end_of_word_suffix`: Suffix for last subword (empty for GPT-2)
|
||||
|
||||
### WordPiece Model
|
||||
|
||||
```python
|
||||
from tokenizers.models import WordPiece
|
||||
|
||||
model = WordPiece(
|
||||
vocab=None,
|
||||
unk_token="[UNK]",
|
||||
max_input_chars_per_word=100, # Max word length
|
||||
continuing_subword_prefix="##" # BERT-style prefix
|
||||
)
|
||||
|
||||
tokenizer = Tokenizer(model)
|
||||
```
|
||||
|
||||
**Key difference**: Uses `##` prefix for continuing subwords.
|
||||
|
||||
### Unigram Model
|
||||
|
||||
```python
|
||||
from tokenizers.models import Unigram
|
||||
|
||||
model = Unigram(
|
||||
vocab=None, # List of (token, score) tuples
|
||||
unk_id=0, # ID for unknown token
|
||||
byte_fallback=False # Fall back to bytes if no match
|
||||
)
|
||||
|
||||
tokenizer = Tokenizer(model)
|
||||
```
|
||||
|
||||
**Probabilistic**: Selects tokenization with highest probability.
|
||||
|
||||
### WordLevel Model
|
||||
|
||||
```python
|
||||
from tokenizers.models import WordLevel
|
||||
|
||||
# Simple word-to-ID mapping (no subwords)
|
||||
model = WordLevel(
|
||||
vocab=None,
|
||||
unk_token="[UNK]"
|
||||
)
|
||||
|
||||
tokenizer = Tokenizer(model)
|
||||
```
|
||||
|
||||
**Warning**: Requires huge vocabulary (one token per word).
|
||||
|
||||
## Post-processors
|
||||
|
||||
Add special tokens and format output.
|
||||
|
||||
### Template processing
|
||||
|
||||
**BERT-style** (`[CLS] sentence [SEP]`):
|
||||
```python
|
||||
from tokenizers.processors import TemplateProcessing
|
||||
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="[CLS] $A [SEP]",
|
||||
pair="[CLS] $A [SEP] $B [SEP]",
|
||||
special_tokens=[
|
||||
("[CLS]", 101),
|
||||
("[SEP]", 102),
|
||||
],
|
||||
)
|
||||
|
||||
# Single sentence
|
||||
output = tokenizer.encode("Hello world")
|
||||
# [101, ..., 102] ([CLS] hello world [SEP])
|
||||
|
||||
# Sentence pair
|
||||
output = tokenizer.encode("Hello", "world")
|
||||
# [101, ..., 102, ..., 102] ([CLS] hello [SEP] world [SEP])
|
||||
```
|
||||
|
||||
**GPT-2 style** (`sentence <|endoftext|>`):
|
||||
```python
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="$A <|endoftext|>",
|
||||
special_tokens=[
|
||||
("<|endoftext|>", 50256),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
**RoBERTa style** (`<s> sentence </s>`):
|
||||
```python
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="<s> $A </s>",
|
||||
pair="<s> $A </s> </s> $B </s>",
|
||||
special_tokens=[
|
||||
("<s>", 0),
|
||||
("</s>", 2),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
**T5 style** (no special tokens):
|
||||
```python
|
||||
# T5 doesn't add special tokens via post-processor
|
||||
tokenizer.post_processor = None
|
||||
```
|
||||
|
||||
### RobertaProcessing
|
||||
|
||||
```python
|
||||
from tokenizers.processors import RobertaProcessing
|
||||
|
||||
tokenizer.post_processor = RobertaProcessing(
|
||||
sep=("</s>", 2),
|
||||
cls=("<s>", 0),
|
||||
add_prefix_space=True, # Add space before first token
|
||||
trim_offsets=True # Trim leading space from offsets
|
||||
)
|
||||
```
|
||||
|
||||
### ByteLevelProcessing
|
||||
|
||||
```python
|
||||
from tokenizers.processors import ByteLevel as ByteLevelProcessing
|
||||
|
||||
tokenizer.post_processor = ByteLevelProcessing(
|
||||
trim_offsets=True # Remove Ġ from offsets
|
||||
)
|
||||
```
|
||||
|
||||
## Decoders
|
||||
|
||||
Convert token IDs back to text.
|
||||
|
||||
### ByteLevel decoder
|
||||
|
||||
```python
|
||||
from tokenizers.decoders import ByteLevel
|
||||
|
||||
tokenizer.decoder = ByteLevel()
|
||||
|
||||
# Handles byte-level tokens
|
||||
# ["ĠHello", "Ġworld"] → "Hello world"
|
||||
```
|
||||
|
||||
### WordPiece decoder
|
||||
|
||||
```python
|
||||
from tokenizers.decoders import WordPiece
|
||||
|
||||
tokenizer.decoder = WordPiece(prefix="##")
|
||||
|
||||
# Removes ## prefix and concatenates
|
||||
# ["token", "##ization"] → "tokenization"
|
||||
```
|
||||
|
||||
### Metaspace decoder
|
||||
|
||||
```python
|
||||
from tokenizers.decoders import Metaspace
|
||||
|
||||
tokenizer.decoder = Metaspace(replacement="▁", add_prefix_space=True)
|
||||
|
||||
# Converts ▁ back to spaces
|
||||
# ["▁Hello", "▁world"] → "Hello world"
|
||||
```
|
||||
|
||||
### BPEDecoder
|
||||
|
||||
```python
|
||||
from tokenizers.decoders import BPEDecoder
|
||||
|
||||
tokenizer.decoder = BPEDecoder(suffix="</w>")
|
||||
|
||||
# Removes suffix and concatenates
|
||||
# ["token", "ization</w>"] → "tokenization"
|
||||
```
|
||||
|
||||
### Sequence decoder
|
||||
|
||||
```python
|
||||
from tokenizers.decoders import Sequence, ByteLevel, Strip
|
||||
|
||||
tokenizer.decoder = Sequence([
|
||||
ByteLevel(), # Decode byte-level first
|
||||
Strip(' ', 1, 1) # Strip leading/trailing spaces
|
||||
])
|
||||
```
|
||||
|
||||
## Complete pipeline examples
|
||||
|
||||
### BERT tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import WordPiece
|
||||
from tokenizers.normalizers import BertNormalizer
|
||||
from tokenizers.pre_tokenizers import BertPreTokenizer
|
||||
from tokenizers.processors import TemplateProcessing
|
||||
from tokenizers.decoders import WordPiece as WordPieceDecoder
|
||||
|
||||
# Model
|
||||
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
|
||||
|
||||
# Normalization
|
||||
tokenizer.normalizer = BertNormalizer(lowercase=True)
|
||||
|
||||
# Pre-tokenization
|
||||
tokenizer.pre_tokenizer = BertPreTokenizer()
|
||||
|
||||
# Post-processing
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="[CLS] $A [SEP]",
|
||||
pair="[CLS] $A [SEP] $B [SEP]",
|
||||
special_tokens=[("[CLS]", 101), ("[SEP]", 102)],
|
||||
)
|
||||
|
||||
# Decoder
|
||||
tokenizer.decoder = WordPieceDecoder(prefix="##")
|
||||
|
||||
# Enable padding
|
||||
tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")
|
||||
|
||||
# Enable truncation
|
||||
tokenizer.enable_truncation(max_length=512)
|
||||
```
|
||||
|
||||
### GPT-2 tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.normalizers import NFC
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
|
||||
from tokenizers.processors import TemplateProcessing
|
||||
|
||||
# Model
|
||||
tokenizer = Tokenizer(BPE())
|
||||
|
||||
# Normalization (minimal)
|
||||
tokenizer.normalizer = NFC()
|
||||
|
||||
# Byte-level pre-tokenization
|
||||
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
|
||||
|
||||
# Post-processing
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="$A <|endoftext|>",
|
||||
special_tokens=[("<|endoftext|>", 50256)],
|
||||
)
|
||||
|
||||
# Byte-level decoder
|
||||
tokenizer.decoder = ByteLevelDecoder()
|
||||
```
|
||||
|
||||
### T5 tokenizer (SentencePiece-style)
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import Unigram
|
||||
from tokenizers.normalizers import NFKC
|
||||
from tokenizers.pre_tokenizers import Metaspace
|
||||
from tokenizers.decoders import Metaspace as MetaspaceDecoder
|
||||
|
||||
# Model
|
||||
tokenizer = Tokenizer(Unigram())
|
||||
|
||||
# Normalization
|
||||
tokenizer.normalizer = NFKC()
|
||||
|
||||
# Metaspace pre-tokenization
|
||||
tokenizer.pre_tokenizer = Metaspace(replacement="▁", add_prefix_space=True)
|
||||
|
||||
# No post-processing (T5 doesn't add CLS/SEP)
|
||||
tokenizer.post_processor = None
|
||||
|
||||
# Metaspace decoder
|
||||
tokenizer.decoder = MetaspaceDecoder(replacement="▁", add_prefix_space=True)
|
||||
```
|
||||
|
||||
## Alignment tracking
|
||||
|
||||
Track token positions in original text.
|
||||
|
||||
### Basic alignment
|
||||
|
||||
```python
|
||||
text = "Hello, world!"
|
||||
output = tokenizer.encode(text)
|
||||
|
||||
for token, (start, end) in zip(output.tokens, output.offsets):
|
||||
print(f"{token:10s} → [{start:2d}, {end:2d}): {text[start:end]!r}")
|
||||
|
||||
# Output:
|
||||
# [CLS] → [ 0, 0): ''
|
||||
# hello → [ 0, 5): 'Hello'
|
||||
# , → [ 5, 6): ','
|
||||
# world → [ 7, 12): 'world'
|
||||
# ! → [12, 13): '!'
|
||||
# [SEP] → [ 0, 0): ''
|
||||
```
|
||||
|
||||
### Word-level alignment
|
||||
|
||||
```python
|
||||
# Get word_ids (which word each token belongs to)
|
||||
encoding = tokenizer.encode("Hello world")
|
||||
word_ids = encoding.word_ids
|
||||
|
||||
print(word_ids)
|
||||
# [None, 0, 0, 1, None]
|
||||
# None = special token, 0 = first word, 1 = second word
|
||||
```
|
||||
|
||||
**Use case**: Token classification (NER)
|
||||
```python
|
||||
# Align predictions to words
|
||||
predictions = ["O", "B-PER", "I-PER", "O", "O"]
|
||||
word_predictions = {}
|
||||
|
||||
for token_idx, word_idx in enumerate(encoding.word_ids):
|
||||
if word_idx is not None and word_idx not in word_predictions:
|
||||
word_predictions[word_idx] = predictions[token_idx]
|
||||
|
||||
print(word_predictions)
|
||||
# {0: "B-PER", 1: "O"} # First word is PERSON, second is OTHER
|
||||
```
|
||||
|
||||
### Span alignment
|
||||
|
||||
```python
|
||||
# Find token span for character span
|
||||
text = "Machine learning is awesome"
|
||||
char_start, char_end = 8, 16 # "learning"
|
||||
|
||||
encoding = tokenizer.encode(text)
|
||||
|
||||
# Find token span
|
||||
token_start = encoding.char_to_token(char_start)
|
||||
token_end = encoding.char_to_token(char_end - 1) + 1
|
||||
|
||||
print(f"Tokens {token_start}:{token_end} = {encoding.tokens[token_start:token_end]}")
|
||||
# Tokens 2:3 = ['learning']
|
||||
```
|
||||
|
||||
**Use case**: Question answering (extract answer span)
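
A hedged sketch of that workflow: convert a gold answer given as a character span in the context into token-level start/end labels. Variable names are illustrative, and it assumes a tokenizer configured with a pair template (e.g. the BERT-style setup above):

```python
question = "What is awesome?"
context = "Machine learning is awesome"
answer_char_start, answer_char_end = 0, 16  # gold answer: "Machine learning"

encoding = tokenizer.encode(question, context)

# sequence_index=1 restricts the lookup to the second sequence (the context)
start_token = encoding.char_to_token(answer_char_start, sequence_index=1)
end_token = encoding.char_to_token(answer_char_end - 1, sequence_index=1)

print(encoding.tokens[start_token:end_token + 1])
# (start_token, end_token) become the labels for a span-extraction QA head
```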
|
||||
|
||||
## Custom components
|
||||
|
||||
### Custom normalizer
|
||||
|
||||
```python
|
||||
from tokenizers import NormalizedString
from tokenizers.normalizers import Normalizer

class CustomNormalizer:
    def normalize(self, normalized: NormalizedString):
        # Custom normalization logic
        normalized.lowercase()
        normalized.replace("  ", " ")  # Collapse double spaces

# Custom Python components must be wrapped with Normalizer.custom
tokenizer.normalizer = Normalizer.custom(CustomNormalizer())
```
|
||||
|
||||
### Custom pre-tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import PreTokenizedString
from tokenizers.pre_tokenizers import PreTokenizer

class CustomPreTokenizer:
    def pre_tokenize(self, pretok: PreTokenizedString):
        # Custom pre-tokenization logic: split each chunk on spaces
        pretok.split(lambda i, normalized: normalized.split(" ", "removed"))

# Custom Python components must be wrapped with PreTokenizer.custom
tokenizer.pre_tokenizer = PreTokenizer.custom(CustomPreTokenizer())
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Misaligned offsets
|
||||
|
||||
**Symptom**: Offsets don't match original text
|
||||
```python
|
||||
text = "  hello"  # Leading spaces
offsets = [(0, 5)]  # Offset covers "  hel" instead of "hello"
|
||||
```
|
||||
|
||||
**Solution**: Check whether normalization strips characters; trim offsets in the post-processor instead
```python
from tokenizers.normalizers import Sequence, Strip
from tokenizers.processors import ByteLevel as ByteLevelProcessing

# Problem: stripping in the normalizer shifts offsets
tokenizer.normalizer = Sequence([
    Strip(),  # This changes offsets!
])

# Prefer trimming offsets in the post-processor
tokenizer.post_processor = ByteLevelProcessing(trim_offsets=True)
```
|
||||
|
||||
### Issue: Special tokens not added
|
||||
|
||||
**Symptom**: No [CLS] or [SEP] in output
|
||||
|
||||
**Solution**: Check post-processor is set
|
||||
```python
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="[CLS] $A [SEP]",
|
||||
special_tokens=[("[CLS]", 101), ("[SEP]", 102)],
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Incorrect decoding
|
||||
|
||||
**Symptom**: Decoded text has ## or ▁
|
||||
|
||||
**Solution**: Set correct decoder
|
||||
```python
|
||||
# For WordPiece
|
||||
tokenizer.decoder = WordPieceDecoder(prefix="##")
|
||||
|
||||
# For SentencePiece
|
||||
tokenizer.decoder = MetaspaceDecoder(replacement="▁")
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Match pipeline to model architecture**:
|
||||
- BERT → BertNormalizer + BertPreTokenizer + WordPiece
|
||||
- GPT-2 → NFC + ByteLevel + BPE
|
||||
- T5 → NFKC + Metaspace + Unigram
|
||||
|
||||
2. **Test pipeline on sample inputs**:
|
||||
- Check normalization doesn't over-normalize
|
||||
- Verify pre-tokenization splits correctly
|
||||
- Ensure decoding reconstructs text
|
||||
|
||||
3. **Preserve alignment for downstream tasks**:
|
||||
- Use `trim_offsets` instead of stripping in normalizer
|
||||
- Test `char_to_token()` on sample spans
|
||||
|
||||
4. **Document your pipeline**:
|
||||
- Save complete tokenizer config
|
||||
- Document special tokens
|
||||
- Note any custom components
# Training Custom Tokenizers
|
||||
|
||||
Complete guide to training tokenizers from scratch.
|
||||
|
||||
## Training workflow
|
||||
|
||||
### Step 1: Choose tokenization algorithm
|
||||
|
||||
**Decision tree**:
|
||||
- **GPT-style model** → BPE
|
||||
- **BERT-style model** → WordPiece
|
||||
- **Multilingual/No word boundaries** → Unigram
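
As a rough helper reflecting this decision tree (defaults invented for illustration):

```python
from tokenizers.models import BPE, WordPiece, Unigram
from tokenizers.trainers import BpeTrainer, WordPieceTrainer, UnigramTrainer

def pick_model_and_trainer(style: str, vocab_size: int = 30000):
    """Map a model family to a (model, trainer) pair; defaults are illustrative."""
    if style == "gpt":
        return BPE(), BpeTrainer(vocab_size=vocab_size)
    if style == "bert":
        return WordPiece(unk_token="[UNK]"), WordPieceTrainer(vocab_size=vocab_size)
    # Multilingual / no clear word boundaries
    return Unigram(), UnigramTrainer(vocab_size=vocab_size,
                                     special_tokens=["<unk>"], unk_token="<unk>")
```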
|
||||
|
||||
### Step 2: Prepare training data
|
||||
|
||||
```python
|
||||
# Option 1: From files
|
||||
files = ["train.txt", "validation.txt"]
|
||||
|
||||
# Option 2: From Python list
|
||||
texts = [
|
||||
"This is the first sentence.",
|
||||
"This is the second sentence.",
|
||||
# ... more texts
|
||||
]
|
||||
|
||||
# Option 3: From dataset iterator
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
|
||||
|
||||
def batch_iterator(batch_size=1000):
|
||||
for i in range(0, len(dataset), batch_size):
|
||||
yield dataset[i:i + batch_size]["text"]
|
||||
```
|
||||
|
||||
### Step 3: Initialize tokenizer
|
||||
|
||||
**BPE example**:
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
|
||||
|
||||
tokenizer = Tokenizer(BPE())
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
tokenizer.decoder = ByteLevelDecoder()
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=50000,
|
||||
min_frequency=2,
|
||||
special_tokens=["<|endoftext|>", "<|padding|>"],
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
**WordPiece example**:
|
||||
```python
|
||||
from tokenizers.models import WordPiece
|
||||
from tokenizers.trainers import WordPieceTrainer
|
||||
from tokenizers.normalizers import BertNormalizer
|
||||
from tokenizers.pre_tokenizers import BertPreTokenizer
|
||||
|
||||
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
|
||||
tokenizer.normalizer = BertNormalizer(lowercase=True)
|
||||
tokenizer.pre_tokenizer = BertPreTokenizer()
|
||||
|
||||
trainer = WordPieceTrainer(
|
||||
vocab_size=30522,
|
||||
min_frequency=2,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
|
||||
continuing_subword_prefix="##",
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
**Unigram example**:
|
||||
```python
|
||||
from tokenizers.models import Unigram
|
||||
from tokenizers.trainers import UnigramTrainer
|
||||
|
||||
tokenizer = Tokenizer(Unigram())
|
||||
|
||||
trainer = UnigramTrainer(
|
||||
vocab_size=8000,
|
||||
special_tokens=["<unk>", "<s>", "</s>", "<pad>"],
|
||||
unk_token="<unk>",
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
### Step 4: Train
|
||||
|
||||
```python
|
||||
# From files
|
||||
tokenizer.train(files=files, trainer=trainer)
|
||||
|
||||
# From iterator (recommended for large datasets)
|
||||
tokenizer.train_from_iterator(
|
||||
batch_iterator(),
|
||||
trainer=trainer,
|
||||
length=len(dataset) # Optional, for progress bar
|
||||
)
|
||||
```
|
||||
|
||||
**Training time** (30k vocab on 16-core CPU):
|
||||
- 10 MB: 15-30 seconds
|
||||
- 100 MB: 1-3 minutes
|
||||
- 1 GB: 15-30 minutes
|
||||
- 10 GB: 2-4 hours
|
||||
|
||||
### Step 5: Add post-processing
|
||||
|
||||
```python
|
||||
from tokenizers.processors import TemplateProcessing
|
||||
|
||||
# BERT-style
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="[CLS] $A [SEP]",
|
||||
pair="[CLS] $A [SEP] $B [SEP]",
|
||||
special_tokens=[
|
||||
("[CLS]", tokenizer.token_to_id("[CLS]")),
|
||||
("[SEP]", tokenizer.token_to_id("[SEP]")),
|
||||
],
|
||||
)
|
||||
|
||||
# GPT-2 style
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="$A <|endoftext|>",
|
||||
special_tokens=[
|
||||
("<|endoftext|>", tokenizer.token_to_id("<|endoftext|>")),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
### Step 6: Save
|
||||
|
||||
```python
|
||||
# Save to JSON
|
||||
tokenizer.save("my-tokenizer.json")
|
||||
|
||||
# Save to directory (for transformers)
|
||||
tokenizer.save("my-tokenizer-dir/tokenizer.json")
|
||||
|
||||
# Convert to transformers format
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
transformers_tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=tokenizer,
|
||||
unk_token="[UNK]",
|
||||
pad_token="[PAD]",
|
||||
cls_token="[CLS]",
|
||||
sep_token="[SEP]",
|
||||
mask_token="[MASK]"
|
||||
)
|
||||
|
||||
transformers_tokenizer.save_pretrained("my-tokenizer-dir")
|
||||
```
|
||||
|
||||
## Trainer configuration
|
||||
|
||||
### BpeTrainer parameters
|
||||
|
||||
```python
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=30000, # Target vocabulary size
|
||||
min_frequency=2, # Minimum frequency for merges
|
||||
special_tokens=["[UNK]"], # Special tokens (added first)
|
||||
limit_alphabet=1000, # Limit initial alphabet size
|
||||
initial_alphabet=[], # Pre-defined initial characters
|
||||
show_progress=True, # Show progress bar
|
||||
continuing_subword_prefix="", # Prefix for continuing subwords
|
||||
end_of_word_suffix="" # Suffix for end of words
|
||||
)
|
||||
```
|
||||
|
||||
**Parameter tuning**:
|
||||
- **vocab_size**: Start with 30k for English, 50k for multilingual
|
||||
- **min_frequency**: 2-5 for large corpora, 1 for small
|
||||
- **limit_alphabet**: Reduce for non-English (CJK languages)
|
||||
|
||||
### WordPieceTrainer parameters
|
||||
|
||||
```python
|
||||
from tokenizers.trainers import WordPieceTrainer
|
||||
|
||||
trainer = WordPieceTrainer(
|
||||
vocab_size=30522, # BERT uses 30,522
|
||||
min_frequency=2,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
|
||||
limit_alphabet=1000,
|
||||
continuing_subword_prefix="##", # BERT-style prefix
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
### UnigramTrainer parameters
|
||||
|
||||
```python
|
||||
from tokenizers.trainers import UnigramTrainer
|
||||
|
||||
trainer = UnigramTrainer(
|
||||
vocab_size=8000, # Typically smaller than BPE/WordPiece
|
||||
special_tokens=["<unk>", "<s>", "</s>"],
|
||||
unk_token="<unk>",
|
||||
max_piece_length=16, # Maximum token length
|
||||
n_sub_iterations=2, # EM algorithm iterations
|
||||
shrinking_factor=0.75, # Vocabulary reduction rate
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
## Training from large datasets
|
||||
|
||||
### Memory-efficient training
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("wikipedia", "20220301.en", split="train", streaming=True)
|
||||
|
||||
# Create iterator (yields batches)
|
||||
def batch_iterator(batch_size=1000):
|
||||
batch = []
|
||||
for sample in dataset:
|
||||
batch.append(sample["text"])
|
||||
if len(batch) >= batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
# Initialize tokenizer
|
||||
tokenizer = Tokenizer(BPE())
|
||||
trainer = BpeTrainer(vocab_size=50000, special_tokens=["<|endoftext|>"])
|
||||
|
||||
# Train (memory efficient - streams data)
|
||||
tokenizer.train_from_iterator(
|
||||
batch_iterator(),
|
||||
trainer=trainer
|
||||
)
|
||||
```
|
||||
|
||||
**Memory usage**: ~200 MB (vs 10+ GB loading full dataset)
|
||||
|
||||
### Multi-file training
|
||||
|
||||
```python
|
||||
import glob
|
||||
|
||||
# Find all training files
|
||||
files = glob.glob("data/train/*.txt")
|
||||
print(f"Training on {len(files)} files")
|
||||
|
||||
# Train on all files
|
||||
tokenizer.train(files=files, trainer=trainer)
|
||||
```
|
||||
|
||||
### Parallel training (multi-processing)
|
||||
|
||||
```python
|
||||
from multiprocessing import Pool, cpu_count
|
||||
import os
|
||||
|
||||
def train_shard(shard_files):
|
||||
"""Train tokenizer on a shard of files."""
|
||||
tokenizer = Tokenizer(BPE())
|
||||
trainer = BpeTrainer(vocab_size=50000)
|
||||
tokenizer.train(files=shard_files, trainer=trainer)
|
||||
return tokenizer.get_vocab()
|
||||
|
||||
# Split files into shards
|
||||
num_shards = cpu_count()
|
||||
file_shards = [files[i::num_shards] for i in range(num_shards)]
|
||||
|
||||
# Train shards in parallel
|
||||
with Pool(num_shards) as pool:
|
||||
vocab_shards = pool.map(train_shard, file_shards)
|
||||
|
||||
# Merge vocabularies (custom logic needed)
|
||||
# This is a simplified example - real implementation would merge intelligently
|
||||
final_vocab = {}
|
||||
for vocab in vocab_shards:
|
||||
final_vocab.update(vocab)
|
||||
```
|
||||
|
||||
## Domain-specific tokenizers
|
||||
|
||||
### Code tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
from tokenizers.normalizers import Sequence, NFC
|
||||
|
||||
# Code-optimized configuration
|
||||
tokenizer = Tokenizer(BPE())
|
||||
|
||||
# Minimal normalization (preserve case, whitespace)
|
||||
tokenizer.normalizer = NFC() # Only normalize Unicode
|
||||
|
||||
# Byte-level pre-tokenization (handles all characters)
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
|
||||
# Train on code corpus
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=50000,
|
||||
special_tokens=["<|endoftext|>", "<|pad|>"],
|
||||
min_frequency=2
|
||||
)
|
||||
|
||||
tokenizer.train(files=["code_corpus.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
### Medical/scientific tokenizer
|
||||
|
||||
```python
|
||||
# Preserve case and special characters
|
||||
from tokenizers.normalizers import NFKC
|
||||
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence
|
||||
|
||||
tokenizer = Tokenizer(BPE())
|
||||
|
||||
# Minimal normalization
|
||||
tokenizer.normalizer = NFKC()
|
||||
|
||||
# Preserve medical terms
|
||||
tokenizer.pre_tokenizer = Sequence([
|
||||
Whitespace(),
|
||||
Punctuation(behavior="isolated") # Keep punctuation separate
|
||||
])
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=50000,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]"],
|
||||
min_frequency=3 # Higher threshold for rare medical terms
|
||||
)
|
||||
|
||||
tokenizer.train(files=["pubmed_corpus.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
### Multilingual tokenizer
|
||||
|
||||
```python
|
||||
# Handle multiple scripts
|
||||
from tokenizers.normalizers import NFKC, Lowercase, Sequence
|
||||
|
||||
tokenizer = Tokenizer(BPE())
|
||||
|
||||
# Normalize but don't lowercase (preserves script differences)
|
||||
tokenizer.normalizer = NFKC()
|
||||
|
||||
# Byte-level handles all Unicode
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=100000, # Larger vocab for multiple languages
|
||||
special_tokens=["<unk>", "<s>", "</s>"],
|
||||
limit_alphabet=None # No limit (handles all scripts)
|
||||
)
|
||||
|
||||
# Train on multilingual corpus
|
||||
tokenizer.train(files=["multilingual_corpus.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
## Vocabulary size selection
|
||||
|
||||
### Guidelines by task
|
||||
|
||||
| Task | Recommended Vocab Size | Rationale |
|
||||
|-----------------------|------------------------|-----------|
|
||||
| English (monolingual) | 30,000 - 50,000 | Balanced coverage |
|
||||
| Multilingual | 50,000 - 250,000 | More languages = more tokens |
|
||||
| Code | 30,000 - 50,000 | Similar to English |
|
||||
| Domain-specific | 10,000 - 30,000 | Smaller, focused vocabulary |
|
||||
| Character-level tasks | 1,000 - 5,000 | Only characters + subwords |
|
||||
|
||||
### Vocabulary size impact
|
||||
|
||||
**Small vocab (10k)**:
|
||||
- Pros: Faster training, smaller model, less memory
|
||||
- Cons: More tokens per sentence, worse OOV handling
|
||||
|
||||
**Medium vocab (30k-50k)**:
|
||||
- Pros: Good balance, standard choice
|
||||
- Cons: None (recommended default)
|
||||
|
||||
**Large vocab (100k+)**:
|
||||
- Pros: Fewer tokens per sentence, better OOV
|
||||
- Cons: Slower training, larger embedding table
|
||||
|
||||
### Empirical testing
|
||||
|
||||
```python
|
||||
# Train multiple tokenizers with different vocab sizes
|
||||
vocab_sizes = [10000, 30000, 50000, 100000]
|
||||
|
||||
for vocab_size in vocab_sizes:
|
||||
tokenizer = Tokenizer(BPE())
|
||||
trainer = BpeTrainer(vocab_size=vocab_size)
|
||||
tokenizer.train(files=["sample.txt"], trainer=trainer)
|
||||
|
||||
# Evaluate on test set
|
||||
test_text = "Test sentence for evaluation..."
|
||||
tokens = tokenizer.encode(test_text).ids
|
||||
|
||||
print(f"Vocab: {vocab_size:6d} | Tokens: {len(tokens):3d} | Avg: {len(test_text)/len(tokens):.2f} chars/token")
|
||||
|
||||
# Example output:
|
||||
# Vocab: 10000 | Tokens: 12 | Avg: 2.33 chars/token
|
||||
# Vocab: 30000 | Tokens: 8 | Avg: 3.50 chars/token
|
||||
# Vocab: 50000 | Tokens: 7 | Avg: 4.00 chars/token
|
||||
# Vocab: 100000 | Tokens: 6 | Avg: 4.67 chars/token
|
||||
```
|
||||
|
||||
## Testing tokenizer quality
|
||||
|
||||
### Coverage test
|
||||
|
||||
```python
|
||||
# Test on held-out data
|
||||
test_corpus = load_dataset("wikitext", "wikitext-103-raw-v1", split="test")
|
||||
|
||||
total_tokens = 0
|
||||
unk_tokens = 0
|
||||
unk_id = tokenizer.token_to_id("[UNK]")
|
||||
|
||||
for text in test_corpus["text"]:
|
||||
if text.strip():
|
||||
encoding = tokenizer.encode(text)
|
||||
total_tokens += len(encoding.ids)
|
||||
unk_tokens += encoding.ids.count(unk_id)
|
||||
|
||||
unk_rate = unk_tokens / total_tokens
|
||||
print(f"Unknown token rate: {unk_rate:.2%}")
|
||||
|
||||
# Good quality: <1% unknown tokens
|
||||
# Acceptable: 1-5%
|
||||
# Poor: >5%
|
||||
```
|
||||
|
||||
### Compression test
|
||||
|
||||
```python
|
||||
# Measure tokenization efficiency
|
||||
import numpy as np
|
||||
|
||||
token_lengths = []
|
||||
|
||||
for text in test_corpus["text"][:1000]:
|
||||
if text.strip():
|
||||
encoding = tokenizer.encode(text)
|
||||
chars_per_token = len(text) / len(encoding.ids)
|
||||
token_lengths.append(chars_per_token)
|
||||
|
||||
avg_chars_per_token = np.mean(token_lengths)
|
||||
print(f"Average characters per token: {avg_chars_per_token:.2f}")
|
||||
|
||||
# Good: 4-6 chars/token (English)
|
||||
# Acceptable: 3-4 chars/token
|
||||
# Poor: <3 chars/token (under-compression)
|
||||
```
|
||||
|
||||
### Semantic test
|
||||
|
||||
```python
|
||||
# Manually inspect tokenization of common words/phrases
|
||||
test_phrases = [
|
||||
"tokenization",
|
||||
"machine learning",
|
||||
"artificial intelligence",
|
||||
"preprocessing",
|
||||
"hello world"
|
||||
]
|
||||
|
||||
for phrase in test_phrases:
|
||||
tokens = tokenizer.encode(phrase).tokens
|
||||
print(f"{phrase:25s} → {tokens}")
|
||||
|
||||
# Good tokenization:
|
||||
# tokenization → ['token', 'ization']
|
||||
# machine learning → ['machine', 'learning']
|
||||
# artificial intelligence → ['artificial', 'intelligence']
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Training too slow
|
||||
|
||||
**Solutions**:
|
||||
1. Reduce vocabulary size
|
||||
2. Increase `min_frequency`
|
||||
3. Use `limit_alphabet` to reduce initial alphabet
|
||||
4. Train on subset first
|
||||
|
||||
```python
|
||||
# Fast training configuration
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=20000, # Smaller vocab
|
||||
min_frequency=5, # Higher threshold
|
||||
limit_alphabet=500, # Limit alphabet
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: High unknown token rate
|
||||
|
||||
**Solutions**:
|
||||
1. Increase vocabulary size
|
||||
2. Decrease `min_frequency`
|
||||
3. Check normalization (might be too aggressive)
|
||||
|
||||
```python
|
||||
# Better coverage configuration
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=50000, # Larger vocab
|
||||
min_frequency=1, # Lower threshold
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Poor quality tokenization
|
||||
|
||||
**Solutions**:
|
||||
1. Verify normalization matches your use case
|
||||
2. Check pre-tokenization splits correctly
|
||||
3. Ensure training data is representative
|
||||
4. Try different algorithm (BPE vs WordPiece vs Unigram)
|
||||
|
||||
```python
|
||||
# Debug tokenization pipeline
|
||||
text = "Sample text to debug"
|
||||
|
||||
# Check normalization
|
||||
normalized = tokenizer.normalizer.normalize_str(text)
|
||||
print(f"Normalized: {normalized}")
|
||||
|
||||
# Check pre-tokenization
|
||||
pre_tokens = tokenizer.pre_tokenizer.pre_tokenize_str(text)
|
||||
print(f"Pre-tokens: {pre_tokens}")
|
||||
|
||||
# Check final tokenization
|
||||
tokens = tokenizer.encode(text).tokens
|
||||
print(f"Tokens: {tokens}")
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use representative training data** - Match your target domain
|
||||
2. **Start with standard configs** - BERT WordPiece or GPT-2 BPE
|
||||
3. **Test on held-out data** - Measure unknown token rate
|
||||
4. **Iterate on vocabulary size** - Test 30k, 50k, 100k
|
||||
5. **Save tokenizer with model** - Ensure reproducibility
|
||||
6. **Version your tokenizers** - Track changes for reproducibility
|
||||
7. **Document special tokens** - Critical for model training
|
||||
743
optional-skills/mlops/instructor/SKILL.md
Normal file
---
|
||||
name: instructor
|
||||
description: Extract structured data from LLM responses with Pydantic validation, retry failed extractions automatically, parse complex JSON with type safety, and stream partial results with Instructor - battle-tested structured output library
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [instructor, pydantic, openai, anthropic]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Prompt Engineering, Instructor, Structured Output, Pydantic, Data Extraction, JSON Parsing, Type Safety, Validation, Streaming, OpenAI, Anthropic]
|
||||
|
||||
---
|
||||
|
||||
# Instructor: Structured LLM Outputs
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use Instructor when you need to:
|
||||
- **Extract structured data** from LLM responses reliably
|
||||
- **Validate outputs** against Pydantic schemas automatically
|
||||
- **Retry failed extractions** with automatic error handling
|
||||
- **Parse complex JSON** with type safety and validation
|
||||
- **Stream partial results** for real-time processing
|
||||
- **Support multiple LLM providers** with consistent API
|
||||
|
||||
**GitHub Stars**: 15,000+ | **Battle-tested**: 100,000+ developers
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Base installation
|
||||
pip install instructor
|
||||
|
||||
# With specific providers
|
||||
pip install "instructor[anthropic]" # Anthropic Claude
|
||||
pip install "instructor[openai]" # OpenAI
|
||||
pip install "instructor[all]" # All providers
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Example: Extract User Data
|
||||
|
||||
```python
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
from anthropic import Anthropic
|
||||
|
||||
# Define output structure
|
||||
class User(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
email: str
|
||||
|
||||
# Create instructor client
|
||||
client = instructor.from_anthropic(Anthropic())
|
||||
|
||||
# Extract structured data
|
||||
user = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "John Doe is 30 years old. His email is john@example.com"
|
||||
}],
|
||||
response_model=User
|
||||
)
|
||||
|
||||
print(user.name) # "John Doe"
|
||||
print(user.age) # 30
|
||||
print(user.email) # "john@example.com"
|
||||
```
|
||||
|
||||
### With OpenAI
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = instructor.from_openai(OpenAI())
|
||||
|
||||
user = client.chat.completions.create(
|
||||
model="gpt-4o-mini",
|
||||
response_model=User,
|
||||
messages=[{"role": "user", "content": "Extract: Alice, 25, alice@email.com"}]
|
||||
)
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. Response Models (Pydantic)
|
||||
|
||||
Response models define the structure and validation rules for LLM outputs.
|
||||
|
||||
#### Basic Model
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Article(BaseModel):
|
||||
title: str = Field(description="Article title")
|
||||
author: str = Field(description="Author name")
|
||||
word_count: int = Field(description="Number of words", gt=0)
|
||||
tags: list[str] = Field(description="List of relevant tags")
|
||||
|
||||
article = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Analyze this article: [article text]"
|
||||
}],
|
||||
response_model=Article
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Type safety with Python type hints
|
||||
- Automatic validation (word_count > 0)
|
||||
- Self-documenting with Field descriptions
|
||||
- IDE autocomplete support
|
||||
|
||||
#### Nested Models
|
||||
|
||||
```python
|
||||
class Address(BaseModel):
|
||||
street: str
|
||||
city: str
|
||||
country: str
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
address: Address # Nested model
|
||||
|
||||
person = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "John lives at 123 Main St, Boston, USA"
|
||||
}],
|
||||
response_model=Person
|
||||
)
|
||||
|
||||
print(person.address.city) # "Boston"
|
||||
```
|
||||
|
||||
#### Optional Fields
|
||||
|
||||
```python
|
||||
from typing import Optional
|
||||
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
price: float
|
||||
discount: Optional[float] = None # Optional
|
||||
description: str = Field(default="No description") # Default value
|
||||
|
||||
# LLM doesn't need to provide discount or description
|
||||
```
|
||||
|
||||
#### Enums for Constraints
|
||||
|
||||
```python
|
||||
from enum import Enum
|
||||
|
||||
class Sentiment(str, Enum):
|
||||
POSITIVE = "positive"
|
||||
NEGATIVE = "negative"
|
||||
NEUTRAL = "neutral"
|
||||
|
||||
class Review(BaseModel):
|
||||
text: str
|
||||
sentiment: Sentiment # Only these 3 values allowed
|
||||
|
||||
review = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "This product is amazing!"
|
||||
}],
|
||||
response_model=Review
|
||||
)
|
||||
|
||||
print(review.sentiment) # Sentiment.POSITIVE
|
||||
```
|
||||
|
||||
### 2. Validation
|
||||
|
||||
Pydantic validates LLM outputs automatically. If validation fails, Instructor retries.
|
||||
|
||||
#### Built-in Validators
|
||||
|
||||
```python
|
||||
from pydantic import Field, EmailStr, HttpUrl
|
||||
|
||||
class Contact(BaseModel):
|
||||
name: str = Field(min_length=2, max_length=100)
|
||||
age: int = Field(ge=0, le=120) # 0 <= age <= 120
|
||||
email: EmailStr # Validates email format
|
||||
website: HttpUrl # Validates URL format
|
||||
|
||||
# If LLM provides invalid data, Instructor retries automatically
|
||||
```
|
||||
|
||||
#### Custom Validators
|
||||
|
||||
```python
|
||||
from pydantic import field_validator
|
||||
|
||||
class Event(BaseModel):
|
||||
name: str
|
||||
date: str
|
||||
attendees: int
|
||||
|
||||
@field_validator('date')
|
||||
def validate_date(cls, v):
|
||||
"""Ensure date is in YYYY-MM-DD format."""
|
||||
import re
|
||||
if not re.match(r'^\d{4}-\d{2}-\d{2}$', v):
|
||||
raise ValueError('Date must be YYYY-MM-DD format')
|
||||
return v
|
||||
|
||||
@field_validator('attendees')
|
||||
def validate_attendees(cls, v):
|
||||
"""Ensure positive attendees."""
|
||||
if v < 1:
|
||||
raise ValueError('Must have at least 1 attendee')
|
||||
return v
|
||||
```
|
||||
|
||||
#### Model-Level Validation
|
||||
|
||||
```python
|
||||
from pydantic import model_validator
|
||||
|
||||
class DateRange(BaseModel):
|
||||
start_date: str
|
||||
end_date: str
|
||||
|
||||
@model_validator(mode='after')
|
||||
def check_dates(self):
|
||||
"""Ensure end_date is after start_date."""
|
||||
from datetime import datetime
|
||||
start = datetime.strptime(self.start_date, '%Y-%m-%d')
|
||||
end = datetime.strptime(self.end_date, '%Y-%m-%d')
|
||||
|
||||
if end < start:
|
||||
raise ValueError('end_date must be after start_date')
|
||||
return self
|
||||
```
|
||||
|
||||
### 3. Automatic Retrying
|
||||
|
||||
Instructor retries automatically when validation fails, providing error feedback to the LLM.
|
||||
|
||||
```python
|
||||
# Retries up to 3 times if validation fails
|
||||
user = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Extract user from: John, age unknown"
|
||||
}],
|
||||
response_model=User,
|
||||
max_retries=3 # Default is 3
|
||||
)
|
||||
|
||||
# If age can't be extracted, Instructor tells the LLM:
|
||||
# "Validation error: age - field required"
|
||||
# LLM tries again with better extraction
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
1. LLM generates output
|
||||
2. Pydantic validates
|
||||
3. If invalid: Error message sent back to LLM
|
||||
4. LLM tries again with error feedback
|
||||
5. Repeats up to max_retries
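
A conceptual sketch of this feedback loop (not Instructor's actual internals; `call_llm` is a hypothetical helper that returns raw JSON text):

```python
from pydantic import BaseModel, ValidationError

def extract_with_feedback(prompt: str, response_model: type[BaseModel], max_retries: int = 3):
    messages = [{"role": "user", "content": prompt}]
    for _ in range(max_retries):
        raw = call_llm(messages)  # hypothetical LLM call returning a JSON string
        try:
            return response_model.model_validate_json(raw)
        except ValidationError as e:
            # Feed the validation error back so the next attempt can correct it
            messages.append({"role": "assistant", "content": raw})
            messages.append({"role": "user", "content": f"Validation error: {e}"})
    raise RuntimeError("Extraction failed after retries")
```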
|
||||
|
||||
### 4. Streaming
|
||||
|
||||
Stream partial results for real-time processing.
|
||||
|
||||
#### Streaming Partial Objects
|
||||
|
||||
```python
|
||||
from instructor import Partial
|
||||
|
||||
class Story(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
tags: list[str]
|
||||
|
||||
# Stream partial updates as LLM generates
|
||||
for partial_story in client.messages.create_partial(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Write a short sci-fi story"
|
||||
}],
|
||||
response_model=Story
|
||||
):
|
||||
print(f"Title: {partial_story.title}")
|
||||
print(f"Content so far: {partial_story.content[:100]}...")
|
||||
# Update UI in real-time
|
||||
```
|
||||
|
||||
#### Streaming Iterables
|
||||
|
||||
```python
|
||||
class Task(BaseModel):
|
||||
title: str
|
||||
priority: str
|
||||
|
||||
# Stream list items as they're generated
|
||||
tasks = client.messages.create_iterable(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Generate 10 project tasks"
|
||||
}],
|
||||
response_model=Task
|
||||
)
|
||||
|
||||
for task in tasks:
|
||||
print(f"- {task.title} ({task.priority})")
|
||||
# Process each task as it arrives
|
||||
```
|
||||
|
||||
## Provider Configuration
|
||||
|
||||
### Anthropic Claude
|
||||
|
||||
```python
|
||||
import instructor
|
||||
from anthropic import Anthropic
|
||||
|
||||
client = instructor.from_anthropic(
|
||||
Anthropic(api_key="your-api-key")
|
||||
)
|
||||
|
||||
# Use with Claude models
|
||||
response = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[...],
|
||||
response_model=YourModel
|
||||
)
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = instructor.from_openai(
|
||||
OpenAI(api_key="your-api-key")
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4o-mini",
|
||||
response_model=YourModel,
|
||||
messages=[...]
|
||||
)
|
||||
```
|
||||
|
||||
### Local Models (Ollama)
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# Point to local Ollama server
|
||||
client = instructor.from_openai(
|
||||
OpenAI(
|
||||
base_url="http://localhost:11434/v1",
|
||||
api_key="ollama" # Required but ignored
|
||||
),
|
||||
mode=instructor.Mode.JSON
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="llama3.1",
|
||||
response_model=YourModel,
|
||||
messages=[...]
|
||||
)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Pattern 1: Data Extraction from Text
|
||||
|
||||
```python
|
||||
class CompanyInfo(BaseModel):
|
||||
name: str
|
||||
founded_year: int
|
||||
industry: str
|
||||
employees: int
|
||||
headquarters: str
|
||||
|
||||
text = """
|
||||
Tesla, Inc. was founded in 2003. It operates in the automotive and energy
|
||||
industry with approximately 140,000 employees. The company is headquartered
|
||||
in Austin, Texas.
|
||||
"""
|
||||
|
||||
company = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Extract company information from: {text}"
|
||||
}],
|
||||
response_model=CompanyInfo
|
||||
)
|
||||
```
|
||||
|
||||
### Pattern 2: Classification
|
||||
|
||||
```python
|
||||
class Category(str, Enum):
|
||||
TECHNOLOGY = "technology"
|
||||
FINANCE = "finance"
|
||||
HEALTHCARE = "healthcare"
|
||||
EDUCATION = "education"
|
||||
OTHER = "other"
|
||||
|
||||
class ArticleClassification(BaseModel):
|
||||
category: Category
|
||||
confidence: float = Field(ge=0.0, le=1.0)
|
||||
keywords: list[str]
|
||||
|
||||
classification = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Classify this article: [article text]"
|
||||
}],
|
||||
response_model=ArticleClassification
|
||||
)
|
||||
```
|
||||
|
||||
### Pattern 3: Multi-Entity Extraction
|
||||
|
||||
```python
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
role: str
|
||||
|
||||
class Organization(BaseModel):
|
||||
name: str
|
||||
industry: str
|
||||
|
||||
class Entities(BaseModel):
|
||||
people: list[Person]
|
||||
organizations: list[Organization]
|
||||
locations: list[str]
|
||||
|
||||
text = "Tim Cook, CEO of Apple, announced at the event in Cupertino..."
|
||||
|
||||
entities = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Extract all entities from: {text}"
|
||||
}],
|
||||
response_model=Entities
|
||||
)
|
||||
|
||||
for person in entities.people:
|
||||
print(f"{person.name} - {person.role}")
|
||||
```
|
||||
|
||||
### Pattern 4: Structured Analysis
|
||||
|
||||
```python
|
||||
class SentimentAnalysis(BaseModel):
|
||||
overall_sentiment: Sentiment
|
||||
positive_aspects: list[str]
|
||||
negative_aspects: list[str]
|
||||
suggestions: list[str]
|
||||
score: float = Field(ge=-1.0, le=1.0)
|
||||
|
||||
review = "The product works well but setup was confusing..."
|
||||
|
||||
analysis = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Analyze this review: {review}"
|
||||
}],
|
||||
response_model=SentimentAnalysis
|
||||
)
|
||||
```
|
||||
|
||||
### Pattern 5: Batch Processing
|
||||
|
||||
```python
|
||||
def extract_person(text: str) -> Person:
|
||||
return client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Extract person from: {text}"
|
||||
}],
|
||||
response_model=Person
|
||||
)
|
||||
|
||||
texts = [
|
||||
"John Doe is a 30-year-old engineer",
|
||||
"Jane Smith, 25, works in marketing",
|
||||
"Bob Johnson, age 40, software developer"
|
||||
]
|
||||
|
||||
people = [extract_person(text) for text in texts]
|
||||
```
|
||||
|
||||
## Advanced Features
|
||||
|
||||
### Union Types
|
||||
|
||||
```python
|
||||
from typing import Union
|
||||
|
||||
class TextContent(BaseModel):
|
||||
type: str = "text"
|
||||
content: str
|
||||
|
||||
class ImageContent(BaseModel):
|
||||
type: str = "image"
|
||||
url: HttpUrl
|
||||
caption: str
|
||||
|
||||
class Post(BaseModel):
|
||||
title: str
|
||||
content: Union[TextContent, ImageContent] # Either type
|
||||
|
||||
# LLM chooses appropriate type based on content
|
||||
```
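
A possible call using the union-bearing model (treat this as a sketch; how reliably the model picks the right variant depends on the provider and mode):

```python
post = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{
        "role": "user",
        "content": "Describe this photo post: a sunset over the bay, captioned 'golden hour'"
    }],
    response_model=Post
)

print(type(post.content).__name__)  # TextContent or ImageContent
```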
|
||||
|
||||
### Dynamic Models
|
||||
|
||||
```python
|
||||
from pydantic import create_model
|
||||
|
||||
# Create model at runtime
|
||||
DynamicUser = create_model(
|
||||
'User',
|
||||
name=(str, ...),
|
||||
age=(int, Field(ge=0)),
|
||||
email=(EmailStr, ...)
|
||||
)
|
||||
|
||||
user = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[...],
|
||||
response_model=DynamicUser
|
||||
)
|
||||
```
|
||||
|
||||
### Custom Modes
|
||||
|
||||
```python
|
||||
# For providers without native structured outputs
|
||||
client = instructor.from_anthropic(
|
||||
Anthropic(),
|
||||
mode=instructor.Mode.JSON # JSON mode
|
||||
)
|
||||
|
||||
# Available modes:
|
||||
# - Mode.ANTHROPIC_TOOLS (recommended for Claude)
|
||||
# - Mode.JSON (fallback)
|
||||
# - Mode.TOOLS (OpenAI tools)
|
||||
```
|
||||
|
||||
### Context Management
|
||||
|
||||
```python
|
||||
# Single-use client
|
||||
with instructor.from_anthropic(Anthropic()) as client:
|
||||
result = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[...],
|
||||
response_model=YourModel
|
||||
)
|
||||
# Client closed automatically
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Handling Validation Errors
|
||||
|
||||
```python
|
||||
from pydantic import ValidationError
|
||||
|
||||
try:
|
||||
user = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[...],
|
||||
response_model=User,
|
||||
max_retries=3
|
||||
)
|
||||
except ValidationError as e:
|
||||
print(f"Failed after retries: {e}")
|
||||
# Handle gracefully
|
||||
|
||||
except Exception as e:
|
||||
print(f"API error: {e}")
|
||||
```
|
||||
|
||||
### Custom Error Messages
|
||||
|
||||
```python
|
||||
class ValidatedUser(BaseModel):
|
||||
name: str = Field(description="Full name, 2-100 characters")
|
||||
age: int = Field(description="Age between 0 and 120", ge=0, le=120)
|
||||
email: EmailStr = Field(description="Valid email address")
|
||||
|
||||
class Config:
|
||||
# Custom error messages
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
{
|
||||
"name": "John Doe",
|
||||
"age": 30,
|
||||
"email": "john@example.com"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Clear Field Descriptions
|
||||
|
||||
```python
|
||||
# ❌ Bad: Vague
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
price: float
|
||||
|
||||
# ✅ Good: Descriptive
|
||||
class Product(BaseModel):
|
||||
name: str = Field(description="Product name from the text")
|
||||
price: float = Field(description="Price in USD, without currency symbol")
|
||||
```
|
||||
|
||||
### 2. Use Appropriate Validation
|
||||
|
||||
```python
|
||||
# ✅ Good: Constrain values
|
||||
class Rating(BaseModel):
|
||||
score: int = Field(ge=1, le=5, description="Rating from 1 to 5 stars")
|
||||
review: str = Field(min_length=10, description="Review text, at least 10 chars")
|
||||
```
|
||||
|
||||
### 3. Provide Examples in Prompts
|
||||
|
||||
```python
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": """Extract person info from: "John, 30, engineer"
|
||||
|
||||
Example format:
|
||||
{
|
||||
"name": "John Doe",
|
||||
"age": 30,
|
||||
"occupation": "engineer"
|
||||
}"""
|
||||
}]
|
||||
```
|
||||
|
||||
### 4. Use Enums for Fixed Categories
|
||||
|
||||
```python
|
||||
# ✅ Good: Enum ensures valid values
|
||||
class Status(str, Enum):
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
REJECTED = "rejected"
|
||||
|
||||
class Application(BaseModel):
|
||||
status: Status # LLM must choose from enum
|
||||
```
|
||||
|
||||
### 5. Handle Missing Data Gracefully
|
||||
|
||||
```python
|
||||
class PartialData(BaseModel):
|
||||
required_field: str
|
||||
optional_field: Optional[str] = None
|
||||
default_field: str = "default_value"
|
||||
|
||||
# LLM only needs to provide required_field
|
||||
```
|
||||
|
||||
## Comparison to Alternatives
|
||||
|
||||
| Feature | Instructor | Manual JSON | LangChain | DSPy |
|
||||
|---------|------------|-------------|-----------|------|
|
||||
| Type Safety | ✅ Yes | ❌ No | ⚠️ Partial | ✅ Yes |
|
||||
| Auto Validation | ✅ Yes | ❌ No | ❌ No | ⚠️ Limited |
|
||||
| Auto Retry | ✅ Yes | ❌ No | ❌ No | ✅ Yes |
|
||||
| Streaming | ✅ Yes | ❌ No | ✅ Yes | ❌ No |
|
||||
| Multi-Provider | ✅ Yes | ⚠️ Manual | ✅ Yes | ✅ Yes |
|
||||
| Learning Curve | Low | Low | Medium | High |
|
||||
|
||||
**When to choose Instructor:**
|
||||
- Need structured, validated outputs
|
||||
- Want type safety and IDE support
|
||||
- Require automatic retries
|
||||
- Building data extraction systems
|
||||
|
||||
**When to choose alternatives:**
|
||||
- DSPy: Need prompt optimization
|
||||
- LangChain: Building complex chains
|
||||
- Manual: Simple, one-off extractions
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://python.useinstructor.com
|
||||
- **GitHub**: https://github.com/jxnl/instructor (15k+ stars)
|
||||
- **Cookbook**: https://python.useinstructor.com/examples
|
||||
- **Discord**: Community support available
|
||||
|
||||
## See Also
|
||||
|
||||
- `references/validation.md` - Advanced validation patterns
|
||||
- `references/providers.md` - Provider-specific configuration
|
||||
- `references/examples.md` - Real-world use cases
|
||||
|
||||
|
||||
107
optional-skills/mlops/instructor/references/examples.md
Normal file
# Real-World Examples
|
||||
|
||||
Practical examples of using Instructor for structured data extraction.
|
||||
|
||||
## Data Extraction
|
||||
|
||||
```python
|
||||
class CompanyInfo(BaseModel):
|
||||
name: str
|
||||
founded: int
|
||||
industry: str
|
||||
employees: int
|
||||
|
||||
text = "Apple was founded in 1976 in the technology industry with 164,000 employees."
|
||||
|
||||
company = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": f"Extract: {text}"}],
|
||||
response_model=CompanyInfo
|
||||
)
|
||||
```
|
||||
|
||||
## Classification
|
||||
|
||||
```python
|
||||
class Sentiment(str, Enum):
|
||||
POSITIVE = "positive"
|
||||
NEGATIVE = "negative"
|
||||
NEUTRAL = "neutral"
|
||||
|
||||
class Review(BaseModel):
|
||||
sentiment: Sentiment
|
||||
confidence: float = Field(ge=0.0, le=1.0)
|
||||
|
||||
review = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": "This product is amazing!"}],
|
||||
response_model=Review
|
||||
)
|
||||
```
|
||||
|
||||
## Multi-Entity Extraction
|
||||
|
||||
```python
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
role: str
|
||||
|
||||
class Entities(BaseModel):
|
||||
people: list[Person]
|
||||
organizations: list[str]
|
||||
locations: list[str]
|
||||
|
||||
entities = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": "Tim Cook, CEO of Apple, spoke in Cupertino..."}],
|
||||
response_model=Entities
|
||||
)
|
||||
```
|
||||
|
||||
## Structured Analysis
|
||||
|
||||
```python
|
||||
class Analysis(BaseModel):
|
||||
summary: str
|
||||
key_points: list[str]
|
||||
sentiment: Sentiment
|
||||
actionable_items: list[str]
|
||||
|
||||
analysis = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": "Analyze: [long text]"}],
|
||||
response_model=Analysis
|
||||
)
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
```python
|
||||
texts = ["text1", "text2", "text3"]
|
||||
results = [
|
||||
client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": text}],
|
||||
response_model=YourModel
|
||||
)
|
||||
for text in texts
|
||||
]
|
||||
```
|
||||
|
||||
## Streaming
|
||||
|
||||
```python
|
||||
for partial in client.messages.create_partial(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": "Generate report..."}],
|
||||
response_model=Report
|
||||
):
|
||||
print(f"Progress: {partial.title}")
|
||||
# Update UI in real-time
|
||||
```
|
||||
70
optional-skills/mlops/instructor/references/providers.md
Normal file
# Provider Configuration
|
||||
|
||||
Guide to using Instructor with different LLM providers.
|
||||
|
||||
## Anthropic Claude
|
||||
|
||||
```python
|
||||
import instructor
|
||||
from anthropic import Anthropic
|
||||
|
||||
# Basic setup
|
||||
client = instructor.from_anthropic(Anthropic())
|
||||
|
||||
# With API key
|
||||
client = instructor.from_anthropic(
|
||||
Anthropic(api_key="your-api-key")
|
||||
)
|
||||
|
||||
# Recommended mode
|
||||
client = instructor.from_anthropic(
|
||||
Anthropic(),
|
||||
mode=instructor.Mode.ANTHROPIC_TOOLS
|
||||
)
|
||||
|
||||
# Usage
|
||||
result = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": "..."}],
|
||||
response_model=YourModel
|
||||
)
|
||||
```
|
||||
|
||||
## OpenAI
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = instructor.from_openai(OpenAI())
|
||||
|
||||
result = client.chat.completions.create(
|
||||
model="gpt-4o-mini",
|
||||
response_model=YourModel,
|
||||
messages=[{"role": "user", "content": "..."}]
|
||||
)
|
||||
```
|
||||
|
||||
## Local Models (Ollama)
|
||||
|
||||
```python
|
||||
client = instructor.from_openai(
|
||||
OpenAI(
|
||||
base_url="http://localhost:11434/v1",
|
||||
api_key="ollama"
|
||||
),
|
||||
mode=instructor.Mode.JSON
|
||||
)
|
||||
|
||||
result = client.chat.completions.create(
|
||||
model="llama3.1",
|
||||
response_model=YourModel,
|
||||
messages=[...]
|
||||
)
|
||||
```
|
||||
|
||||
## Modes
|
||||
|
||||
- `Mode.ANTHROPIC_TOOLS`: Recommended for Claude
|
||||
- `Mode.TOOLS`: OpenAI function calling
|
||||
- `Mode.JSON`: Fallback for unsupported providers
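
For example, selecting a mode explicitly (sketch; default modes can vary by Instructor version):

```python
import instructor
from openai import OpenAI
from anthropic import Anthropic

# OpenAI function calling
openai_client = instructor.from_openai(OpenAI(), mode=instructor.Mode.TOOLS)

# Claude tool use
claude_client = instructor.from_anthropic(Anthropic(), mode=instructor.Mode.ANTHROPIC_TOOLS)
```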
|
||||
606
optional-skills/mlops/instructor/references/validation.md
Normal file
# Advanced Validation Patterns
|
||||
|
||||
Complete guide to validation in Instructor using Pydantic.
|
||||
|
||||
## Table of Contents
|
||||
- Built-in Validators
|
||||
- Custom Field Validators
|
||||
- Model-Level Validation
|
||||
- Complex Validation Patterns
|
||||
- Error Handling
|
||||
|
||||
## Built-in Validators
|
||||
|
||||
### Numeric Constraints
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Product(BaseModel):
|
||||
price: float = Field(gt=0, description="Price must be positive")
|
||||
discount: float = Field(ge=0, le=100, description="Discount 0-100%")
|
||||
quantity: int = Field(ge=1, description="At least 1 item")
|
||||
rating: float = Field(ge=0.0, le=5.0, description="Rating 0-5 stars")
|
||||
|
||||
# If LLM provides invalid values, automatic retry with error feedback
|
||||
```
|
||||
|
||||
**Available constraints:**
|
||||
- `gt`: Greater than
|
||||
- `ge`: Greater than or equal
|
||||
- `lt`: Less than
|
||||
- `le`: Less than or equal
|
||||
- `multiple_of`: Must be multiple of this number
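
`multiple_of` is the one constraint not shown above; a small sketch:

```python
from pydantic import BaseModel, Field

class Pack(BaseModel):
    units: int = Field(gt=0, multiple_of=6, description="Sold in packs of 6 units")
```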
|
||||
|
||||
### String Constraints
|
||||
|
||||
```python
|
||||
class User(BaseModel):
|
||||
username: str = Field(
|
||||
min_length=3,
|
||||
max_length=20,
|
||||
pattern=r'^[a-zA-Z0-9_]+$',
|
||||
description="3-20 alphanumeric characters"
|
||||
)
|
||||
bio: str = Field(max_length=500, description="Bio up to 500 chars")
|
||||
status: str = Field(pattern=r'^(active|inactive|pending)$')
|
||||
|
||||
# pattern validates against regex
|
||||
```
|
||||
|
||||
### Email and URL Validation
|
||||
|
||||
```python
|
||||
from pydantic import EmailStr, HttpUrl, AnyUrl
|
||||
|
||||
class Contact(BaseModel):
|
||||
email: EmailStr # Validates email format
|
||||
website: HttpUrl # Validates HTTP/HTTPS URLs
|
||||
portfolio: AnyUrl # Any valid URL scheme
|
||||
|
||||
contact = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Extract: john@example.com, https://example.com"
|
||||
}],
|
||||
response_model=Contact
|
||||
)
|
||||
```
|
||||
|
||||
### Date and DateTime Validation
|
||||
|
||||
```python
|
||||
from datetime import date, datetime
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
class Event(BaseModel):
|
||||
event_date: date # Validates date format
|
||||
created_at: datetime # Validates datetime format
|
||||
year: int = Field(ge=1900, le=2100)
|
||||
|
||||
@field_validator('event_date')
|
||||
def future_date(cls, v):
|
||||
"""Ensure event is in the future."""
|
||||
if v < date.today():
|
||||
raise ValueError('Event must be in the future')
|
||||
return v
|
||||
```
|
||||
|
||||
### List and Dict Validation
|
||||
|
||||
```python
|
||||
class Document(BaseModel):
|
||||
tags: list[str] = Field(min_length=1, max_length=10)
|
||||
keywords: list[str] = Field(min_length=3, description="At least 3 keywords")
|
||||
metadata: dict[str, str] = Field(description="String key-value pairs")
|
||||
|
||||
@field_validator('tags')
|
||||
def unique_tags(cls, v):
|
||||
"""Ensure tags are unique."""
|
||||
if len(v) != len(set(v)):
|
||||
raise ValueError('Tags must be unique')
|
||||
return v
|
||||
```
|
||||
|
||||
## Custom Field Validators
|
||||
|
||||
### Basic Field Validator
|
||||
|
||||
```python
|
||||
from pydantic import field_validator
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
@field_validator('name')
|
||||
def name_must_not_be_empty(cls, v):
|
||||
"""Validate name is not empty or just whitespace."""
|
||||
if not v or not v.strip():
|
||||
raise ValueError('Name cannot be empty')
|
||||
return v.strip()
|
||||
|
||||
@field_validator('age')
|
||||
def age_must_be_reasonable(cls, v):
|
||||
"""Validate age is between 0 and 120."""
|
||||
if v < 0 or v > 120:
|
||||
raise ValueError('Age must be between 0 and 120')
|
||||
return v
|
||||
```
|
||||
|
||||
### Validator with Field Info
|
||||
|
||||
```python
|
||||
from pydantic import ValidationInfo
|
||||
|
||||
class Article(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
|
||||
@field_validator('content')
|
||||
def content_length(cls, v, info: ValidationInfo):
|
||||
"""Validate content is longer than title."""
|
||||
if 'title' in info.data:
|
||||
title_len = len(info.data['title'])
|
||||
if len(v) < title_len * 2:
|
||||
raise ValueError('Content should be at least 2x title length')
|
||||
return v
|
||||
```
|
||||
|
||||
### Multiple Fields Validation
|
||||
|
||||
```python
|
||||
class TimeRange(BaseModel):
|
||||
start_time: str
|
||||
end_time: str
|
||||
|
||||
@field_validator('start_time', 'end_time')
|
||||
def valid_time_format(cls, v):
|
||||
"""Validate both times are in HH:MM format."""
|
||||
import re
|
||||
if not re.match(r'^\d{2}:\d{2}$', v):
|
||||
raise ValueError('Time must be in HH:MM format')
|
||||
return v
|
||||
```
|
||||
|
||||
### Transform and Validate
|
||||
|
||||
```python
|
||||
class URL(BaseModel):
|
||||
url: str
|
||||
|
||||
@field_validator('url')
|
||||
def normalize_url(cls, v):
|
||||
"""Add https:// if missing."""
|
||||
if not v.startswith(('http://', 'https://')):
|
||||
v = f'https://{v}'
|
||||
return v
|
||||
```
|
||||
|
||||
## Model-Level Validation
|
||||
|
||||
### Cross-Field Validation
|
||||
|
||||
```python
|
||||
from pydantic import model_validator
|
||||
|
||||
class DateRange(BaseModel):
|
||||
start_date: str
|
||||
end_date: str
|
||||
|
||||
@model_validator(mode='after')
|
||||
def check_dates(self):
|
||||
"""Ensure end_date is after start_date."""
|
||||
from datetime import datetime
|
||||
start = datetime.strptime(self.start_date, '%Y-%m-%d')
|
||||
end = datetime.strptime(self.end_date, '%Y-%m-%d')
|
||||
|
||||
if end < start:
|
||||
raise ValueError('end_date must be after start_date')
|
||||
return self
|
||||
|
||||
class PriceRange(BaseModel):
|
||||
min_price: float
|
||||
max_price: float
|
||||
|
||||
@model_validator(mode='after')
|
||||
def check_price_range(self):
|
||||
"""Ensure max > min."""
|
||||
if self.max_price <= self.min_price:
|
||||
raise ValueError('max_price must be greater than min_price')
|
||||
return self
|
||||
```
|
||||
|
||||
### Conditional Validation
|
||||
|
||||
```python
|
||||
class Order(BaseModel):
|
||||
order_type: str # "standard" or "express"
|
||||
delivery_date: str
|
||||
delivery_time: Optional[str] = None
|
||||
|
||||
@model_validator(mode='after')
|
||||
def check_delivery_time(self):
|
||||
"""Express orders need delivery time."""
|
||||
if self.order_type == "express" and not self.delivery_time:
|
||||
raise ValueError('Express orders require delivery_time')
|
||||
return self
|
||||
```
|
||||
|
||||
### Complex Business Logic
|
||||
|
||||
```python
|
||||
class Discount(BaseModel):
|
||||
code: str
|
||||
percentage: float = Field(ge=0, le=100)
|
||||
min_purchase: float = Field(ge=0)
|
||||
max_discount: float = Field(ge=0)
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_discount(self):
|
||||
"""Ensure discount logic is sound."""
|
||||
# Max discount can't exceed percentage of min_purchase
|
||||
theoretical_max = (self.percentage / 100) * self.min_purchase
|
||||
if self.max_discount > theoretical_max:
|
||||
self.max_discount = theoretical_max
|
||||
return self
|
||||
```
|
||||
|
||||
## Complex Validation Patterns
|
||||
|
||||
### Nested Model Validation
|
||||
|
||||
```python
|
||||
class Address(BaseModel):
|
||||
street: str
|
||||
city: str
|
||||
country: str
|
||||
postal_code: str
|
||||
|
||||
@field_validator('postal_code')
|
||||
def validate_postal_code(cls, v, info: ValidationInfo):
|
||||
"""Validate postal code format based on country."""
|
||||
if 'country' in info.data:
|
||||
country = info.data['country']
|
||||
if country == "USA":
|
||||
import re
|
||||
if not re.match(r'^\d{5}(-\d{4})?$', v):
|
||||
raise ValueError('Invalid US postal code')
|
||||
elif country == "Canada":
|
||||
if not re.match(r'^[A-Z]\d[A-Z] \d[A-Z]\d$', v):
|
||||
raise ValueError('Invalid Canadian postal code')
|
||||
return v
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
address: Address
|
||||
|
||||
# Nested validation runs automatically
|
||||
```
|
||||
|
||||
### List of Models
|
||||
|
||||
```python
|
||||
class Task(BaseModel):
|
||||
title: str = Field(min_length=1)
|
||||
priority: int = Field(ge=1, le=5)
|
||||
|
||||
class Project(BaseModel):
|
||||
name: str
|
||||
tasks: list[Task] = Field(min_length=1, description="At least 1 task")
|
||||
|
||||
@field_validator('tasks')
|
||||
def at_least_one_high_priority(cls, v):
|
||||
"""Ensure at least one task has priority >= 4."""
|
||||
if not any(task.priority >= 4 for task in v):
|
||||
raise ValueError('Project needs at least one high-priority task')
|
||||
return v
|
||||
```
|
||||
|
||||
### Union Type Validation
|
||||
|
||||
```python
|
||||
from typing import Union
|
||||
|
||||
class TextBlock(BaseModel):
|
||||
type: str = "text"
|
||||
content: str = Field(min_length=1)
|
||||
|
||||
class ImageBlock(BaseModel):
|
||||
type: str = "image"
|
||||
url: HttpUrl
|
||||
alt_text: str
|
||||
|
||||
class Page(BaseModel):
|
||||
title: str
|
||||
blocks: list[Union[TextBlock, ImageBlock]]
|
||||
|
||||
@field_validator('blocks')
|
||||
def validate_block_types(cls, v):
|
||||
"""Ensure first block is TextBlock."""
|
||||
if v and not isinstance(v[0], TextBlock):
|
||||
raise ValueError('First block must be text')
|
||||
return v
|
||||
```
|
||||
|
||||
### Dependent Fields
|
||||
|
||||
```python
|
||||
class Subscription(BaseModel):
|
||||
plan: str # "free", "pro", "enterprise"
|
||||
max_users: int
|
||||
features: list[str]
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_plan_limits(self):
|
||||
"""Enforce plan-specific limits."""
|
||||
limits = {
|
||||
"free": {"max_users": 1, "required_features": ["basic"]},
|
||||
"pro": {"max_users": 10, "required_features": ["basic", "advanced"]},
|
||||
"enterprise": {"max_users": 999, "required_features": ["basic", "advanced", "premium"]}
|
||||
}
|
||||
|
||||
if self.plan in limits:
|
||||
limit = limits[self.plan]
|
||||
|
||||
if self.max_users > limit["max_users"]:
|
||||
raise ValueError(f'{self.plan} plan limited to {limit["max_users"]} users')
|
||||
|
||||
for feature in limit["required_features"]:
|
||||
if feature not in self.features:
|
||||
raise ValueError(f'{self.plan} plan requires {feature} feature')
|
||||
|
||||
return self
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Graceful Degradation
|
||||
|
||||
```python
|
||||
class OptionalExtraction(BaseModel):
|
||||
# Required fields
|
||||
title: str
|
||||
|
||||
# Optional fields with defaults
|
||||
author: Optional[str] = None
|
||||
date: Optional[str] = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
|
||||
# LLM can succeed even if it can't extract everything
|
||||
```
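As a quick illustration that the defaults kick in when optional fields are missing (constructed directly here, without an LLM call):

```python
doc = OptionalExtraction(title="Quarterly report")
print(doc.author)  # None
print(doc.tags)    # []
```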
|
||||
|
||||
### Partial Validation
|
||||
|
||||
```python
|
||||
from pydantic import ValidationError
|
||||
|
||||
def extract_with_fallback(text: str):
|
||||
"""Try full extraction, fall back to partial."""
|
||||
try:
|
||||
# Try full extraction
|
||||
return client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": text}],
|
||||
response_model=FullModel
|
||||
)
|
||||
except ValidationError:
|
||||
# Fall back to partial model
|
||||
return client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": text}],
|
||||
response_model=PartialModel
|
||||
)
|
||||
```
|
||||
|
||||
### Validation Error Inspection
|
||||
|
||||
```python
|
||||
from pydantic import ValidationError
|
||||
|
||||
try:
|
||||
result = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[...],
|
||||
response_model=MyModel,
|
||||
max_retries=3
|
||||
)
|
||||
except ValidationError as e:
|
||||
# Inspect specific errors
|
||||
for error in e.errors():
|
||||
field = error['loc'][0]
|
||||
message = error['msg']
|
||||
print(f"Field '{field}' failed: {message}")
|
||||
|
||||
# Custom handling per field
|
||||
if field == 'email':
|
||||
# Handle email validation failure
|
||||
pass
|
||||
```
|
||||
|
||||
### Custom Error Messages
|
||||
|
||||
```python
|
||||
class DetailedModel(BaseModel):
|
||||
name: str = Field(
|
||||
min_length=2,
|
||||
max_length=100,
|
||||
description="Name between 2-100 characters"
|
||||
)
|
||||
age: int = Field(
|
||||
ge=0,
|
||||
le=120,
|
||||
description="Age between 0 and 120 years"
|
||||
)
|
||||
|
||||
@field_validator('name')
|
||||
def validate_name(cls, v):
|
||||
"""Provide helpful error message."""
|
||||
if not v.strip():
|
||||
raise ValueError(
|
||||
'Name cannot be empty. '
|
||||
'Please provide a valid name from the text.'
|
||||
)
|
||||
return v
|
||||
|
||||
# When validation fails, LLM sees these helpful messages
|
||||
```
|
||||
|
||||
## Validation Best Practices
|
||||
|
||||
### 1. Be Specific
|
||||
|
||||
```python
|
||||
# ❌ Bad: Vague validation
|
||||
class Item(BaseModel):
|
||||
name: str
|
||||
|
||||
# ✅ Good: Specific constraints
|
||||
class Item(BaseModel):
|
||||
name: str = Field(
|
||||
min_length=1,
|
||||
max_length=200,
|
||||
description="Item name, 1-200 characters"
|
||||
)
|
||||
```
|
||||
|
||||
### 2. Provide Context
|
||||
|
||||
```python
|
||||
# ✅ Good: Explain why validation failed
|
||||
@field_validator('price')
|
||||
def validate_price(cls, v):
|
||||
if v <= 0:
|
||||
raise ValueError(
|
||||
'Price must be positive. '
|
||||
'Extract numeric price from text without currency symbols.'
|
||||
)
|
||||
return v
|
||||
```
|
||||
|
||||
### 3. Use Enums for Fixed Sets
|
||||
|
||||
```python
|
||||
# ❌ Bad: String validation
|
||||
status: str
|
||||
|
||||
@field_validator('status')
|
||||
def validate_status(cls, v):
|
||||
if v not in ['active', 'inactive', 'pending']:
|
||||
raise ValueError('Invalid status')
|
||||
return v
|
||||
|
||||
# ✅ Good: Enum
from enum import Enum

class Status(str, Enum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
PENDING = "pending"
|
||||
|
||||
status: Status # Validation automatic
|
||||
```
|
||||
|
||||
### 4. Balance Strictness
|
||||
|
||||
```python
|
||||
# Too strict: May fail unnecessarily
|
||||
class StrictModel(BaseModel):
|
||||
date: str = Field(pattern=r'^\d{4}-\d{2}-\d{2}$')
|
||||
# Fails if LLM uses "2024-1-5" instead of "2024-01-05"
|
||||
|
||||
# Better: Normalize in validator
|
||||
class FlexibleModel(BaseModel):
|
||||
date: str
|
||||
|
||||
@field_validator('date')
|
||||
def normalize_date(cls, v):
|
||||
from datetime import datetime
|
||||
# Parse flexible formats
|
||||
for fmt in ['%Y-%m-%d', '%Y/%m/%d', '%m/%d/%Y']:
|
||||
try:
|
||||
dt = datetime.strptime(v, fmt)
|
||||
return dt.strftime('%Y-%m-%d') # Normalize
|
||||
except ValueError:
|
||||
continue
|
||||
raise ValueError('Invalid date format')
|
||||
```
|
||||
|
||||
### 5. Test Validation
|
||||
|
||||
```python
|
||||
# Test your validators with edge cases
|
||||
def test_validation():
|
||||
# Should succeed
|
||||
valid = MyModel(field="valid_value")
|
||||
|
||||
# Should fail
|
||||
try:
|
||||
invalid = MyModel(field="invalid")
|
||||
assert False, "Should have raised ValidationError"
|
||||
except ValidationError:
|
||||
pass # Expected
|
||||
|
||||
# Run tests before using in production
|
||||
```
|
||||
|
||||
## Advanced Techniques
|
||||
|
||||
### Conditional Required Fields
|
||||
|
||||
```python
|
||||
from typing import Optional
|
||||
|
||||
class ConditionalModel(BaseModel):
|
||||
type: str
|
||||
detail_a: Optional[str] = None
|
||||
detail_b: Optional[str] = None
|
||||
|
||||
@model_validator(mode='after')
|
||||
def check_required_details(self):
|
||||
"""Require different fields based on type."""
|
||||
if self.type == "type_a" and not self.detail_a:
|
||||
raise ValueError('type_a requires detail_a')
|
||||
if self.type == "type_b" and not self.detail_b:
|
||||
raise ValueError('type_b requires detail_b')
|
||||
return self
|
||||
```
|
||||
|
||||
### Validation with External Data
|
||||
|
||||
```python
|
||||
class Product(BaseModel):
|
||||
sku: str
|
||||
name: str
|
||||
|
||||
@field_validator('sku')
|
||||
def validate_sku(cls, v):
|
||||
"""Check SKU exists in database."""
|
||||
# Query database or API
|
||||
if not database.sku_exists(v):
|
||||
raise ValueError(f'SKU {v} not found in catalog')
|
||||
return v
|
||||
```
|
||||
|
||||
### Progressive Validation
|
||||
|
||||
```python
|
||||
# Start with loose validation
|
||||
class Stage1(BaseModel):
|
||||
data: str # Any string
|
||||
|
||||
# Then strict validation
|
||||
class Stage2(BaseModel):
|
||||
data: str = Field(pattern=r'^[A-Z]{3}-\d{6}$')
|
||||
|
||||
# Use Stage1 for initial extraction
|
||||
# Use Stage2 for final validation
|
||||
```
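A minimal sketch of chaining the two stages, reusing the instructor-patched client shown throughout this guide; the `text` variable is assumed to hold the source document:

```python
# Stage 1: loose extraction, unlikely to fail validation
draft = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{"role": "user", "content": text}],
    response_model=Stage1,
)

# Stage 2: enforce the strict format on the extracted value
final = Stage2(data=draft.data)
```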
|
||||
|
||||
## Resources
|
||||
|
||||
- **Pydantic Docs**: https://docs.pydantic.dev/latest/concepts/validators/
|
||||
- **Instructor Examples**: https://python.useinstructor.com/examples
|
||||
548
optional-skills/mlops/lambda-labs/SKILL.md
Normal file

@ -0,0 +1,548 @@
---
|
||||
name: lambda-labs-gpu-cloud
|
||||
description: Reserved and on-demand GPU cloud instances for ML training and inference. Use when you need dedicated GPU instances with simple SSH access, persistent filesystems, or high-performance multi-node clusters for large-scale training.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [lambda-cloud-client>=1.0.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Infrastructure, GPU Cloud, Training, Inference, Lambda Labs]
|
||||
|
||||
---
|
||||
|
||||
# Lambda Labs GPU Cloud
|
||||
|
||||
Comprehensive guide to running ML workloads on Lambda Labs GPU cloud with on-demand instances and 1-Click Clusters.
|
||||
|
||||
## When to use Lambda Labs
|
||||
|
||||
**Use Lambda Labs when:**
|
||||
- Need dedicated GPU instances with full SSH access
|
||||
- Running long training jobs (hours to days)
|
||||
- Want simple pricing with no egress fees
|
||||
- Need persistent storage across sessions
|
||||
- Require high-performance multi-node clusters (16-512 GPUs)
|
||||
- Want pre-installed ML stack (Lambda Stack with PyTorch, CUDA, NCCL)
|
||||
|
||||
**Key features:**
|
||||
- **GPU variety**: B200, H100, GH200, A100, A10, A6000, V100
|
||||
- **Lambda Stack**: Pre-installed PyTorch, TensorFlow, CUDA, cuDNN, NCCL
|
||||
- **Persistent filesystems**: Keep data across instance restarts
|
||||
- **1-Click Clusters**: 16-512 GPU Slurm clusters with InfiniBand
|
||||
- **Simple pricing**: Pay-per-minute, no egress fees
|
||||
- **Global regions**: 12+ regions worldwide
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Modal**: For serverless, auto-scaling workloads
|
||||
- **SkyPilot**: For multi-cloud orchestration and cost optimization
|
||||
- **RunPod**: For cheaper spot instances and serverless endpoints
|
||||
- **Vast.ai**: For GPU marketplace with lowest prices
|
||||
|
||||
## Quick start
|
||||
|
||||
### Account setup
|
||||
|
||||
1. Create account at https://lambda.ai
|
||||
2. Add payment method
|
||||
3. Generate API key from dashboard
|
||||
4. Add SSH key (required before launching instances); a quick API-key sanity check is sketched below
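As a quick sanity check that the key works, here is a minimal sketch using the `lambda_cloud_client` package covered later in this guide (the `LAMBDA_API_KEY` environment variable name is an assumption):

```python
import os

import lambda_cloud_client

configuration = lambda_cloud_client.Configuration(
    host="https://cloud.lambdalabs.com/api/v1",
    access_token=os.environ["LAMBDA_API_KEY"],  # key generated in step 3
)

with lambda_cloud_client.ApiClient(configuration) as api_client:
    api = lambda_cloud_client.DefaultApi(api_client)
    # A valid key can list instance types; an invalid one raises an auth error
    print(list(api.instance_types().data.keys())[:5])
```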
|
||||
|
||||
### Launch via console
|
||||
|
||||
1. Go to https://cloud.lambda.ai/instances
|
||||
2. Click "Launch instance"
|
||||
3. Select GPU type and region
|
||||
4. Choose SSH key
|
||||
5. Optionally attach filesystem
|
||||
6. Launch and wait 3-15 minutes
|
||||
|
||||
### Connect via SSH
|
||||
|
||||
```bash
|
||||
# Get instance IP from console
|
||||
ssh ubuntu@<INSTANCE-IP>
|
||||
|
||||
# Or with specific key
|
||||
ssh -i ~/.ssh/lambda_key ubuntu@<INSTANCE-IP>
|
||||
```
|
||||
|
||||
## GPU instances
|
||||
|
||||
### Available GPUs
|
||||
|
||||
| GPU | VRAM | Price/GPU/hr | Best For |
|
||||
|-----|------|--------------|----------|
|
||||
| B200 SXM6 | 180 GB | $4.99 | Largest models, fastest training |
|
||||
| H100 SXM | 80 GB | $2.99-3.29 | Large model training |
|
||||
| H100 PCIe | 80 GB | $2.49 | Cost-effective H100 |
|
||||
| GH200 | 96 GB | $1.49 | Single-GPU large models |
|
||||
| A100 80GB | 80 GB | $1.79 | Production training |
|
||||
| A100 40GB | 40 GB | $1.29 | Standard training |
|
||||
| A10 | 24 GB | $0.75 | Inference, fine-tuning |
|
||||
| A6000 | 48 GB | $0.80 | Good VRAM/price ratio |
|
||||
| V100 | 16 GB | $0.55 | Budget training |
|
||||
|
||||
### Instance configurations
|
||||
|
||||
```
|
||||
8x GPU: Best for distributed training (DDP, FSDP)
|
||||
4x GPU: Large models, multi-GPU training
|
||||
2x GPU: Medium workloads
|
||||
1x GPU: Fine-tuning, inference, development
|
||||
```
|
||||
|
||||
### Launch times
|
||||
|
||||
- Single-GPU: 3-5 minutes
|
||||
- Multi-GPU: 10-15 minutes
|
||||
|
||||
## Lambda Stack
|
||||
|
||||
All instances come with Lambda Stack pre-installed:
|
||||
|
||||
- Ubuntu 22.04 LTS
- NVIDIA drivers (latest)
- CUDA 12.x
- cuDNN 8.x
- NCCL (for multi-GPU)
- PyTorch (latest)
- TensorFlow (latest)
- JAX
- JupyterLab
|
||||
|
||||
### Verify installation
|
||||
|
||||
```bash
|
||||
# Check GPU
|
||||
nvidia-smi
|
||||
|
||||
# Check PyTorch
|
||||
python -c "import torch; print(torch.cuda.is_available())"
|
||||
|
||||
# Check CUDA version
|
||||
nvcc --version
|
||||
```
|
||||
|
||||
## Python API
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
pip install lambda-cloud-client
|
||||
```
|
||||
|
||||
### Authentication
|
||||
|
||||
```python
|
||||
import os
|
||||
import lambda_cloud_client
|
||||
|
||||
# Configure with API key
|
||||
configuration = lambda_cloud_client.Configuration(
|
||||
host="https://cloud.lambdalabs.com/api/v1",
|
||||
access_token=os.environ["LAMBDA_API_KEY"]
|
||||
)
|
||||
```
|
||||
|
||||
### List available instances
|
||||
|
||||
```python
|
||||
with lambda_cloud_client.ApiClient(configuration) as api_client:
|
||||
api = lambda_cloud_client.DefaultApi(api_client)
|
||||
|
||||
# Get available instance types
|
||||
types = api.instance_types()
|
||||
for name, info in types.data.items():
|
||||
print(f"{name}: {info.instance_type.description}")
|
||||
```
|
||||
|
||||
### Launch instance
|
||||
|
||||
```python
|
||||
from lambda_cloud_client.models import LaunchInstanceRequest
|
||||
|
||||
request = LaunchInstanceRequest(
|
||||
region_name="us-west-1",
|
||||
instance_type_name="gpu_1x_h100_sxm5",
|
||||
ssh_key_names=["my-ssh-key"],
|
||||
file_system_names=["my-filesystem"], # Optional
|
||||
name="training-job"
|
||||
)
|
||||
|
||||
response = api.launch_instance(request)
|
||||
instance_id = response.data.instance_ids[0]
|
||||
print(f"Launched: {instance_id}")
|
||||
```
|
||||
|
||||
### List running instances
|
||||
|
||||
```python
|
||||
instances = api.list_instances()
|
||||
for instance in instances.data:
|
||||
print(f"{instance.name}: {instance.ip} ({instance.status})")
|
||||
```
|
||||
|
||||
### Terminate instance
|
||||
|
||||
```python
|
||||
from lambda_cloud_client.models import TerminateInstanceRequest
|
||||
|
||||
request = TerminateInstanceRequest(
|
||||
instance_ids=[instance_id]
|
||||
)
|
||||
api.terminate_instance(request)
|
||||
```
|
||||
|
||||
### SSH key management
|
||||
|
||||
```python
|
||||
from lambda_cloud_client.models import AddSshKeyRequest
|
||||
|
||||
# Add SSH key
|
||||
request = AddSshKeyRequest(
|
||||
name="my-key",
|
||||
public_key="ssh-rsa AAAA..."
|
||||
)
|
||||
api.add_ssh_key(request)
|
||||
|
||||
# List keys
|
||||
keys = api.list_ssh_keys()
|
||||
|
||||
# Delete key
|
||||
api.delete_ssh_key(key_id)
|
||||
```
|
||||
|
||||
## CLI with curl
|
||||
|
||||
### List instance types
|
||||
|
||||
```bash
|
||||
curl -u $LAMBDA_API_KEY: \
|
||||
https://cloud.lambdalabs.com/api/v1/instance-types | jq
|
||||
```
|
||||
|
||||
### Launch instance
|
||||
|
||||
```bash
|
||||
curl -u $LAMBDA_API_KEY: \
|
||||
-X POST https://cloud.lambdalabs.com/api/v1/instance-operations/launch \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"region_name": "us-west-1",
|
||||
"instance_type_name": "gpu_1x_h100_sxm5",
|
||||
"ssh_key_names": ["my-key"]
|
||||
}' | jq
|
||||
```
|
||||
|
||||
### Terminate instance
|
||||
|
||||
```bash
|
||||
curl -u $LAMBDA_API_KEY: \
|
||||
-X POST https://cloud.lambdalabs.com/api/v1/instance-operations/terminate \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"instance_ids": ["<INSTANCE-ID>"]}' | jq
|
||||
```
|
||||
|
||||
## Persistent storage
|
||||
|
||||
### Filesystems
|
||||
|
||||
Filesystems persist data across instance restarts:
|
||||
|
||||
```bash
|
||||
# Mount location
|
||||
/lambda/nfs/<FILESYSTEM_NAME>
|
||||
|
||||
# Example: save checkpoints
|
||||
python train.py --checkpoint-dir /lambda/nfs/my-storage/checkpoints
|
||||
```
|
||||
|
||||
### Create filesystem
|
||||
|
||||
1. Go to Storage in Lambda console
|
||||
2. Click "Create filesystem"
|
||||
3. Select region (must match instance region)
|
||||
4. Name and create
|
||||
|
||||
### Attach to instance
|
||||
|
||||
Filesystems must be attached at instance launch time:
|
||||
- Via console: Select filesystem when launching
|
||||
- Via API: include `file_system_names` in the launch request (see the sketch below)
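A minimal sketch of the API route, mirroring the launch example from the Python API section above (instance type and names are illustrative; `api` is the `DefaultApi` client from that section):

```python
from lambda_cloud_client.models import LaunchInstanceRequest

request = LaunchInstanceRequest(
    region_name="us-west-1",              # must match the filesystem's region
    instance_type_name="gpu_1x_a100",
    ssh_key_names=["my-ssh-key"],
    file_system_names=["my-filesystem"],  # mounted at /lambda/nfs/my-filesystem
    name="training-job",
)
response = api.launch_instance(request)
```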
|
||||
|
||||
### Best practices
|
||||
|
||||
```bash
|
||||
# Store on filesystem (persists)
|
||||
/lambda/nfs/storage/
|
||||
├── datasets/
|
||||
├── checkpoints/
|
||||
├── models/
|
||||
└── outputs/
|
||||
|
||||
# Local SSD (faster, ephemeral)
|
||||
/home/ubuntu/
|
||||
└── working/ # Temporary files
|
||||
```
|
||||
|
||||
## SSH configuration
|
||||
|
||||
### Add SSH key
|
||||
|
||||
```bash
|
||||
# Generate key locally
|
||||
ssh-keygen -t ed25519 -f ~/.ssh/lambda_key
|
||||
|
||||
# Add public key to Lambda console
|
||||
# Or via API
|
||||
```
|
||||
|
||||
### Multiple keys
|
||||
|
||||
```bash
|
||||
# On instance, add more keys
|
||||
echo 'ssh-rsa AAAA...' >> ~/.ssh/authorized_keys
|
||||
```
|
||||
|
||||
### Import from GitHub
|
||||
|
||||
```bash
|
||||
# On instance
|
||||
ssh-import-id gh:username
|
||||
```
|
||||
|
||||
### SSH tunneling
|
||||
|
||||
```bash
|
||||
# Forward Jupyter
|
||||
ssh -L 8888:localhost:8888 ubuntu@<IP>
|
||||
|
||||
# Forward TensorBoard
|
||||
ssh -L 6006:localhost:6006 ubuntu@<IP>
|
||||
|
||||
# Multiple ports
|
||||
ssh -L 8888:localhost:8888 -L 6006:localhost:6006 ubuntu@<IP>
|
||||
```
|
||||
|
||||
## JupyterLab
|
||||
|
||||
### Launch from console
|
||||
|
||||
1. Go to Instances page
|
||||
2. Click "Launch" in Cloud IDE column
|
||||
3. JupyterLab opens in browser
|
||||
|
||||
### Manual access
|
||||
|
||||
```bash
|
||||
# On instance
|
||||
jupyter lab --ip=0.0.0.0 --port=8888
|
||||
|
||||
# From local machine with tunnel
|
||||
ssh -L 8888:localhost:8888 ubuntu@<IP>
|
||||
# Open http://localhost:8888
|
||||
```
|
||||
|
||||
## Training workflows
|
||||
|
||||
### Single-GPU training
|
||||
|
||||
```bash
|
||||
# SSH to instance
|
||||
ssh ubuntu@<IP>
|
||||
|
||||
# Clone repo
|
||||
git clone https://github.com/user/project
|
||||
cd project
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Train
|
||||
python train.py --epochs 100 --checkpoint-dir /lambda/nfs/storage/checkpoints
|
||||
```
|
||||
|
||||
### Multi-GPU training (single node)
|
||||
|
||||
```python
|
||||
# train_ddp.py
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
def main():
|
||||
dist.init_process_group("nccl")
|
||||
rank = dist.get_rank()
|
||||
device = rank % torch.cuda.device_count()
|
||||
|
||||
model = MyModel().to(device)
|
||||
model = DDP(model, device_ids=[device])
|
||||
|
||||
# Training loop...
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
|
||||
```bash
|
||||
# Launch with torchrun (8 GPUs)
|
||||
torchrun --nproc_per_node=8 train_ddp.py
|
||||
```
|
||||
|
||||
### Checkpoint to filesystem
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
checkpoint_dir = "/lambda/nfs/my-storage/checkpoints"
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# Save checkpoint
|
||||
torch.save({
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': loss,
|
||||
}, f"{checkpoint_dir}/checkpoint_{epoch}.pt")
|
||||
```
|
||||
|
||||
## 1-Click Clusters
|
||||
|
||||
### Overview
|
||||
|
||||
High-performance Slurm clusters with:
|
||||
- 16-512 NVIDIA H100 or B200 GPUs
|
||||
- NVIDIA Quantum-2 400 Gb/s InfiniBand
|
||||
- GPUDirect RDMA at 3200 Gb/s
|
||||
- Pre-installed distributed ML stack
|
||||
|
||||
### Included software
|
||||
|
||||
- Ubuntu 22.04 LTS + Lambda Stack
|
||||
- NCCL, Open MPI
|
||||
- PyTorch with DDP and FSDP
|
||||
- TensorFlow
|
||||
- OFED drivers
|
||||
|
||||
### Storage
|
||||
|
||||
- 24 TB NVMe per compute node (ephemeral)
|
||||
- Lambda filesystems for persistent data
|
||||
|
||||
### Multi-node training
|
||||
|
||||
```bash
|
||||
# On Slurm cluster
|
||||
srun --nodes=4 --ntasks-per-node=8 --gpus-per-node=8 \
|
||||
torchrun --nnodes=4 --nproc_per_node=8 \
|
||||
--rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29500 \
|
||||
train.py
|
||||
```
|
||||
|
||||
## Networking
|
||||
|
||||
### Bandwidth
|
||||
|
||||
- Inter-instance (same region): up to 200 Gbps
|
||||
- Internet outbound: 20 Gbps max
|
||||
|
||||
### Firewall
|
||||
|
||||
- Default: Only port 22 (SSH) open
|
||||
- Configure additional ports in Lambda console
|
||||
- ICMP traffic allowed by default
|
||||
|
||||
### Private IPs
|
||||
|
||||
```bash
|
||||
# Find private IP
|
||||
ip addr show | grep 'inet '
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Fine-tuning LLM
|
||||
|
||||
```bash
|
||||
# 1. Launch 8x H100 instance with filesystem
|
||||
|
||||
# 2. SSH and setup
|
||||
ssh ubuntu@<IP>
|
||||
pip install transformers accelerate peft
|
||||
|
||||
# 3. Download model to filesystem
|
||||
python -c "
|
||||
from transformers import AutoModelForCausalLM
|
||||
model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
|
||||
model.save_pretrained('/lambda/nfs/storage/models/llama-2-7b')
|
||||
"
|
||||
|
||||
# 4. Fine-tune with checkpoints on filesystem
|
||||
accelerate launch --num_processes 8 train.py \
|
||||
--model_path /lambda/nfs/storage/models/llama-2-7b \
|
||||
--output_dir /lambda/nfs/storage/outputs \
|
||||
--checkpoint_dir /lambda/nfs/storage/checkpoints
|
||||
```
|
||||
|
||||
### Workflow 2: Batch inference
|
||||
|
||||
```bash
|
||||
# 1. Launch A10 instance (cost-effective for inference)
|
||||
|
||||
# 2. Run inference
|
||||
python inference.py \
|
||||
--model /lambda/nfs/storage/models/fine-tuned \
|
||||
--input /lambda/nfs/storage/data/inputs.jsonl \
|
||||
--output /lambda/nfs/storage/data/outputs.jsonl
|
||||
```
|
||||
|
||||
## Cost optimization
|
||||
|
||||
### Choose right GPU
|
||||
|
||||
| Task | Recommended GPU |
|
||||
|------|-----------------|
|
||||
| LLM fine-tuning (7B) | A100 40GB |
|
||||
| LLM fine-tuning (70B) | 8x H100 |
|
||||
| Inference | A10, A6000 |
|
||||
| Development | V100, A10 |
|
||||
| Maximum performance | B200 |
|
||||
|
||||
### Reduce costs
|
||||
|
||||
1. **Use filesystems**: Avoid re-downloading data
|
||||
2. **Checkpoint frequently**: Resume interrupted training
|
||||
3. **Right-size**: Don't over-provision GPUs
|
||||
4. **Terminate idle instances**: there is no auto-stop, so terminate manually (see the sketch below)
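A minimal sketch that tears down every running instance at the end of a session, reusing the `api` client from the Python API section:

```python
from lambda_cloud_client.models import TerminateInstanceRequest

# Anything still listed here is billing by the minute
ids = [instance.id for instance in api.list_instances().data]
if ids:
    api.terminate_instance(TerminateInstanceRequest(instance_ids=ids))
```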
|
||||
|
||||
### Monitor usage
|
||||
|
||||
- Dashboard shows real-time GPU utilization
|
||||
- API for programmatic monitoring (see the sketch below)
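For example, a minimal sketch that polls the API for active instances, reusing the `api` client from the Python API section:

```python
# Anything reported as "active" is running and billing
running = [i for i in api.list_instances().data if i.status == "active"]
for instance in running:
    print(f"{instance.name}: {instance.ip}")
if not running:
    print("No active instances")
```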
|
||||
|
||||
## Common issues
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| Instance won't launch | Check region availability, try different GPU |
|
||||
| SSH connection refused | Wait for instance to initialize (3-15 min) |
|
||||
| Data lost after terminate | Use persistent filesystems |
|
||||
| Slow data transfer | Use filesystem in same region |
|
||||
| GPU not detected | Reboot instance, check drivers |
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Multi-node training, API automation
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://docs.lambda.ai
|
||||
- **Console**: https://cloud.lambda.ai
|
||||
- **Pricing**: https://lambda.ai/instances
|
||||
- **Support**: https://support.lambdalabs.com
|
||||
- **Blog**: https://lambda.ai/blog
|
||||
611
optional-skills/mlops/lambda-labs/references/advanced-usage.md
Normal file

@ -0,0 +1,611 @@
# Lambda Labs Advanced Usage Guide
|
||||
|
||||
## Multi-Node Distributed Training
|
||||
|
||||
### PyTorch DDP across nodes
|
||||
|
||||
```python
|
||||
# train_multi_node.py
|
||||
import os
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
def setup_distributed():
|
||||
# Environment variables set by launcher
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
rank=rank,
|
||||
world_size=world_size
|
||||
)
|
||||
|
||||
torch.cuda.set_device(local_rank)
|
||||
return rank, world_size, local_rank
|
||||
|
||||
def main():
|
||||
rank, world_size, local_rank = setup_distributed()
|
||||
|
||||
model = MyModel().cuda(local_rank)
|
||||
model = DDP(model, device_ids=[local_rank])
|
||||
|
||||
# Training loop with synchronized gradients
|
||||
for epoch in range(num_epochs):
|
||||
train_one_epoch(model, dataloader)
|
||||
|
||||
# Save checkpoint on rank 0 only
|
||||
if rank == 0:
|
||||
torch.save(model.module.state_dict(), f"checkpoint_{epoch}.pt")
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
|
||||
### Launch on multiple instances
|
||||
|
||||
```bash
|
||||
# On Node 0 (master)
|
||||
export MASTER_ADDR=<NODE0_PRIVATE_IP>
|
||||
export MASTER_PORT=29500
|
||||
|
||||
torchrun \
|
||||
--nnodes=2 \
|
||||
--nproc_per_node=8 \
|
||||
--node_rank=0 \
|
||||
--master_addr=$MASTER_ADDR \
|
||||
--master_port=$MASTER_PORT \
|
||||
train_multi_node.py
|
||||
|
||||
# On Node 1
|
||||
export MASTER_ADDR=<NODE0_PRIVATE_IP>
|
||||
export MASTER_PORT=29500
|
||||
|
||||
torchrun \
|
||||
--nnodes=2 \
|
||||
--nproc_per_node=8 \
|
||||
--node_rank=1 \
|
||||
--master_addr=$MASTER_ADDR \
|
||||
--master_port=$MASTER_PORT \
|
||||
train_multi_node.py
|
||||
```
|
||||
|
||||
### FSDP for large models
|
||||
|
||||
```python
|
||||
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
|
||||
|
||||
# Wrap policy for transformer models
|
||||
auto_wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls={LlamaDecoderLayer}
|
||||
)
|
||||
|
||||
model = FSDP(
|
||||
model,
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
mixed_precision=MixedPrecision(
|
||||
param_dtype=torch.bfloat16,
|
||||
reduce_dtype=torch.bfloat16,
|
||||
buffer_dtype=torch.bfloat16,
|
||||
),
|
||||
device_id=local_rank,
|
||||
)
|
||||
```
|
||||
|
||||
### DeepSpeed ZeRO
|
||||
|
||||
```python
|
||||
# ds_config.json
|
||||
{
|
||||
"train_batch_size": 64,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"fp16": {"enabled": true},
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {"device": "cpu"},
|
||||
"offload_param": {"device": "cpu"}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```bash
|
||||
# Launch with DeepSpeed
|
||||
deepspeed --num_nodes=2 \
|
||||
--num_gpus=8 \
|
||||
--hostfile=hostfile.txt \
|
||||
train.py --deepspeed ds_config.json
|
||||
```
|
||||
|
||||
### Hostfile for multi-node
|
||||
|
||||
```bash
|
||||
# hostfile.txt
|
||||
node0_ip slots=8
|
||||
node1_ip slots=8
|
||||
```
|
||||
|
||||
## API Automation
|
||||
|
||||
### Auto-launch training jobs
|
||||
|
||||
```python
|
||||
import os
|
||||
import time
|
||||
import lambda_cloud_client
|
||||
from lambda_cloud_client.models import LaunchInstanceRequest
|
||||
|
||||
class LambdaJobManager:
|
||||
def __init__(self, api_key: str):
|
||||
self.config = lambda_cloud_client.Configuration(
|
||||
host="https://cloud.lambdalabs.com/api/v1",
|
||||
access_token=api_key
|
||||
)
|
||||
|
||||
def find_available_gpu(self, gpu_types: list[str], regions: list[str] = None):
|
||||
"""Find first available GPU type across regions."""
|
||||
with lambda_cloud_client.ApiClient(self.config) as client:
|
||||
api = lambda_cloud_client.DefaultApi(client)
|
||||
types = api.instance_types()
|
||||
|
||||
for gpu_type in gpu_types:
|
||||
if gpu_type in types.data:
|
||||
info = types.data[gpu_type]
|
||||
for region in info.regions_with_capacity_available:
|
||||
if regions is None or region.name in regions:
|
||||
return gpu_type, region.name
|
||||
|
||||
return None, None
|
||||
|
||||
def launch_and_wait(self, instance_type: str, region: str,
|
||||
ssh_key: str, filesystem: str = None,
|
||||
timeout: int = 900) -> dict:
|
||||
"""Launch instance and wait for it to be ready."""
|
||||
with lambda_cloud_client.ApiClient(self.config) as client:
|
||||
api = lambda_cloud_client.DefaultApi(client)
|
||||
|
||||
request = LaunchInstanceRequest(
|
||||
region_name=region,
|
||||
instance_type_name=instance_type,
|
||||
ssh_key_names=[ssh_key],
|
||||
file_system_names=[filesystem] if filesystem else [],
|
||||
)
|
||||
|
||||
response = api.launch_instance(request)
|
||||
instance_id = response.data.instance_ids[0]
|
||||
|
||||
# Poll until ready
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
instance = api.get_instance(instance_id)
|
||||
if instance.data.status == "active":
|
||||
return {
|
||||
"id": instance_id,
|
||||
"ip": instance.data.ip,
|
||||
"status": "active"
|
||||
}
|
||||
time.sleep(30)
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} not ready after {timeout}s")
|
||||
|
||||
def terminate(self, instance_ids: list[str]):
|
||||
"""Terminate instances."""
|
||||
from lambda_cloud_client.models import TerminateInstanceRequest
|
||||
|
||||
with lambda_cloud_client.ApiClient(self.config) as client:
|
||||
api = lambda_cloud_client.DefaultApi(client)
|
||||
request = TerminateInstanceRequest(instance_ids=instance_ids)
|
||||
api.terminate_instance(request)
|
||||
|
||||
|
||||
# Usage
|
||||
manager = LambdaJobManager(os.environ["LAMBDA_API_KEY"])
|
||||
|
||||
# Find available H100 or A100
|
||||
gpu_type, region = manager.find_available_gpu(
|
||||
["gpu_8x_h100_sxm5", "gpu_8x_a100_80gb_sxm4"],
|
||||
regions=["us-west-1", "us-east-1"]
|
||||
)
|
||||
|
||||
if gpu_type:
|
||||
instance = manager.launch_and_wait(
|
||||
gpu_type, region,
|
||||
ssh_key="my-key",
|
||||
filesystem="training-data"
|
||||
)
|
||||
print(f"Ready: ssh ubuntu@{instance['ip']}")
|
||||
```
|
||||
|
||||
### Batch job submission
|
||||
|
||||
```python
|
||||
import subprocess
|
||||
import paramiko
|
||||
|
||||
def run_remote_job(ip: str, ssh_key_path: str, commands: list[str]):
|
||||
"""Execute commands on remote instance."""
|
||||
client = paramiko.SSHClient()
|
||||
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
client.connect(ip, username="ubuntu", key_filename=ssh_key_path)
|
||||
|
||||
for cmd in commands:
|
||||
stdin, stdout, stderr = client.exec_command(cmd)
|
||||
print(stdout.read().decode())
|
||||
        err = stderr.read().decode()
        if err:
            print(f"Error: {err}")
|
||||
|
||||
client.close()
|
||||
|
||||
# Submit training job
|
||||
commands = [
|
||||
"cd /lambda/nfs/storage/project",
|
||||
"git pull",
|
||||
"pip install -r requirements.txt",
|
||||
"nohup torchrun --nproc_per_node=8 train.py > train.log 2>&1 &"
|
||||
]
|
||||
|
||||
run_remote_job(instance["ip"], "~/.ssh/lambda_key", commands)
|
||||
```
|
||||
|
||||
### Monitor training progress
|
||||
|
||||
```python
|
||||
def monitor_job(ip: str, ssh_key_path: str, log_file: str = "train.log"):
|
||||
"""Stream training logs from remote instance."""
|
||||
import time
|
||||
|
||||
client = paramiko.SSHClient()
|
||||
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
client.connect(ip, username="ubuntu", key_filename=ssh_key_path)
|
||||
|
||||
# Tail log file
|
||||
stdin, stdout, stderr = client.exec_command(f"tail -f {log_file}")
|
||||
|
||||
try:
|
||||
for line in stdout:
|
||||
print(line.strip())
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
client.close()
|
||||
```
|
||||
|
||||
## 1-Click Cluster Workflows
|
||||
|
||||
### Slurm job submission
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=llm-training
|
||||
#SBATCH --nodes=4
|
||||
#SBATCH --ntasks-per-node=8
|
||||
#SBATCH --gpus-per-node=8
|
||||
#SBATCH --time=24:00:00
|
||||
#SBATCH --output=logs/%j.out
|
||||
#SBATCH --error=logs/%j.err
|
||||
|
||||
# Set up distributed environment
|
||||
export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
|
||||
export MASTER_PORT=29500
|
||||
|
||||
# Launch training
|
||||
srun torchrun \
|
||||
--nnodes=$SLURM_NNODES \
|
||||
--nproc_per_node=$SLURM_GPUS_PER_NODE \
|
||||
--rdzv_backend=c10d \
|
||||
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
|
||||
train.py \
|
||||
--config config.yaml
|
||||
```
|
||||
|
||||
### Interactive cluster session
|
||||
|
||||
```bash
|
||||
# Request interactive session
|
||||
srun --nodes=1 --ntasks=1 --gpus=8 --time=4:00:00 --pty bash
|
||||
|
||||
# Now on compute node with 8 GPUs
|
||||
nvidia-smi
|
||||
python train.py
|
||||
```
|
||||
|
||||
### Monitoring cluster jobs
|
||||
|
||||
```bash
|
||||
# View job queue
|
||||
squeue
|
||||
|
||||
# View job details
|
||||
scontrol show job <JOB_ID>
|
||||
|
||||
# Cancel job
|
||||
scancel <JOB_ID>
|
||||
|
||||
# View node status
|
||||
sinfo
|
||||
|
||||
# View GPU usage across cluster
|
||||
srun --nodes=4 nvidia-smi --query-gpu=name,utilization.gpu --format=csv
|
||||
```
|
||||
|
||||
## Advanced Filesystem Usage
|
||||
|
||||
### Data staging workflow
|
||||
|
||||
```bash
|
||||
# Stage data from S3 to filesystem (one-time)
|
||||
aws s3 sync s3://my-bucket/dataset /lambda/nfs/storage/datasets/
|
||||
|
||||
# Or use rclone
|
||||
rclone sync s3:my-bucket/dataset /lambda/nfs/storage/datasets/
|
||||
```
|
||||
|
||||
### Shared filesystem across instances
|
||||
|
||||
```python
|
||||
# Instance 1: Write checkpoints
|
||||
checkpoint_path = "/lambda/nfs/shared/checkpoints/model_step_1000.pt"
|
||||
torch.save(model.state_dict(), checkpoint_path)
|
||||
|
||||
# Instance 2: Read checkpoints
|
||||
model.load_state_dict(torch.load(checkpoint_path))
|
||||
```
|
||||
|
||||
### Filesystem best practices
|
||||
|
||||
```bash
|
||||
# Organize for ML workflows
|
||||
/lambda/nfs/storage/
|
||||
├── datasets/
|
||||
│ ├── raw/ # Original data
|
||||
│ └── processed/ # Preprocessed data
|
||||
├── models/
|
||||
│ ├── pretrained/ # Base models
|
||||
│ └── fine-tuned/ # Your trained models
|
||||
├── checkpoints/
|
||||
│ └── experiment_1/ # Per-experiment checkpoints
|
||||
├── logs/
|
||||
│ └── tensorboard/ # Training logs
|
||||
└── outputs/
|
||||
└── inference/ # Inference results
|
||||
```
|
||||
|
||||
## Environment Management
|
||||
|
||||
### Custom Python environments
|
||||
|
||||
```bash
|
||||
# Don't modify system Python, create venv
|
||||
python -m venv ~/myenv
|
||||
source ~/myenv/bin/activate
|
||||
|
||||
# Install packages
|
||||
pip install torch transformers accelerate
|
||||
|
||||
# Save to filesystem for reuse
|
||||
cp -r ~/myenv /lambda/nfs/storage/envs/myenv
|
||||
```
|
||||
|
||||
### Conda environments
|
||||
|
||||
```bash
|
||||
# Install miniconda (if not present)
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
|
||||
bash Miniconda3-latest-Linux-x86_64.sh -b -p ~/miniconda3
|
||||
|
||||
# Create environment
|
||||
~/miniconda3/bin/conda create -n ml python=3.10 pytorch pytorch-cuda=12.1 -c pytorch -c nvidia -y
|
||||
|
||||
# Activate
|
||||
source ~/miniconda3/bin/activate ml
|
||||
```
|
||||
|
||||
### Docker containers
|
||||
|
||||
```bash
|
||||
# Pull and run NVIDIA container
|
||||
docker run --gpus all -it --rm \
|
||||
-v /lambda/nfs/storage:/data \
|
||||
nvcr.io/nvidia/pytorch:24.01-py3
|
||||
|
||||
# Run training in container
|
||||
docker run --gpus all -d \
|
||||
-v /lambda/nfs/storage:/data \
|
||||
-v $(pwd):/workspace \
|
||||
nvcr.io/nvidia/pytorch:24.01-py3 \
|
||||
python /workspace/train.py
|
||||
```
|
||||
|
||||
## Monitoring and Observability
|
||||
|
||||
### GPU monitoring
|
||||
|
||||
```bash
|
||||
# Real-time GPU stats
|
||||
watch -n 1 nvidia-smi
|
||||
|
||||
# GPU utilization over time
|
||||
nvidia-smi dmon -s u -d 1
|
||||
|
||||
# Detailed GPU info
|
||||
nvidia-smi -q
|
||||
```
|
||||
|
||||
### System monitoring
|
||||
|
||||
```bash
|
||||
# CPU and memory
|
||||
htop
|
||||
|
||||
# Disk I/O
|
||||
iostat -x 1
|
||||
|
||||
# Network
|
||||
iftop
|
||||
|
||||
# All resources
|
||||
glances
|
||||
```
|
||||
|
||||
### TensorBoard integration
|
||||
|
||||
```bash
|
||||
# Start TensorBoard
|
||||
tensorboard --logdir /lambda/nfs/storage/logs --port 6006 --bind_all
|
||||
|
||||
# SSH tunnel from local machine
|
||||
ssh -L 6006:localhost:6006 ubuntu@<IP>
|
||||
|
||||
# Access at http://localhost:6006
|
||||
```
|
||||
|
||||
### Weights & Biases integration
|
||||
|
||||
```python
|
||||
import wandb
|
||||
|
||||
# Initialize with API key
|
||||
wandb.login(key=os.environ["WANDB_API_KEY"])
|
||||
|
||||
# Start run
|
||||
wandb.init(
|
||||
project="lambda-training",
|
||||
config={"learning_rate": 1e-4, "epochs": 100}
|
||||
)
|
||||
|
||||
# Log metrics
|
||||
wandb.log({"loss": loss, "accuracy": acc})
|
||||
|
||||
# Save artifacts to filesystem + W&B
|
||||
wandb.save("/lambda/nfs/storage/checkpoints/best_model.pt")
|
||||
```
|
||||
|
||||
## Cost Optimization Strategies
|
||||
|
||||
### Checkpointing for interruption recovery
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
def save_checkpoint(model, optimizer, epoch, loss, path):
|
||||
torch.save({
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': loss,
|
||||
}, path)
|
||||
|
||||
def load_checkpoint(path, model, optimizer):
|
||||
if os.path.exists(path):
|
||||
checkpoint = torch.load(path)
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
return checkpoint['epoch'], checkpoint['loss']
|
||||
return 0, float('inf')
|
||||
|
||||
# Save every N steps to filesystem
|
||||
checkpoint_path = "/lambda/nfs/storage/checkpoints/latest.pt"
|
||||
if step % 1000 == 0:
|
||||
save_checkpoint(model, optimizer, epoch, loss, checkpoint_path)
|
||||
```
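A short usage sketch showing how the loader above fits into a training loop that survives interruptions (`train_one_epoch` and `num_epochs` are illustrative):

```python
# Resume from the latest checkpoint if one exists, otherwise start fresh
start_epoch, best_loss = load_checkpoint(checkpoint_path, model, optimizer)

for epoch in range(start_epoch, num_epochs):
    loss = train_one_epoch(model, optimizer)
    save_checkpoint(model, optimizer, epoch, loss, checkpoint_path)
```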
|
||||
|
||||
### Instance selection by workload
|
||||
|
||||
```python
|
||||
def recommend_instance(model_params: int, batch_size: int, task: str) -> str:
|
||||
"""Recommend Lambda instance based on workload."""
|
||||
|
||||
if task == "inference":
|
||||
if model_params < 7e9:
|
||||
return "gpu_1x_a10" # $0.75/hr
|
||||
elif model_params < 13e9:
|
||||
return "gpu_1x_a6000" # $0.80/hr
|
||||
else:
|
||||
return "gpu_1x_h100_pcie" # $2.49/hr
|
||||
|
||||
elif task == "fine-tuning":
|
||||
if model_params < 7e9:
|
||||
return "gpu_1x_a100" # $1.29/hr
|
||||
elif model_params < 13e9:
|
||||
return "gpu_4x_a100" # $5.16/hr
|
||||
else:
|
||||
return "gpu_8x_h100_sxm5" # $23.92/hr
|
||||
|
||||
elif task == "pretraining":
|
||||
return "gpu_8x_h100_sxm5" # Maximum performance
|
||||
|
||||
return "gpu_1x_a100" # Default
|
||||
```
|
||||
|
||||
### Auto-terminate idle instances
|
||||
|
||||
```python
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
def auto_terminate_idle(api_key: str, idle_threshold_hours: float = 2):
|
||||
"""Terminate instances idle for too long."""
|
||||
manager = LambdaJobManager(api_key)
|
||||
|
||||
with lambda_cloud_client.ApiClient(manager.config) as client:
|
||||
api = lambda_cloud_client.DefaultApi(client)
|
||||
instances = api.list_instances()
|
||||
|
||||
for instance in instances.data:
|
||||
# Check if instance has been running without activity
|
||||
# (You'd need to track this separately)
|
||||
launch_time = instance.launched_at
|
||||
if datetime.now() - launch_time > timedelta(hours=idle_threshold_hours):
|
||||
print(f"Terminating idle instance: {instance.id}")
|
||||
manager.terminate([instance.id])
|
||||
```
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
### SSH key rotation
|
||||
|
||||
```bash
|
||||
# Generate new key pair
|
||||
ssh-keygen -t ed25519 -f ~/.ssh/lambda_key_new -C "lambda-$(date +%Y%m)"
|
||||
|
||||
# Add new key via Lambda console or API
|
||||
# Update authorized_keys on running instances
|
||||
ssh ubuntu@<IP> "echo '$(cat ~/.ssh/lambda_key_new.pub)' >> ~/.ssh/authorized_keys"
|
||||
|
||||
# Test new key
|
||||
ssh -i ~/.ssh/lambda_key_new ubuntu@<IP>
|
||||
|
||||
# Remove old key from Lambda console
|
||||
```
|
||||
|
||||
### Firewall configuration
|
||||
|
||||
```bash
|
||||
# Lambda console: Only open necessary ports
|
||||
# Recommended:
|
||||
# - 22 (SSH) - Always needed
|
||||
# - 6006 (TensorBoard) - If using
|
||||
# - 8888 (Jupyter) - If using
|
||||
# - 29500 (PyTorch distributed) - For multi-node only
|
||||
```
|
||||
|
||||
### Secrets management
|
||||
|
||||
```bash
|
||||
# Don't hardcode API keys in code
|
||||
# Use environment variables
|
||||
export HF_TOKEN="hf_..."
|
||||
export WANDB_API_KEY="..."
|
||||
|
||||
# Or use .env file (add to .gitignore)
|
||||
source .env
|
||||
|
||||
# On instance, store in ~/.bashrc
|
||||
echo 'export HF_TOKEN="..."' >> ~/.bashrc
|
||||
```
|
||||
530
optional-skills/mlops/lambda-labs/references/troubleshooting.md
Normal file

@ -0,0 +1,530 @@
# Lambda Labs Troubleshooting Guide
|
||||
|
||||
## Instance Launch Issues
|
||||
|
||||
### No instances available
|
||||
|
||||
**Error**: "No capacity available" or instance type not listed
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check availability via API
|
||||
curl -u $LAMBDA_API_KEY: \
|
||||
https://cloud.lambdalabs.com/api/v1/instance-types | jq '.data | to_entries[] | select(.value.regions_with_capacity_available | length > 0) | .key'
|
||||
|
||||
# Try different regions
|
||||
# US regions: us-west-1, us-east-1, us-south-1
|
||||
# International: eu-west-1, asia-northeast-1, etc.
|
||||
|
||||
# Try alternative GPU types
|
||||
# H100 not available? Try A100
|
||||
# A100 not available? Try A10 or A6000
|
||||
```
|
||||
|
||||
### Instance stuck launching
|
||||
|
||||
**Problem**: Instance shows "booting" for over 20 minutes
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Single-GPU: Should be ready in 3-5 minutes
|
||||
# Multi-GPU (8x): May take 10-15 minutes
|
||||
|
||||
# If stuck longer:
|
||||
# 1. Terminate the instance
|
||||
# 2. Try a different region
|
||||
# 3. Try a different instance type
|
||||
# 4. Contact Lambda support if persistent
|
||||
```
|
||||
|
||||
### API authentication fails
|
||||
|
||||
**Error**: `401 Unauthorized` or `403 Forbidden`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Verify API key format (should start with specific prefix)
|
||||
echo $LAMBDA_API_KEY
|
||||
|
||||
# Test API key
|
||||
curl -u $LAMBDA_API_KEY: \
|
||||
https://cloud.lambdalabs.com/api/v1/instance-types
|
||||
|
||||
# Generate new API key from Lambda console if needed
|
||||
# Settings > API keys > Generate
|
||||
```
|
||||
|
||||
### Quota limits reached
|
||||
|
||||
**Error**: "Instance limit reached" or "Quota exceeded"
|
||||
|
||||
**Solutions**:
|
||||
- Check current running instances in console
|
||||
- Terminate unused instances
|
||||
- Contact Lambda support to request quota increase
|
||||
- Use 1-Click Clusters for large-scale needs
|
||||
|
||||
## SSH Connection Issues
|
||||
|
||||
### Connection refused
|
||||
|
||||
**Error**: `ssh: connect to host <IP> port 22: Connection refused`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Wait for instance to fully initialize
|
||||
# Single-GPU: 3-5 minutes
|
||||
# Multi-GPU: 10-15 minutes
|
||||
|
||||
# Check instance status in console (should be "active")
|
||||
|
||||
# Verify correct IP address
|
||||
curl -u $LAMBDA_API_KEY: \
|
||||
https://cloud.lambdalabs.com/api/v1/instances | jq '.data[].ip'
|
||||
```
|
||||
|
||||
### Permission denied
|
||||
|
||||
**Error**: `Permission denied (publickey)`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Verify SSH key matches
|
||||
ssh -v -i ~/.ssh/lambda_key ubuntu@<IP>
|
||||
|
||||
# Check key permissions
|
||||
chmod 600 ~/.ssh/lambda_key
|
||||
chmod 644 ~/.ssh/lambda_key.pub
|
||||
|
||||
# Verify key was added to Lambda console before launch
|
||||
# Keys must be added BEFORE launching instance
|
||||
|
||||
# Check authorized_keys on instance (if you have another way in)
|
||||
cat ~/.ssh/authorized_keys
|
||||
```
|
||||
|
||||
### Host key verification failed
|
||||
|
||||
**Error**: `WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# This happens when IP is reused by different instance
|
||||
# Remove old key
|
||||
ssh-keygen -R <IP>
|
||||
|
||||
# Then connect again
|
||||
ssh ubuntu@<IP>
|
||||
```
|
||||
|
||||
### Timeout during SSH
|
||||
|
||||
**Error**: `ssh: connect to host <IP> port 22: Operation timed out`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check if instance is in "active" state
|
||||
|
||||
# Verify firewall allows SSH (port 22)
|
||||
# Lambda console > Firewall
|
||||
|
||||
# Check your local network allows outbound SSH
|
||||
|
||||
# Try from different network/VPN
|
||||
```
|
||||
|
||||
## GPU Issues
|
||||
|
||||
### GPU not detected
|
||||
|
||||
**Error**: `nvidia-smi: command not found` or no GPUs shown
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Reboot instance
|
||||
sudo reboot
|
||||
|
||||
# Reinstall NVIDIA drivers (if needed)
|
||||
wget -nv -O- https://lambdalabs.com/install-lambda-stack.sh | sh -
|
||||
sudo reboot
|
||||
|
||||
# Check driver status
|
||||
nvidia-smi
|
||||
lsmod | grep nvidia
|
||||
```
|
||||
|
||||
### CUDA out of memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check GPU memory
|
||||
import torch
|
||||
print(torch.cuda.get_device_properties(0).total_memory / 1e9, "GB")
|
||||
|
||||
# Clear cache
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Reduce batch size
|
||||
batch_size = batch_size // 2
|
||||
|
||||
# Enable gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Use mixed precision
|
||||
from torch.cuda.amp import autocast
|
||||
with autocast():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Use larger GPU instance
|
||||
# A100-40GB → A100-80GB → H100
|
||||
```
|
||||
|
||||
### CUDA version mismatch
|
||||
|
||||
**Error**: `CUDA driver version is insufficient for CUDA runtime version`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check versions
|
||||
nvidia-smi # Shows driver CUDA version
|
||||
nvcc --version # Shows toolkit version
|
||||
|
||||
# Lambda Stack should have compatible versions
|
||||
# If mismatch, reinstall Lambda Stack
|
||||
wget -nv -O- https://lambdalabs.com/install-lambda-stack.sh | sh -
|
||||
sudo reboot
|
||||
|
||||
# Or install specific PyTorch version
|
||||
pip install torch==2.1.0+cu121 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
```
|
||||
|
||||
### Multi-GPU not working
|
||||
|
||||
**Error**: Only one GPU being used
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check all GPUs visible
|
||||
import torch
|
||||
print(f"GPUs available: {torch.cuda.device_count()}")
|
||||
|
||||
# Verify CUDA_VISIBLE_DEVICES not set restrictively
|
||||
import os
|
||||
print(os.environ.get("CUDA_VISIBLE_DEVICES", "not set"))
|
||||
|
||||
# Use DataParallel or DistributedDataParallel
|
||||
model = torch.nn.DataParallel(model)
|
||||
# or
|
||||
model = torch.nn.parallel.DistributedDataParallel(model)
|
||||
```
|
||||
|
||||
## Filesystem Issues
|
||||
|
||||
### Filesystem not mounted
|
||||
|
||||
**Error**: `/lambda/nfs/<name>` doesn't exist
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Filesystem must be attached at launch time
|
||||
# Cannot attach to running instance
|
||||
|
||||
# Verify filesystem was selected during launch
|
||||
|
||||
# Check mount points
|
||||
df -h | grep lambda
|
||||
|
||||
# If missing, terminate and relaunch with filesystem
|
||||
```
|
||||
|
||||
### Slow filesystem performance
|
||||
|
||||
**Problem**: Reading/writing to filesystem is slow
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Use local SSD for temporary/intermediate files
|
||||
# /home/ubuntu has fast NVMe storage
|
||||
|
||||
# Copy frequently accessed data to local storage
|
||||
cp -r /lambda/nfs/storage/dataset /home/ubuntu/dataset
|
||||
|
||||
# Use filesystem for checkpoints and final outputs only
|
||||
|
||||
# Check network bandwidth
|
||||
iperf3 -c <filesystem_server>
|
||||
```
|
||||
|
||||
### Data lost after termination
|
||||
|
||||
**Problem**: Files disappeared after instance terminated
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Root volume (/home/ubuntu) is EPHEMERAL
|
||||
# Data there is lost on termination
|
||||
|
||||
# ALWAYS use filesystem for persistent data
|
||||
/lambda/nfs/<filesystem_name>/
|
||||
|
||||
# Sync important local files before terminating
|
||||
rsync -av /home/ubuntu/outputs/ /lambda/nfs/storage/outputs/
|
||||
```
|
||||
|
||||
### Filesystem full
|
||||
|
||||
**Error**: `No space left on device`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check filesystem usage
|
||||
df -h /lambda/nfs/storage
|
||||
|
||||
# Find large files
|
||||
du -sh /lambda/nfs/storage/* | sort -h
|
||||
|
||||
# Clean up old checkpoints
|
||||
find /lambda/nfs/storage/checkpoints -mtime +7 -delete
|
||||
|
||||
# Increase filesystem size in Lambda console
|
||||
# (may require support request)
|
||||
```
|
||||
|
||||
## Network Issues
|
||||
|
||||
### Port not accessible
|
||||
|
||||
**Error**: Cannot connect to service (TensorBoard, Jupyter, etc.)
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Lambda default: Only port 22 is open
|
||||
# Configure firewall in Lambda console
|
||||
|
||||
# Or use SSH tunneling (recommended)
|
||||
ssh -L 6006:localhost:6006 ubuntu@<IP>
|
||||
# Access at http://localhost:6006
|
||||
|
||||
# For Jupyter
|
||||
ssh -L 8888:localhost:8888 ubuntu@<IP>
|
||||
```
|
||||
|
||||
### Slow data download
|
||||
|
||||
**Problem**: Downloading datasets is slow
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check available bandwidth
|
||||
speedtest-cli
|
||||
|
||||
# Use multi-threaded download
|
||||
aria2c -x 16 <URL>
|
||||
|
||||
# For HuggingFace models
|
||||
export HF_HUB_ENABLE_HF_TRANSFER=1
|
||||
pip install hf_transfer
|
||||
|
||||
# For S3, use parallel transfer
|
||||
aws s3 sync s3://bucket/data /local/data --quiet
|
||||
```
|
||||
|
||||
### Inter-node communication fails
|
||||
|
||||
**Error**: Distributed training can't connect between nodes
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Verify nodes in same region (required)
|
||||
|
||||
# Check private IPs can communicate
|
||||
ping <other_node_private_ip>
|
||||
|
||||
# Verify NCCL settings
|
||||
export NCCL_DEBUG=INFO
|
||||
export NCCL_IB_DISABLE=0 # Enable InfiniBand if available
|
||||
|
||||
# Check firewall allows distributed ports
|
||||
# Need: 29500 (PyTorch), or configured MASTER_PORT
|
||||
```
|
||||
|
||||
## Software Issues
|
||||
|
||||
### Package installation fails
|
||||
|
||||
**Error**: `pip install` errors
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Use virtual environment (don't modify system Python)
|
||||
python -m venv ~/myenv
|
||||
source ~/myenv/bin/activate
|
||||
pip install <package>
|
||||
|
||||
# For CUDA packages, match CUDA version
|
||||
pip install torch --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# Clear pip cache if corrupted
|
||||
pip cache purge
|
||||
```
|
||||
|
||||
### Python version issues
|
||||
|
||||
**Error**: Package requires different Python version
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install alternate Python (don't replace system Python)
|
||||
sudo apt install python3.11 python3.11-venv python3.11-dev
|
||||
|
||||
# Create venv with specific Python
|
||||
python3.11 -m venv ~/py311env
|
||||
source ~/py311env/bin/activate
|
||||
```
|
||||
|
||||
### ImportError or ModuleNotFoundError
|
||||
|
||||
**Error**: Module not found despite installation
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Verify correct Python environment
|
||||
which python
|
||||
pip list | grep <module>
|
||||
|
||||
# Ensure virtual environment is activated
|
||||
source ~/myenv/bin/activate
|
||||
|
||||
# Reinstall in correct environment
|
||||
pip uninstall <package>
|
||||
pip install <package>
|
||||
```
|
||||
|
||||
## Training Issues
|
||||
|
||||
### Training hangs
|
||||
|
||||
**Problem**: Training stops progressing, no output
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check GPU utilization
|
||||
watch -n 1 nvidia-smi
|
||||
|
||||
# If GPUs at 0%, likely data loading bottleneck
|
||||
# Increase num_workers in DataLoader
|
||||
|
||||
# Check for deadlocks in distributed training
|
||||
export NCCL_DEBUG=INFO
|
||||
|
||||
# Add timeouts
|
||||
dist.init_process_group(..., timeout=timedelta(minutes=30))
|
||||
```
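
The `init_process_group` timeout shown above is Python rather than shell; a minimal sketch of the call with its imports, assuming a standard NCCL backend:

```python
from datetime import timedelta

import torch.distributed as dist

# Fail with an error after 30 minutes instead of hanging indefinitely
dist.init_process_group(backend="nccl", timeout=timedelta(minutes=30))
```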
|
||||
|
||||
### Checkpoint corruption
|
||||
|
||||
**Error**: `RuntimeError: storage has wrong size` or similar
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import os
import torch

# Use safe saving pattern
|
||||
checkpoint_path = "/lambda/nfs/storage/checkpoint.pt"
|
||||
temp_path = checkpoint_path + ".tmp"
|
||||
|
||||
# Save to temp first
|
||||
torch.save(state_dict, temp_path)
|
||||
# Then atomic rename
|
||||
os.rename(temp_path, checkpoint_path)
|
||||
|
||||
# For loading corrupted checkpoint
|
||||
try:
|
||||
state = torch.load(checkpoint_path)
|
||||
except Exception:  # corrupted or truncated checkpoint
|
||||
# Fall back to a previous copy (assumes you keep one, e.g. by copying to .backup before each save)
|
||||
state = torch.load(checkpoint_path + ".backup")
|
||||
```
|
||||
|
||||
### Memory leak
|
||||
|
||||
**Problem**: Memory usage grows over time
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Clear CUDA cache periodically
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Detach tensors when logging
|
||||
loss_value = loss.detach().cpu().item()
|
||||
|
||||
# Don't accumulate gradients unintentionally
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Use gradient accumulation properly
|
||||
if (step + 1) % accumulation_steps == 0:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
## Billing Issues
|
||||
|
||||
### Unexpected charges
|
||||
|
||||
**Problem**: Bill higher than expected
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check for forgotten running instances
|
||||
curl -u $LAMBDA_API_KEY: \
|
||||
https://cloud.lambdalabs.com/api/v1/instances | jq '.data[].id'
|
||||
|
||||
# Terminate all instances
|
||||
# Lambda console > Instances > Terminate all
|
||||
|
||||
# Lambda charges by the minute
|
||||
# There is no stop/pause feature: instances keep billing until you terminate them
|
||||
```
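
To terminate from the command line instead of the console, a sketch against the same API (the `instance-operations/terminate` path and JSON payload are assumptions here; confirm them in the Lambda Cloud API reference):

```bash
# Assumed endpoint and payload; verify against the Lambda Cloud API docs
curl -u $LAMBDA_API_KEY: \
  -X POST https://cloud.lambdalabs.com/api/v1/instance-operations/terminate \
  -H "Content-Type: application/json" \
  -d '{"instance_ids": ["<instance_id>"]}'
```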
|
||||
|
||||
### Instance terminated unexpectedly
|
||||
|
||||
**Problem**: Instance disappeared without manual termination
|
||||
|
||||
**Possible causes**:
|
||||
- Payment issue (card declined)
|
||||
- Account suspension
|
||||
- Instance health check failure
|
||||
|
||||
**Solutions**:
|
||||
- Check email for Lambda notifications
|
||||
- Verify payment method in console
|
||||
- Contact Lambda support
|
||||
- Always checkpoint to the persistent filesystem (a minimal sync loop is sketched below)
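
A minimal sync loop that limits the damage, assuming a filesystem mounted at `/lambda/nfs/storage` (run it in a separate tmux/screen session):

```bash
# Copy local outputs to the persistent filesystem every 10 minutes
while true; do
  rsync -av /home/ubuntu/outputs/ /lambda/nfs/storage/outputs/
  sleep 600
done
```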
|
||||
|
||||
## Common Error Messages
|
||||
|
||||
| Error | Cause | Solution |
|
||||
|-------|-------|----------|
|
||||
| `No capacity available` | Region/GPU sold out | Try different region or GPU type |
|
||||
| `Permission denied (publickey)` | SSH key mismatch | Re-add key, check permissions |
|
||||
| `CUDA out of memory` | Model too large | Reduce batch size, use larger GPU |
|
||||
| `No space left on device` | Disk full | Clean up or use filesystem |
|
||||
| `Connection refused` | Instance not ready | Wait 3-15 minutes for boot |
|
||||
| `Module not found` | Wrong Python env | Activate correct virtualenv |
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **Documentation**: https://docs.lambda.ai
|
||||
2. **Support**: https://support.lambdalabs.com
|
||||
3. **Email**: support@lambdalabs.com
|
||||
4. **Status**: Check Lambda status page for outages
|
||||
|
||||
### Information to Include
|
||||
|
||||
When contacting support, include:
|
||||
- Instance ID
|
||||
- Region
|
||||
- Instance type
|
||||
- Error message (full traceback)
|
||||
- Steps to reproduce
|
||||
- Time of occurrence
|
||||
307
optional-skills/mlops/llava/SKILL.md
Normal file
|
|
@ -0,0 +1,307 @@
|
|||
---
|
||||
name: llava
|
||||
description: Large Language and Vision Assistant. Enables visual instruction tuning and image-based conversations. Combines CLIP vision encoder with Vicuna/LLaMA language models. Supports multi-turn image chat, visual question answering, and instruction following. Use for vision-language chatbots or image understanding tasks. Best for conversational image analysis.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [transformers, torch, pillow]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [LLaVA, Vision-Language, Multimodal, Visual Question Answering, Image Chat, CLIP, Vicuna, Conversational AI, Instruction Tuning, VQA]
|
||||
|
||||
---
|
||||
|
||||
# LLaVA - Large Language and Vision Assistant
|
||||
|
||||
Open-source vision-language model for conversational image understanding.
|
||||
|
||||
## When to use LLaVA
|
||||
|
||||
**Use when:**
|
||||
- Building vision-language chatbots
|
||||
- Visual question answering (VQA)
|
||||
- Image description and captioning
|
||||
- Multi-turn image conversations
|
||||
- Visual instruction following
|
||||
- Document understanding with images
|
||||
|
||||
**Metrics**:
|
||||
- **23,000+ GitHub stars**
|
||||
- GPT-4V level capabilities (targeted)
|
||||
- Apache 2.0 License
|
||||
- Multiple model sizes (7B-34B params)
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **GPT-4V**: Highest quality, API-based
|
||||
- **CLIP**: Simple zero-shot classification
|
||||
- **BLIP-2**: Better for captioning only
|
||||
- **Flamingo**: Research, not open-source
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Clone repository
|
||||
git clone https://github.com/haotian-liu/LLaVA
|
||||
cd LLaVA
|
||||
|
||||
# Install
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### Basic usage
|
||||
|
||||
```python
|
||||
from llava.model.builder import load_pretrained_model
|
||||
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
|
||||
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
||||
from llava.conversation import conv_templates
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
# Load model
|
||||
model_path = "liuhaotian/llava-v1.5-7b"
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
||||
model_path=model_path,
|
||||
model_base=None,
|
||||
model_name=get_model_name_from_path(model_path)
|
||||
)
|
||||
|
||||
# Load image
|
||||
image = Image.open("image.jpg")
|
||||
image_tensor = process_images([image], image_processor, model.config)
|
||||
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
|
||||
|
||||
# Create conversation
|
||||
conv = conv_templates["llava_v1"].copy()
|
||||
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
prompt = conv.get_prompt()
|
||||
|
||||
# Generate response
|
||||
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
images=image_tensor,
|
||||
do_sample=True,
|
||||
temperature=0.2,
|
||||
max_new_tokens=512
|
||||
)
|
||||
|
||||
response = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
|
||||
print(response)
|
||||
```
|
||||
|
||||
## Available models
|
||||
|
||||
| Model | Parameters | VRAM | Quality |
|
||||
|-------|------------|------|---------|
|
||||
| LLaVA-v1.5-7B | 7B | ~14 GB | Good |
|
||||
| LLaVA-v1.5-13B | 13B | ~28 GB | Better |
|
||||
| LLaVA-v1.6-34B | 34B | ~70 GB | Best |
|
||||
|
||||
```python
|
||||
# Load different models
|
||||
model_7b = "liuhaotian/llava-v1.5-7b"
|
||||
model_13b = "liuhaotian/llava-v1.5-13b"
|
||||
model_34b = "liuhaotian/llava-v1.6-34b"
|
||||
|
||||
# 4-bit quantization for lower VRAM
|
||||
load_4bit = True # Reduces VRAM by ~4×
|
||||
```
|
||||
|
||||
## CLI usage
|
||||
|
||||
```bash
|
||||
# Single image query
|
||||
python -m llava.serve.cli \
|
||||
--model-path liuhaotian/llava-v1.5-7b \
|
||||
--image-file image.jpg \
|
||||
--query "What is in this image?"
|
||||
|
||||
# Multi-turn conversation
|
||||
python -m llava.serve.cli \
|
||||
--model-path liuhaotian/llava-v1.5-7b \
|
||||
--image-file image.jpg
|
||||
# Then type questions interactively
|
||||
```
|
||||
|
||||
## Web UI (Gradio)
|
||||
|
||||
```bash
|
||||
# Launch Gradio interface
|
||||
python -m llava.serve.gradio_web_server \
|
||||
--model-path liuhaotian/llava-v1.5-7b \
|
||||
--load-4bit # Optional: reduce VRAM
|
||||
|
||||
# Access at http://localhost:7860
|
||||
```
|
||||
|
||||
## Multi-turn conversations
|
||||
|
||||
```python
|
||||
# Initialize conversation (generate() below stands in for the tokenize/generate/decode steps from Quick start)
|
||||
conv = conv_templates["llava_v1"].copy()
|
||||
|
||||
# Turn 1
|
||||
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
response1 = generate(conv, model, image) # "A dog playing in a park"
|
||||
|
||||
# Turn 2
|
||||
conv.messages[-1][1] = response1 # Add previous response
|
||||
conv.append_message(conv.roles[0], "What breed is the dog?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
response2 = generate(conv, model, image) # "Golden Retriever"
|
||||
|
||||
# Turn 3
|
||||
conv.messages[-1][1] = response2
|
||||
conv.append_message(conv.roles[0], "What time of day is it?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
response3 = generate(conv, model, image)
|
||||
```
|
||||
|
||||
## Common tasks
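
The snippets below use a small `ask()` convenience helper that is not part of the LLaVA package; a minimal sketch, assuming the imports and the `tokenizer`, `model`, and `image_processor` objects loaded in Quick start are in scope:

```python
def ask(model, image, question, max_new_tokens=512):
    """Hypothetical helper wrapping the Quick start generation steps."""
    conv = conv_templates["llava_v1"].copy()
    conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    image_tensor = process_images([image], image_processor, model.config)
    image_tensor = image_tensor.to(model.device, dtype=torch.float16)
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(model.device)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.2,
            max_new_tokens=max_new_tokens,
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
```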
|
||||
|
||||
### Image captioning
|
||||
|
||||
```python
|
||||
question = "Describe this image in detail."
|
||||
response = ask(model, image, question)
|
||||
```
|
||||
|
||||
### Visual question answering
|
||||
|
||||
```python
|
||||
question = "How many people are in the image?"
|
||||
response = ask(model, image, question)
|
||||
```
|
||||
|
||||
### Object detection (textual)
|
||||
|
||||
```python
|
||||
question = "List all the objects you can see in this image."
|
||||
response = ask(model, image, question)
|
||||
```
|
||||
|
||||
### Scene understanding
|
||||
|
||||
```python
|
||||
question = "What is happening in this scene?"
|
||||
response = ask(model, image, question)
|
||||
```
|
||||
|
||||
### Document understanding
|
||||
|
||||
```python
|
||||
question = "What is the main topic of this document?"
|
||||
response = ask(model, document_image, question)
|
||||
```
|
||||
|
||||
## Training custom model
|
||||
|
||||
```bash
|
||||
# Stage 1: Feature alignment (558K image-caption pairs)
|
||||
bash scripts/v1_5/pretrain.sh
|
||||
|
||||
# Stage 2: Visual instruction tuning (150K instruction data)
|
||||
bash scripts/v1_5/finetune.sh
|
||||
```
|
||||
|
||||
## Quantization (reduce VRAM)
|
||||
|
||||
```python
|
||||
# 4-bit quantization
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
||||
model_path="liuhaotian/llava-v1.5-13b",
|
||||
model_base=None,
|
||||
model_name=get_model_name_from_path("liuhaotian/llava-v1.5-13b"),
|
||||
load_4bit=True # Reduces VRAM ~4×
|
||||
)
|
||||
|
||||
# 8-bit quantization
|
||||
load_8bit=True # Reduces VRAM ~2×
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with 7B model** - Good quality, manageable VRAM
|
||||
2. **Use 4-bit quantization** - Reduces VRAM significantly
|
||||
3. **GPU required** - CPU inference extremely slow
|
||||
4. **Clear prompts** - Specific questions get better answers
|
||||
5. **Multi-turn conversations** - Maintain conversation context
|
||||
6. **Temperature 0.2-0.7** - Balance creativity/consistency
|
||||
7. **max_new_tokens 512-1024** - For detailed responses
|
||||
8. **Batch processing** - Process multiple images sequentially (see the sketch below)
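
A minimal sequential-processing sketch, assuming the hypothetical `ask()` helper from Common tasks and a local `images/` folder:

```python
from pathlib import Path

from PIL import Image

results = {}
for path in sorted(Path("images").glob("*.jpg")):
    image = Image.open(path).convert("RGB")
    # ask() uses temperature 0.2 and max_new_tokens 512, matching the tips above
    results[path.name] = ask(model, image, "Describe this image in detail.")
```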
|
||||
|
||||
## Performance
|
||||
|
||||
| Model | VRAM (FP16) | VRAM (4-bit) | Speed (tokens/s) |
|
||||
|-------|-------------|--------------|------------------|
|
||||
| 7B | ~14 GB | ~4 GB | ~20 |
|
||||
| 13B | ~28 GB | ~8 GB | ~12 |
|
||||
| 34B | ~70 GB | ~18 GB | ~5 |
|
||||
|
||||
*On A100 GPU*
|
||||
|
||||
## Benchmarks
|
||||
|
||||
LLaVA achieves competitive scores on:
|
||||
- **VQAv2**: 78.5%
|
||||
- **GQA**: 62.0%
|
||||
- **MM-Vet**: 35.4%
|
||||
- **MMBench**: 64.3%
|
||||
|
||||
## Limitations
|
||||
|
||||
1. **Hallucinations** - May describe things not in image
|
||||
2. **Spatial reasoning** - Struggles with precise locations
|
||||
3. **Small text** - Difficulty reading fine print
|
||||
4. **Object counting** - Imprecise for many objects
|
||||
5. **VRAM requirements** - Need powerful GPU
|
||||
6. **Inference speed** - Slower than CLIP
|
||||
|
||||
## Integration with frameworks
|
||||
|
||||
### LangChain
|
||||
|
||||
```python
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
class LLaVALLM(LLM):
|
||||
def _call(self, prompt, stop=None):
|
||||
# Custom LLaVA inference
|
||||
return response
|
||||
|
||||
llm = LLaVALLM()
|
||||
```
|
||||
|
||||
### Gradio App
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
|
||||
def chat(image, text, history):
|
||||
response = ask_llava(model, image, text)
|
||||
return response
|
||||
|
||||
demo = gr.ChatInterface(
|
||||
chat,
|
||||
additional_inputs=[gr.Image(type="pil")],
|
||||
title="LLaVA Chat"
|
||||
)
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/haotian-liu/LLaVA ⭐ 23,000+
|
||||
- **Paper**: https://arxiv.org/abs/2304.08485
|
||||
- **Demo**: https://llava.hliu.cc
|
||||
- **Models**: https://huggingface.co/liuhaotian
|
||||
- **License**: Apache 2.0
|
||||
|
||||
|
||||
197
optional-skills/mlops/llava/references/training.md
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
# LLaVA Training Guide
|
||||
|
||||
Guide to training and fine-tuning LLaVA models.
|
||||
|
||||
## Training stages
|
||||
|
||||
### Stage 1: Feature alignment (Pretraining)
|
||||
|
||||
**Purpose**: Align vision encoder with language model
|
||||
|
||||
**Data**: 558K image-caption pairs (CC3M subset)
|
||||
|
||||
```bash
|
||||
# Download pretrained projector or train from scratch
|
||||
bash scripts/v1_5/pretrain.sh
|
||||
```
|
||||
|
||||
**Configuration:**
|
||||
- Base model: Vicuna-7B or LLaMA-2-7B
|
||||
- Vision encoder: CLIP ViT-L/14
|
||||
- Training time: ~20 hours on 8× A100
|
||||
|
||||
### Stage 2: Visual instruction tuning
|
||||
|
||||
**Purpose**: Teach model to follow visual instructions
|
||||
|
||||
**Data**: 150K GPT-generated multimodal instruction data
|
||||
|
||||
```bash
|
||||
# Fine-tune with instruction data
|
||||
bash scripts/v1_5/finetune.sh
|
||||
```
|
||||
|
||||
**Configuration:**
|
||||
- Epochs: 1
|
||||
- Batch size: 128 (across 8 GPUs)
|
||||
- Learning rate: 2e-5
|
||||
- Training time: ~24 hours on 8× A100
|
||||
|
||||
## Data format
|
||||
|
||||
### Instruction data format
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"id": "001",
|
||||
"image": "path/to/image.jpg",
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "<image>\nWhat is in this image?"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "The image shows a dog playing in a park."
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "What breed is the dog?"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "It appears to be a Golden Retriever."
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
## Fine-tuning on custom data
|
||||
|
||||
### Prepare your data
|
||||
|
||||
```python
|
||||
import json
|
||||
|
||||
# Create instruction data
|
||||
data = []
|
||||
for image_path, qa_pairs in your_dataset:
|
||||
conversations = []
|
||||
for q, a in qa_pairs:
|
||||
conversations.append({"from": "human", "value": f"<image>\n{q}"})
|
||||
conversations.append({"from": "gpt", "value": a})
|
||||
|
||||
data.append({
|
||||
"id": str(len(data)),
|
||||
"image": image_path,
|
||||
"conversations": conversations
|
||||
})
|
||||
|
||||
# Save
|
||||
with open("custom_data.json", "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
```
|
||||
|
||||
### Fine-tune script
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
|
||||
# Set paths
|
||||
DATA_PATH="custom_data.json"
|
||||
IMAGE_FOLDER="path/to/images"
|
||||
MODEL_PATH="liuhaotian/llava-v1.5-7b"
|
||||
OUTPUT_DIR="./checkpoints/llava-custom"
|
||||
|
||||
# Fine-tune
|
||||
deepspeed llava/train/train_mem.py \
|
||||
--deepspeed ./scripts/zero2.json \
|
||||
--model_name_or_path $MODEL_PATH \
|
||||
--version v1 \
|
||||
--data_path $DATA_PATH \
|
||||
--image_folder $IMAGE_FOLDER \
|
||||
--vision_tower openai/clip-vit-large-patch14-336 \
|
||||
--mm_projector_type mlp2x_gelu \
|
||||
--mm_vision_select_layer -2 \
|
||||
--mm_use_im_start_end False \
|
||||
--mm_use_im_patch_token False \
|
||||
--image_aspect_ratio pad \
|
||||
--group_by_modality_length True \
|
||||
--bf16 True \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--num_train_epochs 1 \
|
||||
--per_device_train_batch_size 16 \
|
||||
--per_device_eval_batch_size 4 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--evaluation_strategy "no" \
|
||||
--save_strategy "steps" \
|
||||
--save_steps 50000 \
|
||||
--save_total_limit 1 \
|
||||
--learning_rate 2e-5 \
|
||||
--weight_decay 0. \
|
||||
--warmup_ratio 0.03 \
|
||||
--lr_scheduler_type "cosine" \
|
||||
--logging_steps 1 \
|
||||
--tf32 True \
|
||||
--model_max_length 2048 \
|
||||
--gradient_checkpointing True \
|
||||
--dataloader_num_workers 4 \
|
||||
--lazy_preprocess True \
|
||||
--report_to wandb
|
||||
```
|
||||
|
||||
## LoRA fine-tuning (memory efficient)
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
# LoRA config
|
||||
lora_config = LoraConfig(
|
||||
r=8, # LoRA rank
|
||||
lora_alpha=16,
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
model = get_peft_model(base_model, lora_config)  # base_model: the loaded LLaVA model
|
||||
|
||||
# Train with much lower memory
|
||||
```
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
### Full fine-tuning
|
||||
|
||||
- **7B model**: 8× A100 (40GB)
|
||||
- **13B model**: 8× A100 (80GB)
|
||||
- **Training time**: 20-48 hours
|
||||
|
||||
### LoRA fine-tuning
|
||||
|
||||
- **7B model**: 1× A100 (40GB)
|
||||
- **13B model**: 2× A100 (40GB)
|
||||
- **Training time**: 10-24 hours
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with pretrained** - Don't train from scratch
|
||||
2. **Use LoRA for efficiency** - 10× less memory
|
||||
3. **Quality over quantity** - 1K high-quality > 10K low-quality
|
||||
4. **Multi-turn conversations** - More engaging than single Q&A
|
||||
5. **Diverse images** - Cover different scenarios
|
||||
6. **Clear instructions** - Specific questions get better answers
|
||||
7. **Monitor loss** - Should decrease smoothly
|
||||
8. **Save checkpoints** - Training can fail
|
||||
9. **Test regularly** - Validate on held-out set
|
||||
10. **Use DeepSpeed** - For multi-GPU training
|
||||
|
||||
## Resources
|
||||
|
||||
- **Training script**: https://github.com/haotian-liu/LLaVA/tree/main/scripts
|
||||
- **Data format**: https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md
|
||||
- **Paper**: https://arxiv.org/abs/2304.08485
|
||||
386
optional-skills/mlops/nemo-curator/SKILL.md
Normal file
|
|
@ -0,0 +1,386 @@
|
|||
---
|
||||
name: nemo-curator
|
||||
description: GPU-accelerated data curation for LLM training. Supports text/image/video/audio. Features fuzzy deduplication (16× faster), quality filtering (30+ heuristics), semantic deduplication, PII redaction, NSFW detection. Scales across GPUs with RAPIDS. Use for preparing high-quality training datasets, cleaning web data, or deduplicating large corpora.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [nemo-curator, cudf, dask, rapids]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Data Processing, NeMo Curator, Data Curation, GPU Acceleration, Deduplication, Quality Filtering, NVIDIA, RAPIDS, PII Redaction, Multimodal, LLM Training Data]
|
||||
|
||||
---
|
||||
|
||||
# NeMo Curator - GPU-Accelerated Data Curation
|
||||
|
||||
NVIDIA's toolkit for preparing high-quality training data for LLMs.
|
||||
|
||||
## When to use NeMo Curator
|
||||
|
||||
**Use NeMo Curator when:**
|
||||
- Preparing LLM training data from web scrapes (Common Crawl)
|
||||
- Need fast deduplication (16× faster than CPU)
|
||||
- Curating multi-modal datasets (text, images, video, audio)
|
||||
- Filtering low-quality or toxic content
|
||||
- Scaling data processing across GPU cluster
|
||||
|
||||
**Performance**:
|
||||
- **16× faster** fuzzy deduplication (8TB RedPajama v2)
|
||||
- **40% lower TCO** vs CPU alternatives
|
||||
- **Near-linear scaling** across GPU nodes
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **datatrove**: CPU-based, open-source data processing
|
||||
- **dolma**: Allen AI's data toolkit
|
||||
- **Ray Data**: General ML data processing (no curation focus)
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Text curation (CUDA 12)
|
||||
uv pip install "nemo-curator[text_cuda12]"
|
||||
|
||||
# All modalities
|
||||
uv pip install "nemo-curator[all_cuda12]"
|
||||
|
||||
# CPU-only (slower)
|
||||
uv pip install "nemo-curator[cpu]"
|
||||
```
|
||||
|
||||
### Basic text curation pipeline
|
||||
|
||||
```python
|
||||
from nemo_curator import ScoreFilter, Modify
|
||||
from nemo_curator.datasets import DocumentDataset
|
||||
import pandas as pd
|
||||
|
||||
# Load data
|
||||
df = pd.DataFrame({"text": ["Good document", "Bad doc", "Excellent text"]})
|
||||
dataset = DocumentDataset(df)
|
||||
|
||||
# Quality filtering
|
||||
def quality_score(doc):
|
||||
return len(doc["text"].split()) > 5 # Filter short docs
|
||||
|
||||
filtered = ScoreFilter(quality_score)(dataset)
|
||||
|
||||
# Deduplication
|
||||
from nemo_curator.modules import ExactDuplicates
|
||||
deduped = ExactDuplicates()(filtered)
|
||||
|
||||
# Save
|
||||
deduped.to_parquet("curated_data/")
|
||||
```
|
||||
|
||||
## Data curation pipeline
|
||||
|
||||
### Stage 1: Quality filtering
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import (
|
||||
WordCountFilter,
|
||||
RepeatedLinesFilter,
|
||||
UrlRatioFilter,
|
||||
NonAlphaNumericFilter
|
||||
)
|
||||
|
||||
# Apply 30+ heuristic filters
|
||||
from nemo_curator import ScoreFilter
|
||||
|
||||
# Word count filter
|
||||
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
|
||||
|
||||
# Remove repetitive content
|
||||
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
|
||||
|
||||
# URL ratio filter
|
||||
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
|
||||
```
|
||||
|
||||
### Stage 2: Deduplication
|
||||
|
||||
**Exact deduplication**:
|
||||
```python
|
||||
from nemo_curator.modules import ExactDuplicates
|
||||
|
||||
# Remove exact duplicates
|
||||
deduped = ExactDuplicates(id_field="id", text_field="text")(dataset)
|
||||
```
|
||||
|
||||
**Fuzzy deduplication** (16× faster on GPU):
|
||||
```python
|
||||
from nemo_curator.modules import FuzzyDuplicates
|
||||
|
||||
# MinHash + LSH deduplication
|
||||
fuzzy_dedup = FuzzyDuplicates(
|
||||
id_field="id",
|
||||
text_field="text",
|
||||
num_hashes=260, # MinHash parameters
|
||||
num_buckets=20,
|
||||
hash_method="md5"
|
||||
)
|
||||
|
||||
deduped = fuzzy_dedup(dataset)
|
||||
```
|
||||
|
||||
**Semantic deduplication**:
|
||||
```python
|
||||
from nemo_curator.modules import SemanticDuplicates
|
||||
|
||||
# Embedding-based deduplication
|
||||
semantic_dedup = SemanticDuplicates(
|
||||
id_field="id",
|
||||
text_field="text",
|
||||
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
threshold=0.8 # Cosine similarity threshold
|
||||
)
|
||||
|
||||
deduped = semantic_dedup(dataset)
|
||||
```
|
||||
|
||||
### Stage 3: PII redaction
|
||||
|
||||
```python
|
||||
from nemo_curator.modules import Modify
|
||||
from nemo_curator.modifiers import PIIRedactor
|
||||
|
||||
# Redact personally identifiable information
|
||||
pii_redactor = PIIRedactor(
|
||||
supported_entities=["EMAIL_ADDRESS", "PHONE_NUMBER", "PERSON", "LOCATION"],
|
||||
anonymize_action="replace" # or "redact"
|
||||
)
|
||||
|
||||
redacted = Modify(pii_redactor)(dataset)
|
||||
```
|
||||
|
||||
### Stage 4: Classifier filtering
|
||||
|
||||
```python
|
||||
from nemo_curator.classifiers import QualityClassifier
|
||||
|
||||
# Quality classification
|
||||
quality_clf = QualityClassifier(
|
||||
model_path="nvidia/quality-classifier-deberta",
|
||||
batch_size=256,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Filter low-quality documents
|
||||
high_quality = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
|
||||
```
|
||||
|
||||
## GPU acceleration
|
||||
|
||||
### GPU vs CPU performance
|
||||
|
||||
| Operation | CPU (16 cores) | GPU (A100) | Speedup |
|
||||
|-----------|----------------|------------|---------|
|
||||
| Fuzzy dedup (8TB) | 120 hours | 7.5 hours | 16× |
|
||||
| Exact dedup (1TB) | 8 hours | 0.5 hours | 16× |
|
||||
| Quality filtering | 2 hours | 0.2 hours | 10× |
|
||||
|
||||
### Multi-GPU scaling
|
||||
|
||||
```python
|
||||
from nemo_curator import get_client
|
||||
import dask_cuda
|
||||
|
||||
# Initialize GPU cluster
|
||||
client = get_client(cluster_type="gpu", n_workers=8)
|
||||
|
||||
# Process with 8 GPUs
|
||||
deduped = FuzzyDuplicates(...)(dataset)
|
||||
```
|
||||
|
||||
## Multi-modal curation
|
||||
|
||||
### Image curation
|
||||
|
||||
```python
|
||||
from nemo_curator.image import (
|
||||
AestheticFilter,
|
||||
NSFWFilter,
|
||||
CLIPEmbedder
|
||||
)
|
||||
|
||||
# Aesthetic scoring
|
||||
aesthetic_filter = AestheticFilter(threshold=5.0)
|
||||
filtered_images = aesthetic_filter(image_dataset)
|
||||
|
||||
# NSFW detection
|
||||
nsfw_filter = NSFWFilter(threshold=0.9)
|
||||
safe_images = nsfw_filter(filtered_images)
|
||||
|
||||
# Generate CLIP embeddings
|
||||
clip_embedder = CLIPEmbedder(model="openai/clip-vit-base-patch32")
|
||||
image_embeddings = clip_embedder(safe_images)
|
||||
```
|
||||
|
||||
### Video curation
|
||||
|
||||
```python
|
||||
from nemo_curator.video import (
|
||||
SceneDetector,
|
||||
ClipExtractor,
|
||||
InternVideo2Embedder
|
||||
)
|
||||
|
||||
# Detect scenes
|
||||
scene_detector = SceneDetector(threshold=27.0)
|
||||
scenes = scene_detector(video_dataset)
|
||||
|
||||
# Extract clips
|
||||
clip_extractor = ClipExtractor(min_duration=2.0, max_duration=10.0)
|
||||
clips = clip_extractor(scenes)
|
||||
|
||||
# Generate embeddings
|
||||
video_embedder = InternVideo2Embedder()
|
||||
video_embeddings = video_embedder(clips)
|
||||
```
|
||||
|
||||
### Audio curation
|
||||
|
||||
```python
|
||||
from nemo_curator.audio import (
|
||||
ASRInference,
|
||||
WERFilter,
|
||||
DurationFilter
|
||||
)
|
||||
|
||||
# ASR transcription
|
||||
asr = ASRInference(model="nvidia/stt_en_fastconformer_hybrid_large_pc")
|
||||
transcribed = asr(audio_dataset)
|
||||
|
||||
# Filter by WER (word error rate)
|
||||
wer_filter = WERFilter(max_wer=0.3)
|
||||
high_quality_audio = wer_filter(transcribed)
|
||||
|
||||
# Duration filtering
|
||||
duration_filter = DurationFilter(min_duration=1.0, max_duration=30.0)
|
||||
filtered_audio = duration_filter(high_quality_audio)
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Web scrape curation (Common Crawl)
|
||||
|
||||
```python
|
||||
from nemo_curator import ScoreFilter, Modify
|
||||
from nemo_curator.filters import *
|
||||
from nemo_curator.modules import *
|
||||
from nemo_curator.datasets import DocumentDataset
|
||||
|
||||
# Load Common Crawl data
|
||||
dataset = DocumentDataset.read_parquet("common_crawl/*.parquet")
|
||||
|
||||
# Pipeline
|
||||
pipeline = [
|
||||
# 1. Quality filtering
|
||||
WordCountFilter(min_words=100, max_words=50000),
|
||||
RepeatedLinesFilter(max_repeated_line_fraction=0.2),
|
||||
SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3),
|
||||
UrlRatioFilter(max_url_ratio=0.3),
|
||||
|
||||
# 2. Language filtering
|
||||
LanguageIdentificationFilter(target_languages=["en"]),
|
||||
|
||||
# 3. Deduplication
|
||||
ExactDuplicates(id_field="id", text_field="text"),
|
||||
FuzzyDuplicates(id_field="id", text_field="text", num_hashes=260),
|
||||
|
||||
# 4. PII redaction
|
||||
PIIRedactor(),
|
||||
|
||||
# 5. NSFW filtering
|
||||
NSFWClassifier(threshold=0.8)
|
||||
]
|
||||
|
||||
# Execute
|
||||
for stage in pipeline:
|
||||
dataset = stage(dataset)
|
||||
|
||||
# Save
|
||||
dataset.to_parquet("curated_common_crawl/")
|
||||
```
|
||||
|
||||
### Distributed processing
|
||||
|
||||
```python
|
||||
from nemo_curator import get_client
|
||||
from dask_cuda import LocalCUDACluster
|
||||
|
||||
# Multi-GPU cluster
|
||||
cluster = LocalCUDACluster(n_workers=8)
|
||||
client = get_client(cluster=cluster)
|
||||
|
||||
# Process large dataset
|
||||
dataset = DocumentDataset.read_parquet("s3://large_dataset/*.parquet")
|
||||
deduped = FuzzyDuplicates(...)(dataset)
|
||||
|
||||
# Cleanup
|
||||
client.close()
|
||||
cluster.close()
|
||||
```
|
||||
|
||||
## Performance benchmarks
|
||||
|
||||
### Fuzzy deduplication (8TB RedPajama v2)
|
||||
|
||||
- **CPU (256 cores)**: 120 hours
|
||||
- **GPU (8× A100)**: 7.5 hours
|
||||
- **Speedup**: 16×
|
||||
|
||||
### Exact deduplication (1TB)
|
||||
|
||||
- **CPU (64 cores)**: 8 hours
|
||||
- **GPU (4× A100)**: 0.5 hours
|
||||
- **Speedup**: 16×
|
||||
|
||||
### Quality filtering (100GB)
|
||||
|
||||
- **CPU (32 cores)**: 2 hours
|
||||
- **GPU (2× A100)**: 0.2 hours
|
||||
- **Speedup**: 10×
|
||||
|
||||
## Cost comparison
|
||||
|
||||
**CPU-based curation** (AWS c5.18xlarge × 10):
|
||||
- Cost: $3.60/hour × 10 = $36/hour
|
||||
- Time for 8TB: 120 hours
|
||||
- **Total**: $4,320
|
||||
|
||||
**GPU-based curation** (AWS p4d.24xlarge × 2):
|
||||
- Cost: $32.77/hour × 2 = $65.54/hour
|
||||
- Time for 8TB: 7.5 hours
|
||||
- **Total**: $491.55
|
||||
|
||||
**Savings**: 89% reduction ($3,828 saved)
|
||||
|
||||
## Supported data formats
|
||||
|
||||
- **Input**: Parquet, JSONL, CSV
|
||||
- **Output**: Parquet (recommended), JSONL
|
||||
- **WebDataset**: TAR archives for multi-modal
|
||||
|
||||
## Use cases
|
||||
|
||||
**Production deployments**:
|
||||
- NVIDIA used NeMo Curator to prepare Nemotron-4 training data
|
||||
- Open-source datasets curated: RedPajama v2, The Pile
|
||||
|
||||
## References
|
||||
|
||||
- **[Filtering Guide](references/filtering.md)** - 30+ quality filters, heuristics
|
||||
- **[Deduplication Guide](references/deduplication.md)** - Exact, fuzzy, semantic methods
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/NVIDIA/NeMo-Curator ⭐ 500+
|
||||
- **Docs**: https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/
|
||||
- **Version**: 0.4.0+
|
||||
- **License**: Apache 2.0
|
||||
|
||||
|
||||
|
||||
87
optional-skills/mlops/nemo-curator/references/deduplication.md
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
# Deduplication Guide
|
||||
|
||||
Complete guide to exact, fuzzy, and semantic deduplication.
|
||||
|
||||
## Exact deduplication
|
||||
|
||||
Remove documents with identical content.
|
||||
|
||||
```python
|
||||
from nemo_curator.modules import ExactDuplicates
|
||||
|
||||
# Exact deduplication
|
||||
exact_dedup = ExactDuplicates(
|
||||
id_field="id",
|
||||
text_field="text",
|
||||
hash_method="md5" # or "sha256"
|
||||
)
|
||||
|
||||
deduped = exact_dedup(dataset)
|
||||
```
|
||||
|
||||
**Performance**: ~16× faster on GPU vs CPU
|
||||
|
||||
## Fuzzy deduplication
|
||||
|
||||
Remove near-duplicate documents using MinHash + LSH.
|
||||
|
||||
```python
|
||||
from nemo_curator.modules import FuzzyDuplicates
|
||||
|
||||
fuzzy_dedup = FuzzyDuplicates(
|
||||
id_field="id",
|
||||
text_field="text",
|
||||
num_hashes=260, # MinHash permutations (more = accurate)
|
||||
num_buckets=20, # LSH buckets (more = faster, less recall)
|
||||
hash_method="md5",
|
||||
jaccard_threshold=0.8 # Similarity threshold
|
||||
)
|
||||
|
||||
deduped = fuzzy_dedup(dataset)
|
||||
```
|
||||
|
||||
**Parameters**:
|
||||
- `num_hashes`: 128-512 (default 260)
|
||||
- `num_buckets`: 10-50 (default 20)
|
||||
- `jaccard_threshold`: 0.7-0.9 (default 0.8)
|
||||
|
||||
**Performance**: 16× faster on 8TB dataset (120h → 7.5h)
|
||||
|
||||
## Semantic deduplication
|
||||
|
||||
Remove semantically similar documents using embeddings.
|
||||
|
||||
```python
|
||||
from nemo_curator.modules import SemanticDuplicates
|
||||
|
||||
semantic_dedup = SemanticDuplicates(
|
||||
id_field="id",
|
||||
text_field="text",
|
||||
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
embedding_batch_size=256,
|
||||
threshold=0.85, # Cosine similarity threshold
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
deduped = semantic_dedup(dataset)
|
||||
```
|
||||
|
||||
**Models**:
|
||||
- `all-MiniLM-L6-v2`: Fast, 384 dims
|
||||
- `all-mpnet-base-v2`: Better quality, 768 dims
|
||||
- Custom models supported
|
||||
|
||||
## Comparison
|
||||
|
||||
| Method | Speed | Recall | Use Case |
|
||||
|--------|-------|--------|----------|
|
||||
| Exact | Fastest | 100% | Exact matches only |
|
||||
| Fuzzy | Fast | ~95% | Near-duplicates (recommended) |
|
||||
| Semantic | Slow | ~90% | Paraphrases, rewrites |
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with exact dedup** - Remove obvious duplicates
|
||||
2. **Use fuzzy for large datasets** - Best speed/quality trade-off
|
||||
3. **Semantic for high-value data** - Expensive but thorough
|
||||
4. **GPU acceleration required** - 10-16× speedup
|
||||
102
optional-skills/mlops/nemo-curator/references/filtering.md
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
# Quality Filtering Guide
|
||||
|
||||
Complete guide to NeMo Curator's 30+ quality filters.
|
||||
|
||||
## Text-based filters
|
||||
|
||||
### Word count
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import WordCountFilter
|
||||
|
||||
# Filter by word count
|
||||
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
|
||||
```
|
||||
|
||||
### Repeated content
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import RepeatedLinesFilter
|
||||
|
||||
# Remove documents with >30% repeated lines
|
||||
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
|
||||
```
|
||||
|
||||
### Symbol ratio
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import SymbolToWordRatioFilter
|
||||
|
||||
# Remove documents with too many symbols
|
||||
dataset = dataset.filter(SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3))
|
||||
```
|
||||
|
||||
### URL ratio
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import UrlRatioFilter
|
||||
|
||||
# Remove documents with many URLs
|
||||
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
|
||||
```
|
||||
|
||||
## Language filtering
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import LanguageIdentificationFilter
|
||||
|
||||
# Keep only English documents
|
||||
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en"]))
|
||||
|
||||
# Multiple languages
|
||||
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en", "es", "fr"]))
|
||||
```
|
||||
|
||||
## Classifier-based filtering
|
||||
|
||||
### Quality classifier
|
||||
|
||||
```python
|
||||
from nemo_curator.classifiers import QualityClassifier
|
||||
|
||||
quality_clf = QualityClassifier(
|
||||
model_path="nvidia/quality-classifier-deberta",
|
||||
batch_size=256,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Filter low-quality (threshold > 0.5 = high quality)
|
||||
dataset = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
|
||||
```
|
||||
|
||||
### NSFW classifier
|
||||
|
||||
```python
|
||||
from nemo_curator.classifiers import NSFWClassifier
|
||||
|
||||
nsfw_clf = NSFWClassifier(threshold=0.9, device="cuda")
|
||||
|
||||
# Remove NSFW content
|
||||
dataset = dataset.filter(lambda doc: nsfw_clf(doc["text"]) < 0.9)
|
||||
```
|
||||
|
||||
## Heuristic filters
|
||||
|
||||
Full list of 30+ filters:
|
||||
- WordCountFilter
|
||||
- RepeatedLinesFilter
|
||||
- UrlRatioFilter
|
||||
- SymbolToWordRatioFilter
|
||||
- NonAlphaNumericFilter
|
||||
- BulletsFilter
|
||||
- WhiteSpaceFilter
|
||||
- ParenthesesFilter
|
||||
- LongWordFilter
|
||||
- And 20+ more...
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Apply cheap filters first** - Word count before GPU classifiers
|
||||
2. **Tune thresholds on sample** - Test on 10k docs before full run
|
||||
3. **Use GPU classifiers sparingly** - Expensive but effective
|
||||
4. **Chain filters efficiently** - Order by cost (cheap → expensive); see the sketch below
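
A minimal sketch of that ordering, reusing filters shown above and the quality classifier from the classifier section (`quality_clf` is assumed to be constructed already):

```python
from nemo_curator.filters import (
    RepeatedLinesFilter,
    UrlRatioFilter,
    WordCountFilter,
)

# Cheap heuristics first...
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))

# ...expensive GPU classifier last, only on documents that survived
dataset = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
```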
|
||||
361
optional-skills/mlops/pinecone/SKILL.md
Normal file
|
|
@ -0,0 +1,361 @@
|
|||
---
|
||||
name: pinecone
|
||||
description: Managed vector database for production AI applications. Fully managed, auto-scaling, with hybrid search (dense + sparse), metadata filtering, and namespaces. Low latency (<100ms p95). Use for production RAG, recommendation systems, or semantic search at scale. Best for serverless, managed infrastructure.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [pinecone-client]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [RAG, Pinecone, Vector Database, Managed Service, Serverless, Hybrid Search, Production, Auto-Scaling, Low Latency, Recommendations]
|
||||
|
||||
---
|
||||
|
||||
# Pinecone - Managed Vector Database
|
||||
|
||||
The vector database for production AI applications.
|
||||
|
||||
## When to use Pinecone
|
||||
|
||||
**Use when:**
|
||||
- Need managed, serverless vector database
|
||||
- Production RAG applications
|
||||
- Auto-scaling required
|
||||
- Low latency critical (<100ms)
|
||||
- Don't want to manage infrastructure
|
||||
- Need hybrid search (dense + sparse vectors)
|
||||
|
||||
**Metrics**:
|
||||
- Fully managed SaaS
|
||||
- Auto-scales to billions of vectors
|
||||
- **p95 latency <100ms**
|
||||
- 99.9% uptime SLA
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **Chroma**: Self-hosted, open-source
|
||||
- **FAISS**: Offline, pure similarity search
|
||||
- **Weaviate**: Self-hosted with more features
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
pip install pinecone-client
|
||||
```
|
||||
|
||||
### Basic usage
|
||||
|
||||
```python
|
||||
from pinecone import Pinecone, ServerlessSpec
|
||||
|
||||
# Initialize
|
||||
pc = Pinecone(api_key="your-api-key")
|
||||
|
||||
# Create index
|
||||
pc.create_index(
|
||||
name="my-index",
|
||||
dimension=1536, # Must match embedding dimension
|
||||
metric="cosine", # or "euclidean", "dotproduct"
|
||||
spec=ServerlessSpec(cloud="aws", region="us-east-1")
|
||||
)
|
||||
|
||||
# Connect to index
|
||||
index = pc.Index("my-index")
|
||||
|
||||
# Upsert vectors
|
||||
index.upsert(vectors=[
|
||||
{"id": "vec1", "values": [0.1, 0.2, ...], "metadata": {"category": "A"}},
|
||||
{"id": "vec2", "values": [0.3, 0.4, ...], "metadata": {"category": "B"}}
|
||||
])
|
||||
|
||||
# Query
|
||||
results = index.query(
|
||||
vector=[0.1, 0.2, ...],
|
||||
top_k=5,
|
||||
include_metadata=True
|
||||
)
|
||||
|
||||
print(results["matches"])
|
||||
```
|
||||
|
||||
## Core operations
|
||||
|
||||
### Create index
|
||||
|
||||
```python
|
||||
# Serverless (recommended)
|
||||
pc.create_index(
|
||||
name="my-index",
|
||||
dimension=1536,
|
||||
metric="cosine",
|
||||
spec=ServerlessSpec(
|
||||
cloud="aws", # or "gcp", "azure"
|
||||
region="us-east-1"
|
||||
)
|
||||
)
|
||||
|
||||
# Pod-based (for consistent performance)
|
||||
from pinecone import PodSpec
|
||||
|
||||
pc.create_index(
|
||||
name="my-index",
|
||||
dimension=1536,
|
||||
metric="cosine",
|
||||
spec=PodSpec(
|
||||
environment="us-east1-gcp",
|
||||
pod_type="p1.x1"
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
### Upsert vectors
|
||||
|
||||
```python
|
||||
# Single upsert
|
||||
index.upsert(vectors=[
|
||||
{
|
||||
"id": "doc1",
|
||||
"values": [0.1, 0.2, ...], # 1536 dimensions
|
||||
"metadata": {
|
||||
"text": "Document content",
|
||||
"category": "tutorial",
|
||||
"timestamp": "2025-01-01"
|
||||
}
|
||||
}
|
||||
])
|
||||
|
||||
# Batch upsert (recommended)
|
||||
vectors = [
|
||||
{"id": f"vec{i}", "values": embedding, "metadata": metadata}
|
||||
for i, (embedding, metadata) in enumerate(zip(embeddings, metadatas))
|
||||
]
|
||||
|
||||
index.upsert(vectors=vectors, batch_size=100)
|
||||
```
|
||||
|
||||
### Query vectors
|
||||
|
||||
```python
|
||||
# Basic query
|
||||
results = index.query(
|
||||
vector=[0.1, 0.2, ...],
|
||||
top_k=10,
|
||||
include_metadata=True,
|
||||
include_values=False
|
||||
)
|
||||
|
||||
# With metadata filtering
|
||||
results = index.query(
|
||||
vector=[0.1, 0.2, ...],
|
||||
top_k=5,
|
||||
filter={"category": {"$eq": "tutorial"}}
|
||||
)
|
||||
|
||||
# Namespace query
|
||||
results = index.query(
|
||||
vector=[0.1, 0.2, ...],
|
||||
top_k=5,
|
||||
namespace="production"
|
||||
)
|
||||
|
||||
# Access results
|
||||
for match in results["matches"]:
|
||||
print(f"ID: {match['id']}")
|
||||
print(f"Score: {match['score']}")
|
||||
print(f"Metadata: {match['metadata']}")
|
||||
```
|
||||
|
||||
### Metadata filtering
|
||||
|
||||
```python
|
||||
# Exact match
|
||||
filter = {"category": "tutorial"}
|
||||
|
||||
# Comparison
|
||||
filter = {"price": {"$gte": 100}} # $gt, $gte, $lt, $lte, $ne
|
||||
|
||||
# Logical operators
|
||||
filter = {
|
||||
"$and": [
|
||||
{"category": "tutorial"},
|
||||
{"difficulty": {"$lte": 3}}
|
||||
]
|
||||
} # Also: $or
|
||||
|
||||
# In operator
|
||||
filter = {"tags": {"$in": ["python", "ml"]}}
|
||||
```
|
||||
|
||||
## Namespaces
|
||||
|
||||
```python
|
||||
# Partition data by namespace
|
||||
index.upsert(
|
||||
vectors=[{"id": "vec1", "values": [...]}],
|
||||
namespace="user-123"
|
||||
)
|
||||
|
||||
# Query specific namespace
|
||||
results = index.query(
|
||||
vector=[...],
|
||||
namespace="user-123",
|
||||
top_k=5
|
||||
)
|
||||
|
||||
# List namespaces
|
||||
stats = index.describe_index_stats()
|
||||
print(stats['namespaces'])
|
||||
```
|
||||
|
||||
## Hybrid search (dense + sparse)
|
||||
|
||||
```python
|
||||
# Upsert with sparse vectors
|
||||
index.upsert(vectors=[
|
||||
{
|
||||
"id": "doc1",
|
||||
"values": [0.1, 0.2, ...], # Dense vector
|
||||
"sparse_values": {
|
||||
"indices": [10, 45, 123], # Token IDs
|
||||
"values": [0.5, 0.3, 0.8] # TF-IDF scores
|
||||
},
|
||||
"metadata": {"text": "..."}
|
||||
}
|
||||
])
|
||||
|
||||
# Hybrid query
|
||||
results = index.query(
|
||||
vector=[0.1, 0.2, ...],
|
||||
sparse_vector={
|
||||
"indices": [10, 45],
|
||||
"values": [0.5, 0.3]
|
||||
},
|
||||
top_k=5,
|
||||
alpha=0.5 # 0=sparse, 1=dense, 0.5=hybrid
|
||||
)
|
||||
```
|
||||
|
||||
## LangChain integration
|
||||
|
||||
```python
|
||||
from langchain_pinecone import PineconeVectorStore
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
# Create vector store
|
||||
vectorstore = PineconeVectorStore.from_documents(
|
||||
documents=docs,
|
||||
embedding=OpenAIEmbeddings(),
|
||||
index_name="my-index"
|
||||
)
|
||||
|
||||
# Query
|
||||
results = vectorstore.similarity_search("query", k=5)
|
||||
|
||||
# With metadata filter
|
||||
results = vectorstore.similarity_search(
|
||||
"query",
|
||||
k=5,
|
||||
filter={"category": "tutorial"}
|
||||
)
|
||||
|
||||
# As retriever
|
||||
retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
|
||||
```
|
||||
|
||||
## LlamaIndex integration
|
||||
|
||||
```python
|
||||
from llama_index.vector_stores.pinecone import PineconeVectorStore
|
||||
|
||||
# Connect to Pinecone
|
||||
pc = Pinecone(api_key="your-key")
|
||||
pinecone_index = pc.Index("my-index")
|
||||
|
||||
# Create vector store
|
||||
vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
|
||||
|
||||
# Use in LlamaIndex
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
|
||||
```
|
||||
|
||||
## Index management
|
||||
|
||||
```python
|
||||
# List indices
|
||||
indexes = pc.list_indexes()
|
||||
|
||||
# Describe index
|
||||
index_info = pc.describe_index("my-index")
|
||||
print(index_info)
|
||||
|
||||
# Get index stats
|
||||
stats = index.describe_index_stats()
|
||||
print(f"Total vectors: {stats['total_vector_count']}")
|
||||
print(f"Namespaces: {stats['namespaces']}")
|
||||
|
||||
# Delete index
|
||||
pc.delete_index("my-index")
|
||||
```
|
||||
|
||||
## Delete vectors
|
||||
|
||||
```python
|
||||
# Delete by ID
|
||||
index.delete(ids=["vec1", "vec2"])
|
||||
|
||||
# Delete by filter
|
||||
index.delete(filter={"category": "old"})
|
||||
|
||||
# Delete all in namespace
|
||||
index.delete(delete_all=True, namespace="test")
|
||||
|
||||
# Delete entire index
|
||||
index.delete(delete_all=True)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use serverless** - Auto-scaling, cost-effective
|
||||
2. **Batch upserts** - More efficient (100-200 per batch)
|
||||
3. **Add metadata** - Enable filtering
|
||||
4. **Use namespaces** - Isolate data by user/tenant
|
||||
5. **Monitor usage** - Check Pinecone dashboard
|
||||
6. **Optimize filters** - Index frequently filtered fields
|
||||
7. **Test with free tier** - 1 index, 100K vectors free
|
||||
8. **Use hybrid search** - Better quality
|
||||
9. **Set appropriate dimensions** - Match embedding model
|
||||
10. **Regular backups** - Export important data
|
||||
|
||||
## Performance
|
||||
|
||||
| Operation | Latency | Notes |
|
||||
|-----------|---------|-------|
|
||||
| Upsert | ~50-100ms | Per batch |
|
||||
| Query (p50) | ~50ms | Depends on index size |
|
||||
| Query (p95) | ~100ms | SLA target |
|
||||
| Metadata filter | ~+10-20ms | Additional overhead |
|
||||
|
||||
## Pricing (as of 2025)
|
||||
|
||||
**Serverless**:
|
||||
- $0.096 per million read units
|
||||
- $0.06 per million write units
|
||||
- $0.06 per GB storage/month (worked example below)
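
As a rough worked example using the serverless rates above (ignoring metadata overhead and any platform minimums): 1M vectors at 1,536 dimensions in float32 is about 1,000,000 × 1,536 × 4 bytes ≈ 6.1 GB, so roughly 6.1 × $0.06 ≈ $0.37/month of storage; reads and writes are billed separately in read/write units.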
|
||||
|
||||
**Free tier**:
|
||||
- 1 serverless index
|
||||
- 100K vectors (1536 dimensions)
|
||||
- Great for prototyping
|
||||
|
||||
## Resources
|
||||
|
||||
- **Website**: https://www.pinecone.io
|
||||
- **Docs**: https://docs.pinecone.io
|
||||
- **Console**: https://app.pinecone.io
|
||||
- **Pricing**: https://www.pinecone.io/pricing
|
||||
|
||||
|
||||
181
optional-skills/mlops/pinecone/references/deployment.md
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
# Pinecone Deployment Guide
|
||||
|
||||
Production deployment patterns for Pinecone.
|
||||
|
||||
## Serverless vs Pod-based
|
||||
|
||||
### Serverless (Recommended)
|
||||
|
||||
```python
|
||||
from pinecone import Pinecone, ServerlessSpec
|
||||
|
||||
pc = Pinecone(api_key="your-key")
|
||||
|
||||
# Create serverless index
|
||||
pc.create_index(
|
||||
name="my-index",
|
||||
dimension=1536,
|
||||
metric="cosine",
|
||||
spec=ServerlessSpec(
|
||||
cloud="aws", # or "gcp", "azure"
|
||||
region="us-east-1"
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Auto-scaling
|
||||
- Pay per usage
|
||||
- No infrastructure management
|
||||
- Cost-effective for variable load
|
||||
|
||||
**Use when:**
|
||||
- Variable traffic
|
||||
- Cost optimization important
|
||||
- Don't need consistent latency
|
||||
|
||||
### Pod-based
|
||||
|
||||
```python
|
||||
from pinecone import PodSpec
|
||||
|
||||
pc.create_index(
|
||||
name="my-index",
|
||||
dimension=1536,
|
||||
metric="cosine",
|
||||
spec=PodSpec(
|
||||
environment="us-east1-gcp",
|
||||
pod_type="p1.x1", # or p1.x2, p1.x4, p1.x8
|
||||
pods=2, # Number of pods
|
||||
replicas=2 # High availability
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Consistent performance
|
||||
- Predictable latency
|
||||
- Higher throughput
|
||||
- Dedicated resources
|
||||
|
||||
**Use when:**
|
||||
- Production workloads
|
||||
- Need consistent p95 latency
|
||||
- High throughput required
|
||||
|
||||
## Hybrid search
|
||||
|
||||
### Dense + Sparse vectors
|
||||
|
||||
```python
|
||||
# Upsert with both dense and sparse vectors
|
||||
index.upsert(vectors=[
|
||||
{
|
||||
"id": "doc1",
|
||||
"values": [0.1, 0.2, ...], # Dense (semantic)
|
||||
"sparse_values": {
|
||||
"indices": [10, 45, 123], # Token IDs
|
||||
"values": [0.5, 0.3, 0.8] # TF-IDF/BM25 scores
|
||||
},
|
||||
"metadata": {"text": "..."}
|
||||
}
|
||||
])
|
||||
|
||||
# Hybrid query
|
||||
results = index.query(
|
||||
vector=[0.1, 0.2, ...], # Dense query
|
||||
sparse_vector={
|
||||
"indices": [10, 45],
|
||||
"values": [0.5, 0.3]
|
||||
},
|
||||
top_k=10,
|
||||
alpha=0.5 # 0=sparse only, 1=dense only, 0.5=balanced
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Best of both worlds
|
||||
- Semantic + keyword matching
|
||||
- Better recall than either alone
|
||||
|
||||
## Namespaces for multi-tenancy
|
||||
|
||||
```python
|
||||
# Separate data by user/tenant
|
||||
index.upsert(
|
||||
vectors=[{"id": "doc1", "values": [...]}],
|
||||
namespace="user-123"
|
||||
)
|
||||
|
||||
# Query specific namespace
|
||||
results = index.query(
|
||||
vector=[...],
|
||||
namespace="user-123",
|
||||
top_k=5
|
||||
)
|
||||
|
||||
# List namespaces
|
||||
stats = index.describe_index_stats()
|
||||
print(stats['namespaces'])
|
||||
```
|
||||
|
||||
**Use cases:**
|
||||
- Multi-tenant SaaS
|
||||
- User-specific data isolation
|
||||
- A/B testing (prod/staging namespaces)
|
||||
|
||||
## Metadata filtering
|
||||
|
||||
### Exact match
|
||||
|
||||
```python
|
||||
results = index.query(
|
||||
vector=[...],
|
||||
filter={"category": "tutorial"},
|
||||
top_k=5
|
||||
)
|
||||
```
|
||||
|
||||
### Range queries
|
||||
|
||||
```python
|
||||
results = index.query(
|
||||
vector=[...],
|
||||
filter={"price": {"$gte": 100, "$lte": 500}},
|
||||
top_k=5
|
||||
)
|
||||
```
|
||||
|
||||
### Complex filters
|
||||
|
||||
```python
|
||||
results = index.query(
|
||||
vector=[...],
|
||||
filter={
|
||||
"$and": [
|
||||
{"category": {"$in": ["tutorial", "guide"]}},
|
||||
{"difficulty": {"$lte": 3}},
|
||||
{"published": {"$gte": "2024-01-01"}}
|
||||
]
|
||||
},
|
||||
top_k=5
|
||||
)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use serverless for development** - Cost-effective
|
||||
2. **Switch to pods for production** - Consistent performance
|
||||
3. **Implement namespaces** - Multi-tenancy
|
||||
4. **Add metadata strategically** - Enable filtering
|
||||
5. **Use hybrid search** - Better quality
|
||||
6. **Batch upserts** - 100-200 vectors per batch
|
||||
7. **Monitor usage** - Check Pinecone dashboard
|
||||
8. **Set up alerts** - Usage/cost thresholds
|
||||
9. **Regular backups** - Export important data
|
||||
10. **Test filters** - Verify performance
|
||||
|
||||
## Resources
|
||||
|
||||
- **Docs**: https://docs.pinecone.io
|
||||
- **Console**: https://app.pinecone.io
|
||||
349
optional-skills/mlops/pytorch-lightning/SKILL.md
Normal file
|
|
@ -0,0 +1,349 @@
|
|||
---
|
||||
name: pytorch-lightning
|
||||
description: High-level PyTorch framework with Trainer class, automatic distributed training (DDP/FSDP/DeepSpeed), callbacks system, and minimal boilerplate. Scales from laptop to supercomputer with same code. Use when you want clean training loops with built-in best practices.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [lightning, torch, transformers]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [PyTorch Lightning, Training Framework, Distributed Training, DDP, FSDP, DeepSpeed, High-Level API, Callbacks, Best Practices, Scalable]
|
||||
|
||||
---
|
||||
|
||||
# PyTorch Lightning - High-Level Training Framework
|
||||
|
||||
## Quick start
|
||||
|
||||
PyTorch Lightning organizes PyTorch code to eliminate boilerplate while maintaining flexibility.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install lightning
|
||||
```
|
||||
|
||||
**Convert PyTorch to Lightning** (3 steps):
|
||||
|
||||
```python
|
||||
import lightning as L
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
# Step 1: Define LightningModule (organize your PyTorch code)
|
||||
class LitModel(L.LightningModule):
|
||||
def __init__(self, hidden_size=128):
|
||||
super().__init__()
|
||||
self.model = nn.Sequential(
|
||||
nn.Linear(28 * 28, hidden_size),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_size, 10)
|
||||
)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.model(x)
|
||||
loss = nn.functional.cross_entropy(y_hat, y)
|
||||
self.log('train_loss', loss) # Auto-logged to TensorBoard
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||
|
||||
# Step 2: Create data
|
||||
train_loader = DataLoader(train_dataset, batch_size=32)
|
||||
|
||||
# Step 3: Train with Trainer (handles everything else!)
|
||||
trainer = L.Trainer(max_epochs=10, accelerator='gpu', devices=2)
|
||||
model = LitModel()
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
**That's it!** Trainer handles all of the following (a configuration sketch follows the list):
|
||||
- GPU/TPU/CPU switching
|
||||
- Distributed training (DDP, FSDP, DeepSpeed)
|
||||
- Mixed precision (FP16, BF16)
|
||||
- Gradient accumulation
|
||||
- Checkpointing
|
||||
- Logging
|
||||
- Progress bars
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: From PyTorch to Lightning
|
||||
|
||||
**Original PyTorch code**:
|
||||
```python
|
||||
model = MyModel()
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
model.to('cuda')
|
||||
|
||||
for epoch in range(max_epochs):
|
||||
for batch in train_loader:
|
||||
batch = batch.to('cuda')
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Lightning version**:
|
||||
```python
|
||||
class LitModel(L.LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = MyModel()
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
loss = self.model(batch) # No .to('cuda') needed!
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters())
|
||||
|
||||
# Train
|
||||
trainer = L.Trainer(max_epochs=10, accelerator='gpu')
|
||||
trainer.fit(LitModel(), train_loader)
|
||||
```
|
||||
|
||||
**Benefits**: 40+ lines → 15 lines, no device management, automatic distributed
|
||||
|
||||
### Workflow 2: Validation and testing
|
||||
|
||||
```python
|
||||
class LitModel(L.LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = MyModel()
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.model(x)
|
||||
loss = nn.functional.cross_entropy(y_hat, y)
|
||||
self.log('train_loss', loss)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.model(x)
|
||||
val_loss = nn.functional.cross_entropy(y_hat, y)
|
||||
acc = (y_hat.argmax(dim=1) == y).float().mean()
|
||||
self.log('val_loss', val_loss)
|
||||
self.log('val_acc', acc)
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.model(x)
|
||||
test_loss = nn.functional.cross_entropy(y_hat, y)
|
||||
self.log('test_loss', test_loss)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||
|
||||
# Train with validation
|
||||
trainer = L.Trainer(max_epochs=10)
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
|
||||
# Test
|
||||
trainer.test(model, test_loader)
|
||||
```
|
||||
|
||||
**Automatic features**:
|
||||
- Validation runs every epoch by default
|
||||
- Metrics logged to TensorBoard
|
||||
- Checkpointing of the last epoch by default; add a `ModelCheckpoint(monitor='val_loss')` callback to keep the best models (Workflow 4)
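
Validation cadence is also a `Trainer` argument. A small sketch of the two standard knobs (values illustrative):

```python
# Validate every 2 training epochs
trainer = L.Trainer(max_epochs=10, check_val_every_n_epoch=2)

# Or validate 4 times within each epoch
trainer = L.Trainer(max_epochs=10, val_check_interval=0.25)
```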
|
||||
|
||||
### Workflow 3: Distributed training (DDP)
|
||||
|
||||
```python
|
||||
# Same code as single GPU!
|
||||
model = LitModel()
|
||||
|
||||
# 8 GPUs with DDP (automatic!)
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=8,
|
||||
strategy='ddp' # Or 'fsdp', 'deepspeed'
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
**Launch**:
|
||||
```bash
|
||||
# Single command, Lightning handles the rest
|
||||
python train.py
|
||||
```
|
||||
|
||||
**No changes needed**:
|
||||
- Automatic data distribution
|
||||
- Gradient synchronization
|
||||
- Multi-node support (just set `num_nodes=2`)
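
For example, the same script runs on two 8-GPU nodes by adding `num_nodes` (launch details are in references/distributed.md):

```python
trainer = L.Trainer(
    accelerator='gpu',
    devices=8,      # GPUs per node
    num_nodes=2,    # total nodes
    strategy='ddp'
)
```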
|
||||
|
||||
### Workflow 4: Callbacks for monitoring
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
|
||||
|
||||
# Create callbacks
|
||||
checkpoint = ModelCheckpoint(
|
||||
monitor='val_loss',
|
||||
mode='min',
|
||||
save_top_k=3,
|
||||
filename='model-{epoch:02d}-{val_loss:.2f}'
|
||||
)
|
||||
|
||||
early_stop = EarlyStopping(
|
||||
monitor='val_loss',
|
||||
patience=5,
|
||||
mode='min'
|
||||
)
|
||||
|
||||
lr_monitor = LearningRateMonitor(logging_interval='epoch')
|
||||
|
||||
# Add to Trainer
|
||||
trainer = L.Trainer(
|
||||
max_epochs=100,
|
||||
callbacks=[checkpoint, early_stop, lr_monitor]
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
```
|
||||
|
||||
**Result**:
|
||||
- Auto-saves best 3 models
|
||||
- Stops early if no improvement for 5 epochs
|
||||
- Logs learning rate to TensorBoard
|
||||
|
||||
### Workflow 5: Learning rate scheduling
|
||||
|
||||
```python
|
||||
class LitModel(L.LightningModule):
|
||||
# ... (training_step, etc.)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||
|
||||
# Cosine annealing
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer,
|
||||
T_max=100,
|
||||
eta_min=1e-5
|
||||
)
|
||||
|
||||
return {
|
||||
'optimizer': optimizer,
|
||||
'lr_scheduler': {
|
||||
'scheduler': scheduler,
|
||||
'interval': 'epoch', # Update per epoch
|
||||
'frequency': 1
|
||||
}
|
||||
}
|
||||
|
||||
# Learning rate auto-logged!
|
||||
trainer = L.Trainer(max_epochs=100)
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use PyTorch Lightning when**:
|
||||
- Want clean, organized code
|
||||
- Need production-ready training loops
|
||||
- Switching between single GPU, multi-GPU, TPU
|
||||
- Want built-in callbacks and logging
|
||||
- Team collaboration (standardized structure)
|
||||
|
||||
**Key advantages**:
|
||||
- **Organized**: Separates research code from engineering
|
||||
- **Automatic**: DDP, FSDP, DeepSpeed with 1 line
|
||||
- **Callbacks**: Modular training extensions
|
||||
- **Reproducible**: Less boilerplate = fewer bugs
|
||||
- **Tested**: 1M+ downloads/month, battle-tested
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **Accelerate**: Minimal changes to existing code, more flexibility
|
||||
- **Ray Train**: Multi-node orchestration, hyperparameter tuning
|
||||
- **Raw PyTorch**: Maximum control, learning purposes
|
||||
- **Keras**: TensorFlow ecosystem
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Loss not decreasing**
|
||||
|
||||
Check data and model setup:
|
||||
```python
|
||||
# Add to training_step
|
||||
def training_step(self, batch, batch_idx):
|
||||
if batch_idx == 0:
|
||||
print(f"Batch shape: {batch[0].shape}")
|
||||
print(f"Labels: {batch[1]}")
|
||||
loss = ...
|
||||
return loss
|
||||
```
|
||||
|
||||
**Issue: Out of memory**
|
||||
|
||||
Reduce batch size or use gradient accumulation:
|
||||
```python
|
||||
trainer = L.Trainer(
|
||||
accumulate_grad_batches=4, # Effective batch = batch_size × 4
|
||||
precision='bf16' # Or 'fp16', reduces memory 50%
|
||||
)
|
||||
```
|
||||
|
||||
**Issue: Validation not running**
|
||||
|
||||
Ensure you pass val_loader:
|
||||
```python
|
||||
# WRONG
|
||||
trainer.fit(model, train_loader)
|
||||
|
||||
# CORRECT
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
```
|
||||
|
||||
**Issue: DDP spawns multiple processes unexpectedly**
|
||||
|
||||
Lightning auto-detects GPUs. Explicitly set devices:
|
||||
```python
|
||||
# Test on CPU first
|
||||
trainer = L.Trainer(accelerator='cpu', devices=1)
|
||||
|
||||
# Then GPU
|
||||
trainer = L.Trainer(accelerator='gpu', devices=1)
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Callbacks**: See [references/callbacks.md](references/callbacks.md) for EarlyStopping, ModelCheckpoint, custom callbacks, and callback hooks.
|
||||
|
||||
**Distributed strategies**: See [references/distributed.md](references/distributed.md) for DDP, FSDP, DeepSpeed ZeRO integration, multi-node setup.
|
||||
|
||||
**Hyperparameter tuning**: See [references/hyperparameter-tuning.md](references/hyperparameter-tuning.md) for integration with Optuna, Ray Tune, and WandB sweeps.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **CPU**: Works (good for debugging)
|
||||
- **Single GPU**: Works
|
||||
- **Multi-GPU**: DDP (default), FSDP, or DeepSpeed
|
||||
- **Multi-node**: DDP, FSDP, DeepSpeed
|
||||
- **TPU**: Supported (8 cores)
|
||||
- **Apple MPS**: Supported
|
||||
|
||||
**Precision options**:
|
||||
- FP32 (default)
|
||||
- FP16 (V100, older GPUs)
|
||||
- BF16 (A100/H100, recommended)
|
||||
- FP8 (H100)
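
Precision is selected with the Trainer's `precision` argument (Lightning 2.x string names):

```python
trainer = L.Trainer(precision='bf16-mixed')           # A100/H100
trainer = L.Trainer(precision='16-mixed')             # V100 and older GPUs
trainer = L.Trainer(precision='transformer-engine')   # FP8 on H100, needs transformer_engine
```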
|
||||
|
||||
## Resources
|
||||
|
||||
- Docs: https://lightning.ai/docs/pytorch/stable/
|
||||
- GitHub: https://github.com/Lightning-AI/pytorch-lightning ⭐ 29,000+
|
||||
- Version: 2.5.5+
|
||||
- Examples: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples
|
||||
- Discord: https://discord.gg/lightning-ai
|
||||
- Used by: Kaggle winners, research labs, production teams
|
||||
|
||||
|
||||
436
optional-skills/mlops/pytorch-lightning/references/callbacks.md
Normal file
@ -0,0 +1,436 @@
# PyTorch Lightning Callbacks
|
||||
|
||||
## Overview
|
||||
|
||||
Callbacks add functionality to training without modifying the LightningModule. They capture **non-essential logic** like checkpointing, early stopping, and logging.
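
The mechanism in one small sketch: a callback is a class with hook methods, and the Trainer calls each hook at the matching point of the loop. `LossPrinter` below is an illustrative example, not a built-in:

```python
import lightning as L
from lightning.pytorch.callbacks import Callback

class LossPrinter(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # callback_metrics holds everything logged via self.log(...)
        loss = trainer.callback_metrics.get('train_loss')
        print(f"epoch {trainer.current_epoch}: train_loss={loss}")

trainer = L.Trainer(callbacks=[LossPrinter()])
```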
|
||||
|
||||
## Built-In Callbacks
|
||||
|
||||
### 1. ModelCheckpoint
|
||||
|
||||
**Saves best models during training**:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
|
||||
# Save top 3 models based on validation loss
|
||||
checkpoint = ModelCheckpoint(
|
||||
dirpath='checkpoints/',
|
||||
filename='model-{epoch:02d}-{val_loss:.2f}',
|
||||
monitor='val_loss',
|
||||
mode='min',
|
||||
save_top_k=3,
|
||||
save_last=True, # Also save last epoch
|
||||
verbose=True
|
||||
)
|
||||
|
||||
trainer = L.Trainer(callbacks=[checkpoint])
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
```
|
||||
|
||||
**Configuration options**:
|
||||
```python
|
||||
checkpoint = ModelCheckpoint(
|
||||
monitor='val_acc', # Metric to monitor
|
||||
mode='max', # 'max' for accuracy, 'min' for loss
|
||||
save_top_k=5, # Keep best 5 models
|
||||
save_last=True, # Save last epoch separately
|
||||
every_n_epochs=1, # Save every N epochs
|
||||
save_on_train_epoch_end=False, # Save on validation end instead
|
||||
filename='best-{epoch}-{val_acc:.3f}', # Naming pattern
|
||||
auto_insert_metric_name=False # Don't auto-add metric to filename
|
||||
)
|
||||
```
|
||||
|
||||
**Load checkpoint**:
|
||||
```python
|
||||
# Load best model
|
||||
best_model_path = checkpoint.best_model_path
|
||||
model = LitModel.load_from_checkpoint(best_model_path)
|
||||
|
||||
# Resume training
|
||||
trainer = L.Trainer(callbacks=[checkpoint])
|
||||
trainer.fit(model, train_loader, val_loader, ckpt_path='checkpoints/last.ckpt')
|
||||
```
|
||||
|
||||
### 2. EarlyStopping
|
||||
|
||||
**Stops training when metric stops improving**:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import EarlyStopping
|
||||
|
||||
early_stop = EarlyStopping(
|
||||
monitor='val_loss',
|
||||
patience=5, # Wait 5 epochs
|
||||
mode='min',
|
||||
min_delta=0.001, # Minimum change to qualify as improvement
|
||||
verbose=True,
|
||||
strict=True, # Crash if monitored metric not found
|
||||
check_on_train_epoch_end=False # Check on validation end
|
||||
)
|
||||
|
||||
trainer = L.Trainer(callbacks=[early_stop])
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
# Stops automatically if no improvement for 5 epochs
|
||||
```
|
||||
|
||||
**Advanced usage**:
|
||||
```python
|
||||
early_stop = EarlyStopping(
|
||||
monitor='val_loss',
|
||||
patience=10,
|
||||
min_delta=0.0,
|
||||
verbose=True,
|
||||
mode='min',
|
||||
stopping_threshold=0.1, # Stop if val_loss < 0.1
|
||||
divergence_threshold=5.0, # Stop if val_loss > 5.0
|
||||
check_finite=True # Stop on NaN/Inf
|
||||
)
|
||||
```
|
||||
|
||||
### 3. LearningRateMonitor
|
||||
|
||||
**Logs learning rate**:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import LearningRateMonitor
|
||||
|
||||
lr_monitor = LearningRateMonitor(
|
||||
logging_interval='epoch', # Or 'step'
|
||||
log_momentum=True # Also log momentum
|
||||
)
|
||||
|
||||
trainer = L.Trainer(callbacks=[lr_monitor])
|
||||
# Learning rate automatically logged to TensorBoard/WandB
|
||||
```
|
||||
|
||||
### 4. TQDMProgressBar
|
||||
|
||||
**Customizes progress bar**:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import TQDMProgressBar
|
||||
|
||||
progress_bar = TQDMProgressBar(
|
||||
refresh_rate=10, # Update every 10 batches
|
||||
process_position=0
|
||||
)
|
||||
|
||||
trainer = L.Trainer(callbacks=[progress_bar])
|
||||
```
|
||||
|
||||
### 5. GradientAccumulationScheduler
|
||||
|
||||
**Dynamic gradient accumulation**:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import GradientAccumulationScheduler
|
||||
|
||||
# Accumulate more gradients as training progresses
|
||||
accumulator = GradientAccumulationScheduler(
|
||||
scheduling={
|
||||
0: 8, # Epochs 0-4: accumulate 8 batches
|
||||
5: 4, # Epochs 5-9: accumulate 4 batches
|
||||
10: 2 # Epochs 10+: accumulate 2 batches
|
||||
}
|
||||
)
|
||||
|
||||
trainer = L.Trainer(callbacks=[accumulator])
|
||||
```
|
||||
|
||||
### 6. StochasticWeightAveraging (SWA)
|
||||
|
||||
**Averages weights for better generalization**:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import StochasticWeightAveraging
|
||||
|
||||
swa = StochasticWeightAveraging(
|
||||
swa_lrs=1e-2, # SWA learning rate
|
||||
swa_epoch_start=0.8, # Start at 80% of training
|
||||
annealing_epochs=10, # Annealing period
|
||||
annealing_strategy='cos' # 'cos' or 'linear'
|
||||
)
|
||||
|
||||
trainer = L.Trainer(callbacks=[swa])
|
||||
```
|
||||
|
||||
## Custom Callbacks
|
||||
|
||||
### Basic Custom Callback
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
|
||||
class PrintingCallback(Callback):
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
print("Training is starting!")
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
print("Training is done!")
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
print(f"Epoch {trainer.current_epoch} ended")
|
||||
|
||||
# Use it
|
||||
trainer = L.Trainer(callbacks=[PrintingCallback()])
|
||||
```
|
||||
|
||||
### Advanced Custom Callback
|
||||
|
||||
```python
|
||||
class MetricsCallback(Callback):
|
||||
"""Logs custom metrics every N batches."""
|
||||
|
||||
def __init__(self, log_every_n_batches=100):
|
||||
self.log_every_n_batches = log_every_n_batches
|
||||
self.metrics = []
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
if batch_idx % self.log_every_n_batches == 0:
|
||||
# Compute custom metric
|
||||
metric = self.compute_metric(outputs)
|
||||
self.metrics.append(metric)
|
||||
|
||||
# Log to Lightning
|
||||
pl_module.log('custom_metric', metric)
|
||||
|
||||
def compute_metric(self, outputs):
|
||||
# Your custom logic
|
||||
return outputs['loss'].item()
|
||||
|
||||
def state_dict(self):
|
||||
"""Save callback state in checkpoint."""
|
||||
return {'metrics': self.metrics}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Restore callback state from checkpoint."""
|
||||
self.metrics = state_dict['metrics']
|
||||
```
|
||||
|
||||
### Gradient Monitoring Callback
|
||||
|
||||
```python
|
||||
class GradientMonitorCallback(Callback):
|
||||
"""Monitor gradient norms."""
|
||||
|
||||
def on_after_backward(self, trainer, pl_module):
|
||||
# Compute gradient norm
|
||||
total_norm = 0.0
|
||||
for p in pl_module.parameters():
|
||||
if p.grad is not None:
|
||||
param_norm = p.grad.data.norm(2)
|
||||
total_norm += param_norm.item() ** 2
|
||||
total_norm = total_norm ** 0.5
|
||||
|
||||
# Log
|
||||
pl_module.log('grad_norm', total_norm)
|
||||
|
||||
# Warn if exploding
|
||||
if total_norm > 100:
|
||||
print(f"Warning: Large gradient norm: {total_norm:.2f}")
|
||||
```
|
||||
|
||||
### Model Inspection Callback
|
||||
|
||||
```python
|
||||
class ModelInspectionCallback(Callback):
|
||||
"""Inspect model activations during training."""
|
||||
|
||||
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
|
||||
if batch_idx == 0: # First batch of epoch
|
||||
# Register hooks
|
||||
self.activations = {}
|
||||
|
||||
def get_activation(name):
|
||||
def hook(model, input, output):
|
||||
self.activations[name] = output.detach()
|
||||
return hook
|
||||
|
||||
# Attach to specific layers
|
||||
pl_module.model.layer1.register_forward_hook(get_activation('layer1'))
|
||||
pl_module.model.layer2.register_forward_hook(get_activation('layer2'))
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
if batch_idx == 0:
|
||||
# Log activation statistics
|
||||
for name, activation in self.activations.items():
|
||||
mean = activation.mean().item()
|
||||
std = activation.std().item()
|
||||
pl_module.log(f'{name}_mean', mean)
|
||||
pl_module.log(f'{name}_std', std)
|
||||
```
|
||||
|
||||
## Callback Hooks
|
||||
|
||||
**All available hooks**:
|
||||
|
||||
```python
|
||||
class MyCallback(Callback):
|
||||
# Setup/Teardown
|
||||
def setup(self, trainer, pl_module, stage):
|
||||
"""Called at beginning of fit/test/predict."""
|
||||
pass
|
||||
|
||||
def teardown(self, trainer, pl_module, stage):
|
||||
"""Called at end of fit/test/predict."""
|
||||
pass
|
||||
|
||||
# Training
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
def on_train_epoch_start(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
|
||||
pass
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
pass
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
# Validation
|
||||
def on_validation_start(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
def on_validation_epoch_start(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
|
||||
pass
|
||||
|
||||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
pass
|
||||
|
||||
def on_validation_epoch_end(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
def on_validation_end(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
# Test (same structure as validation)
|
||||
def on_test_start(self, trainer, pl_module):
|
||||
pass
|
||||
# ... (test_epoch_start, test_batch_start, etc.)
|
||||
|
||||
# Predict
|
||||
def on_predict_start(self, trainer, pl_module):
|
||||
pass
|
||||
# ... (predict_epoch_start, predict_batch_start, etc.)
|
||||
|
||||
# Backward
|
||||
def on_before_backward(self, trainer, pl_module, loss):
|
||||
pass
|
||||
|
||||
def on_after_backward(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
# Optimizer
|
||||
def on_before_optimizer_step(self, trainer, pl_module, optimizer):
|
||||
pass
|
||||
|
||||
# Checkpointing
|
||||
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
|
||||
"""Add data to checkpoint."""
|
||||
pass
|
||||
|
||||
def on_load_checkpoint(self, trainer, pl_module, checkpoint):
|
||||
"""Restore data from checkpoint."""
|
||||
pass
|
||||
```
|
||||
|
||||
## Combining Multiple Callbacks
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
|
||||
|
||||
# Create all callbacks
|
||||
checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=3)
|
||||
early_stop = EarlyStopping(monitor='val_loss', patience=5)
|
||||
lr_monitor = LearningRateMonitor(logging_interval='epoch')
|
||||
custom_callback = MyCustomCallback()
|
||||
|
||||
# Add all to Trainer
|
||||
trainer = L.Trainer(
|
||||
callbacks=[checkpoint, early_stop, lr_monitor, custom_callback]
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
```
|
||||
|
||||
**Execution order**: Callbacks execute in the order they're added
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Keep Callbacks Independent
|
||||
|
||||
**Bad** (dependent on other callback):
|
||||
```python
|
||||
class BadCallback(Callback):
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
# Assumes ModelCheckpoint is present
|
||||
best_path = trainer.checkpoint_callback.best_model_path # Fragile!
|
||||
```
|
||||
|
||||
**Good** (self-contained):
|
||||
```python
|
||||
class GoodCallback(Callback):
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
# Find checkpoint callback if present
|
||||
for callback in trainer.callbacks:
|
||||
if isinstance(callback, ModelCheckpoint):
|
||||
best_path = callback.best_model_path
|
||||
break
|
||||
```
|
||||
|
||||
### 2. Use State Dict for Persistence
|
||||
|
||||
```python
|
||||
class StatefulCallback(Callback):
|
||||
def __init__(self):
|
||||
self.counter = 0
|
||||
self.history = []
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
self.counter += 1
|
||||
self.history.append(outputs['loss'].item())
|
||||
|
||||
def state_dict(self):
|
||||
"""Save state."""
|
||||
return {
|
||||
'counter': self.counter,
|
||||
'history': self.history
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Restore state."""
|
||||
self.counter = state_dict['counter']
|
||||
self.history = state_dict['history']
|
||||
```
|
||||
|
||||
### 3. Handle Distributed Training
|
||||
|
||||
```python
|
||||
class DistributedCallback(Callback):
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
# Only run on main process
|
||||
if trainer.is_global_zero:
|
||||
print("This only prints once in distributed training")
|
||||
|
||||
# Run on all processes
|
||||
loss = outputs['loss']
|
||||
# ... do something with loss on each GPU
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Callback API: https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html
|
||||
- Built-in callbacks: https://lightning.ai/docs/pytorch/stable/api_references.html#callbacks
|
||||
- Examples: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/callbacks
|
490
optional-skills/mlops/pytorch-lightning/references/distributed.md
Normal file
@ -0,0 +1,490 @@
# PyTorch Lightning Distributed Training
|
||||
|
||||
## Distributed Strategies
|
||||
|
||||
Lightning supports multiple distributed strategies with a single parameter change.
|
||||
|
||||
### 1. DDP (DistributedDataParallel)
|
||||
|
||||
**Default strategy for multi-GPU**:
|
||||
|
||||
```python
|
||||
# Automatic DDP on all available GPUs
|
||||
trainer = L.Trainer(accelerator='gpu', devices=4, strategy='ddp')
|
||||
|
||||
# Or auto-detect
|
||||
trainer = L.Trainer(accelerator='gpu', devices='auto')
|
||||
```
|
||||
|
||||
**How DDP works**:
|
||||
- Replicates model on each GPU
|
||||
- Each GPU processes different batch
|
||||
- Gradients all-reduced across GPUs
|
||||
- Model weights synchronized
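
One practical consequence, sketched with illustrative numbers: the DataLoader `batch_size` is per process, so the global batch per optimizer step scales with the number of devices.

```python
per_gpu_batch = 32
devices = 4
global_batch_per_step = per_gpu_batch * devices  # 128 samples per step under DDP
```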
|
||||
|
||||
**Launch**:
|
||||
```bash
|
||||
# Lightning handles spawning processes automatically
|
||||
python train.py
|
||||
```
|
||||
|
||||
**DDP Configuration**:
|
||||
```python
|
||||
from lightning.pytorch.strategies import DDPStrategy
|
||||
|
||||
strategy = DDPStrategy(
|
||||
find_unused_parameters=False, # Set True if model has unused params
|
||||
gradient_as_bucket_view=True, # Memory optimization
|
||||
static_graph=False, # Set True if graph doesn't change
|
||||
)
|
||||
|
||||
trainer = L.Trainer(strategy=strategy)
|
||||
```
|
||||
|
||||
### 2. FSDP (Fully Sharded Data Parallel)
|
||||
|
||||
**For large models (7B+ parameters)**:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.strategies import FSDPStrategy
|
||||
|
||||
strategy = FSDPStrategy(
|
||||
sharding_strategy="FULL_SHARD", # ZeRO-3 equivalent
|
||||
activation_checkpointing=None, # Or specify layer types
|
||||
cpu_offload=False, # CPU offload for memory
|
||||
)
|
||||
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=8,
|
||||
strategy=strategy,
|
||||
precision='bf16' # Recommended with FSDP
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
**FSDP Sharding Strategies**:
|
||||
```python
|
||||
# FULL_SHARD (most memory efficient, equivalent to ZeRO-3)
|
||||
strategy = FSDPStrategy(sharding_strategy="FULL_SHARD")
|
||||
|
||||
# SHARD_GRAD_OP (less memory efficient, equivalent to ZeRO-2)
|
||||
strategy = FSDPStrategy(sharding_strategy="SHARD_GRAD_OP")
|
||||
|
||||
# NO_SHARD (no sharding, like DDP)
|
||||
strategy = FSDPStrategy(sharding_strategy="NO_SHARD")
|
||||
```
|
||||
|
||||
**Auto-wrap policy** (wrap transformer blocks):
|
||||
```python
|
||||
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
|
||||
import functools
|
||||
|
||||
auto_wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls={GPT2Block}
|
||||
)
|
||||
|
||||
strategy = FSDPStrategy(
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
activation_checkpointing_policy={GPT2Block} # Checkpoint these blocks
|
||||
)
|
||||
```
|
||||
|
||||
### 3. DeepSpeed
|
||||
|
||||
**For massive models (70B+ parameters)**:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.strategies import DeepSpeedStrategy
|
||||
|
||||
# DeepSpeed ZeRO-3 with CPU offload
|
||||
strategy = DeepSpeedStrategy(
|
||||
stage=3, # ZeRO-3
|
||||
offload_optimizer=True, # CPU offload optimizer
|
||||
offload_parameters=True, # CPU offload parameters
|
||||
cpu_checkpointing=True, # Checkpoint to CPU
|
||||
)
|
||||
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=8,
|
||||
strategy=strategy,
|
||||
precision='bf16'
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
**DeepSpeed configuration file**:
|
||||
```json
|
||||
{
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"stage3_prefetch_bucket_size": 5e8,
|
||||
"stage3_param_persistence_threshold": 1e6
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Use config file**:
|
||||
```python
|
||||
strategy = DeepSpeedStrategy(config='deepspeed_config.json')
|
||||
trainer = L.Trainer(strategy=strategy)
|
||||
```
|
||||
|
||||
### 4. DDP Spawn
|
||||
|
||||
**Windows-compatible DDP**:
|
||||
|
||||
```python
|
||||
# Use when DDP doesn't work (e.g., Windows, Jupyter)
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=2,
|
||||
strategy='ddp_spawn' # Spawns new processes
|
||||
)
|
||||
```
|
||||
|
||||
**Note**: Slower than DDP due to process spawning overhead
|
||||
|
||||
## Multi-Node Training
|
||||
|
||||
### Setup Multi-Node Cluster
|
||||
|
||||
**Node 0 (master)**:
|
||||
```bash
|
||||
export MASTER_ADDR=192.168.1.100
|
||||
export MASTER_PORT=12355
|
||||
export WORLD_SIZE=16 # 2 nodes × 8 GPUs
|
||||
export NODE_RANK=0
|
||||
|
||||
python train.py
|
||||
```
|
||||
|
||||
**Node 1 (worker)**:
|
||||
```bash
|
||||
export MASTER_ADDR=192.168.1.100
|
||||
export MASTER_PORT=12355
|
||||
export WORLD_SIZE=16
|
||||
export NODE_RANK=1
|
||||
|
||||
python train.py
|
||||
```
|
||||
|
||||
**Training script**:
|
||||
```python
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=8, # GPUs per node
|
||||
num_nodes=2, # Total nodes
|
||||
strategy='ddp'
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
### SLURM Integration
|
||||
|
||||
**SLURM job script**:
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#SBATCH --nodes=4
|
||||
#SBATCH --ntasks-per-node=8
|
||||
#SBATCH --gres=gpu:8
|
||||
#SBATCH --time=24:00:00
|
||||
|
||||
# Lightning auto-detects SLURM environment
|
||||
srun python train.py
|
||||
```
|
||||
|
||||
**Training script** (no changes needed):
|
||||
```python
|
||||
# Lightning automatically reads SLURM environment variables
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=8,
|
||||
num_nodes=4, # From SBATCH --nodes
|
||||
strategy='ddp'
|
||||
)
|
||||
```
|
||||
|
||||
### Kubernetes (KubeFlow)
|
||||
|
||||
**Training script**:
|
||||
```python
|
||||
import os
|
||||
|
||||
# Lightning auto-detects Kubernetes
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=int(os.getenv('WORLD_SIZE', 1)),
|
||||
strategy='ddp'
|
||||
)
|
||||
```
|
||||
|
||||
## Mixed Precision Training
|
||||
|
||||
### BF16 (A100/H100)
|
||||
|
||||
```python
|
||||
trainer = L.Trainer(
|
||||
precision='bf16', # Or 'bf16-mixed'
|
||||
accelerator='gpu'
|
||||
)
|
||||
```
|
||||
|
||||
**Advantages**:
|
||||
- No gradient scaler needed
|
||||
- Same dynamic range as FP32
|
||||
- 2× speedup, 50% memory reduction
|
||||
|
||||
### FP16 (V100, older GPUs)
|
||||
|
||||
```python
|
||||
trainer = L.Trainer(
|
||||
precision='16-mixed', # Or just '16'
|
||||
accelerator='gpu'
|
||||
)
|
||||
```
|
||||
|
||||
**Automatic gradient scaling** handled by Lightning
|
||||
|
||||
### FP8 (H100)
|
||||
|
||||
```python
|
||||
# Requires transformer_engine
|
||||
# pip install transformer-engine[pytorch]
|
||||
|
||||
trainer = L.Trainer(
|
||||
precision='transformer-engine',
|
||||
accelerator='gpu'
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits**: 2× faster than BF16 on H100
|
||||
|
||||
## Gradient Accumulation
|
||||
|
||||
**Simulate larger batch size**:
|
||||
|
||||
```python
|
||||
trainer = L.Trainer(
|
||||
accumulate_grad_batches=4, # Accumulate 4 batches
|
||||
precision='bf16'
|
||||
)
|
||||
|
||||
# Effective batch = batch_size × accumulate_grad_batches × num_gpus
|
||||
# Example: 32 × 4 × 8 = 1024
|
||||
```
|
||||
|
||||
**Dynamic accumulation**:
|
||||
```python
|
||||
# Accumulate more early in training
|
||||
trainer = L.Trainer(
|
||||
accumulate_grad_batches={
|
||||
0: 8, # Epochs 0-4: accumulate 8
|
||||
5: 4, # Epochs 5-9: accumulate 4
|
||||
10: 2 # Epochs 10+: accumulate 2
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Checkpointing in Distributed
|
||||
|
||||
### Save Checkpoint
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
|
||||
# Only rank 0 saves by default
|
||||
checkpoint = ModelCheckpoint(
|
||||
dirpath='checkpoints/',
|
||||
filename='model-{epoch:02d}',
|
||||
save_top_k=3
|
||||
)
|
||||
|
||||
trainer = L.Trainer(callbacks=[checkpoint], strategy='ddp')
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
**Manual save**:
|
||||
```python
|
||||
class MyModel(L.LightningModule):
|
||||
def training_step(self, batch, batch_idx):
|
||||
# Training...
|
||||
loss = ...
|
||||
|
||||
# Save every 1000 steps (only rank 0)
|
||||
if batch_idx % 1000 == 0 and self.trainer.is_global_zero:
|
||||
self.trainer.save_checkpoint(f'checkpoint_step_{batch_idx}.ckpt')
|
||||
|
||||
return loss
|
||||
```
|
||||
|
||||
### Load Checkpoint
|
||||
|
||||
```python
|
||||
# Resume training
|
||||
trainer = L.Trainer(strategy='ddp')
|
||||
trainer.fit(model, train_loader, ckpt_path='checkpoints/last.ckpt')
|
||||
|
||||
# Load for inference
|
||||
model = MyModel.load_from_checkpoint('checkpoints/best.ckpt')
|
||||
model.eval()
|
||||
```
|
||||
|
||||
## Strategy Comparison
|
||||
|
||||
| Strategy | Memory Efficiency | Speed | Use Case |
|
||||
|----------|------------------|-------|----------|
|
||||
| DDP | Low | Fast | Small models (<7B), single node |
|
||||
| FSDP | High | Medium | Large models (7-70B) |
|
||||
| DeepSpeed ZeRO-2 | Medium | Fast | Medium models (1-13B) |
|
||||
| DeepSpeed ZeRO-3 | Very High | Slower | Massive models (70B+) |
|
||||
| DDP Spawn | Low | Slow | Windows, debugging |
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Choose Right Strategy
|
||||
|
||||
```python
|
||||
# Model size guide
|
||||
if model_params < 1e9: # <1B
|
||||
strategy = 'ddp'
|
||||
elif model_params < 7e9: # 1-7B
|
||||
    strategy = DeepSpeedStrategy(stage=2)  # plain 'ddp' also works at this size
|
||||
elif model_params < 70e9: # 7-70B
|
||||
strategy = FSDPStrategy(sharding_strategy="FULL_SHARD")
|
||||
else: # 70B+
|
||||
strategy = DeepSpeedStrategy(stage=3, offload_optimizer=True)
|
||||
|
||||
trainer = L.Trainer(strategy=strategy)
|
||||
```
|
||||
|
||||
### 2. Avoid Sync Issues
|
||||
|
||||
```python
|
||||
class MyModel(L.LightningModule):
|
||||
def training_step(self, batch, batch_idx):
|
||||
# WRONG: This runs on all GPUs independently
|
||||
if batch_idx % 100 == 0:
|
||||
self.log_something() # Logged 8 times on 8 GPUs!
|
||||
|
||||
# CORRECT: Use is_global_zero
|
||||
if batch_idx % 100 == 0 and self.trainer.is_global_zero:
|
||||
self.log_something() # Logged once
|
||||
|
||||
loss = ...
|
||||
return loss
|
||||
```
|
||||
|
||||
### 3. Efficient Data Loading
|
||||
|
||||
```python
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
|
||||
# Lightning handles DistributedSampler automatically
|
||||
train_loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=32,
|
||||
num_workers=4, # 4 workers per GPU
|
||||
pin_memory=True,
|
||||
persistent_workers=True
|
||||
)
|
||||
|
||||
# Lightning automatically wraps with DistributedSampler in DDP
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
### 4. Reduce Communication Overhead
|
||||
|
||||
```python
|
||||
from lightning.pytorch.strategies import DDPStrategy
|
||||
|
||||
strategy = DDPStrategy(
|
||||
gradient_as_bucket_view=True, # Reduce memory copies
|
||||
static_graph=True, # If model graph doesn't change (faster)
|
||||
)
|
||||
|
||||
trainer = L.Trainer(strategy=strategy)
|
||||
```
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Issue: NCCL Timeout
|
||||
|
||||
**Symptom**: Training hangs with `NCCL timeout` error
|
||||
|
||||
**Solution 1**: Increase timeout
|
||||
```bash
|
||||
export NCCL_TIMEOUT=3600 # 1 hour
|
||||
python train.py
|
||||
```
|
||||
|
||||
**Solution 2**: Check network
|
||||
```bash
|
||||
# Test inter-node communication
|
||||
nvidia-smi nvlink -s
|
||||
|
||||
# Verify all nodes can ping each other
|
||||
ping <node-2-ip>
|
||||
```
|
||||
|
||||
### Issue: OOM with FSDP
|
||||
|
||||
**Solution**: Enable CPU offload
|
||||
```python
|
||||
strategy = FSDPStrategy(
|
||||
sharding_strategy="FULL_SHARD",
|
||||
cpu_offload=True # Offload to CPU
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Different Results with DDP
|
||||
|
||||
**Cause**: Different random seeds per GPU
|
||||
|
||||
**Solution**: Set seed in LightningModule
|
||||
```python
|
||||
class MyModel(L.LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
L.seed_everything(42, workers=True) # Same seed everywhere
|
||||
```
|
||||
|
||||
### Issue: DeepSpeed Config Errors
|
||||
|
||||
**Solution**: Use Lightning's auto config
|
||||
```python
|
||||
strategy = DeepSpeedStrategy(
|
||||
stage=3,
|
||||
# Don't specify config file, Lightning generates automatically
|
||||
)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Distributed strategies: https://lightning.ai/docs/pytorch/stable/accelerators/gpu_intermediate.html
|
||||
- FSDP guide: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html
|
||||
- DeepSpeed: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/deepspeed.html
|
||||
- Multi-node: https://lightning.ai/docs/pytorch/stable/clouds/cluster.html
|
||||
556
optional-skills/mlops/pytorch-lightning/references/hyperparameter-tuning.md
Normal file
@ -0,0 +1,556 @@
# Hyperparameter Tuning with PyTorch Lightning
|
||||
|
||||
## Integration with Tuning Frameworks
|
||||
|
||||
Lightning integrates seamlessly with popular hyperparameter tuning libraries.
|
||||
|
||||
### 1. Ray Tune Integration
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install ray[tune]
|
||||
pip install lightning
|
||||
```
|
||||
|
||||
**Basic Ray Tune example**:
|
||||
|
||||
```python
|
||||
import lightning as L
import torch
from torch import nn
|
||||
from ray import tune
|
||||
from ray.tune.integration.pytorch_lightning import TuneReportCallback
|
||||
|
||||
class LitModel(L.LightningModule):
|
||||
def __init__(self, lr, batch_size):
|
||||
super().__init__()
|
||||
self.lr = lr
|
||||
self.batch_size = batch_size
|
||||
self.model = nn.Sequential(nn.Linear(10, 128), nn.ReLU(), nn.Linear(128, 1))
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
loss = self.model(batch).mean()
|
||||
self.log('train_loss', loss)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
val_loss = self.model(batch).mean()
|
||||
self.log('val_loss', val_loss)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||
|
||||
def train_fn(config):
|
||||
"""Training function for Ray Tune."""
|
||||
model = LitModel(lr=config["lr"], batch_size=config["batch_size"])
|
||||
|
||||
# Add callback to report metrics to Tune
|
||||
trainer = L.Trainer(
|
||||
max_epochs=10,
|
||||
callbacks=[TuneReportCallback({"loss": "val_loss"}, on="validation_end")]
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
|
||||
# Define search space
|
||||
config = {
|
||||
"lr": tune.loguniform(1e-5, 1e-1),
|
||||
"batch_size": tune.choice([16, 32, 64, 128])
|
||||
}
|
||||
|
||||
# Run hyperparameter search
|
||||
analysis = tune.run(
|
||||
train_fn,
|
||||
config=config,
|
||||
num_samples=20, # 20 trials
|
||||
resources_per_trial={"gpu": 1}
|
||||
)
|
||||
|
||||
# Best hyperparameters
|
||||
best_config = analysis.get_best_config(metric="loss", mode="min")
|
||||
print(f"Best config: {best_config}")
|
||||
```
|
||||
|
||||
**Advanced: Population-Based Training (PBT)**:
|
||||
|
||||
```python
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
|
||||
# PBT scheduler
|
||||
scheduler = PopulationBasedTraining(
|
||||
time_attr='training_iteration',
|
||||
metric='val_loss',
|
||||
mode='min',
|
||||
perturbation_interval=5, # Perturb every 5 epochs
|
||||
hyperparam_mutations={
|
||||
"lr": tune.loguniform(1e-5, 1e-1),
|
||||
"batch_size": [16, 32, 64, 128]
|
||||
}
|
||||
)
|
||||
|
||||
analysis = tune.run(
|
||||
train_fn,
|
||||
config=config,
|
||||
num_samples=8, # Population size
|
||||
scheduler=scheduler,
|
||||
resources_per_trial={"gpu": 1}
|
||||
)
|
||||
```
|
||||
|
||||
### 2. Optuna Integration
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install optuna
|
||||
pip install optuna-integration
|
||||
```
|
||||
|
||||
**Optuna example**:
|
||||
|
||||
```python
|
||||
import optuna
|
||||
from optuna.integration import PyTorchLightningPruningCallback
|
||||
|
||||
def objective(trial):
|
||||
# Suggest hyperparameters
|
||||
lr = trial.suggest_loguniform('lr', 1e-5, 1e-1)
|
||||
batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128])
|
||||
n_layers = trial.suggest_int('n_layers', 1, 3)
|
||||
hidden_size = trial.suggest_int('hidden_size', 64, 512, step=64)
|
||||
|
||||
# Create model
|
||||
model = LitModel(lr=lr, n_layers=n_layers, hidden_size=hidden_size)
|
||||
|
||||
# Pruning callback (early stopping for bad trials)
|
||||
pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_loss")
|
||||
|
||||
trainer = L.Trainer(
|
||||
max_epochs=20,
|
||||
callbacks=[pruning_callback],
|
||||
enable_progress_bar=False,
|
||||
logger=False
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
|
||||
return trainer.callback_metrics["val_loss"].item()
|
||||
|
||||
# Create study
|
||||
study = optuna.create_study(
|
||||
direction='minimize',
|
||||
pruner=optuna.pruners.MedianPruner() # Prune bad trials early
|
||||
)
|
||||
|
||||
# Optimize
|
||||
study.optimize(objective, n_trials=50, timeout=3600)
|
||||
|
||||
# Best params
|
||||
print(f"Best trial: {study.best_trial.params}")
|
||||
print(f"Best value: {study.best_value}")
|
||||
|
||||
# Visualization
|
||||
optuna.visualization.plot_optimization_history(study).show()
|
||||
optuna.visualization.plot_param_importances(study).show()
|
||||
```
|
||||
|
||||
**Optuna with distributed training**:
|
||||
|
||||
```python
|
||||
import optuna
|
||||
|
||||
# Shared database for distributed optimization
|
||||
storage = optuna.storages.RDBStorage(
|
||||
url='postgresql://user:pass@localhost/optuna'
|
||||
)
|
||||
|
||||
study = optuna.create_study(
|
||||
study_name='distributed_study',
|
||||
storage=storage,
|
||||
load_if_exists=True,
|
||||
direction='minimize'
|
||||
)
|
||||
|
||||
# Run on multiple machines
|
||||
study.optimize(objective, n_trials=50)
|
||||
```
|
||||
|
||||
### 3. Weights & Biases (WandB) Sweeps
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install wandb
|
||||
```
|
||||
|
||||
**WandB sweep config** (`sweep.yaml`):
|
||||
```yaml
|
||||
program: train.py
|
||||
method: bayes
|
||||
metric:
|
||||
name: val_loss
|
||||
goal: minimize
|
||||
parameters:
|
||||
lr:
|
||||
distribution: log_uniform_values
|
||||
min: 0.00001
|
||||
max: 0.1
|
||||
batch_size:
|
||||
values: [16, 32, 64, 128]
|
||||
optimizer:
|
||||
values: ['adam', 'sgd', 'adamw']
|
||||
dropout:
|
||||
distribution: uniform
|
||||
min: 0.0
|
||||
max: 0.5
|
||||
```
|
||||
|
||||
**Training script** (`train.py`):
|
||||
```python
|
||||
import wandb
|
||||
import lightning as L
|
||||
from lightning.pytorch.loggers import WandbLogger
|
||||
|
||||
def train():
|
||||
# Initialize wandb
|
||||
wandb.init()
|
||||
config = wandb.config
|
||||
|
||||
# Create model with sweep params
|
||||
model = LitModel(
|
||||
lr=config.lr,
|
||||
batch_size=config.batch_size,
|
||||
optimizer=config.optimizer,
|
||||
dropout=config.dropout
|
||||
)
|
||||
|
||||
# WandB logger
|
||||
wandb_logger = WandbLogger(project='hyperparameter-sweep')
|
||||
|
||||
trainer = L.Trainer(
|
||||
max_epochs=20,
|
||||
logger=wandb_logger
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
||||
```
|
||||
|
||||
**Launch sweep**:
|
||||
```bash
|
||||
# Initialize sweep
|
||||
wandb sweep sweep.yaml
|
||||
# Output: wandb: Created sweep with ID: abc123
|
||||
|
||||
# Run agent (can run on multiple machines)
|
||||
wandb agent your-entity/your-project/abc123
|
||||
```
|
||||
|
||||
### 4. Hyperopt Integration
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install hyperopt
|
||||
```
|
||||
|
||||
**Hyperopt example**:
|
||||
|
||||
```python
|
||||
import numpy as np
from hyperopt import hp, fmin, tpe, Trials
|
||||
|
||||
def objective(params):
|
||||
model = LitModel(
|
||||
lr=params['lr'],
|
||||
batch_size=int(params['batch_size']),
|
||||
hidden_size=int(params['hidden_size'])
|
||||
)
|
||||
|
||||
trainer = L.Trainer(
|
||||
max_epochs=10,
|
||||
enable_progress_bar=False,
|
||||
logger=False
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
|
||||
# Return loss (minimize)
|
||||
return trainer.callback_metrics["val_loss"].item()
|
||||
|
||||
# Define search space
|
||||
space = {
|
||||
'lr': hp.loguniform('lr', np.log(1e-5), np.log(1e-1)),
|
||||
'batch_size': hp.quniform('batch_size', 16, 128, 16),
|
||||
'hidden_size': hp.quniform('hidden_size', 64, 512, 64)
|
||||
}
|
||||
|
||||
# Optimize
|
||||
trials = Trials()
|
||||
best = fmin(
|
||||
fn=objective,
|
||||
space=space,
|
||||
algo=tpe.suggest, # Tree-structured Parzen Estimator
|
||||
max_evals=50,
|
||||
trials=trials
|
||||
)
|
||||
|
||||
print(f"Best hyperparameters: {best}")
|
||||
```
|
||||
|
||||
## Built-In Lightning Tuning
|
||||
|
||||
### Auto Learning Rate Finder
|
||||
|
||||
```python
|
||||
class LitModel(L.LightningModule):
|
||||
def __init__(self, lr=1e-3):
|
||||
super().__init__()
|
||||
self.lr = lr
|
||||
self.model = nn.Linear(10, 1)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
loss = self.model(batch).mean()
|
||||
return loss
|
||||
|
||||
# Find optimal learning rate (Lightning 2.x API: the old auto_lr_find flag
# and trainer.tune() were removed, use the Tuner instead)
from lightning.pytorch.tuner import Tuner

model = LitModel()
trainer = L.Trainer()
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model, train_loader)
|
||||
|
||||
# Plot results
|
||||
fig = lr_finder.plot(suggest=True)
|
||||
fig.show()
|
||||
|
||||
# Get suggested LR
|
||||
suggested_lr = lr_finder.suggestion()
|
||||
print(f"Suggested LR: {suggested_lr}")
|
||||
|
||||
# Update model
|
||||
model.lr = suggested_lr
|
||||
|
||||
# Train with optimal LR
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
### Auto Batch Size Finder
|
||||
|
||||
```python
|
||||
class LitModel(L.LightningModule):
|
||||
def __init__(self, batch_size=32):
|
||||
super().__init__()
|
||||
self.batch_size = batch_size
|
||||
self.model = nn.Linear(10, 1)
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(dataset, batch_size=self.batch_size)
|
||||
|
||||
model = LitModel()
trainer = L.Trainer()

# Find optimal batch size (Lightning 2.x API: auto_scale_batch_size and
# trainer.tune() were removed, use the Tuner instead)
from lightning.pytorch.tuner import Tuner
tuner = Tuner(trainer)
tuner.scale_batch_size(model, mode='binsearch')
|
||||
|
||||
print(f"Optimal batch size: {model.batch_size}")
|
||||
|
||||
# Train with optimal batch size
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
## Advanced Tuning Strategies
|
||||
|
||||
### 1. Multi-Fidelity Optimization (Successive Halving)
|
||||
|
||||
```python
|
||||
from ray.tune.schedulers import ASHAScheduler
|
||||
|
||||
# ASHA: Asynchronous Successive Halving Algorithm
|
||||
scheduler = ASHAScheduler(
|
||||
max_t=100, # Max epochs
|
||||
grace_period=10, # Min epochs before stopping
|
||||
reduction_factor=2 # Halve resources each round
|
||||
)
|
||||
|
||||
analysis = tune.run(
|
||||
train_fn,
|
||||
config=config,
|
||||
num_samples=64,
|
||||
scheduler=scheduler,
|
||||
resources_per_trial={"gpu": 1}
|
||||
)
|
||||
```
|
||||
|
||||
**How it works**:
|
||||
- Start 64 trials
|
||||
- After 10 epochs, stop bottom 50% (32 trials remain)
|
||||
- After 20 epochs, stop bottom 50% (16 trials remain)
|
||||
- After 40 epochs, stop bottom 50% (8 trials remain)
|
||||
- After 80 epochs, stop bottom 50% (4 trials remain)
|
||||
- Run remaining 4 trials to completion (100 epochs)
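
The same schedule as a small arithmetic sketch (numbers mirror the list above; with `reduction_factor=2` roughly half the trials survive each rung):

```python
trials = 64
for epochs in (10, 20, 40, 80):
    trials //= 2
    print(f"after epoch {epochs}: {trials} trials remain")
# the remaining 4 trials run to max_t=100
```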
|
||||
|
||||
### 2. Bayesian Optimization
|
||||
|
||||
```python
|
||||
from ray.tune.search.bayesopt import BayesOptSearch
|
||||
|
||||
search = BayesOptSearch(
|
||||
metric="val_loss",
|
||||
mode="min"
|
||||
)
|
||||
|
||||
analysis = tune.run(
|
||||
train_fn,
|
||||
config=config,
|
||||
num_samples=50,
|
||||
search_alg=search,
|
||||
resources_per_trial={"gpu": 1}
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Grid Search
|
||||
|
||||
```python
|
||||
from ray import tune
|
||||
|
||||
# Exhaustive grid search
|
||||
config = {
|
||||
"lr": tune.grid_search([1e-5, 1e-4, 1e-3, 1e-2]),
|
||||
"batch_size": tune.grid_search([16, 32, 64, 128]),
|
||||
"optimizer": tune.grid_search(['adam', 'sgd', 'adamw'])
|
||||
}
|
||||
|
||||
# Total trials: 4 × 4 × 3 = 48
|
||||
analysis = tune.run(train_fn, config=config)
|
||||
```
|
||||
|
||||
### 4. Random Search
|
||||
|
||||
```python
|
||||
config = {
|
||||
"lr": tune.loguniform(1e-5, 1e-1),
|
||||
"batch_size": tune.choice([16, 32, 64, 128]),
|
||||
"dropout": tune.uniform(0.0, 0.5),
|
||||
"hidden_size": tune.randint(64, 512)
|
||||
}
|
||||
|
||||
# Random sampling
|
||||
analysis = tune.run(
|
||||
train_fn,
|
||||
config=config,
|
||||
num_samples=100 # 100 random samples
|
||||
)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Start Simple
|
||||
|
||||
```python
|
||||
# Phase 1: Coarse search (fast)
|
||||
coarse_config = {
|
||||
"lr": tune.loguniform(1e-5, 1e-1),
|
||||
"batch_size": tune.choice([32, 64])
|
||||
}
|
||||
coarse_analysis = tune.run(train_fn, config=coarse_config, num_samples=10, stop={"training_iteration": 5})  # tune.run has no max_epochs argument
|
||||
|
||||
# Phase 2: Fine-tune around best (slow)
|
||||
best_lr = coarse_analysis.best_config["lr"]
|
||||
fine_config = {
|
||||
"lr": tune.uniform(best_lr * 0.5, best_lr * 2),
|
||||
"batch_size": tune.choice([16, 32, 64, 128])
|
||||
}
|
||||
fine_analysis = tune.run(train_fn, config=fine_config, num_samples=20, stop={"training_iteration": 20})
|
||||
```
|
||||
|
||||
### 2. Use Checkpointing
|
||||
|
||||
```python
|
||||
import os
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback  # assumed import path; adjust for your Ray version

def train_fn(config, checkpoint_dir=None):
|
||||
model = LitModel(lr=config["lr"])
|
||||
|
||||
trainer = L.Trainer(
|
||||
max_epochs=100,
|
||||
callbacks=[
|
||||
TuneReportCheckpointCallback(
|
||||
metrics={"loss": "val_loss"},
|
||||
filename="checkpoint",
|
||||
on="validation_end"
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Resume from checkpoint if exists
|
||||
ckpt_path = None
|
||||
if checkpoint_dir:
|
||||
ckpt_path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
|
||||
trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
|
||||
```
|
||||
|
||||
### 3. Monitor Resource Usage
|
||||
|
||||
```python
|
||||
import GPUtil
|
||||
|
||||
def train_fn(config):
|
||||
# Before training
|
||||
GPUs = GPUtil.getGPUs()
|
||||
print(f"GPU memory before: {GPUs[0].memoryUsed} MB")
|
||||
|
||||
# Train
|
||||
model = LitModel(lr=config["lr"], batch_size=config["batch_size"])
|
||||
trainer.fit(model, train_loader)
|
||||
|
||||
# After training
|
||||
GPUs = GPUtil.getGPUs()
|
||||
print(f"GPU memory after: {GPUs[0].memoryUsed} MB")
|
||||
```
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Issue: Trials Running Out of Memory
|
||||
|
||||
**Solution**: Reduce concurrent trials or batch size
|
||||
```python
|
||||
analysis = tune.run(
|
||||
train_fn,
|
||||
config=config,
|
||||
resources_per_trial={"gpu": 0.5}, # 2 trials per GPU
|
||||
max_concurrent_trials=2 # Limit concurrent trials
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Slow Hyperparameter Search
|
||||
|
||||
**Solution**: Use early stopping scheduler
|
||||
```python
|
||||
from ray.tune.schedulers import ASHAScheduler
|
||||
|
||||
scheduler = ASHAScheduler(
|
||||
max_t=100,
|
||||
grace_period=5, # Stop bad trials after 5 epochs
|
||||
reduction_factor=3
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Can't Reproduce Best Trial
|
||||
|
||||
**Solution**: Set seeds in training function
|
||||
```python
|
||||
def train_fn(config):
|
||||
L.seed_everything(42, workers=True)
|
||||
# Rest of training...
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Ray Tune + Lightning: https://docs.ray.io/en/latest/tune/examples/tune-pytorch-lightning.html
|
||||
- Optuna: https://optuna.readthedocs.io/
|
||||
- WandB Sweeps: https://docs.wandb.ai/guides/sweeps
|
||||
- Lightning Tuner: https://lightning.ai/docs/pytorch/stable/tuning.html
|
||||
496
optional-skills/mlops/qdrant/SKILL.md
Normal file
@ -0,0 +1,496 @@
---
|
||||
name: qdrant-vector-search
|
||||
description: High-performance vector similarity search engine for RAG and semantic search. Use when building production RAG systems requiring fast nearest neighbor search, hybrid search with filtering, or scalable vector storage with Rust-powered performance.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [qdrant-client>=1.12.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [RAG, Vector Search, Qdrant, Semantic Search, Embeddings, Similarity Search, HNSW, Production, Distributed]
|
||||
|
||||
---
|
||||
|
||||
# Qdrant - Vector Similarity Search Engine
|
||||
|
||||
High-performance vector database written in Rust for production RAG and semantic search.
|
||||
|
||||
## When to use Qdrant
|
||||
|
||||
**Use Qdrant when:**
|
||||
- Building production RAG systems requiring low latency
|
||||
- Need hybrid search (vectors + metadata filtering)
|
||||
- Require horizontal scaling with sharding/replication
|
||||
- Want on-premise deployment with full data control
|
||||
- Need multi-vector storage per record (dense + sparse)
|
||||
- Building real-time recommendation systems
|
||||
|
||||
**Key features:**
|
||||
- **Rust-powered**: Memory-safe, high performance
|
||||
- **Rich filtering**: Filter by any payload field during search
|
||||
- **Multiple vectors**: Dense, sparse, multi-dense per point
|
||||
- **Quantization**: Scalar, product, binary for memory efficiency
|
||||
- **Distributed**: Raft consensus, sharding, replication
|
||||
- **REST + gRPC**: Both APIs with full feature parity
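
For example, the Python client can use the gRPC API instead of REST with a flag (a sketch; `prefer_grpc` and `grpc_port` are standard `QdrantClient` parameters):

```python
from qdrant_client import QdrantClient

# Same client API, lower-latency gRPC transport on port 6334
client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True)
```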
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Chroma**: Simpler setup, embedded use cases
|
||||
- **FAISS**: Maximum raw speed, research/batch processing
|
||||
- **Pinecone**: Fully managed, zero ops preferred
|
||||
- **Weaviate**: GraphQL preference, built-in vectorizers
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Python client
|
||||
pip install qdrant-client
|
||||
|
||||
# Docker (recommended for development)
|
||||
docker run -p 6333:6333 -p 6334:6334 qdrant/qdrant
|
||||
|
||||
# Docker with persistent storage
|
||||
docker run -p 6333:6333 -p 6334:6334 \
|
||||
-v $(pwd)/qdrant_storage:/qdrant/storage \
|
||||
qdrant/qdrant
|
||||
```
|
||||
|
||||
### Basic usage
|
||||
|
||||
```python
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import Distance, VectorParams, PointStruct
|
||||
|
||||
# Connect to Qdrant
|
||||
client = QdrantClient(host="localhost", port=6333)
|
||||
|
||||
# Create collection
|
||||
client.create_collection(
|
||||
collection_name="documents",
|
||||
vectors_config=VectorParams(size=384, distance=Distance.COSINE)
|
||||
)
|
||||
|
||||
# Insert vectors with payload
|
||||
client.upsert(
|
||||
collection_name="documents",
|
||||
points=[
|
||||
PointStruct(
|
||||
id=1,
|
||||
vector=[0.1, 0.2, ...], # 384-dim vector
|
||||
payload={"title": "Doc 1", "category": "tech"}
|
||||
),
|
||||
PointStruct(
|
||||
id=2,
|
||||
vector=[0.3, 0.4, ...],
|
||||
payload={"title": "Doc 2", "category": "science"}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Search with filtering
|
||||
results = client.search(
|
||||
collection_name="documents",
|
||||
query_vector=[0.15, 0.25, ...],
|
||||
query_filter={
|
||||
"must": [{"key": "category", "match": {"value": "tech"}}]
|
||||
},
|
||||
limit=10
|
||||
)
|
||||
|
||||
for point in results:
|
||||
print(f"ID: {point.id}, Score: {point.score}, Payload: {point.payload}")
|
||||
```
|
||||
|
||||
## Core concepts
|
||||
|
||||
### Points - Basic data unit
|
||||
|
||||
```python
|
||||
from qdrant_client.models import PointStruct
|
||||
|
||||
# Point = ID + Vector(s) + Payload
|
||||
point = PointStruct(
|
||||
id=123, # Integer or UUID string
|
||||
vector=[0.1, 0.2, 0.3, ...], # Dense vector
|
||||
payload={ # Arbitrary JSON metadata
|
||||
"title": "Document title",
|
||||
"category": "tech",
|
||||
"timestamp": 1699900000,
|
||||
"tags": ["python", "ml"]
|
||||
}
|
||||
)
|
||||
|
||||
# Batch upsert (recommended)
|
||||
client.upsert(
|
||||
collection_name="documents",
|
||||
points=[point1, point2, point3],
|
||||
wait=True # Wait for indexing
|
||||
)
|
||||
```
|
||||
|
||||
### Collections - Vector containers
|
||||
|
||||
```python
|
||||
from qdrant_client.models import VectorParams, Distance, HnswConfigDiff
|
||||
|
||||
# Create with HNSW configuration
|
||||
client.create_collection(
|
||||
collection_name="documents",
|
||||
vectors_config=VectorParams(
|
||||
size=384, # Vector dimensions
|
||||
distance=Distance.COSINE # COSINE, EUCLID, DOT, MANHATTAN
|
||||
),
|
||||
hnsw_config=HnswConfigDiff(
|
||||
m=16, # Connections per node (default 16)
|
||||
ef_construct=100, # Build-time accuracy (default 100)
|
||||
full_scan_threshold=10000 # Switch to brute force below this
|
||||
),
|
||||
on_disk_payload=True # Store payload on disk
|
||||
)
|
||||
|
||||
# Collection info
|
||||
info = client.get_collection("documents")
|
||||
print(f"Points: {info.points_count}, Vectors: {info.vectors_count}")
|
||||
```
|
||||
|
||||
### Distance metrics
|
||||
|
||||
| Metric | Use Case | Range |
|
||||
|--------|----------|-------|
|
||||
| `COSINE` | Text embeddings, normalized vectors | 0 to 2 |
|
||||
| `EUCLID` | Spatial data, image features | 0 to ∞ |
|
||||
| `DOT` | Recommendations, unnormalized | -∞ to ∞ |
|
||||
| `MANHATTAN` | Sparse features, discrete data | 0 to ∞ |
|
||||
|
||||
## Search operations
|
||||
|
||||
### Basic search
|
||||
|
||||
```python
|
||||
# Simple nearest neighbor search
|
||||
results = client.search(
|
||||
collection_name="documents",
|
||||
query_vector=[0.1, 0.2, ...],
|
||||
limit=10,
|
||||
with_payload=True,
|
||||
with_vectors=False # Don't return vectors (faster)
|
||||
)
|
||||
```
|
||||
|
||||
### Filtered search
|
||||
|
||||
```python
|
||||
from qdrant_client.models import Filter, FieldCondition, MatchValue, Range
|
||||
|
||||
# Complex filtering
|
||||
results = client.search(
|
||||
collection_name="documents",
|
||||
query_vector=query_embedding,
|
||||
query_filter=Filter(
|
||||
must=[
|
||||
FieldCondition(key="category", match=MatchValue(value="tech")),
|
||||
FieldCondition(key="timestamp", range=Range(gte=1699000000))
|
||||
],
|
||||
must_not=[
|
||||
FieldCondition(key="status", match=MatchValue(value="archived"))
|
||||
]
|
||||
),
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Shorthand filter syntax
|
||||
results = client.search(
|
||||
collection_name="documents",
|
||||
query_vector=query_embedding,
|
||||
query_filter={
|
||||
"must": [
|
||||
{"key": "category", "match": {"value": "tech"}},
|
||||
{"key": "price", "range": {"gte": 10, "lte": 100}}
|
||||
]
|
||||
},
|
||||
limit=10
|
||||
)
|
||||
```
|
||||
|
||||
### Batch search
|
||||
|
||||
```python
|
||||
from qdrant_client.models import SearchRequest
|
||||
|
||||
# Multiple queries in one request
|
||||
results = client.search_batch(
|
||||
collection_name="documents",
|
||||
requests=[
|
||||
SearchRequest(vector=[0.1, ...], limit=5),
|
||||
SearchRequest(vector=[0.2, ...], limit=5, filter={"must": [...]}),
|
||||
SearchRequest(vector=[0.3, ...], limit=10)
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
## RAG integration
|
||||
|
||||
### With sentence-transformers
|
||||
|
||||
```python
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import VectorParams, Distance, PointStruct
|
||||
|
||||
# Initialize
|
||||
encoder = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
client = QdrantClient(host="localhost", port=6333)
|
||||
|
||||
# Create collection
|
||||
client.create_collection(
|
||||
collection_name="knowledge_base",
|
||||
vectors_config=VectorParams(size=384, distance=Distance.COSINE)
|
||||
)
|
||||
|
||||
# Index documents
|
||||
documents = [
|
||||
{"id": 1, "text": "Python is a programming language", "source": "wiki"},
|
||||
{"id": 2, "text": "Machine learning uses algorithms", "source": "textbook"},
|
||||
]
|
||||
|
||||
points = [
|
||||
PointStruct(
|
||||
id=doc["id"],
|
||||
vector=encoder.encode(doc["text"]).tolist(),
|
||||
payload={"text": doc["text"], "source": doc["source"]}
|
||||
)
|
||||
for doc in documents
|
||||
]
|
||||
client.upsert(collection_name="knowledge_base", points=points)
|
||||
|
||||
# RAG retrieval
|
||||
def retrieve(query: str, top_k: int = 5) -> list[dict]:
|
||||
query_vector = encoder.encode(query).tolist()
|
||||
results = client.search(
|
||||
collection_name="knowledge_base",
|
||||
query_vector=query_vector,
|
||||
limit=top_k
|
||||
)
|
||||
return [{"text": r.payload["text"], "score": r.score} for r in results]
|
||||
|
||||
# Use in RAG pipeline
|
||||
context = retrieve("What is Python?")
context_text = "\n".join(r["text"] for r in context)
prompt = f"Context: {context_text}\n\nQuestion: What is Python?"
|
||||
```
|
||||
|
||||
### With LangChain
|
||||
|
||||
```python
|
||||
from langchain_community.vectorstores import Qdrant
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
|
||||
vectorstore = Qdrant.from_documents(documents, embeddings, url="http://localhost:6333", collection_name="docs")
|
||||
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
|
||||
```
|
||||
|
||||
### With LlamaIndex
|
||||
|
||||
```python
|
||||
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||
from llama_index.core import VectorStoreIndex, StorageContext
|
||||
|
||||
vector_store = QdrantVectorStore(client=client, collection_name="llama_docs")
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
|
||||
query_engine = index.as_query_engine()
|
||||
```
|
||||
|
||||
## Multi-vector support
|
||||
|
||||
### Named vectors (different embedding models)
|
||||
|
||||
```python
|
||||
from qdrant_client.models import VectorParams, Distance
|
||||
|
||||
# Collection with multiple vector types
|
||||
client.create_collection(
|
||||
collection_name="hybrid_search",
|
||||
vectors_config={
|
||||
"dense": VectorParams(size=384, distance=Distance.COSINE),
|
||||
"sparse": VectorParams(size=30000, distance=Distance.DOT)
|
||||
}
|
||||
)
|
||||
|
||||
# Insert with named vectors
|
||||
client.upsert(
|
||||
collection_name="hybrid_search",
|
||||
points=[
|
||||
PointStruct(
|
||||
id=1,
|
||||
vector={
|
||||
"dense": dense_embedding,
|
||||
"sparse": sparse_embedding
|
||||
},
|
||||
payload={"text": "document text"}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Search specific vector
|
||||
results = client.search(
|
||||
collection_name="hybrid_search",
|
||||
query_vector=("dense", query_dense), # Specify which vector
|
||||
limit=10
|
||||
)
|
||||
```
|
||||
|
||||
### Sparse vectors (BM25, SPLADE)
|
||||
|
||||
```python
|
||||
from qdrant_client.models import SparseVectorParams, SparseIndexParams, SparseVector
|
||||
|
||||
# Collection with sparse vectors
|
||||
client.create_collection(
|
||||
collection_name="sparse_search",
|
||||
vectors_config={},
|
||||
sparse_vectors_config={"text": SparseVectorParams(index=SparseIndexParams(on_disk=False))}
|
||||
)
|
||||
|
||||
# Insert sparse vector
|
||||
client.upsert(
|
||||
collection_name="sparse_search",
|
||||
points=[PointStruct(id=1, vector={"text": SparseVector(indices=[1, 5, 100], values=[0.5, 0.8, 0.2])}, payload={"text": "document"})]
|
||||
)
|
||||
```
|
||||
|
||||
## Quantization (memory optimization)
|
||||
|
||||
```python
|
||||
from qdrant_client.models import ScalarQuantization, ScalarQuantizationConfig, ScalarType
|
||||
|
||||
# Scalar quantization (4x memory reduction)
|
||||
client.create_collection(
|
||||
collection_name="quantized",
|
||||
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
|
||||
quantization_config=ScalarQuantization(
|
||||
scalar=ScalarQuantizationConfig(
|
||||
type=ScalarType.INT8,
|
||||
quantile=0.99, # Clip outliers
|
||||
always_ram=True # Keep quantized in RAM
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Search with rescoring
|
||||
results = client.search(
|
||||
collection_name="quantized",
|
||||
query_vector=query,
|
||||
search_params={"quantization": {"rescore": True}}, # Rescore top results
|
||||
limit=10
|
||||
)
|
||||
```
|
||||
|
||||
## Payload indexing
|
||||
|
||||
```python
|
||||
from qdrant_client.models import PayloadSchemaType
|
||||
|
||||
# Create payload index for faster filtering
|
||||
client.create_payload_index(
|
||||
collection_name="documents",
|
||||
field_name="category",
|
||||
field_schema=PayloadSchemaType.KEYWORD
|
||||
)
|
||||
|
||||
client.create_payload_index(
|
||||
collection_name="documents",
|
||||
field_name="timestamp",
|
||||
field_schema=PayloadSchemaType.INTEGER
|
||||
)
|
||||
|
||||
# Index types: KEYWORD, INTEGER, FLOAT, GEO, TEXT (full-text), BOOL
|
||||
```
|
||||
|
||||
## Production deployment
|
||||
|
||||
### Qdrant Cloud
|
||||
|
||||
```python
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
# Connect to Qdrant Cloud
|
||||
client = QdrantClient(
|
||||
url="https://your-cluster.cloud.qdrant.io",
|
||||
api_key="your-api-key"
|
||||
)
|
||||
```
|
||||
|
||||
### Performance tuning
|
||||
|
||||
```python
|
||||
# Optimize for recall (higher m / ef_construct; slower to build)
|
||||
client.update_collection(
|
||||
collection_name="documents",
|
||||
hnsw_config=HnswConfigDiff(ef_construct=200, m=32)
|
||||
)
|
||||
|
||||
# Optimize for indexing speed (bulk loads)
|
||||
client.update_collection(
|
||||
collection_name="documents",
|
||||
    optimizers_config={"indexing_threshold": 20000}
|
||||
)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Batch operations** - Use batch upsert/search for efficiency
|
||||
2. **Payload indexing** - Index fields used in filters
|
||||
3. **Quantization** - Enable for large collections (>1M vectors)
|
||||
4. **Sharding** - Use for collections >10M vectors
|
||||
5. **On-disk storage** - Enable `on_disk_payload` for large payloads
|
||||
6. **Connection pooling** - Reuse client instances (a minimal sketch follows this list)
|
||||
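A minimal sketch of practice 6, assuming a module-level factory is an acceptable place to hold the shared client (the helper name is illustrative):

```python
from functools import lru_cache
from qdrant_client import QdrantClient

@lru_cache(maxsize=1)
def get_client() -> QdrantClient:
    """Return one shared client instance instead of reconnecting per request."""
    return QdrantClient(host="localhost", port=6333, prefer_grpc=True)

client = get_client()  # every caller reuses the same pooled connection
```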
|
||||
## Common issues
|
||||
|
||||
**Slow search with filters:**
|
||||
```python
|
||||
# Create payload index for filtered fields
|
||||
client.create_payload_index(
|
||||
collection_name="docs",
|
||||
field_name="category",
|
||||
field_schema=PayloadSchemaType.KEYWORD
|
||||
)
|
||||
```
|
||||
|
||||
**Out of memory:**
|
||||
```python
|
||||
# Enable quantization and on-disk storage
|
||||
client.create_collection(
|
||||
collection_name="large_collection",
|
||||
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
|
||||
quantization_config=ScalarQuantization(...),
|
||||
on_disk_payload=True
|
||||
)
|
||||
```
|
||||
|
||||
**Connection issues:**
|
||||
```python
|
||||
# Use timeout and retry
|
||||
client = QdrantClient(
|
||||
host="localhost",
|
||||
port=6333,
|
||||
timeout=30,
|
||||
prefer_grpc=True # gRPC for better performance
|
||||
)
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Distributed mode, hybrid search, recommendations
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues, debugging, performance tuning
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/qdrant/qdrant (22k+ stars)
|
||||
- **Docs**: https://qdrant.tech/documentation/
|
||||
- **Python Client**: https://github.com/qdrant/qdrant-client
|
||||
- **Cloud**: https://cloud.qdrant.io
|
||||
- **Version**: 1.12.0+
|
||||
- **License**: Apache 2.0
|
||||
648
optional-skills/mlops/qdrant/references/advanced-usage.md
Normal file
|
|
@ -0,0 +1,648 @@
|
|||
# Qdrant Advanced Usage Guide
|
||||
|
||||
## Distributed Deployment
|
||||
|
||||
### Cluster Setup
|
||||
|
||||
Qdrant uses Raft consensus for distributed coordination.
|
||||
|
||||
```yaml
|
||||
# docker-compose.yml for 3-node cluster
|
||||
version: '3.8'
|
||||
services:
|
||||
qdrant-node-1:
|
||||
image: qdrant/qdrant:latest
|
||||
ports:
|
||||
- "6333:6333"
|
||||
- "6334:6334"
|
||||
- "6335:6335"
|
||||
volumes:
|
||||
- ./node1_storage:/qdrant/storage
|
||||
environment:
|
||||
- QDRANT__CLUSTER__ENABLED=true
|
||||
- QDRANT__CLUSTER__P2P__PORT=6335
|
||||
- QDRANT__SERVICE__HTTP_PORT=6333
|
||||
- QDRANT__SERVICE__GRPC_PORT=6334
|
||||
|
||||
qdrant-node-2:
|
||||
image: qdrant/qdrant:latest
|
||||
ports:
|
||||
- "6343:6333"
|
||||
- "6344:6334"
|
||||
- "6345:6335"
|
||||
volumes:
|
||||
- ./node2_storage:/qdrant/storage
|
||||
environment:
|
||||
- QDRANT__CLUSTER__ENABLED=true
|
||||
- QDRANT__CLUSTER__P2P__PORT=6335
|
||||
- QDRANT__CLUSTER__BOOTSTRAP=http://qdrant-node-1:6335
|
||||
depends_on:
|
||||
- qdrant-node-1
|
||||
|
||||
qdrant-node-3:
|
||||
image: qdrant/qdrant:latest
|
||||
ports:
|
||||
- "6353:6333"
|
||||
- "6354:6334"
|
||||
- "6355:6335"
|
||||
volumes:
|
||||
- ./node3_storage:/qdrant/storage
|
||||
environment:
|
||||
- QDRANT__CLUSTER__ENABLED=true
|
||||
- QDRANT__CLUSTER__P2P__PORT=6335
|
||||
- QDRANT__CLUSTER__BOOTSTRAP=http://qdrant-node-1:6335
|
||||
depends_on:
|
||||
- qdrant-node-1
|
||||
```
|
||||
|
||||
### Sharding Configuration
|
||||
|
||||
```python
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import VectorParams, Distance, ShardingMethod
|
||||
|
||||
client = QdrantClient(host="localhost", port=6333)
|
||||
|
||||
# Create sharded collection
|
||||
client.create_collection(
|
||||
collection_name="large_collection",
|
||||
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
|
||||
shard_number=6, # Number of shards
|
||||
replication_factor=2, # Replicas per shard
|
||||
write_consistency_factor=1 # Required acks for write
|
||||
)
|
||||
|
||||
# Check cluster status
|
||||
cluster_info = client.get_cluster_info()
|
||||
print(f"Peers: {cluster_info.peers}")
|
||||
print(f"Raft state: {cluster_info.raft_info}")
|
||||
```
|
||||
|
||||
### Replication and Consistency
|
||||
|
||||
```python
|
||||
from qdrant_client.models import WriteOrdering
|
||||
|
||||
# Strong consistency write
|
||||
client.upsert(
|
||||
collection_name="critical_data",
|
||||
points=points,
|
||||
ordering=WriteOrdering.STRONG # Wait for all replicas
|
||||
)
|
||||
|
||||
# Eventual consistency (faster)
|
||||
client.upsert(
|
||||
collection_name="logs",
|
||||
points=points,
|
||||
ordering=WriteOrdering.WEAK # Return after primary ack
|
||||
)
|
||||
|
||||
# Read from specific shard
|
||||
results = client.search(
|
||||
collection_name="documents",
|
||||
query_vector=query,
|
||||
consistency="majority" # Read from majority of replicas
|
||||
)
|
||||
```
|
||||
|
||||
## Hybrid Search
|
||||
|
||||
### Dense + Sparse Vectors
|
||||
|
||||
Combine semantic (dense) and keyword (sparse) search:
|
||||
|
||||
```python
|
||||
from qdrant_client.models import (
|
||||
VectorParams, SparseVectorParams, SparseIndexParams,
|
||||
Distance, PointStruct, SparseVector, Prefetch, Query
|
||||
)
|
||||
|
||||
# Create hybrid collection
|
||||
client.create_collection(
|
||||
collection_name="hybrid",
|
||||
vectors_config={
|
||||
"dense": VectorParams(size=384, distance=Distance.COSINE)
|
||||
},
|
||||
sparse_vectors_config={
|
||||
"sparse": SparseVectorParams(
|
||||
index=SparseIndexParams(on_disk=False)
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
# Insert with both vector types
|
||||
def encode_sparse(text: str) -> SparseVector:
|
||||
"""Simple BM25-like sparse encoding"""
|
||||
from collections import Counter
|
||||
tokens = text.lower().split()
|
||||
counts = Counter(tokens)
|
||||
# Map tokens to indices (use vocabulary in production)
|
||||
indices = [hash(t) % 30000 for t in counts.keys()]
|
||||
values = list(counts.values())
|
||||
return SparseVector(indices=indices, values=values)
|
||||
|
||||
client.upsert(
|
||||
collection_name="hybrid",
|
||||
points=[
|
||||
PointStruct(
|
||||
id=1,
|
||||
vector={
|
||||
"dense": dense_encoder.encode("Python programming").tolist(),
|
||||
"sparse": encode_sparse("Python programming language code")
|
||||
},
|
||||
payload={"text": "Python programming language code"}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Hybrid search with Reciprocal Rank Fusion (RRF)
|
||||
from qdrant_client.models import FusionQuery
|
||||
|
||||
results = client.query_points(
|
||||
collection_name="hybrid",
|
||||
prefetch=[
|
||||
Prefetch(query=dense_query, using="dense", limit=20),
|
||||
Prefetch(query=sparse_query, using="sparse", limit=20)
|
||||
],
|
||||
query=FusionQuery(fusion="rrf"), # Combine results
|
||||
limit=10
|
||||
)
|
||||
```
|
||||
|
||||
### Multi-Stage Search
|
||||
|
||||
```python
|
||||
from qdrant_client.models import Prefetch, Query
|
||||
|
||||
# Two-stage retrieval: coarse then fine
|
||||
results = client.query_points(
|
||||
collection_name="documents",
|
||||
prefetch=[
|
||||
Prefetch(
|
||||
query=query_vector,
|
||||
limit=100, # Broad first stage
|
||||
params={"quantization": {"rescore": False}} # Fast, approximate
|
||||
)
|
||||
],
|
||||
query=Query(nearest=query_vector),
|
||||
limit=10,
|
||||
params={"quantization": {"rescore": True}} # Accurate reranking
|
||||
)
|
||||
```
|
||||
|
||||
## Recommendations
|
||||
|
||||
### Item-to-Item Recommendations
|
||||
|
||||
```python
|
||||
# Find similar items
|
||||
recommendations = client.recommend(
|
||||
collection_name="products",
|
||||
positive=[1, 2, 3], # IDs user liked
|
||||
negative=[4], # IDs user disliked
|
||||
limit=10
|
||||
)
|
||||
|
||||
# With filtering
|
||||
recommendations = client.recommend(
|
||||
collection_name="products",
|
||||
positive=[1, 2],
|
||||
query_filter={
|
||||
"must": [
|
||||
{"key": "category", "match": {"value": "electronics"}},
|
||||
{"key": "in_stock", "match": {"value": True}}
|
||||
]
|
||||
},
|
||||
limit=10
|
||||
)
|
||||
```
|
||||
|
||||
### Lookup from Another Collection
|
||||
|
||||
```python
|
||||
from qdrant_client.models import RecommendStrategy, LookupLocation
|
||||
|
||||
# Recommend using vectors from another collection
|
||||
results = client.recommend(
    collection_name="products",
    positive=["user_123"],  # ID of a point in the lookup collection
    lookup_from=LookupLocation(collection="user_history"),
    strategy=RecommendStrategy.AVERAGE_VECTOR,
    limit=10
)
|
||||
```
|
||||
|
||||
## Advanced Filtering
|
||||
|
||||
### Nested Payload Filtering
|
||||
|
||||
```python
|
||||
from qdrant_client.models import Filter, FieldCondition, MatchValue, NestedCondition
|
||||
|
||||
# Filter on nested objects
|
||||
results = client.search(
|
||||
collection_name="documents",
|
||||
query_vector=query,
|
||||
query_filter=Filter(
|
||||
must=[
|
||||
NestedCondition(
|
||||
key="metadata",
|
||||
filter=Filter(
|
||||
must=[
|
||||
FieldCondition(
|
||||
key="author.name",
|
||||
match=MatchValue(value="John")
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
]
|
||||
),
|
||||
limit=10
|
||||
)
|
||||
```
|
||||
|
||||
### Geo Filtering
|
||||
|
||||
```python
|
||||
from qdrant_client.models import FieldCondition, GeoRadius, GeoPoint
|
||||
|
||||
# Find within radius
|
||||
results = client.search(
|
||||
collection_name="locations",
|
||||
query_vector=query,
|
||||
query_filter=Filter(
|
||||
must=[
|
||||
FieldCondition(
|
||||
key="location",
|
||||
geo_radius=GeoRadius(
|
||||
center=GeoPoint(lat=40.7128, lon=-74.0060),
|
||||
radius=5000 # meters
|
||||
)
|
||||
)
|
||||
]
|
||||
),
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Geo bounding box
|
||||
from qdrant_client.models import GeoBoundingBox
|
||||
|
||||
results = client.search(
|
||||
collection_name="locations",
|
||||
query_vector=query,
|
||||
query_filter=Filter(
|
||||
must=[
|
||||
FieldCondition(
|
||||
key="location",
|
||||
geo_bounding_box=GeoBoundingBox(
|
||||
top_left=GeoPoint(lat=40.8, lon=-74.1),
|
||||
bottom_right=GeoPoint(lat=40.6, lon=-73.9)
|
||||
)
|
||||
)
|
||||
]
|
||||
),
|
||||
limit=10
|
||||
)
|
||||
```
|
||||
|
||||
### Full-Text Search
|
||||
|
||||
```python
|
||||
from qdrant_client.models import TextIndexParams, TokenizerType
|
||||
|
||||
# Create text index
|
||||
client.create_payload_index(
|
||||
collection_name="documents",
|
||||
field_name="content",
|
||||
field_schema=TextIndexParams(
|
||||
type="text",
|
||||
tokenizer=TokenizerType.WORD,
|
||||
min_token_len=2,
|
||||
max_token_len=15,
|
||||
lowercase=True
|
||||
)
|
||||
)
|
||||
|
||||
# Full-text filter
|
||||
from qdrant_client.models import MatchText
|
||||
|
||||
results = client.search(
|
||||
collection_name="documents",
|
||||
query_vector=query,
|
||||
query_filter=Filter(
|
||||
must=[
|
||||
FieldCondition(
|
||||
key="content",
|
||||
match=MatchText(text="machine learning")
|
||||
)
|
||||
]
|
||||
),
|
||||
limit=10
|
||||
)
|
||||
```
|
||||
|
||||
## Quantization Strategies
|
||||
|
||||
### Scalar Quantization (INT8)
|
||||
|
||||
```python
|
||||
from qdrant_client.models import ScalarQuantization, ScalarQuantizationConfig, ScalarType
|
||||
|
||||
# ~4x memory reduction, minimal accuracy loss
|
||||
client.create_collection(
|
||||
collection_name="scalar_quantized",
|
||||
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
|
||||
quantization_config=ScalarQuantization(
|
||||
scalar=ScalarQuantizationConfig(
|
||||
type=ScalarType.INT8,
|
||||
quantile=0.99, # Clip extreme values
|
||||
always_ram=True # Keep quantized vectors in RAM
|
||||
)
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
### Product Quantization
|
||||
|
||||
```python
|
||||
from qdrant_client.models import ProductQuantization, ProductQuantizationConfig, CompressionRatio
|
||||
|
||||
# ~16x memory reduction, some accuracy loss
|
||||
client.create_collection(
|
||||
collection_name="product_quantized",
|
||||
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
|
||||
quantization_config=ProductQuantization(
|
||||
product=ProductQuantizationConfig(
|
||||
compression=CompressionRatio.X16,
|
||||
always_ram=True
|
||||
)
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
### Binary Quantization
|
||||
|
||||
```python
|
||||
from qdrant_client.models import BinaryQuantization, BinaryQuantizationConfig
|
||||
|
||||
# ~32x memory reduction, requires oversampling
|
||||
client.create_collection(
|
||||
collection_name="binary_quantized",
|
||||
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
|
||||
quantization_config=BinaryQuantization(
|
||||
binary=BinaryQuantizationConfig(always_ram=True)
|
||||
)
|
||||
)
|
||||
|
||||
# Search with oversampling
|
||||
results = client.search(
|
||||
collection_name="binary_quantized",
|
||||
query_vector=query,
|
||||
search_params={
|
||||
"quantization": {
|
||||
"rescore": True,
|
||||
"oversampling": 2.0 # Retrieve 2x candidates, rescore
|
||||
}
|
||||
},
|
||||
limit=10
|
||||
)
|
||||
```
|
||||
|
||||
## Snapshots and Backups
|
||||
|
||||
### Create Snapshot
|
||||
|
||||
```python
|
||||
# Create collection snapshot
|
||||
snapshot_info = client.create_snapshot(collection_name="documents")
|
||||
print(f"Snapshot: {snapshot_info.name}")
|
||||
|
||||
# List snapshots
|
||||
snapshots = client.list_snapshots(collection_name="documents")
|
||||
for s in snapshots:
|
||||
print(f"{s.name}: {s.size} bytes")
|
||||
|
||||
# Full storage snapshot
|
||||
full_snapshot = client.create_full_snapshot()
|
||||
```
|
||||
|
||||
### Restore from Snapshot
|
||||
|
||||
```python
|
||||
# Download snapshot
|
||||
client.download_snapshot(
|
||||
collection_name="documents",
|
||||
snapshot_name="documents-2024-01-01.snapshot",
|
||||
target_path="./backup/"
|
||||
)
|
||||
|
||||
# Restore (via REST API)
|
||||
import requests
|
||||
|
||||
response = requests.put(
|
||||
"http://localhost:6333/collections/documents/snapshots/recover",
|
||||
json={"location": "file:///backup/documents-2024-01-01.snapshot"}
|
||||
)
|
||||
```
|
||||
|
||||
## Collection Aliases
|
||||
|
||||
```python
|
||||
# Create alias
|
||||
client.update_collection_aliases(
|
||||
change_aliases_operations=[
|
||||
{"create_alias": {"alias_name": "production", "collection_name": "documents_v2"}}
|
||||
]
|
||||
)
|
||||
|
||||
# Blue-green deployment
|
||||
# 1. Create new collection with updates
|
||||
client.create_collection(collection_name="documents_v3", ...)
|
||||
|
||||
# 2. Populate new collection
|
||||
client.upsert(collection_name="documents_v3", points=new_points)
|
||||
|
||||
# 3. Atomic switch
|
||||
client.update_collection_aliases(
|
||||
change_aliases_operations=[
|
||||
{"delete_alias": {"alias_name": "production"}},
|
||||
{"create_alias": {"alias_name": "production", "collection_name": "documents_v3"}}
|
||||
]
|
||||
)
|
||||
|
||||
# Search via alias
|
||||
results = client.search(collection_name="production", query_vector=query, limit=10)
|
||||
```
|
||||
|
||||
## Scroll and Iteration
|
||||
|
||||
### Scroll Through All Points
|
||||
|
||||
```python
|
||||
# Paginated iteration
|
||||
offset = None
|
||||
all_points = []
|
||||
|
||||
while True:
|
||||
results, offset = client.scroll(
|
||||
collection_name="documents",
|
||||
limit=100,
|
||||
offset=offset,
|
||||
with_payload=True,
|
||||
with_vectors=False
|
||||
)
|
||||
all_points.extend(results)
|
||||
|
||||
if offset is None:
|
||||
break
|
||||
|
||||
print(f"Total points: {len(all_points)}")
|
||||
```
|
||||
|
||||
### Filtered Scroll
|
||||
|
||||
```python
|
||||
# Scroll with filter
|
||||
results, _ = client.scroll(
|
||||
collection_name="documents",
|
||||
scroll_filter=Filter(
|
||||
must=[
|
||||
FieldCondition(key="status", match=MatchValue(value="active"))
|
||||
]
|
||||
),
|
||||
limit=1000
|
||||
)
|
||||
```
|
||||
|
||||
## Async Client
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
|
||||
async def main():
|
||||
client = AsyncQdrantClient(host="localhost", port=6333)
|
||||
|
||||
# Async operations
|
||||
await client.create_collection(
|
||||
collection_name="async_docs",
|
||||
vectors_config=VectorParams(size=384, distance=Distance.COSINE)
|
||||
)
|
||||
|
||||
await client.upsert(
|
||||
collection_name="async_docs",
|
||||
points=points
|
||||
)
|
||||
|
||||
results = await client.search(
|
||||
collection_name="async_docs",
|
||||
query_vector=query,
|
||||
limit=10
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
results = asyncio.run(main())
|
||||
```
|
||||
|
||||
## gRPC Client
|
||||
|
||||
```python
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
# Prefer gRPC for better performance
|
||||
client = QdrantClient(
|
||||
host="localhost",
|
||||
port=6333,
|
||||
grpc_port=6334,
|
||||
prefer_grpc=True # Use gRPC when available
|
||||
)
|
||||
|
||||
# gRPC-only client
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
client = QdrantClient(
|
||||
host="localhost",
|
||||
grpc_port=6334,
|
||||
prefer_grpc=True,
|
||||
https=False
|
||||
)
|
||||
```
|
||||
|
||||
## Multitenancy
|
||||
|
||||
### Payload-Based Isolation
|
||||
|
||||
```python
|
||||
# Single collection, filter by tenant
|
||||
client.upsert(
|
||||
collection_name="multi_tenant",
|
||||
points=[
|
||||
PointStruct(
|
||||
id=1,
|
||||
vector=embedding,
|
||||
payload={"tenant_id": "tenant_a", "text": "..."}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Search within tenant
|
||||
results = client.search(
|
||||
collection_name="multi_tenant",
|
||||
query_vector=query,
|
||||
query_filter=Filter(
|
||||
must=[FieldCondition(key="tenant_id", match=MatchValue(value="tenant_a"))]
|
||||
),
|
||||
limit=10
|
||||
)
|
||||
```
|
||||
|
||||
### Collection-Per-Tenant
|
||||
|
||||
```python
|
||||
# Create tenant collection
|
||||
def create_tenant_collection(tenant_id: str):
|
||||
client.create_collection(
|
||||
collection_name=f"tenant_{tenant_id}",
|
||||
vectors_config=VectorParams(size=384, distance=Distance.COSINE)
|
||||
)
|
||||
|
||||
# Search tenant collection
|
||||
def search_tenant(tenant_id: str, query_vector: list, limit: int = 10):
|
||||
return client.search(
|
||||
collection_name=f"tenant_{tenant_id}",
|
||||
query_vector=query_vector,
|
||||
limit=limit
|
||||
)
|
||||
```
|
||||
|
||||
## Performance Monitoring
|
||||
|
||||
### Collection Statistics
|
||||
|
||||
```python
|
||||
# Collection info
|
||||
info = client.get_collection("documents")
|
||||
print(f"Points: {info.points_count}")
|
||||
print(f"Indexed vectors: {info.indexed_vectors_count}")
|
||||
print(f"Segments: {len(info.segments)}")
|
||||
print(f"Status: {info.status}")
|
||||
|
||||
# Optimizer status and full collection config
print(f"Optimizer status: {info.optimizer_status}")
print(f"Config: {info.config}")
|
||||
```
|
||||
|
||||
### Telemetry
|
||||
|
||||
```python
|
||||
# Get telemetry data
|
||||
telemetry = client.get_telemetry()
|
||||
print(f"Collections: {telemetry.collections}")
|
||||
print(f"Operations: {telemetry.operations}")
|
||||
```
|
||||
631
optional-skills/mlops/qdrant/references/troubleshooting.md
Normal file
|
|
@ -0,0 +1,631 @@
|
|||
# Qdrant Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### Docker Issues
|
||||
|
||||
**Error**: `Cannot connect to Docker daemon`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Start Docker daemon
|
||||
sudo systemctl start docker
|
||||
|
||||
# Or use Docker Desktop on Mac/Windows
|
||||
open -a Docker
|
||||
```
|
||||
|
||||
**Error**: `Port 6333 already in use`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Find process using port
|
||||
lsof -i :6333
|
||||
|
||||
# Kill process or use different port
|
||||
docker run -p 6334:6333 qdrant/qdrant
|
||||
```
|
||||
|
||||
### Python Client Issues
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'qdrant_client'`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
pip install qdrant-client
|
||||
|
||||
# With specific version
|
||||
pip install qdrant-client>=1.12.0
|
||||
```
|
||||
|
||||
**Error**: `grpc._channel._InactiveRpcError`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Install with gRPC support
|
||||
pip install 'qdrant-client[grpc]'
|
||||
|
||||
# Or disable gRPC
|
||||
client = QdrantClient(host="localhost", port=6333, prefer_grpc=False)
|
||||
```
|
||||
|
||||
## Connection Issues
|
||||
|
||||
### Cannot Connect to Server
|
||||
|
||||
**Error**: `ConnectionRefusedError: [Errno 111] Connection refused`
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Check server is running**:
|
||||
```bash
|
||||
docker ps | grep qdrant
|
||||
curl http://localhost:6333/healthz
|
||||
```
|
||||
|
||||
2. **Verify port binding**:
|
||||
```bash
|
||||
# Check listening ports
|
||||
netstat -tlnp | grep 6333
|
||||
|
||||
# Docker port mapping
|
||||
docker port <container_id>
|
||||
```
|
||||
|
||||
3. **Use correct host**:
|
||||
```python
|
||||
# Docker on Linux
|
||||
client = QdrantClient(host="localhost", port=6333)
|
||||
|
||||
# Docker on Mac/Windows with networking issues
|
||||
client = QdrantClient(host="127.0.0.1", port=6333)
|
||||
|
||||
# Inside Docker network
|
||||
client = QdrantClient(host="qdrant", port=6333)
|
||||
```
|
||||
|
||||
### Timeout Errors
|
||||
|
||||
**Error**: `TimeoutError: Connection timed out`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Increase timeout
|
||||
client = QdrantClient(
|
||||
host="localhost",
|
||||
port=6333,
|
||||
timeout=60 # seconds
|
||||
)
|
||||
|
||||
# For large operations
|
||||
client.upsert(
|
||||
collection_name="documents",
|
||||
points=large_batch,
|
||||
wait=False # Don't wait for indexing
|
||||
)
|
||||
```
|
||||
|
||||
### SSL/TLS Errors
|
||||
|
||||
**Error**: `ssl.SSLCertVerificationError`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Qdrant Cloud
|
||||
client = QdrantClient(
|
||||
url="https://cluster.cloud.qdrant.io",
|
||||
api_key="your-api-key"
|
||||
)
|
||||
|
||||
# Self-signed certificate
|
||||
client = QdrantClient(
|
||||
host="localhost",
|
||||
port=6333,
|
||||
https=True,
|
||||
verify=False # Disable verification (not recommended for production)
|
||||
)
|
||||
```
|
||||
|
||||
## Collection Issues
|
||||
|
||||
### Collection Already Exists
|
||||
|
||||
**Error**: `ValueError: Collection 'documents' already exists`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Check before creating
|
||||
collections = client.get_collections().collections
|
||||
names = [c.name for c in collections]
|
||||
|
||||
if "documents" not in names:
|
||||
client.create_collection(...)
|
||||
|
||||
# Or recreate
|
||||
client.recreate_collection(
|
||||
collection_name="documents",
|
||||
vectors_config=VectorParams(size=384, distance=Distance.COSINE)
|
||||
)
|
||||
```
|
||||
|
||||
### Collection Not Found
|
||||
|
||||
**Error**: `NotFoundException: Collection 'docs' not found`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# List available collections
|
||||
collections = client.get_collections()
|
||||
print([c.name for c in collections.collections])
|
||||
|
||||
# Check exact name (case-sensitive)
|
||||
try:
|
||||
info = client.get_collection("documents")
|
||||
except Exception as e:
|
||||
print(f"Collection not found: {e}")
|
||||
```
|
||||
|
||||
### Vector Dimension Mismatch
|
||||
|
||||
**Error**: `ValueError: Vector dimension mismatch. Expected 384, got 768`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Check collection config
|
||||
info = client.get_collection("documents")
|
||||
print(f"Expected dimension: {info.config.params.vectors.size}")
|
||||
|
||||
# Recreate with correct dimension
|
||||
client.recreate_collection(
|
||||
collection_name="documents",
|
||||
vectors_config=VectorParams(size=768, distance=Distance.COSINE) # Match your embeddings
|
||||
)
|
||||
```
|
||||
|
||||
## Search Issues
|
||||
|
||||
### Empty Search Results
|
||||
|
||||
**Problem**: Search returns empty results.
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Verify data exists**:
|
||||
```python
|
||||
info = client.get_collection("documents")
|
||||
print(f"Points: {info.points_count}")
|
||||
|
||||
# Scroll to check data
|
||||
points, _ = client.scroll(
|
||||
collection_name="documents",
|
||||
limit=10,
|
||||
with_payload=True
|
||||
)
|
||||
print(points)
|
||||
```
|
||||
|
||||
2. **Check vector format**:
|
||||
```python
|
||||
# Must be list of floats
|
||||
query_vector = embedding.tolist() # Convert numpy to list
|
||||
|
||||
# Check dimensions
|
||||
print(f"Query dimension: {len(query_vector)}")
|
||||
```
|
||||
|
||||
3. **Verify filter conditions**:
|
||||
```python
|
||||
# Test without filter first
|
||||
results = client.search(
|
||||
collection_name="documents",
|
||||
query_vector=query,
|
||||
limit=10
|
||||
# No filter
|
||||
)
|
||||
|
||||
# Then add filter incrementally
|
||||
```
|
||||
|
||||
### Slow Search Performance
|
||||
|
||||
**Problem**: Search takes too long.
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Create payload indexes**:
|
||||
```python
|
||||
# Index fields used in filters
|
||||
client.create_payload_index(
|
||||
collection_name="documents",
|
||||
field_name="category",
|
||||
field_schema="keyword"
|
||||
)
|
||||
```
|
||||
|
||||
2. **Enable quantization**:
|
||||
```python
|
||||
client.update_collection(
|
||||
collection_name="documents",
|
||||
quantization_config=ScalarQuantization(
|
||||
scalar=ScalarQuantizationConfig(type=ScalarType.INT8)
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
3. **Tune HNSW parameters**:
|
||||
```python
|
||||
# Faster search (less accurate)
|
||||
client.update_collection(
|
||||
collection_name="documents",
|
||||
hnsw_config=HnswConfigDiff(ef_construct=64, m=8)
|
||||
)
|
||||
|
||||
# Use ef search parameter
|
||||
results = client.search(
|
||||
collection_name="documents",
|
||||
query_vector=query,
|
||||
search_params={"hnsw_ef": 64}, # Lower = faster
|
||||
limit=10
|
||||
)
|
||||
```
|
||||
|
||||
4. **Use gRPC**:
|
||||
```python
|
||||
client = QdrantClient(
|
||||
host="localhost",
|
||||
port=6333,
|
||||
grpc_port=6334,
|
||||
prefer_grpc=True
|
||||
)
|
||||
```
|
||||
|
||||
### Inconsistent Results
|
||||
|
||||
**Problem**: Same query returns different results.
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Wait for indexing**:
|
||||
```python
|
||||
client.upsert(
|
||||
collection_name="documents",
|
||||
points=points,
|
||||
wait=True # Wait for index update
|
||||
)
|
||||
```
|
||||
|
||||
2. **Check replication consistency**:
|
||||
```python
|
||||
# Strong consistency read
|
||||
results = client.search(
|
||||
collection_name="documents",
|
||||
query_vector=query,
|
||||
consistency="all" # Read from all replicas
|
||||
)
|
||||
```
|
||||
|
||||
## Upsert Issues
|
||||
|
||||
### Batch Upsert Fails
|
||||
|
||||
**Error**: `PayloadError: Payload too large`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Split into smaller batches
|
||||
def batch_upsert(client, collection, points, batch_size=100):
|
||||
for i in range(0, len(points), batch_size):
|
||||
batch = points[i:i + batch_size]
|
||||
client.upsert(
|
||||
collection_name=collection,
|
||||
points=batch,
|
||||
wait=True
|
||||
)
|
||||
|
||||
batch_upsert(client, "documents", large_points_list)
|
||||
```
|
||||
|
||||
### Invalid Point ID
|
||||
|
||||
**Error**: `ValueError: Invalid point ID`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Valid ID types: int or UUID string
|
||||
from uuid import uuid4
|
||||
|
||||
# Integer ID
|
||||
PointStruct(id=123, vector=vec, payload={})
|
||||
|
||||
# UUID string
|
||||
PointStruct(id=str(uuid4()), vector=vec, payload={})
|
||||
|
||||
# NOT valid
|
||||
PointStruct(id="custom-string-123", ...) # Use UUID format
|
||||
```
|
||||
|
||||
### Payload Validation Errors
|
||||
|
||||
**Error**: `ValidationError: Invalid payload`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure JSON-serializable payload
|
||||
import json
from datetime import datetime
|
||||
|
||||
payload = {
|
||||
"title": "Document",
|
||||
"count": 42,
|
||||
"tags": ["a", "b"],
|
||||
"nested": {"key": "value"}
|
||||
}
|
||||
|
||||
# Validate before upsert
|
||||
json.dumps(payload) # Should not raise
|
||||
|
||||
# Avoid non-serializable types
|
||||
# NOT valid: datetime, numpy arrays, custom objects
|
||||
payload = {
|
||||
"timestamp": datetime.now().isoformat(), # Convert to string
|
||||
"vector": embedding.tolist() # Convert numpy to list
|
||||
}
|
||||
```
|
||||
|
||||
## Memory Issues
|
||||
|
||||
### Out of Memory
|
||||
|
||||
**Error**: `MemoryError` or container killed
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable on-disk storage**:
|
||||
```python
|
||||
client.create_collection(
|
||||
collection_name="large_collection",
|
||||
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
|
||||
on_disk_payload=True, # Store payloads on disk
|
||||
hnsw_config=HnswConfigDiff(on_disk=True) # Store HNSW on disk
|
||||
)
|
||||
```
|
||||
|
||||
2. **Use quantization**:
|
||||
```python
|
||||
# 4x memory reduction
|
||||
client.update_collection(
|
||||
collection_name="large_collection",
|
||||
quantization_config=ScalarQuantization(
|
||||
scalar=ScalarQuantizationConfig(
|
||||
type=ScalarType.INT8,
|
||||
always_ram=False # Keep on disk
|
||||
)
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
3. **Increase Docker memory**:
|
||||
```bash
|
||||
docker run -m 8g -p 6333:6333 qdrant/qdrant
|
||||
```
|
||||
|
||||
4. **Configure Qdrant storage**:
|
||||
```yaml
|
||||
# config.yaml
|
||||
storage:
|
||||
performance:
|
||||
max_search_threads: 2
|
||||
optimizers:
|
||||
memmap_threshold_kb: 20000
|
||||
```
|
||||
|
||||
### High Memory Usage During Indexing
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Increase indexing threshold for bulk loads
|
||||
client.update_collection(
|
||||
collection_name="documents",
|
||||
    optimizers_config={
|
||||
"indexing_threshold": 50000 # Delay indexing
|
||||
}
|
||||
)
|
||||
|
||||
# Bulk insert
|
||||
client.upsert(collection_name="documents", points=all_points, wait=False)
|
||||
|
||||
# Then optimize
|
||||
client.update_collection(
|
||||
collection_name="documents",
|
||||
    optimizers_config={
|
||||
"indexing_threshold": 10000 # Resume normal indexing
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Cluster Issues
|
||||
|
||||
### Node Not Joining Cluster
|
||||
|
||||
**Problem**: New node fails to join cluster.
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Check network connectivity
|
||||
docker exec qdrant-node-2 ping qdrant-node-1
|
||||
|
||||
# Verify bootstrap URL
|
||||
docker logs qdrant-node-2 | grep bootstrap
|
||||
|
||||
# Check Raft state
|
||||
curl http://localhost:6333/cluster
|
||||
```
|
||||
|
||||
### Split Brain
|
||||
|
||||
**Problem**: Cluster has inconsistent state.
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Force leader election
|
||||
curl -X POST http://localhost:6333/cluster/recover
|
||||
|
||||
# Or restart minority nodes
|
||||
docker restart qdrant-node-2 qdrant-node-3
|
||||
```
|
||||
|
||||
### Replication Lag
|
||||
|
||||
**Problem**: Replicas fall behind.
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Check collection status
|
||||
info = client.get_collection("documents")
|
||||
print(f"Status: {info.status}")
|
||||
|
||||
# Use strong consistency for critical writes
|
||||
client.upsert(
|
||||
collection_name="documents",
|
||||
points=points,
|
||||
ordering=WriteOrdering.STRONG
|
||||
)
|
||||
```
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Benchmark Configuration
|
||||
|
||||
```python
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
def benchmark_search(client, collection, n_queries=100, dimension=384):
|
||||
# Generate random queries
|
||||
queries = [np.random.rand(dimension).tolist() for _ in range(n_queries)]
|
||||
|
||||
# Warmup
|
||||
for q in queries[:10]:
|
||||
client.search(collection_name=collection, query_vector=q, limit=10)
|
||||
|
||||
# Benchmark
|
||||
start = time.perf_counter()
|
||||
for q in queries:
|
||||
client.search(collection_name=collection, query_vector=q, limit=10)
|
||||
elapsed = time.perf_counter() - start
|
||||
|
||||
print(f"QPS: {n_queries / elapsed:.2f}")
|
||||
print(f"Latency: {elapsed / n_queries * 1000:.2f}ms")
|
||||
|
||||
benchmark_search(client, "documents")
|
||||
```
|
||||
|
||||
### Optimal HNSW Parameters
|
||||
|
||||
```python
|
||||
# High recall (slower)
|
||||
client.create_collection(
|
||||
collection_name="high_recall",
|
||||
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
|
||||
hnsw_config=HnswConfigDiff(
|
||||
m=32, # More connections
|
||||
ef_construct=200 # Higher build quality
|
||||
)
|
||||
)
|
||||
|
||||
# High speed (lower recall)
|
||||
client.create_collection(
|
||||
collection_name="high_speed",
|
||||
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
|
||||
hnsw_config=HnswConfigDiff(
|
||||
m=8, # Fewer connections
|
||||
ef_construct=64 # Lower build quality
|
||||
)
|
||||
)
|
||||
|
||||
# Balanced
|
||||
client.create_collection(
|
||||
collection_name="balanced",
|
||||
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
|
||||
hnsw_config=HnswConfigDiff(
|
||||
m=16, # Default
|
||||
ef_construct=100 # Default
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
## Debugging Tips
|
||||
|
||||
### Enable Verbose Logging
|
||||
|
||||
```python
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logging.getLogger("qdrant_client").setLevel(logging.DEBUG)
|
||||
```
|
||||
|
||||
### Check Server Logs
|
||||
|
||||
```bash
|
||||
# Docker logs
|
||||
docker logs -f qdrant
|
||||
|
||||
# With timestamps
|
||||
docker logs --timestamps qdrant
|
||||
|
||||
# Last 100 lines
|
||||
docker logs --tail 100 qdrant
|
||||
```
|
||||
|
||||
### Inspect Collection State
|
||||
|
||||
```python
|
||||
# Collection info
|
||||
info = client.get_collection("documents")
|
||||
print(f"Status: {info.status}")
|
||||
print(f"Points: {info.points_count}")
|
||||
print(f"Segments: {len(info.segments)}")
|
||||
print(f"Config: {info.config}")
|
||||
|
||||
# Sample points
|
||||
points, _ = client.scroll(
|
||||
collection_name="documents",
|
||||
limit=5,
|
||||
with_payload=True,
|
||||
with_vectors=True
|
||||
)
|
||||
for p in points:
|
||||
print(f"ID: {p.id}, Payload: {p.payload}")
|
||||
```
|
||||
|
||||
### Test Connection
|
||||
|
||||
```python
|
||||
def test_connection(host="localhost", port=6333):
|
||||
try:
|
||||
client = QdrantClient(host=host, port=port, timeout=5)
|
||||
collections = client.get_collections()
|
||||
print(f"Connected! Collections: {len(collections.collections)}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Connection failed: {e}")
|
||||
return False
|
||||
|
||||
test_connection()
|
||||
```
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **Documentation**: https://qdrant.tech/documentation/
|
||||
2. **GitHub Issues**: https://github.com/qdrant/qdrant/issues
|
||||
3. **Discord**: https://discord.gg/qdrant
|
||||
4. **Stack Overflow**: Tag `qdrant`
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- Qdrant version: `curl http://localhost:6333/`
|
||||
- Python client version: `pip show qdrant-client`
|
||||
- Full error traceback
|
||||
- Minimal reproducible code
|
||||
- Collection configuration
|
||||
389
optional-skills/mlops/saelens/SKILL.md
Normal file
|
|
@ -0,0 +1,389 @@
|
|||
---
|
||||
name: sparse-autoencoder-training
|
||||
description: Provides guidance for training and analyzing Sparse Autoencoders (SAEs) using SAELens to decompose neural network activations into interpretable features. Use when discovering interpretable features, analyzing superposition, or studying monosemantic representations in language models.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [sae-lens>=6.0.0, transformer-lens>=2.0.0, torch>=2.0.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Sparse Autoencoders, SAE, Mechanistic Interpretability, Feature Discovery, Superposition]
|
||||
|
||||
---
|
||||
|
||||
# SAELens: Sparse Autoencoders for Mechanistic Interpretability
|
||||
|
||||
SAELens is the primary library for training and analyzing Sparse Autoencoders (SAEs) - a technique for decomposing polysemantic neural network activations into sparse, interpretable features. Based on Anthropic's groundbreaking research on monosemanticity.
|
||||
|
||||
**GitHub**: [jbloomAus/SAELens](https://github.com/jbloomAus/SAELens) (1,100+ stars)
|
||||
|
||||
## The Problem: Polysemanticity & Superposition
|
||||
|
||||
Individual neurons in neural networks are **polysemantic** - they activate in multiple, semantically distinct contexts. This happens because models use **superposition** to represent more features than they have neurons, making interpretability difficult.
|
||||
|
||||
**SAEs solve this** by decomposing dense activations into sparse, monosemantic features - typically only a small number of features activate for any given input, and each feature corresponds to an interpretable concept.
|
||||
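A toy numerical illustration (not SAELens code) of the geometry behind superposition: packing three feature directions into two dimensions forces every dimension to respond to multiple features.

```python
import numpy as np

# Three unit-length "feature" directions squeezed into a 2-dimensional space
directions = np.array([
    [1.0, 0.0],
    [-0.5, np.sqrt(3) / 2],
    [-0.5, -np.sqrt(3) / 2],
])

# Off-diagonal entries (-0.5) are the interference between features;
# any single dimension mixes several features -> polysemantic neurons.
print(directions @ directions.T)
```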
|
||||
## When to Use SAELens
|
||||
|
||||
**Use SAELens when you need to:**
|
||||
- Discover interpretable features in model activations
|
||||
- Understand what concepts a model has learned
|
||||
- Study superposition and feature geometry
|
||||
- Perform feature-based steering or ablation
|
||||
- Analyze safety-relevant features (deception, bias, harmful content)
|
||||
|
||||
**Consider alternatives when:**
|
||||
- You need basic activation analysis → Use **TransformerLens** directly
|
||||
- You want causal intervention experiments → Use **pyvene** or **TransformerLens**
|
||||
- You need production steering → Consider direct activation engineering
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install sae-lens
|
||||
```
|
||||
|
||||
Requirements: Python 3.10+, transformer-lens>=2.0.0
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### What SAEs Learn
|
||||
|
||||
SAEs are trained to reconstruct model activations through a sparse bottleneck:
|
||||
|
||||
```
|
||||
Input Activation (d_model)
  → Encoder → Sparse Features (d_sae >> d_model)      [L1 sparsity penalty]
  → Decoder → Reconstructed Activation (d_model)      [reconstruction (MSE) loss]
|
||||
```
|
||||
|
||||
**Loss Function**: `MSE(original, reconstructed) + L1_coefficient × L1(features)`
|
||||
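A minimal PyTorch sketch of that objective, with random placeholder weights and a plain ReLU encoder; SAELens implements this (plus details such as decoder normalization) internally, so treat it as illustration only:

```python
import torch
import torch.nn.functional as F

d_model, d_sae, l1_coefficient = 768, 768 * 8, 8e-5

W_enc = torch.randn(d_model, d_sae) * 0.01
W_dec = torch.randn(d_sae, d_model) * 0.01
b_enc = torch.zeros(d_sae)
b_dec = torch.zeros(d_model)

def sae_loss(activations: torch.Tensor) -> torch.Tensor:
    """activations: [batch, d_model] residual-stream activations."""
    features = F.relu((activations - b_dec) @ W_enc + b_enc)  # sparse features
    reconstructed = features @ W_dec + b_dec                  # back to d_model
    mse = (activations - reconstructed).pow(2).mean()
    l1 = features.abs().sum(dim=-1).mean()
    return mse + l1_coefficient * l1
```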
|
||||
### Key Validation (Anthropic Research)
|
||||
|
||||
In "Towards Monosemanticity", human evaluators found **70% of SAE features genuinely interpretable**. Features discovered include:
|
||||
- DNA sequences, legal language, HTTP requests
|
||||
- Hebrew text, nutrition statements, code syntax
|
||||
- Sentiment, named entities, grammatical structures
|
||||
|
||||
## Workflow 1: Loading and Analyzing Pre-trained SAEs
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
from sae_lens import SAE
|
||||
|
||||
# 1. Load model and pre-trained SAE
|
||||
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
|
||||
sae, cfg_dict, sparsity = SAE.from_pretrained(
|
||||
release="gpt2-small-res-jb",
|
||||
sae_id="blocks.8.hook_resid_pre",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# 2. Get model activations
|
||||
tokens = model.to_tokens("The capital of France is Paris")
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
activations = cache["resid_pre", 8] # [batch, pos, d_model]
|
||||
|
||||
# 3. Encode to SAE features
|
||||
sae_features = sae.encode(activations) # [batch, pos, d_sae]
|
||||
print(f"Active features: {(sae_features > 0).sum()}")
|
||||
|
||||
# 4. Find top features for each position
|
||||
for pos in range(tokens.shape[1]):
|
||||
top_features = sae_features[0, pos].topk(5)
|
||||
token = model.to_str_tokens(tokens[0, pos:pos+1])[0]
|
||||
print(f"Token '{token}': features {top_features.indices.tolist()}")
|
||||
|
||||
# 5. Reconstruct activations
|
||||
reconstructed = sae.decode(sae_features)
|
||||
reconstruction_error = (activations - reconstructed).norm()
|
||||
```
|
||||
|
||||
### Available Pre-trained SAEs
|
||||
|
||||
| Release | Model | Layers |
|
||||
|---------|-------|--------|
|
||||
| `gpt2-small-res-jb` | GPT-2 Small | Multiple residual streams |
|
||||
| `gemma-2b-res` | Gemma 2B | Residual streams |
|
||||
| Various on HuggingFace | Search tag `saelens` | Various |
|
||||
|
||||
### Checklist
|
||||
- [ ] Load model with TransformerLens
|
||||
- [ ] Load matching SAE for target layer
|
||||
- [ ] Encode activations to sparse features
|
||||
- [ ] Identify top-activating features per token
|
||||
- [ ] Validate reconstruction quality
|
||||
|
||||
## Workflow 2: Training a Custom SAE
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from sae_lens import SAE, LanguageModelSAERunnerConfig, SAETrainingRunner
|
||||
|
||||
# 1. Configure training
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
# Model
|
||||
model_name="gpt2-small",
|
||||
hook_name="blocks.8.hook_resid_pre",
|
||||
hook_layer=8,
|
||||
d_in=768, # Model dimension
|
||||
|
||||
# SAE architecture
|
||||
architecture="standard", # or "gated", "topk"
|
||||
d_sae=768 * 8, # Expansion factor of 8
|
||||
activation_fn="relu",
|
||||
|
||||
# Training
|
||||
lr=4e-4,
|
||||
l1_coefficient=8e-5, # Sparsity penalty
|
||||
l1_warm_up_steps=1000,
|
||||
train_batch_size_tokens=4096,
|
||||
training_tokens=100_000_000,
|
||||
|
||||
# Data
|
||||
dataset_path="monology/pile-uncopyrighted",
|
||||
context_size=128,
|
||||
|
||||
# Logging
|
||||
log_to_wandb=True,
|
||||
wandb_project="sae-training",
|
||||
|
||||
# Checkpointing
|
||||
checkpoint_path="checkpoints",
|
||||
n_checkpoints=5,
|
||||
)
|
||||
|
||||
# 2. Train
|
||||
trainer = SAETrainingRunner(cfg)
|
||||
sae = trainer.run()
|
||||
|
||||
# 3. Evaluate
|
||||
print(f"L0 (avg active features): {trainer.metrics['l0']}")
|
||||
print(f"CE Loss Recovered: {trainer.metrics['ce_loss_score']}")
|
||||
```
|
||||
|
||||
### Key Hyperparameters
|
||||
|
||||
| Parameter | Typical Value | Effect |
|
||||
|-----------|---------------|--------|
|
||||
| `d_sae` | 4-16× d_model | More features, higher capacity |
|
||||
| `l1_coefficient` | 5e-5 to 1e-4 | Higher = sparser, less accurate |
|
||||
| `lr` | 1e-4 to 1e-3 | Standard optimizer LR |
|
||||
| `l1_warm_up_steps` | 500-2000 | Prevents early feature death |
|
||||
|
||||
### Evaluation Metrics
|
||||
|
||||
| Metric | Target | Meaning |
|
||||
|--------|--------|---------|
|
||||
| **L0** | 50-200 | Average active features per token |
|
||||
| **CE Loss Score** | 80-95% | Cross-entropy recovered vs original |
|
||||
| **Dead Features** | <5% | Features that never activate |
|
||||
| **Explained Variance** | >90% | Reconstruction quality |
|
||||
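The two activation-based metrics in the table above can be estimated directly from collected feature activations; CE loss recovery additionally requires a forward pass with the SAE spliced into the model, so it is omitted from this sketch (tensor shapes are assumptions):

```python
import torch

def sparsity_metrics(features: torch.Tensor) -> dict:
    """features: [n_tokens, d_sae] SAE activations over an evaluation set."""
    active = features > 0
    l0 = active.sum(dim=-1).float().mean().item()             # avg active features per token
    dead_ratio = (~active.any(dim=0)).float().mean().item()   # fraction that never fired
    return {"l0": l0, "dead_feature_ratio": dead_ratio}
```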
|
||||
### Checklist
|
||||
- [ ] Choose target layer and hook point
|
||||
- [ ] Set expansion factor (d_sae = 4-16× d_model)
|
||||
- [ ] Tune L1 coefficient for desired sparsity
|
||||
- [ ] Enable L1 warm-up to prevent dead features
|
||||
- [ ] Monitor metrics during training (W&B)
|
||||
- [ ] Validate L0 and CE loss recovery
|
||||
- [ ] Check dead feature ratio
|
||||
|
||||
## Workflow 3: Feature Analysis and Steering
|
||||
|
||||
### Analyzing Individual Features
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
from sae_lens import SAE
|
||||
import torch
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
|
||||
sae, _, _ = SAE.from_pretrained(
|
||||
release="gpt2-small-res-jb",
|
||||
sae_id="blocks.8.hook_resid_pre",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Find what activates a specific feature
|
||||
feature_idx = 1234
|
||||
test_texts = [
|
||||
"The scientist conducted an experiment",
|
||||
"I love chocolate cake",
|
||||
"The code compiles successfully",
|
||||
"Paris is beautiful in spring",
|
||||
]
|
||||
|
||||
for text in test_texts:
|
||||
tokens = model.to_tokens(text)
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
features = sae.encode(cache["resid_pre", 8])
|
||||
activation = features[0, :, feature_idx].max().item()
|
||||
print(f"{activation:.3f}: {text}")
|
||||
```
|
||||
|
||||
### Feature Steering
|
||||
|
||||
```python
|
||||
def steer_with_feature(model, sae, prompt, feature_idx, strength=5.0):
|
||||
"""Add SAE feature direction to residual stream."""
|
||||
tokens = model.to_tokens(prompt)
|
||||
|
||||
# Get feature direction from decoder
|
||||
feature_direction = sae.W_dec[feature_idx] # [d_model]
|
||||
|
||||
def steering_hook(activation, hook):
|
||||
# Add scaled feature direction at all positions
|
||||
activation += strength * feature_direction
|
||||
return activation
|
||||
|
||||
    # Generate with steering (apply the hook via TransformerLens's context manager)
    with model.hooks(fwd_hooks=[("blocks.8.hook_resid_pre", steering_hook)]):
        output = model.generate(tokens, max_new_tokens=50)
|
||||
return model.to_string(output[0])
|
||||
```
|
||||
|
||||
### Feature Attribution
|
||||
|
||||
```python
|
||||
# Which features most affect a specific output?
|
||||
tokens = model.to_tokens("The capital of France is")
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
|
||||
# Get features at final position
|
||||
features = sae.encode(cache["resid_pre", 8])[0, -1] # [d_sae]
|
||||
|
||||
# Get logit attribution per feature
|
||||
# Feature contribution = feature_activation × decoder_weight × unembedding
|
||||
W_dec = sae.W_dec # [d_sae, d_model]
|
||||
W_U = model.W_U # [d_model, vocab]
|
||||
|
||||
# Contribution to "Paris" logit
|
||||
paris_token = model.to_single_token(" Paris")
|
||||
feature_contributions = features * (W_dec @ W_U[:, paris_token])
|
||||
|
||||
top_features = feature_contributions.topk(10)
|
||||
print("Top features for 'Paris' prediction:")
|
||||
for idx, val in zip(top_features.indices, top_features.values):
|
||||
print(f" Feature {idx.item()}: {val.item():.3f}")
|
||||
```
|
||||
|
||||
## Common Issues & Solutions
|
||||
|
||||
### Issue: High dead feature ratio
|
||||
```python
|
||||
# WRONG: No warm-up, features die early
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
l1_coefficient=1e-4,
|
||||
l1_warm_up_steps=0, # Bad!
|
||||
)
|
||||
|
||||
# RIGHT: Warm-up L1 penalty
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
l1_coefficient=8e-5,
|
||||
l1_warm_up_steps=1000, # Gradually increase
|
||||
use_ghost_grads=True, # Revive dead features
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Poor reconstruction (low CE recovery)
|
||||
```python
|
||||
# Reduce sparsity penalty
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
l1_coefficient=5e-5, # Lower = better reconstruction
|
||||
d_sae=768 * 16, # More capacity
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Features not interpretable
|
||||
```python
|
||||
# Increase sparsity (higher L1)
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
l1_coefficient=1e-4, # Higher = sparser, more interpretable
|
||||
)
|
||||
# Or use TopK architecture
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
architecture="topk",
|
||||
activation_fn_kwargs={"k": 50}, # Exactly 50 active features
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Memory errors during training
|
||||
```python
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
train_batch_size_tokens=2048, # Reduce batch size
|
||||
store_batch_size_prompts=4, # Fewer prompts in buffer
|
||||
n_batches_in_buffer=8, # Smaller activation buffer
|
||||
)
|
||||
```
|
||||
|
||||
## Integration with Neuronpedia
|
||||
|
||||
Browse pre-trained SAE features at [neuronpedia.org](https://neuronpedia.org):
|
||||
|
||||
```python
|
||||
# Features are indexed by SAE ID
|
||||
# Example: gpt2-small layer 8 feature 1234
|
||||
# → neuronpedia.org/gpt2-small/8-res-jb/1234
|
||||
```
|
||||
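A tiny helper following the URL pattern above; the exact release slug varies by SAE, so treat the format string as an assumption:

```python
def neuronpedia_url(model: str, sae_slug: str, feature_idx: int) -> str:
    """Build a browser link for a feature, e.g. gpt2-small / 8-res-jb / 1234."""
    return f"https://neuronpedia.org/{model}/{sae_slug}/{feature_idx}"

print(neuronpedia_url("gpt2-small", "8-res-jb", 1234))
```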
|
||||
## Key Classes Reference
|
||||
|
||||
| Class | Purpose |
|
||||
|-------|---------|
|
||||
| `SAE` | Sparse Autoencoder model |
|
||||
| `LanguageModelSAERunnerConfig` | Training configuration |
|
||||
| `SAETrainingRunner` | Training loop manager |
|
||||
| `ActivationsStore` | Activation collection and batching |
|
||||
| `HookedSAETransformer` | TransformerLens + SAE integration |
|
||||
|
||||
## Reference Documentation
|
||||
|
||||
For detailed API documentation, tutorials, and advanced usage, see the `references/` folder:
|
||||
|
||||
| File | Contents |
|
||||
|------|----------|
|
||||
| [references/README.md](references/README.md) | Overview and quick start guide |
|
||||
| [references/api.md](references/api.md) | Complete API reference for SAE, TrainingSAE, configurations |
|
||||
| [references/tutorials.md](references/tutorials.md) | Step-by-step tutorials for training, analysis, steering |
|
||||
|
||||
## External Resources
|
||||
|
||||
### Tutorials
|
||||
- [Basic Loading & Analysis](https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
|
||||
- [Training a Sparse Autoencoder](https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
|
||||
- [ARENA SAE Curriculum](https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab)
|
||||
|
||||
### Papers
|
||||
- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features) - Anthropic (2023)
|
||||
- [Scaling Monosemanticity](https://transformer-circuits.pub/2024/scaling-monosemanticity/) - Anthropic (2024)
|
||||
- [Sparse Autoencoders Find Highly Interpretable Features](https://arxiv.org/abs/2309.08600) - Cunningham et al. (ICLR 2024)
|
||||
|
||||
### Official Documentation
|
||||
- [SAELens Docs](https://jbloomaus.github.io/SAELens/)
|
||||
- [Neuronpedia](https://neuronpedia.org) - Feature browser
|
||||
|
||||
## SAE Architectures
|
||||
|
||||
| Architecture | Description | Use Case |
|
||||
|--------------|-------------|----------|
|
||||
| **Standard** | ReLU + L1 penalty | General purpose |
|
||||
| **Gated** | Learned gating mechanism | Better sparsity control |
|
||||
| **TopK** | Exactly K active features | Consistent sparsity |
|
||||
|
||||
```python
|
||||
# TopK SAE (exactly 50 features active)
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
architecture="topk",
|
||||
activation_fn="topk",
|
||||
activation_fn_kwargs={"k": 50},
|
||||
)
|
||||
```
|
||||
70
optional-skills/mlops/saelens/references/README.md
Normal file
@@ -0,0 +1,70 @@
# SAELens Reference Documentation
|
||||
|
||||
This directory contains comprehensive reference materials for SAELens.
|
||||
|
||||
## Contents
|
||||
|
||||
- [api.md](api.md) - Complete API reference for SAE, TrainingSAE, and configuration classes
|
||||
- [tutorials.md](tutorials.md) - Step-by-step tutorials for training and analyzing SAEs
|
||||
- [papers.md](papers.md) - Key research papers on sparse autoencoders
|
||||
|
||||
## Quick Links
|
||||
|
||||
- **GitHub Repository**: https://github.com/jbloomAus/SAELens
|
||||
- **Neuronpedia**: https://neuronpedia.org (browse pre-trained SAE features)
|
||||
- **HuggingFace SAEs**: Search for tag `saelens`
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install sae-lens
|
||||
```
|
||||
|
||||
Requirements: Python 3.10+, transformer-lens>=2.0.0
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
from sae_lens import SAE
|
||||
|
||||
# Load model and SAE
|
||||
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
|
||||
sae, cfg_dict, sparsity = SAE.from_pretrained(
|
||||
release="gpt2-small-res-jb",
|
||||
sae_id="blocks.8.hook_resid_pre",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Encode activations to sparse features
|
||||
tokens = model.to_tokens("Hello world")
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
activations = cache["resid_pre", 8]
|
||||
|
||||
features = sae.encode(activations) # Sparse feature activations
|
||||
reconstructed = sae.decode(features) # Reconstructed activations
|
||||
```
|
||||
|
||||
## Key Concepts
|
||||
|
||||
### Sparse Autoencoders
|
||||
SAEs decompose dense neural activations into sparse, interpretable features:
|
||||
- **Encoder**: Maps d_model → d_sae (typically 4-16x expansion)
|
||||
- **ReLU/TopK**: Enforces sparsity
|
||||
- **Decoder**: Reconstructs original activations
|
||||
|
||||
### Training Loss
|
||||
`Loss = MSE(original, reconstructed) + L1_coefficient × L1(features)`
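
A minimal PyTorch sketch of that objective, assuming an `sae` with the `encode`/`decode` methods shown above and a batch of activations `acts` of shape `[batch, d_in]`:

```python
import torch.nn.functional as F

def sae_training_loss(sae, acts, l1_coefficient=8e-5):
    feats = sae.encode(acts)              # [batch, d_sae] sparse feature activations
    recon = sae.decode(feats)             # [batch, d_in] reconstruction
    mse = F.mse_loss(recon, acts)         # reconstruction term
    l1 = feats.abs().sum(dim=-1).mean()   # sparsity penalty on feature activations
    return mse + l1_coefficient * l1
```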
|
||||
|
||||
### Key Metrics
|
||||
- **L0**: Average number of active features (target: 50-200)
|
||||
- **CE Loss Score**: Cross-entropy recovered vs original model (target: 80-95%)
|
||||
- **Dead Features**: Features that never activate (target: <5%); see the sketch below for a rough batch-level estimate
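
A rough, single-batch estimate of the first and last metrics, given feature activations `feats` of shape `[n_tokens, d_sae]` (CE loss recovery needs full model forward passes and is omitted here):

```python
def sae_batch_metrics(feats):
    active = feats > 0
    l0 = active.float().sum(dim=-1).mean().item()            # avg active features per token
    dead_frac = (~active.any(dim=0)).float().mean().item()   # silent on this batch (proxy; real tracking uses a long window)
    return {"l0": l0, "dead_feature_fraction": dead_frac}
```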
|
||||
|
||||
## Available Pre-trained SAEs
|
||||
|
||||
| Release | Model | Description |
|
||||
|---------|-------|-------------|
|
||||
| `gpt2-small-res-jb` | GPT-2 Small | Residual stream SAEs |
|
||||
| `gemma-2b-res` | Gemma 2B | Residual stream SAEs |
|
||||
| Various | Search HuggingFace | Community-trained SAEs |
|
||||
333
optional-skills/mlops/saelens/references/api.md
Normal file
@@ -0,0 +1,333 @@
# SAELens API Reference
|
||||
|
||||
## SAE Class
|
||||
|
||||
The core class representing a Sparse Autoencoder.
|
||||
|
||||
### Loading Pre-trained SAEs
|
||||
|
||||
```python
|
||||
from sae_lens import SAE
|
||||
|
||||
# From official releases
|
||||
sae, cfg_dict, sparsity = SAE.from_pretrained(
|
||||
release="gpt2-small-res-jb",
|
||||
sae_id="blocks.8.hook_resid_pre",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# From HuggingFace
|
||||
sae, cfg_dict, sparsity = SAE.from_pretrained(
|
||||
release="username/repo-name",
|
||||
sae_id="path/to/sae",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# From local disk
|
||||
sae = SAE.load_from_disk("/path/to/sae", device="cuda")
|
||||
```
|
||||
|
||||
### SAE Attributes
|
||||
|
||||
| Attribute | Shape | Description |
|
||||
|-----------|-------|-------------|
|
||||
| `W_enc` | [d_in, d_sae] | Encoder weights |
|
||||
| `W_dec` | [d_sae, d_in] | Decoder weights |
|
||||
| `b_enc` | [d_sae] | Encoder bias |
|
||||
| `b_dec` | [d_in] | Decoder bias |
|
||||
| `cfg` | SAEConfig | Configuration object |
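
As a quick sanity check that a loaded SAE matches the table, using the `sae` object from the loading examples above:

```python
# Shape checks mirroring the attribute table.
assert sae.W_enc.shape == (sae.cfg.d_in, sae.cfg.d_sae)
assert sae.W_dec.shape == (sae.cfg.d_sae, sae.cfg.d_in)
assert sae.b_enc.shape == (sae.cfg.d_sae,)
assert sae.b_dec.shape == (sae.cfg.d_in,)
```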
|
||||
|
||||
### Core Methods
|
||||
|
||||
#### encode()
|
||||
|
||||
```python
|
||||
# Encode activations to sparse features
|
||||
features = sae.encode(activations)
|
||||
# Input: [batch, pos, d_in]
|
||||
# Output: [batch, pos, d_sae]
|
||||
```
|
||||
|
||||
#### decode()
|
||||
|
||||
```python
|
||||
# Reconstruct activations from features
|
||||
reconstructed = sae.decode(features)
|
||||
# Input: [batch, pos, d_sae]
|
||||
# Output: [batch, pos, d_in]
|
||||
```
|
||||
|
||||
#### forward()
|
||||
|
||||
```python
|
||||
# Full forward pass (encode + decode)
|
||||
reconstructed = sae(activations)
|
||||
# Returns reconstructed activations
|
||||
```
|
||||
|
||||
#### save_model()
|
||||
|
||||
```python
|
||||
sae.save_model("/path/to/save")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## SAEConfig
|
||||
|
||||
Configuration class for SAE architecture and training context.
|
||||
|
||||
### Key Parameters
|
||||
|
||||
| Parameter | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `d_in` | int | Input dimension (model's d_model) |
|
||||
| `d_sae` | int | SAE hidden dimension |
|
||||
| `architecture` | str | "standard", "gated", "jumprelu", "topk" |
|
||||
| `activation_fn_str` | str | Activation function name |
|
||||
| `model_name` | str | Source model name |
|
||||
| `hook_name` | str | Hook point in model |
|
||||
| `normalize_activations` | str | Normalization method |
|
||||
| `dtype` | str | Data type |
|
||||
| `device` | str | Device |
|
||||
|
||||
### Accessing Config
|
||||
|
||||
```python
|
||||
print(sae.cfg.d_in) # 768 for GPT-2 small
|
||||
print(sae.cfg.d_sae) # e.g., 24576 (32x expansion)
|
||||
print(sae.cfg.hook_name) # e.g., "blocks.8.hook_resid_pre"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## LanguageModelSAERunnerConfig
|
||||
|
||||
Comprehensive configuration for training SAEs.
|
||||
|
||||
### Example Configuration
|
||||
|
||||
```python
|
||||
from sae_lens import LanguageModelSAERunnerConfig
|
||||
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
# Model and hook
|
||||
model_name="gpt2-small",
|
||||
hook_name="blocks.8.hook_resid_pre",
|
||||
hook_layer=8,
|
||||
d_in=768,
|
||||
|
||||
# SAE architecture
|
||||
architecture="standard", # "standard", "gated", "jumprelu", "topk"
|
||||
d_sae=768 * 8, # Expansion factor
|
||||
activation_fn="relu",
|
||||
|
||||
# Training hyperparameters
|
||||
lr=4e-4,
|
||||
l1_coefficient=8e-5,
|
||||
lp_norm=1.0,
|
||||
lr_scheduler_name="constant",
|
||||
lr_warm_up_steps=500,
|
||||
|
||||
# Sparsity control
|
||||
l1_warm_up_steps=1000,
|
||||
use_ghost_grads=True,
|
||||
feature_sampling_window=1000,
|
||||
dead_feature_window=5000,
|
||||
dead_feature_threshold=1e-8,
|
||||
|
||||
# Data
|
||||
dataset_path="monology/pile-uncopyrighted",
|
||||
streaming=True,
|
||||
context_size=128,
|
||||
|
||||
# Batch sizes
|
||||
train_batch_size_tokens=4096,
|
||||
store_batch_size_prompts=16,
|
||||
n_batches_in_buffer=64,
|
||||
|
||||
# Training duration
|
||||
training_tokens=100_000_000,
|
||||
|
||||
# Logging
|
||||
log_to_wandb=True,
|
||||
wandb_project="sae-training",
|
||||
wandb_log_frequency=100,
|
||||
|
||||
# Checkpointing
|
||||
checkpoint_path="checkpoints",
|
||||
n_checkpoints=5,
|
||||
|
||||
# Hardware
|
||||
device="cuda",
|
||||
dtype="float32",
|
||||
)
|
||||
```
|
||||
|
||||
### Key Parameters Explained
|
||||
|
||||
#### Architecture Parameters
|
||||
|
||||
| Parameter | Description |
|
||||
|-----------|-------------|
|
||||
| `architecture` | SAE type: "standard", "gated", "jumprelu", "topk" |
|
||||
| `d_sae` | Hidden dimension (or use `expansion_factor`) |
|
||||
| `expansion_factor` | Alternative to d_sae: d_sae = d_in × expansion_factor |
|
||||
| `activation_fn` | "relu", "topk", etc. |
|
||||
| `activation_fn_kwargs` | Dict for activation params (e.g., {"k": 50} for topk) |
|
||||
|
||||
#### Sparsity Parameters
|
||||
|
||||
| Parameter | Description |
|
||||
|-----------|-------------|
|
||||
| `l1_coefficient` | L1 penalty weight (higher = sparser) |
|
||||
| `l1_warm_up_steps` | Steps to ramp up L1 penalty |
|
||||
| `use_ghost_grads` | Apply gradients to dead features |
|
||||
| `dead_feature_threshold` | Activation threshold for "dead" |
|
||||
| `dead_feature_window` | Steps to check for dead features |
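
The warm-up interacts with `l1_coefficient` roughly as a linear ramp; the sketch below shows the implied schedule and is illustrative rather than SAELens's exact internal code:

```python
def effective_l1_coefficient(step, l1_coefficient=8e-5, l1_warm_up_steps=1000):
    # Ramp the sparsity penalty linearly from 0 to its final value, then hold.
    return l1_coefficient * min(1.0, step / max(1, l1_warm_up_steps))
```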
|
||||
|
||||
#### Learning Rate Parameters
|
||||
|
||||
| Parameter | Description |
|
||||
|-----------|-------------|
|
||||
| `lr` | Base learning rate |
|
||||
| `lr_scheduler_name` | "constant", "cosineannealing", etc. |
|
||||
| `lr_warm_up_steps` | LR warmup steps |
|
||||
| `lr_decay_steps` | Steps for LR decay |
|
||||
|
||||
---
|
||||
|
||||
## SAETrainingRunner
|
||||
|
||||
Main class for executing training.
|
||||
|
||||
### Basic Training
|
||||
|
||||
```python
|
||||
from sae_lens import SAETrainingRunner, LanguageModelSAERunnerConfig
|
||||
|
||||
cfg = LanguageModelSAERunnerConfig(...)
|
||||
runner = SAETrainingRunner(cfg)
|
||||
sae = runner.run()
|
||||
```
|
||||
|
||||
### Accessing Training Metrics
|
||||
|
||||
```python
|
||||
# During training, metrics logged to W&B include:
|
||||
# - l0: Average active features
|
||||
# - ce_loss_score: Cross-entropy recovery
|
||||
# - mse_loss: Reconstruction loss
|
||||
# - l1_loss: Sparsity loss
|
||||
# - dead_features: Count of dead features
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ActivationsStore
|
||||
|
||||
Manages activation collection and batching.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from sae_lens import ActivationsStore
|
||||
|
||||
store = ActivationsStore.from_sae(
|
||||
model=model,
|
||||
sae=sae,
|
||||
store_batch_size_prompts=8,
|
||||
train_batch_size_tokens=4096,
|
||||
n_batches_in_buffer=32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
# Get a batch of tokens; the store computes and buffers activations from them internally
batch_tokens = store.get_batch_tokens()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## HookedSAETransformer
|
||||
|
||||
Integration of SAEs with TransformerLens models.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from sae_lens import HookedSAETransformer
|
||||
|
||||
# Load model with SAE
|
||||
model = HookedSAETransformer.from_pretrained("gpt2-small")
|
||||
model.add_sae(sae)
|
||||
|
||||
# Run with SAE in the loop
|
||||
output = model.run_with_saes(tokens, saes=[sae])
|
||||
|
||||
# Cache with SAE activations
|
||||
output, cache = model.run_with_cache_with_saes(tokens, saes=[sae])
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## SAE Architectures
|
||||
|
||||
### Standard (ReLU + L1)
|
||||
|
||||
```python
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
architecture="standard",
|
||||
activation_fn="relu",
|
||||
l1_coefficient=8e-5,
|
||||
)
|
||||
```
|
||||
|
||||
### Gated
|
||||
|
||||
```python
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
architecture="gated",
|
||||
)
|
||||
```
|
||||
|
||||
### TopK
|
||||
|
||||
```python
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
architecture="topk",
|
||||
activation_fn="topk",
|
||||
activation_fn_kwargs={"k": 50}, # Exactly 50 active features
|
||||
)
|
||||
```
|
||||
|
||||
### JumpReLU (State-of-the-art)
|
||||
|
||||
```python
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
architecture="jumprelu",
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Utility Functions
|
||||
|
||||
### Upload to HuggingFace
|
||||
|
||||
```python
|
||||
from sae_lens import upload_saes_to_huggingface
|
||||
|
||||
upload_saes_to_huggingface(
|
||||
saes=[sae],
|
||||
repo_id="username/my-saes",
|
||||
token="hf_token",
|
||||
)
|
||||
```
|
||||
|
||||
### Neuronpedia Integration
|
||||
|
||||
```python
|
||||
# Features can be viewed on Neuronpedia
|
||||
# URL format: neuronpedia.org/{model}/{layer}-{sae_type}/{feature_id}
|
||||
# Example: neuronpedia.org/gpt2-small/8-res-jb/1234
|
||||
```
|
||||
318
optional-skills/mlops/saelens/references/tutorials.md
Normal file
@@ -0,0 +1,318 @@
# SAELens Tutorials
|
||||
|
||||
## Tutorial 1: Loading and Analyzing Pre-trained SAEs
|
||||
|
||||
### Goal
|
||||
Load a pre-trained SAE and analyze which features activate on specific inputs.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
from sae_lens import SAE
|
||||
import torch
|
||||
|
||||
# 1. Load model and SAE
|
||||
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
|
||||
sae, cfg_dict, sparsity = SAE.from_pretrained(
|
||||
release="gpt2-small-res-jb",
|
||||
sae_id="blocks.8.hook_resid_pre",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
print(f"SAE input dim: {sae.cfg.d_in}")
|
||||
print(f"SAE hidden dim: {sae.cfg.d_sae}")
|
||||
print(f"Expansion factor: {sae.cfg.d_sae / sae.cfg.d_in:.1f}x")
|
||||
|
||||
# 2. Get model activations
|
||||
prompt = "The capital of France is Paris"
|
||||
tokens = model.to_tokens(prompt)
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
activations = cache["resid_pre", 8] # [1, seq_len, 768]
|
||||
|
||||
# 3. Encode to SAE features
|
||||
features = sae.encode(activations) # [1, seq_len, d_sae]
|
||||
|
||||
# 4. Analyze sparsity
|
||||
active_per_token = (features > 0).sum(dim=-1)
|
||||
print(f"Average active features per token: {active_per_token.float().mean():.1f}")
|
||||
|
||||
# 5. Find top features for each token
|
||||
str_tokens = model.to_str_tokens(prompt)
|
||||
for pos in range(len(str_tokens)):
|
||||
top_features = features[0, pos].topk(5)
|
||||
print(f"\nToken '{str_tokens[pos]}':")
|
||||
for feat_idx, feat_val in zip(top_features.indices, top_features.values):
|
||||
print(f" Feature {feat_idx.item()}: {feat_val.item():.3f}")
|
||||
|
||||
# 6. Check reconstruction quality
|
||||
reconstructed = sae.decode(features)
|
||||
mse = ((activations - reconstructed) ** 2).mean()
|
||||
print(f"\nReconstruction MSE: {mse.item():.6f}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 2: Training a Custom SAE
|
||||
|
||||
### Goal
|
||||
Train a Sparse Autoencoder on GPT-2 activations.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner
|
||||
|
||||
# 1. Configure training
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
# Model
|
||||
model_name="gpt2-small",
|
||||
hook_name="blocks.6.hook_resid_pre",
|
||||
hook_layer=6,
|
||||
d_in=768,
|
||||
|
||||
# SAE architecture
|
||||
architecture="standard",
|
||||
d_sae=768 * 8, # 8x expansion
|
||||
activation_fn="relu",
|
||||
|
||||
# Training
|
||||
lr=4e-4,
|
||||
l1_coefficient=8e-5,
|
||||
l1_warm_up_steps=1000,
|
||||
train_batch_size_tokens=4096,
|
||||
training_tokens=10_000_000, # Small run for demo
|
||||
|
||||
# Data
|
||||
dataset_path="monology/pile-uncopyrighted",
|
||||
streaming=True,
|
||||
context_size=128,
|
||||
|
||||
# Dead feature prevention
|
||||
use_ghost_grads=True,
|
||||
dead_feature_window=5000,
|
||||
|
||||
# Logging
|
||||
log_to_wandb=True,
|
||||
wandb_project="sae-training-demo",
|
||||
|
||||
# Hardware
|
||||
device="cuda",
|
||||
dtype="float32",
|
||||
)
|
||||
|
||||
# 2. Train
|
||||
runner = SAETrainingRunner(cfg)
|
||||
sae = runner.run()
|
||||
|
||||
# 3. Save
|
||||
sae.save_model("./my_trained_sae")
|
||||
```
|
||||
|
||||
### Hyperparameter Tuning Guide
|
||||
|
||||
| If you see... | Try... |
|
||||
|---------------|--------|
|
||||
| High L0 (>200) | Increase `l1_coefficient` |
|
||||
| Low CE recovery (<80%) | Decrease `l1_coefficient`, increase `d_sae` |
|
||||
| Many dead features (>5%) | Enable `use_ghost_grads`, increase `l1_warm_up_steps` |
|
||||
| Training instability | Lower `lr`, increase `lr_warm_up_steps` |
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 3: Feature Attribution and Steering
|
||||
|
||||
### Goal
|
||||
Identify which SAE features contribute to specific predictions and use them for steering.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
from sae_lens import SAE
|
||||
import torch
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
|
||||
sae, _, _ = SAE.from_pretrained(
|
||||
release="gpt2-small-res-jb",
|
||||
sae_id="blocks.8.hook_resid_pre",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# 1. Feature attribution for a specific prediction
|
||||
prompt = "The capital of France is"
|
||||
tokens = model.to_tokens(prompt)
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
activations = cache["resid_pre", 8]
|
||||
features = sae.encode(activations)
|
||||
|
||||
# Target token
|
||||
target_token = model.to_single_token(" Paris")
|
||||
|
||||
# Compute feature contributions to target logit
|
||||
# contribution = feature_activation * decoder_weight * unembedding
|
||||
W_dec = sae.W_dec # [d_sae, d_model]
|
||||
W_U = model.W_U # [d_model, d_vocab]
|
||||
|
||||
# Feature direction projected to vocabulary
|
||||
feature_to_logit = W_dec @ W_U # [d_sae, d_vocab]
|
||||
|
||||
# Contribution of each feature to "Paris" at final position
|
||||
feature_acts = features[0, -1] # [d_sae]
|
||||
contributions = feature_acts * feature_to_logit[:, target_token]
|
||||
|
||||
# Top contributing features
|
||||
top_features = contributions.topk(10)
|
||||
print("Top features contributing to 'Paris':")
|
||||
for idx, val in zip(top_features.indices, top_features.values):
|
||||
print(f" Feature {idx.item()}: {val.item():.3f}")
|
||||
|
||||
# 2. Feature steering
|
||||
def steer_with_feature(feature_idx, strength=5.0):
|
||||
"""Add a feature direction to the residual stream."""
|
||||
feature_direction = sae.W_dec[feature_idx] # [d_model]
|
||||
|
||||
def hook(activation, hook_obj):
|
||||
activation[:, -1, :] += strength * feature_direction
|
||||
return activation
|
||||
|
||||
    # HookedTransformer.generate does not accept fwd_hooks directly,
    # so run generation inside the hooks context manager.
    with model.hooks(fwd_hooks=[("blocks.8.hook_resid_pre", hook)]):
        output = model.generate(tokens, max_new_tokens=10)
|
||||
return model.to_string(output[0])
|
||||
|
||||
# Try steering with top feature
|
||||
top_feature_idx = top_features.indices[0].item()
|
||||
print(f"\nSteering with feature {top_feature_idx}:")
|
||||
print(steer_with_feature(top_feature_idx, strength=10.0))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 4: Feature Ablation
|
||||
|
||||
### Goal
|
||||
Test the causal importance of features by ablating them.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
from sae_lens import SAE
|
||||
import torch
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
|
||||
sae, _, _ = SAE.from_pretrained(
|
||||
release="gpt2-small-res-jb",
|
||||
sae_id="blocks.8.hook_resid_pre",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
prompt = "The capital of France is"
|
||||
tokens = model.to_tokens(prompt)
|
||||
|
||||
# Baseline prediction
|
||||
baseline_logits = model(tokens)
|
||||
target_token = model.to_single_token(" Paris")
|
||||
baseline_prob = torch.softmax(baseline_logits[0, -1], dim=-1)[target_token].item()
|
||||
print(f"Baseline P(Paris): {baseline_prob:.4f}")
|
||||
|
||||
# Get features to ablate
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
activations = cache["resid_pre", 8]
|
||||
features = sae.encode(activations)
|
||||
top_features = features[0, -1].topk(10).indices
|
||||
|
||||
# Ablate top features one by one
|
||||
for feat_idx in top_features:
|
||||
def ablation_hook(activation, hook, feat_idx=feat_idx):
|
||||
# Encode → zero feature → decode
|
||||
feats = sae.encode(activation)
|
||||
feats[:, :, feat_idx] = 0
|
||||
return sae.decode(feats)
|
||||
|
||||
ablated_logits = model.run_with_hooks(
|
||||
tokens,
|
||||
fwd_hooks=[("blocks.8.hook_resid_pre", ablation_hook)]
|
||||
)
|
||||
ablated_prob = torch.softmax(ablated_logits[0, -1], dim=-1)[target_token].item()
|
||||
change = (ablated_prob - baseline_prob) / baseline_prob * 100
|
||||
print(f"Ablate feature {feat_idx.item()}: P(Paris)={ablated_prob:.4f} ({change:+.1f}%)")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 5: Comparing Features Across Prompts
|
||||
|
||||
### Goal
|
||||
Find which features activate consistently for a concept.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
from sae_lens import SAE
|
||||
import torch
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
|
||||
sae, _, _ = SAE.from_pretrained(
|
||||
release="gpt2-small-res-jb",
|
||||
sae_id="blocks.8.hook_resid_pre",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Test prompts about the same concept
|
||||
prompts = [
|
||||
"The Eiffel Tower is located in",
|
||||
"Paris is the capital of",
|
||||
"France's largest city is",
|
||||
"The Louvre museum is in",
|
||||
]
|
||||
|
||||
# Collect feature activations
|
||||
all_features = []
|
||||
for prompt in prompts:
|
||||
tokens = model.to_tokens(prompt)
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
activations = cache["resid_pre", 8]
|
||||
features = sae.encode(activations)
|
||||
# Take max activation across positions
|
||||
max_features = features[0].max(dim=0).values
|
||||
all_features.append(max_features)
|
||||
|
||||
all_features = torch.stack(all_features) # [n_prompts, d_sae]
|
||||
|
||||
# Find features that activate consistently
|
||||
mean_activation = all_features.mean(dim=0)
|
||||
min_activation = all_features.min(dim=0).values
|
||||
|
||||
# Features active in ALL prompts
|
||||
consistent_features = (min_activation > 0.5).nonzero().squeeze(-1)
|
||||
print(f"Features active in all prompts: {len(consistent_features)}")
|
||||
|
||||
# Top consistent features
|
||||
top_consistent = mean_activation[consistent_features].topk(min(10, len(consistent_features)))
|
||||
print("\nTop consistent features (possibly 'France/Paris' related):")
|
||||
for idx, val in zip(top_consistent.indices, top_consistent.values):
|
||||
feat_idx = consistent_features[idx].item()
|
||||
print(f" Feature {feat_idx}: mean activation {val.item():.3f}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## External Resources
|
||||
|
||||
### Official Tutorials
|
||||
- [Basic Loading & Analysis](https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
|
||||
- [Training SAEs](https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
|
||||
- [Logits Lens with Features](https://github.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
|
||||
|
||||
### ARENA Curriculum
|
||||
Comprehensive SAE course: https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab
|
||||
|
||||
### Key Papers
|
||||
- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features) - Anthropic (2023)
|
||||
- [Scaling Monosemanticity](https://transformer-circuits.pub/2024/scaling-monosemanticity/) - Anthropic (2024)
|
||||
- [Sparse Autoencoders Find Interpretable Features](https://arxiv.org/abs/2309.08600) - ICLR 2024
|
||||
222
optional-skills/mlops/simpo/SKILL.md
Normal file
@@ -0,0 +1,222 @@
---
|
||||
name: simpo-training
|
||||
description: Simple Preference Optimization for LLM alignment. Reference-free alternative to DPO with better performance (+6.4 points on AlpacaEval 2.0). No reference model needed, more efficient than DPO. Use for preference alignment when you want simpler, faster training than DPO/PPO.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [torch, transformers, datasets, trl, accelerate]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Post-Training, SimPO, Preference Optimization, Alignment, DPO Alternative, Reference-Free, LLM Alignment, Efficient Training]
|
||||
|
||||
---
|
||||
|
||||
# SimPO - Simple Preference Optimization
|
||||
|
||||
## Quick start
|
||||
|
||||
SimPO is a reference-free preference optimization method that outperforms DPO without needing a reference model.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
# Create environment
|
||||
conda create -n simpo python=3.10 && conda activate simpo
|
||||
|
||||
# Install PyTorch 2.2.2
|
||||
# Visit: https://pytorch.org/get-started/locally/
|
||||
|
||||
# Install alignment-handbook
|
||||
git clone https://github.com/huggingface/alignment-handbook.git
|
||||
cd alignment-handbook
|
||||
python -m pip install .
|
||||
|
||||
# Install Flash Attention 2
|
||||
python -m pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
**Training** (Mistral 7B):
|
||||
```bash
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch \
|
||||
--config_file accelerate_configs/deepspeed_zero3.yaml \
|
||||
scripts/run_simpo.py \
|
||||
training_configs/mistral-7b-base-simpo.yaml
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Train from base model (Mistral 7B)
|
||||
|
||||
**Config** (`mistral-7b-base-simpo.yaml`):
|
||||
```yaml
|
||||
# Model
|
||||
model_name_or_path: mistralai/Mistral-7B-v0.1
|
||||
torch_dtype: bfloat16
|
||||
|
||||
# Dataset
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 1.0
|
||||
dataset_splits:
|
||||
- train_prefs
|
||||
- test_prefs
|
||||
|
||||
# SimPO hyperparameters
|
||||
beta: 2.0 # Reward scaling (2.0-10.0)
|
||||
gamma_beta_ratio: 0.5 # Target margin (0-1)
|
||||
loss_type: sigmoid # sigmoid or hinge
|
||||
sft_weight: 0.0 # Optional SFT regularization
|
||||
|
||||
# Training
|
||||
learning_rate: 5e-7 # Critical: 3e-7 to 1e-6
|
||||
num_train_epochs: 1
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
|
||||
# Output
|
||||
output_dir: ./outputs/mistral-7b-simpo
|
||||
```
|
||||
|
||||
**Launch training**:
|
||||
```bash
|
||||
accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml \
|
||||
scripts/run_simpo.py training_configs/mistral-7b-base-simpo.yaml
|
||||
```
|
||||
|
||||
### Workflow 2: Fine-tune instruct model (Llama 3 8B)
|
||||
|
||||
**Config** (`llama3-8b-instruct-simpo.yaml`):
|
||||
```yaml
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
dataset_mixer:
|
||||
argilla/ultrafeedback-binarized-preferences-cleaned: 1.0
|
||||
|
||||
beta: 2.5
|
||||
gamma_beta_ratio: 0.5
|
||||
learning_rate: 5e-7
|
||||
sft_weight: 0.1 # Add SFT loss to preserve capabilities
|
||||
|
||||
num_train_epochs: 1
|
||||
per_device_train_batch_size: 2
|
||||
gradient_accumulation_steps: 4
|
||||
output_dir: ./outputs/llama3-8b-simpo
|
||||
```
|
||||
|
||||
**Launch**:
|
||||
```bash
|
||||
accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml \
|
||||
scripts/run_simpo.py training_configs/llama3-8b-instruct-simpo.yaml
|
||||
```
|
||||
|
||||
### Workflow 3: Reasoning-intensive tasks (lower LR)
|
||||
|
||||
**For math/code tasks**:
|
||||
```yaml
|
||||
model_name_or_path: deepseek-ai/deepseek-math-7b-base
|
||||
|
||||
dataset_mixer:
|
||||
argilla/distilabel-math-preference-dpo: 1.0
|
||||
|
||||
beta: 5.0 # Higher for stronger signal
|
||||
gamma_beta_ratio: 0.7 # Larger margin
|
||||
learning_rate: 3e-7 # Lower LR for reasoning
|
||||
sft_weight: 0.0
|
||||
|
||||
num_train_epochs: 1
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 16
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use SimPO when**:
|
||||
- Want simpler training than DPO (no reference model)
|
||||
- Have preference data (chosen/rejected pairs)
|
||||
- Need better performance than DPO
|
||||
- Limited compute resources
|
||||
- Single-node training sufficient
|
||||
|
||||
**Algorithm selection**:
|
||||
- **SimPO**: Simplest, best performance, no reference model
|
||||
- **DPO**: Need reference model baseline, more conservative
|
||||
- **PPO**: Maximum control, need reward model, complex setup
|
||||
- **GRPO**: Memory-efficient RL, no critic
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **OpenRLHF**: Multi-node distributed training, PPO/GRPO
|
||||
- **TRL**: Need multiple methods in one framework
|
||||
- **DPO**: Established baseline comparison
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Loss divergence**
|
||||
|
||||
Reduce learning rate:
|
||||
```yaml
|
||||
learning_rate: 3e-7 # Reduce from 5e-7
|
||||
```
|
||||
|
||||
Reduce beta:
|
||||
```yaml
|
||||
beta: 1.0 # Reduce from 2.0
|
||||
```
|
||||
|
||||
**Issue: Model forgets capabilities**
|
||||
|
||||
Add SFT regularization:
|
||||
```yaml
|
||||
sft_weight: 0.1 # Add SFT loss component
|
||||
```
|
||||
|
||||
**Issue: Poor preference separation**
|
||||
|
||||
Increase beta and margin:
|
||||
```yaml
|
||||
beta: 5.0 # Increase from 2.0
|
||||
gamma_beta_ratio: 0.8 # Increase from 0.5
|
||||
```
|
||||
|
||||
**Issue: OOM during training**
|
||||
|
||||
Reduce batch size:
|
||||
```yaml
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 16 # Maintain effective batch
|
||||
```
|
||||
|
||||
Enable gradient checkpointing:
|
||||
```yaml
|
||||
gradient_checkpointing: true
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Loss functions**: See [references/loss-functions.md](references/loss-functions.md) for sigmoid vs hinge loss, mathematical formulations, and when to use each.
|
||||
|
||||
**Hyperparameter tuning**: See [references/hyperparameters.md](references/hyperparameters.md) for beta, gamma, learning rate selection guide, and model-size-specific recommendations.
|
||||
|
||||
**Dataset preparation**: See [references/datasets.md](references/datasets.md) for preference data formats, quality filtering, and custom dataset creation.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA A100/H100 recommended
|
||||
- **VRAM**:
|
||||
- 7B model: 1× A100 40GB (DeepSpeed ZeRO-3)
|
||||
- 8B model: 2× A100 40GB
|
||||
- 70B model: 8× A100 80GB
|
||||
- **Single-node**: DeepSpeed ZeRO-3 sufficient
|
||||
- **Mixed precision**: BF16 recommended
|
||||
|
||||
**Memory optimization**:
|
||||
- DeepSpeed ZeRO-3 (default config)
|
||||
- Gradient checkpointing
|
||||
- Flash Attention 2
|
||||
|
||||
## Resources
|
||||
|
||||
- Paper: https://arxiv.org/abs/2405.14734 (NeurIPS 2024)
|
||||
- GitHub: https://github.com/princeton-nlp/SimPO
|
||||
- Models: https://huggingface.co/princeton-nlp
|
||||
- Alignment Handbook: https://github.com/huggingface/alignment-handbook
|
||||
|
||||
|
||||
|
||||
478
optional-skills/mlops/simpo/references/datasets.md
Normal file
@@ -0,0 +1,478 @@
# Datasets
|
||||
|
||||
Complete guide to preference datasets for SimPO training.
|
||||
|
||||
## Dataset Format
|
||||
|
||||
### Required Fields
|
||||
|
||||
Preference datasets must contain:
|
||||
```json
|
||||
{
|
||||
"prompt": "User question or instruction",
|
||||
"chosen": "Better/preferred response",
|
||||
"rejected": "Worse/rejected response"
|
||||
}
|
||||
```
|
||||
|
||||
**Alternative field names** (auto-detected):
|
||||
- `prompt` → `question`, `instruction`, `input`
|
||||
- `chosen` → `response_chosen`, `winner`, `preferred`
|
||||
- `rejected` → `response_rejected`, `loser` (see the normalization sketch below)
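
A small helper can map those aliases onto the canonical keys before training; this is an assumed utility, not something the alignment-handbook scripts provide:

```python
FIELD_ALIASES = {
    "prompt": ["prompt", "question", "instruction", "input"],
    "chosen": ["chosen", "response_chosen", "winner", "preferred"],
    "rejected": ["rejected", "response_rejected", "loser"],
}

def normalize_fields(example: dict) -> dict:
    # Copy whichever alias is present onto the canonical field name.
    out = {}
    for canonical, aliases in FIELD_ALIASES.items():
        for name in aliases:
            if name in example:
                out[canonical] = example[name]
                break
    return out
```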
|
||||
|
||||
### Example Entry
|
||||
|
||||
```json
|
||||
{
|
||||
"prompt": "Explain quantum computing in simple terms.",
|
||||
"chosen": "Quantum computing uses quantum bits (qubits) that can exist in multiple states simultaneously through superposition. This allows quantum computers to process many possibilities at once, making them potentially much faster than classical computers for specific tasks like cryptography and optimization.",
|
||||
"rejected": "It's like regular computing but quantum."
|
||||
}
|
||||
```
|
||||
|
||||
## Popular Datasets
|
||||
|
||||
### 1. UltraFeedback (Recommended)
|
||||
|
||||
**HuggingFaceH4/ultrafeedback_binarized**:
|
||||
- **Size**: 60K preference pairs
|
||||
- **Quality**: High (GPT-4 annotations)
|
||||
- **Domain**: General instruction following
|
||||
- **Format**: Clean, ready-to-use
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 1.0
|
||||
dataset_splits:
|
||||
- train_prefs
|
||||
- test_prefs
|
||||
```
|
||||
|
||||
### 2. Argilla UltraFeedback (Cleaned)
|
||||
|
||||
**argilla/ultrafeedback-binarized-preferences-cleaned**:
|
||||
- **Size**: 50K pairs (filtered)
|
||||
- **Quality**: Very high (deduped, cleaned)
|
||||
- **Domain**: General
|
||||
- **Format**: Clean
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
argilla/ultrafeedback-binarized-preferences-cleaned: 1.0
|
||||
```
|
||||
|
||||
### 3. Distilabel Math
|
||||
|
||||
**argilla/distilabel-math-preference-dpo**:
|
||||
- **Size**: 30K pairs
|
||||
- **Quality**: High (GSM8K, MATH)
|
||||
- **Domain**: Math reasoning
|
||||
- **Format**: Math-specific
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
argilla/distilabel-math-preference-dpo: 1.0
|
||||
```
|
||||
|
||||
### 4. HelpSteer
|
||||
|
||||
**nvidia/HelpSteer**:
|
||||
- **Size**: 38K samples
|
||||
- **Quality**: High (human ratings)
|
||||
- **Domain**: Helpfulness alignment
|
||||
- **Format**: Multi-attribute ratings
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
nvidia/HelpSteer: 1.0
|
||||
```
|
||||
|
||||
### 5. Anthropic HH-RLHF
|
||||
|
||||
**Anthropic/hh-rlhf**:
|
||||
- **Size**: 161K samples
|
||||
- **Quality**: High (human preferences)
|
||||
- **Domain**: Harmless + helpful
|
||||
- **Format**: Conversational
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
Anthropic/hh-rlhf: 1.0
|
||||
```
|
||||
|
||||
## Dataset Mixing
|
||||
|
||||
### Multiple Datasets
|
||||
|
||||
**Equal mix**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 0.5
|
||||
Anthropic/hh-rlhf: 0.5
|
||||
```
|
||||
|
||||
**Weighted mix**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 0.7
|
||||
argilla/distilabel-math-preference-dpo: 0.2
|
||||
nvidia/HelpSteer: 0.1
|
||||
```
|
||||
|
||||
**Domain-specific emphasis**:
|
||||
```yaml
|
||||
# 80% general + 20% math
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 0.8
|
||||
argilla/distilabel-math-preference-dpo: 0.2
|
||||
```
|
||||
|
||||
## Data Quality
|
||||
|
||||
### Quality Indicators
|
||||
|
||||
**Good preference data**:
|
||||
- ✅ Clear quality difference between chosen/rejected
|
||||
- ✅ Diverse prompts
|
||||
- ✅ Minimal noise/annotation errors
|
||||
- ✅ Appropriate difficulty level
|
||||
|
||||
**Poor preference data**:
|
||||
- ❌ Ambiguous preferences
|
||||
- ❌ Repetitive prompts
|
||||
- ❌ Annotation noise
|
||||
- ❌ Too easy/hard prompts
|
||||
|
||||
### Quality Filtering
|
||||
|
||||
**Filter by length difference**:
|
||||
```python
|
||||
def filter_by_length(example):
|
||||
chosen_len = len(example['chosen'].split())
|
||||
rejected_len = len(example['rejected'].split())
|
||||
# Reject if chosen is much shorter (potential low-effort)
|
||||
return chosen_len >= rejected_len * 0.5
|
||||
|
||||
dataset = dataset.filter(filter_by_length)
|
||||
```
|
||||
|
||||
**Filter by diversity**:
|
||||
```python
|
||||
seen_prompts = set()
|
||||
|
||||
def filter_duplicates(example):
|
||||
prompt = example['prompt']
|
||||
if prompt in seen_prompts:
|
||||
return False
|
||||
seen_prompts.add(prompt)
|
||||
return True
|
||||
|
||||
dataset = dataset.filter(filter_duplicates)
|
||||
```
|
||||
|
||||
## Custom Dataset Creation
|
||||
|
||||
### Format 1: JSON Lines
|
||||
|
||||
**File** (`preferences.jsonl`):
|
||||
```jsonl
|
||||
{"prompt": "What is Python?", "chosen": "Python is a high-level programming language...", "rejected": "It's a snake."}
|
||||
{"prompt": "Explain AI.", "chosen": "AI refers to systems that can...", "rejected": "It's computers that think."}
|
||||
```
|
||||
|
||||
**Load**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
json:
|
||||
data_files: preferences.jsonl
|
||||
```
|
||||
|
||||
### Format 2: HuggingFace Dataset
|
||||
|
||||
**Create from dict**:
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
data = {
|
||||
"prompt": ["What is Python?", "Explain AI."],
|
||||
"chosen": ["Python is...", "AI refers to..."],
|
||||
"rejected": ["It's a snake.", "It's computers..."]
|
||||
}
|
||||
|
||||
dataset = Dataset.from_dict(data)
|
||||
dataset.push_to_hub("username/my-preferences")
|
||||
```
|
||||
|
||||
**Use in config**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
username/my-preferences: 1.0
|
||||
```
|
||||
|
||||
### Format 3: ChatML
|
||||
|
||||
**For conversational data**:
|
||||
```json
|
||||
{
|
||||
"prompt": [
|
||||
{"role": "user", "content": "What is quantum computing?"}
|
||||
],
|
||||
"chosen": [
|
||||
{"role": "assistant", "content": "Quantum computing uses qubits..."}
|
||||
],
|
||||
"rejected": [
|
||||
{"role": "assistant", "content": "It's like regular computing but quantum."}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Apply chat template**:
|
||||
```yaml
|
||||
dataset_text_field: null # Will apply chat template
|
||||
```
|
||||
|
||||
## Synthetic Data Generation
|
||||
|
||||
### Using GPT-4
|
||||
|
||||
**Prompt template**:
|
||||
```
|
||||
Given the following question:
|
||||
{prompt}
|
||||
|
||||
Generate two responses:
|
||||
1. A high-quality, detailed response (chosen)
|
||||
2. A low-quality, brief response (rejected)
|
||||
|
||||
Format as JSON with "chosen" and "rejected" fields.
|
||||
```
|
||||
|
||||
**Example code**:
|
||||
```python
|
||||
import json

import openai
|
||||
|
||||
def generate_pair(prompt):
|
||||
response = openai.ChatCompletion.create(
|
||||
model="gpt-4",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Given: {prompt}\n\nGenerate chosen/rejected pair in JSON."
|
||||
}]
|
||||
)
|
||||
return json.loads(response.choices[0].message.content)
|
||||
|
||||
# Generate dataset
|
||||
prompts = load_prompts()
|
||||
dataset = [generate_pair(p) for p in prompts]
|
||||
```
|
||||
|
||||
### Using Local Model
|
||||
|
||||
**With vLLM**:
|
||||
```python
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
llm = LLM(model="meta-llama/Meta-Llama-3-70B-Instruct")
|
||||
|
||||
def generate_variations(prompt):
|
||||
# Generate multiple completions
|
||||
outputs = llm.generate(
|
||||
[prompt] * 4,
|
||||
        # vLLM expects a SamplingParams object rather than a plain dict
        sampling_params=SamplingParams(
            temperature=0.8,
            top_p=0.9,
            max_tokens=512,
        ),
|
||||
)
|
||||
|
||||
# Select best/worst
|
||||
chosen = max(outputs, key=lambda x: len(x.outputs[0].text))
|
||||
rejected = min(outputs, key=lambda x: len(x.outputs[0].text))
|
||||
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"chosen": chosen.outputs[0].text,
|
||||
"rejected": rejected.outputs[0].text
|
||||
}
|
||||
```
|
||||
|
||||
## Data Preprocessing
|
||||
|
||||
### Truncation
|
||||
|
||||
**Limit sequence length**:
|
||||
```yaml
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 512
|
||||
max_length: 1024 # Total
|
||||
```
|
||||
|
||||
**Implementation**:
|
||||
```python
|
||||
def truncate_example(example):
|
||||
tokenizer.truncation_side = "left" # For prompts
|
||||
prompt_tokens = tokenizer(
|
||||
example['prompt'],
|
||||
max_length=512,
|
||||
truncation=True
|
||||
)
|
||||
|
||||
tokenizer.truncation_side = "right" # For completions
|
||||
chosen_tokens = tokenizer(
|
||||
example['chosen'],
|
||||
max_length=512,
|
||||
truncation=True
|
||||
)
|
||||
|
||||
return {
|
||||
"prompt": tokenizer.decode(prompt_tokens['input_ids']),
|
||||
"chosen": tokenizer.decode(chosen_tokens['input_ids'])
|
||||
}
|
||||
|
||||
dataset = dataset.map(truncate_example)
|
||||
```
|
||||
|
||||
### Deduplication
|
||||
|
||||
**Remove exact duplicates**:
|
||||
```python
|
||||
# Dataset.unique() returns the distinct column values, not a filtered dataset;
# deduplicate rows via pandas instead (Dataset comes from `datasets`, as above)
dataset = Dataset.from_pandas(dataset.to_pandas().drop_duplicates(subset="prompt"), preserve_index=False)
|
||||
```
|
||||
|
||||
**Remove near-duplicates** (MinHash):
|
||||
```python
|
||||
from datasets import Dataset
from datasketch import MinHash, MinHashLSH
|
||||
|
||||
def deduplicate_lsh(dataset, threshold=0.8):
|
||||
lsh = MinHashLSH(threshold=threshold, num_perm=128)
|
||||
seen = []
|
||||
|
||||
for i, example in enumerate(dataset):
|
||||
m = MinHash(num_perm=128)
|
||||
for word in example['prompt'].split():
|
||||
m.update(word.encode('utf8'))
|
||||
|
||||
if not lsh.query(m):
|
||||
lsh.insert(i, m)
|
||||
seen.append(example)
|
||||
|
||||
return Dataset.from_list(seen)
|
||||
|
||||
dataset = deduplicate_lsh(dataset)
|
||||
```
|
||||
|
||||
## Data Augmentation
|
||||
|
||||
### Paraphrasing Prompts
|
||||
|
||||
```python
|
||||
def paraphrase_prompt(batch):
    # Batched map: each input row yields two output rows (original + paraphrase).
    # `paraphrase_model` is assumed to be a callable that rewrites a prompt.
    prompts, chosen, rejected = [], [], []
    for p, c, r in zip(batch["prompt"], batch["chosen"], batch["rejected"]):
        paraphrased = paraphrase_model(p)
        for q in (p, paraphrased):
            prompts.append(q)
            chosen.append(c)
            rejected.append(r)
    return {"prompt": prompts, "chosen": chosen, "rejected": rejected}

dataset = dataset.map(
    paraphrase_prompt,
    batched=True,
    remove_columns=dataset.column_names,
)
|
||||
```
|
||||
|
||||
### Difficulty Balancing
|
||||
|
||||
**Mix easy/medium/hard**:
|
||||
```python
|
||||
from datasets import concatenate_datasets

def categorize_difficulty(example):
|
||||
prompt_len = len(example['prompt'].split())
|
||||
if prompt_len < 20:
|
||||
return "easy"
|
||||
elif prompt_len < 50:
|
||||
return "medium"
|
||||
else:
|
||||
return "hard"
|
||||
|
||||
dataset = dataset.map(lambda x: {"difficulty": categorize_difficulty(x)})
|
||||
|
||||
# Sample balanced dataset
|
||||
easy = dataset.filter(lambda x: x['difficulty'] == 'easy').shuffle().select(range(1000))
|
||||
medium = dataset.filter(lambda x: x['difficulty'] == 'medium').shuffle().select(range(1000))
|
||||
hard = dataset.filter(lambda x: x['difficulty'] == 'hard').shuffle().select(range(1000))
|
||||
|
||||
balanced = concatenate_datasets([easy, medium, hard]).shuffle()
|
||||
```
|
||||
|
||||
## Dataset Statistics
|
||||
|
||||
### Compute Stats
|
||||
|
||||
```python
|
||||
import numpy as np

def compute_stats(dataset):
|
||||
prompt_lens = [len(x['prompt'].split()) for x in dataset]
|
||||
chosen_lens = [len(x['chosen'].split()) for x in dataset]
|
||||
rejected_lens = [len(x['rejected'].split()) for x in dataset]
|
||||
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
print(f"Avg prompt length: {np.mean(prompt_lens):.1f} words")
|
||||
print(f"Avg chosen length: {np.mean(chosen_lens):.1f} words")
|
||||
print(f"Avg rejected length: {np.mean(rejected_lens):.1f} words")
|
||||
print(f"Chosen > Rejected: {sum(c > r for c, r in zip(chosen_lens, rejected_lens)) / len(dataset):.1%}")
|
||||
|
||||
compute_stats(dataset)
|
||||
```
|
||||
|
||||
**Expected output**:
|
||||
```
|
||||
Dataset size: 50000
|
||||
Avg prompt length: 45.2 words
|
||||
Avg chosen length: 180.5 words
|
||||
Avg rejected length: 120.3 words
|
||||
Chosen > Rejected: 85.2%
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Data Quality Over Quantity
|
||||
|
||||
- **Prefer**: 10K high-quality pairs
|
||||
- **Over**: 100K noisy pairs
|
||||
|
||||
### 2. Clear Preference Signals
|
||||
|
||||
- Chosen should be noticeably better
|
||||
- Avoid marginal differences
|
||||
- Remove ambiguous pairs
|
||||
|
||||
### 3. Domain Matching
|
||||
|
||||
- Match dataset domain to target use case
|
||||
- Mix datasets for broader coverage
|
||||
- Include safety-filtered data
|
||||
|
||||
### 4. Validate Before Training
|
||||
|
||||
```python
|
||||
# Sample 10 random examples
|
||||
samples = dataset.shuffle().select(range(10))
|
||||
|
||||
for ex in samples:
|
||||
print(f"Prompt: {ex['prompt']}")
|
||||
print(f"Chosen: {ex['chosen'][:100]}...")
|
||||
print(f"Rejected: {ex['rejected'][:100]}...")
|
||||
print(f"Preference clear: {'✓' if len(ex['chosen']) > len(ex['rejected']) else '?'}")
|
||||
print()
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- HuggingFace Datasets: https://huggingface.co/datasets
|
||||
- Alignment Handbook: https://github.com/huggingface/alignment-handbook
|
||||
- UltraFeedback: https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized
|
||||
452
optional-skills/mlops/simpo/references/hyperparameters.md
Normal file
@@ -0,0 +1,452 @@
# Hyperparameters
|
||||
|
||||
Complete guide to SimPO hyperparameter selection and tuning.
|
||||
|
||||
## Overview
|
||||
|
||||
Key hyperparameters in SimPO:
|
||||
1. **Learning Rate** - Most critical
|
||||
2. **Beta (β)** - Reward scaling
|
||||
3. **Gamma-Beta Ratio (γ/β)** - Target margin
|
||||
4. **SFT Weight** - Regularization strength
|
||||
|
||||
## Learning Rate
|
||||
|
||||
### Recommended Ranges
|
||||
|
||||
**By model size**:
|
||||
| Model Size | Learning Rate | Notes |
|
||||
|------------|---------------|-------|
|
||||
| 1B-3B | 5e-7 to 1e-6 | Higher end safe |
|
||||
| 7B-8B | 3e-7 to 5e-7 | **Standard** |
|
||||
| 13B-30B | 1e-7 to 3e-7 | Lower for stability |
|
||||
| 70B+ | 5e-8 to 1e-7 | Very conservative |
|
||||
|
||||
**By task type**:
|
||||
| Task | Learning Rate | Reason |
|
||||
|------|---------------|--------|
|
||||
| General chat | 5e-7 | Standard |
|
||||
| Code generation | 3e-7 | **Precise reasoning** |
|
||||
| Math reasoning | 3e-7 | **Careful optimization** |
|
||||
| Creative writing | 1e-6 | More aggressive OK |
|
||||
|
||||
### Why Learning Rate Matters
|
||||
|
||||
**Too high** (> 1e-6 for 7B):
|
||||
- Loss divergence
|
||||
- Catastrophic forgetting
|
||||
- Unstable training
|
||||
|
||||
**Too low** (< 1e-7 for 7B):
|
||||
- Very slow convergence
|
||||
- May not finish in time
|
||||
- Undertraining
|
||||
|
||||
**Optimal** (3e-7 to 5e-7 for 7B):
|
||||
- Stable convergence
|
||||
- Good final performance
|
||||
- Efficient training
|
||||
|
||||
### Config Examples
|
||||
|
||||
**Mistral 7B (general)**:
|
||||
```yaml
|
||||
learning_rate: 5e-7
|
||||
num_train_epochs: 1
|
||||
warmup_ratio: 0.1
|
||||
lr_scheduler_type: cosine
|
||||
```
|
||||
|
||||
**Llama 3 8B (reasoning)**:
|
||||
```yaml
|
||||
learning_rate: 3e-7
|
||||
num_train_epochs: 1
|
||||
warmup_ratio: 0.1
|
||||
lr_scheduler_type: cosine
|
||||
```
|
||||
|
||||
**Gemma 2 9B (creative)**:
|
||||
```yaml
|
||||
learning_rate: 1e-6
|
||||
num_train_epochs: 1
|
||||
warmup_ratio: 0.1
|
||||
lr_scheduler_type: linear
|
||||
```
|
||||
|
||||
## Beta (β)
|
||||
|
||||
### Recommended Values
|
||||
|
||||
**Range**: 2.0 to 10.0 (much higher than DPO's 0.01-0.1)
|
||||
|
||||
**By preference strength**:
|
||||
| Beta | Preference Strength | Use Case |
|
||||
|------|-------------------|----------|
|
||||
| 1.0-2.0 | Weak | Subtle preferences |
|
||||
| 2.0-5.0 | **Standard** | General alignment |
|
||||
| 5.0-10.0 | Strong | Clear preferences |
|
||||
|
||||
**Default**: 2.0 to 2.5
|
||||
|
||||
### Why Beta Matters
|
||||
|
||||
**Low beta** (< 2.0):
|
||||
- Weak reward signal
|
||||
- Slow preference learning
|
||||
- May underfit
|
||||
|
||||
**High beta** (> 10.0):
|
||||
- Very strong reward signal
|
||||
- Risk of overfitting
|
||||
- May ignore weak preferences
|
||||
|
||||
**Optimal** (2.0-5.0):
|
||||
- Balanced reward scaling
|
||||
- Stable training
|
||||
- Good generalization
|
||||
|
||||
### Interaction with Gamma
|
||||
|
||||
**Beta and gamma together**:
|
||||
```
|
||||
Target margin in reward space = gamma
|
||||
Target margin in logit space = gamma / beta
|
||||
```
|
||||
|
||||
**Example**:
|
||||
```yaml
|
||||
beta: 2.0
|
||||
gamma_beta_ratio: 0.5
|
||||
# Effective gamma = 2.0 * 0.5 = 1.0
|
||||
```
|
||||
|
||||
### Config Examples
|
||||
|
||||
**Weak preferences**:
|
||||
```yaml
|
||||
beta: 2.0
|
||||
gamma_beta_ratio: 0.3 # Small margin
|
||||
```
|
||||
|
||||
**Standard**:
|
||||
```yaml
|
||||
beta: 2.5
|
||||
gamma_beta_ratio: 0.5 # Default
|
||||
```
|
||||
|
||||
**Strong preferences**:
|
||||
```yaml
|
||||
beta: 5.0
|
||||
gamma_beta_ratio: 0.7 # Larger margin
|
||||
```
|
||||
|
||||
## Gamma-Beta Ratio (γ/β)
|
||||
|
||||
### Recommended Values
|
||||
|
||||
**Range**: 0.0 to 1.0
|
||||
|
||||
**By scenario**:
|
||||
| Ratio | Margin | Use Case |
|
||||
|-------|--------|----------|
|
||||
| 0.0-0.3 | Small | Weak preference data |
|
||||
| 0.4-0.6 | **Standard** | General use |
|
||||
| 0.7-1.0 | Large | Very clear preferences |
|
||||
|
||||
**Default**: 0.5
|
||||
|
||||
### Why Gamma Matters
|
||||
|
||||
**Low gamma** (< 0.3):
|
||||
- Small target margin
|
||||
- Less aggressive alignment
|
||||
- More conservative
|
||||
|
||||
**High gamma** (> 0.7):
|
||||
- Large target margin
|
||||
- Stronger alignment
|
||||
- More aggressive
|
||||
|
||||
**Optimal** (0.4-0.6):
|
||||
- Balanced margin
|
||||
- Stable training
|
||||
- Good alignment
|
||||
|
||||
### Mathematical Meaning
|
||||
|
||||
**In loss function**:
|
||||
```python
|
||||
logits = pi_logratios - gamma_beta_ratio
|
||||
loss = -log(sigmoid(beta * logits))
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- gamma_beta_ratio shifts the decision boundary
|
||||
- Higher ratio = requires larger log prob difference
|
||||
- Controls how "clear" preferences must be
|
||||
|
||||
### Config Examples
|
||||
|
||||
**Noisy preferences**:
|
||||
```yaml
|
||||
gamma_beta_ratio: 0.3 # Smaller margin, more tolerant
|
||||
```
|
||||
|
||||
**Standard**:
|
||||
```yaml
|
||||
gamma_beta_ratio: 0.5 # Default
|
||||
```
|
||||
|
||||
**High-quality preferences**:
|
||||
```yaml
|
||||
gamma_beta_ratio: 0.8 # Larger margin, stricter
|
||||
```
|
||||
|
||||
## SFT Weight
|
||||
|
||||
### Recommended Values
|
||||
|
||||
**Range**: 0.0 to 1.0
|
||||
|
||||
**By model type**:
|
||||
| Model Type | SFT Weight | Reason |
|
||||
|------------|-----------|--------|
|
||||
| Base model | 0.0 | No prior capabilities |
|
||||
| **Instruct model** | 0.05-0.1 | Preserve instruction following |
|
||||
| Chat model | 0.1-0.2 | Preserve conversational skills |
|
||||
|
||||
**Default**: 0.0 (no SFT regularization)
|
||||
|
||||
### Why SFT Weight Matters
|
||||
|
||||
**Zero SFT** (0.0):
|
||||
- Pure preference optimization
|
||||
- May forget capabilities
|
||||
- Standard for base models
|
||||
|
||||
**Low SFT** (0.05-0.1):
|
||||
- Balanced approach
|
||||
- **Recommended for instruct models**
|
||||
- Slight capability preservation
|
||||
|
||||
**High SFT** (> 0.2):
|
||||
- Strong capability preservation
|
||||
- Weaker preference alignment
|
||||
- May reduce alignment gains
|
||||
|
||||
### Trade-off
|
||||
|
||||
```
|
||||
Total Loss = SimPO Loss + (sft_weight * SFT Loss)
|
||||
```
|
||||
|
||||
**Example**:
|
||||
```yaml
|
||||
sft_weight: 0.1
|
||||
# 90% preference optimization + 10% capability preservation
|
||||
```
|
||||
|
||||
### Config Examples
|
||||
|
||||
**Base model (no SFT)**:
|
||||
```yaml
|
||||
model_name_or_path: mistralai/Mistral-7B-v0.1
|
||||
sft_weight: 0.0
|
||||
```
|
||||
|
||||
**Instruct model (light SFT)**:
|
||||
```yaml
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
sft_weight: 0.1
|
||||
```
|
||||
|
||||
**Chat model (moderate SFT)**:
|
||||
```yaml
|
||||
model_name_or_path: HuggingFaceH4/zephyr-7b-beta
|
||||
sft_weight: 0.2
|
||||
```
|
||||
|
||||
## Model-Size-Specific Recommendations
|
||||
|
||||
### 7B Models (Mistral, Llama 3)
|
||||
|
||||
**Standard config**:
|
||||
```yaml
|
||||
learning_rate: 5e-7
|
||||
beta: 2.0
|
||||
gamma_beta_ratio: 0.5
|
||||
sft_weight: 0.0 # 0.1 if instruct model
|
||||
num_train_epochs: 1
|
||||
per_device_train_batch_size: 2
|
||||
gradient_accumulation_steps: 4
|
||||
```
|
||||
|
||||
### 8B-13B Models
|
||||
|
||||
**Standard config**:
|
||||
```yaml
|
||||
learning_rate: 3e-7
|
||||
beta: 2.5
|
||||
gamma_beta_ratio: 0.5
|
||||
sft_weight: 0.1 # If instruct
|
||||
num_train_epochs: 1
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
```
|
||||
|
||||
### 70B Models
|
||||
|
||||
**Standard config**:
|
||||
```yaml
|
||||
learning_rate: 1e-7
|
||||
beta: 2.0
|
||||
gamma_beta_ratio: 0.5
|
||||
sft_weight: 0.05
|
||||
num_train_epochs: 1
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 16
|
||||
```
|
||||
|
||||
## Batch Size & Gradient Accumulation
|
||||
|
||||
### Effective Batch Size
|
||||
|
||||
```
|
||||
Effective Batch Size = per_device_batch_size * num_gpus * grad_accum_steps
|
||||
```
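
As a sketch of the arithmetic:

```python
def effective_batch_size(per_device: int, num_gpus: int, grad_accum_steps: int) -> int:
    return per_device * num_gpus * grad_accum_steps

# Matches the 4-GPU example below: 2 * 4 * 16 = 128
assert effective_batch_size(2, 4, 16) == 128
```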
|
||||
|
||||
**Recommended effective batch sizes**:
|
||||
- 7B: 128-256
|
||||
- 13B: 64-128
|
||||
- 70B: 32-64
|
||||
|
||||
### Config Examples
|
||||
|
||||
**Single GPU (A100 40GB)**:
|
||||
```yaml
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 128 # Effective batch = 128
|
||||
```
|
||||
|
||||
**4 GPUs (A100 40GB)**:
|
||||
```yaml
|
||||
per_device_train_batch_size: 2
|
||||
gradient_accumulation_steps: 16 # Effective batch = 2*4*16 = 128
|
||||
```
|
||||
|
||||
**8 GPUs (A100 80GB)**:
|
||||
```yaml
|
||||
per_device_train_batch_size: 2
|
||||
gradient_accumulation_steps: 8 # Effective batch = 2*8*8 = 128
|
||||
```
|
||||
|
||||
## Loss Type
|
||||
|
||||
### Sigmoid vs Hinge
|
||||
|
||||
**Sigmoid** (default, recommended):
|
||||
```yaml
|
||||
loss_type: sigmoid
|
||||
label_smoothing: 0.0
|
||||
```
|
||||
|
||||
**Hinge** (experimental):
|
||||
```yaml
|
||||
loss_type: hinge
|
||||
# No label smoothing for hinge
|
||||
```
|
||||
|
||||
**When to use hinge**:
|
||||
- Margin-based tasks
|
||||
- SVM-style optimization
|
||||
- Experimental purposes
|
||||
|
||||
**Generally**: Stick with sigmoid
|
||||
|
||||
## Tuning Guide
|
||||
|
||||
### Step 1: Start with Defaults
|
||||
|
||||
```yaml
|
||||
learning_rate: 5e-7 # For 7B
|
||||
beta: 2.0
|
||||
gamma_beta_ratio: 0.5
|
||||
sft_weight: 0.0 # 0.1 if instruct
|
||||
loss_type: sigmoid
|
||||
```
|
||||
|
||||
### Step 2: Monitor Training
|
||||
|
||||
**Check every 100 steps**:
|
||||
- Loss curve (should decrease smoothly)
|
||||
- Reward margin (should increase)
|
||||
- Chosen/rejected logps (should separate)
|
||||
|
||||
### Step 3: Adjust if Needed
|
||||
|
||||
**If loss diverges**:
|
||||
```yaml
|
||||
learning_rate: 3e-7 # Reduce from 5e-7
|
||||
beta: 1.0 # Reduce from 2.0
|
||||
```
|
||||
|
||||
**If loss plateaus early**:
|
||||
```yaml
|
||||
learning_rate: 1e-6 # Increase from 5e-7
|
||||
beta: 5.0 # Increase from 2.0
|
||||
```
|
||||
|
||||
**If model forgets**:
|
||||
```yaml
|
||||
sft_weight: 0.2 # Increase from 0.0
|
||||
```
|
||||
|
||||
## Complete Example Configs
|
||||
|
||||
### Mistral 7B Base (Standard)
|
||||
|
||||
```yaml
|
||||
model_name_or_path: mistralai/Mistral-7B-v0.1
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 1.0
|
||||
|
||||
learning_rate: 5e-7
|
||||
beta: 2.0
|
||||
gamma_beta_ratio: 0.5
|
||||
loss_type: sigmoid
|
||||
sft_weight: 0.0
|
||||
|
||||
num_train_epochs: 1
|
||||
per_device_train_batch_size: 2
|
||||
gradient_accumulation_steps: 4
|
||||
warmup_ratio: 0.1
|
||||
lr_scheduler_type: cosine
|
||||
|
||||
bf16: true
|
||||
gradient_checkpointing: true
|
||||
```
|
||||
|
||||
### Llama 3 8B Instruct (Reasoning)
|
||||
|
||||
```yaml
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
dataset_mixer:
|
||||
argilla/distilabel-math-preference-dpo: 1.0
|
||||
|
||||
learning_rate: 3e-7
|
||||
beta: 5.0
|
||||
gamma_beta_ratio: 0.7
|
||||
loss_type: sigmoid
|
||||
sft_weight: 0.1
|
||||
|
||||
num_train_epochs: 1
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 16
|
||||
warmup_ratio: 0.1
|
||||
lr_scheduler_type: cosine
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- SimPO paper: https://arxiv.org/abs/2405.14734
|
||||
- Alignment Handbook: https://github.com/huggingface/alignment-handbook
|
||||
350
optional-skills/mlops/simpo/references/loss-functions.md
Normal file
@@ -0,0 +1,350 @@
# Loss Functions
|
||||
|
||||
Complete guide to SimPO loss functions and mathematical formulations.
|
||||
|
||||
## Overview
|
||||
|
||||
SimPO supports two loss types:
|
||||
- **Sigmoid** (default) - Smooth, differentiable loss
|
||||
- **Hinge** - Margin-based, sparse loss
|
||||
|
||||
Both are reference-free (no reference model needed).
|
||||
|
||||
## SimPO Loss Formula
|
||||
|
||||
### Core Calculation
|
||||
|
||||
**Step 1: Log probability ratio**:
|
||||
```
|
||||
pi_logratios = log P_θ(y_chosen|x) - log P_θ(y_rejected|x)
|
||||
```
|
||||
|
||||
**Step 2: Apply target margin**:
|
||||
```
|
||||
logits = pi_logratios - γ/β
|
||||
```
|
||||
Where:
|
||||
- γ/β = `gamma_beta_ratio` (target margin)
|
||||
|
||||
**Step 3: Compute loss** (depends on loss type)
|
||||
|
||||
### Sigmoid Loss (Default)
|
||||
|
||||
**Formula**:
|
||||
```
|
||||
L = -log σ(β * logits) * (1 - ε) - log σ(-β * logits) * ε
|
||||
```
|
||||
|
||||
Where:
|
||||
- β = `beta` (reward scaling)
|
||||
- σ = sigmoid function
|
||||
- ε = `label_smoothing` (default 0.0)
|
||||
|
||||
**Implementation**:
|
||||
```python
|
||||
losses = (
|
||||
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
||||
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
||||
)
|
||||
```
|
||||
|
||||
**Characteristics**:
|
||||
- Smooth, continuous gradients
|
||||
- Probabilistic interpretation
|
||||
- Standard choice for most tasks
|
||||
- Works well with higher beta values
|
||||
|
||||
### Hinge Loss
|
||||
|
||||
**Formula**:
|
||||
```
|
||||
L = max(0, 1 - β * logits)
|
||||
```
|
||||
|
||||
**Implementation**:
|
||||
```python
|
||||
losses = torch.relu(1 - self.beta * logits)
|
||||
```
|
||||
|
||||
**Characteristics**:
|
||||
- Non-smooth (has kink at logits = 1/β)
|
||||
- Margin-based (SVM-style)
|
||||
- Can lead to sparser solutions
|
||||
- Less commonly used
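
Putting steps 1-3 together, a minimal PyTorch sketch of the loss (assuming the average chosen/rejected log probabilities have already been computed):

```python
import torch
import torch.nn.functional as F

def simpo_loss(chosen_logps, rejected_logps, beta=2.0, gamma_beta_ratio=0.5,
               loss_type="sigmoid", label_smoothing=0.0):
    """Reference-free SimPO loss from average log probabilities (sketch)."""
    pi_logratios = chosen_logps - rejected_logps       # Step 1: log probability ratio
    logits = pi_logratios - gamma_beta_ratio            # Step 2: subtract target margin γ/β
    if loss_type == "sigmoid":                           # Step 3: compute loss
        losses = (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
                  - F.logsigmoid(-beta * logits) * label_smoothing)
    else:  # hinge
        losses = torch.relu(1 - beta * logits)
    chosen_rewards = beta * chosen_logps.detach()
    rejected_rewards = beta * rejected_logps.detach()
    return losses.mean(), chosen_rewards, rejected_rewards
```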
|
||||
|
||||
## Comparison to DPO
|
||||
|
||||
### DPO Loss (Reference Model Required)
|
||||
|
||||
**Formula**:
|
||||
```
|
||||
L_DPO = -E[log σ(β * log(π_θ(y_w|x)/π_ref(y_w|x)) - β * log(π_θ(y_l|x)/π_ref(y_l|x)))]
|
||||
```
|
||||
|
||||
**Key features**:
|
||||
- Requires reference model π_ref
|
||||
- Normalizes by reference log probabilities
|
||||
- More conservative (stays close to reference)
|
||||
|
||||
### SimPO Loss (Reference-Free)
|
||||
|
||||
**Formula**:
|
||||
```
|
||||
L_SimPO = -log σ(β * (log π_θ(y_w|x) - log π_θ(y_l|x) - γ/β))
|
||||
```
|
||||
|
||||
**Key features**:
|
||||
- No reference model needed
|
||||
- Direct preference optimization
|
||||
- Target margin γ/β controls preference strength
|
||||
- More efficient (fewer model forward passes)
|
||||
|
||||
**Visual comparison**:
|
||||
```
|
||||
DPO: [Policy] - [Reference] → Loss
|
||||
SimPO: [Policy] → Loss
|
||||
```
|
||||
|
||||
## Average Log Probability Reward
|
||||
|
||||
### Calculation
|
||||
|
||||
**Per-token log probabilities**:
|
||||
```python
|
||||
# Get log probs for each token
|
||||
per_token_logps = logits.log_softmax(dim=-1).gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
|
||||
|
||||
# Create mask to ignore padding
|
||||
loss_mask = (labels != label_pad_token_id)
|
||||
```
|
||||
|
||||
**Average log probability** (if `average_log_prob=True`):
|
||||
```python
|
||||
avg_logp = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
||||
```
|
||||
|
||||
**Sum log probability** (if `average_log_prob=False`):
|
||||
```python
|
||||
sum_logp = (per_token_logps * loss_mask).sum(-1)
|
||||
```
|
||||
|
||||
**Why average?**
|
||||
- Normalizes for sequence length
|
||||
- Prevents bias toward shorter/longer responses
|
||||
- Standard practice in SimPO
|
||||
|
||||
### Reward Metrics
|
||||
|
||||
**Chosen reward**:
|
||||
```python
|
||||
chosen_rewards = beta * policy_chosen_logps.detach()
|
||||
```
|
||||
|
||||
**Rejected reward**:
|
||||
```python
|
||||
rejected_rewards = beta * policy_rejected_logps.detach()
|
||||
```
|
||||
|
||||
**Reward margin**:
|
||||
```python
|
||||
reward_margin = chosen_rewards.mean() - rejected_rewards.mean()
|
||||
```
|
||||
|
||||
## Label Smoothing
|
||||
|
||||
### Formula with Smoothing
|
||||
|
||||
**Sigmoid loss**:
|
||||
```
|
||||
L = -log σ(β * logits) * (1 - ε) - log σ(-β * logits) * ε
|
||||
```
|
||||
|
||||
**Effect**:
|
||||
- ε = 0.0: No smoothing (default)
|
||||
- ε = 0.1: 10% smoothing (soft labels)
|
||||
- ε = 0.5: Maximum smoothing
|
||||
|
||||
**When to use**:
|
||||
- Noisy preference labels
|
||||
- Uncertain preferences
|
||||
- Prevent overconfidence
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
label_smoothing: 0.1 # 10% smoothing
|
||||
```
|
||||
|
||||
## SFT Regularization
|
||||
|
||||
### Combined Loss
|
||||
|
||||
**With SFT component**:
|
||||
```
|
||||
L_total = L_SimPO + λ * L_SFT
|
||||
```
|
||||
|
||||
Where:
|
||||
- L_SFT = cross-entropy loss on chosen responses
|
||||
- λ = `sft_weight` (0.0 to 1.0)
|
||||
|
||||
**Implementation**:
|
||||
```python
|
||||
if self.sft_weight > 0:
|
||||
sft_loss = -policy_chosen_logps
|
||||
total_loss = simpo_loss + self.sft_weight * sft_loss
|
||||
```
|
||||
|
||||
**When to use**:
|
||||
- Preserve model capabilities
|
||||
- Prevent catastrophic forgetting
|
||||
- Fine-tuning instruct models
|
||||
|
||||
**Trade-off**:
|
||||
- Higher sft_weight: Preserve capabilities, less alignment
|
||||
- Lower sft_weight: Stronger alignment, may forget capabilities
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
sft_weight: 0.1 # 10% SFT regularization
|
||||
```
|
||||
|
||||
## Loss Type Selection
|
||||
|
||||
### Sigmoid vs Hinge
|
||||
|
||||
| Aspect | Sigmoid | Hinge |
|
||||
|--------|---------|-------|
|
||||
| Smoothness | Smooth | Non-smooth |
|
||||
| Gradients | Continuous | Discontinuous at margin |
|
||||
| Sparsity | Dense solutions | Sparse solutions |
|
||||
| Interpretability | Probabilistic | Geometric margin |
|
||||
| Use case | **General purpose** | Margin-based tasks |
|
||||
| Recommendation | **Default choice** | Experimental |
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
# Sigmoid (default)
|
||||
loss_type: sigmoid
|
||||
|
||||
# Hinge (alternative)
|
||||
loss_type: hinge
|
||||
```
|
||||
|
||||
## Mathematical Properties
|
||||
|
||||
### Gradient Analysis
|
||||
|
||||
**Sigmoid loss gradient**:
|
||||
```
|
||||
∂L/∂logits = -β * σ(-β * logits) * (1 - ε) + β * σ(β * logits) * ε
|
||||
```
|
||||
|
||||
**Hinge loss gradient**:
|
||||
```
|
||||
∂L/∂logits = -β if logits < 1/β
|
||||
0 otherwise
|
||||
```
|
||||
|
||||
**Implications**:
|
||||
- Sigmoid: Always provides gradient signal
|
||||
- Hinge: No gradient when margin satisfied
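
A quick autograd sanity check of the sigmoid gradient above (standalone sketch; the β, ε, and logit values are arbitrary):

```python
import torch
import torch.nn.functional as F

beta, eps = 2.0, 0.1
z = torch.tensor(0.8, requires_grad=True)   # z = logits after the margin subtraction

loss = -F.logsigmoid(beta * z) * (1 - eps) - F.logsigmoid(-beta * z) * eps
loss.backward()

with torch.no_grad():
    analytic = -beta * torch.sigmoid(-beta * z) * (1 - eps) + beta * torch.sigmoid(beta * z) * eps

print(z.grad.item(), analytic.item())  # the two values should agree
```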
|
||||
|
||||
### Convergence Behavior
|
||||
|
||||
**Sigmoid**:
|
||||
- Asymptotically approaches zero loss
|
||||
- Continues optimizing even with large margins
|
||||
- Smoother training curves
|
||||
|
||||
**Hinge**:
|
||||
- Reaches zero loss at margin
|
||||
- Stops optimizing once margin satisfied
|
||||
- May have training plateaus
|
||||
|
||||
## Complete Loss Examples
|
||||
|
||||
### Example 1: Basic SimPO (Sigmoid)
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
beta: 2.0
|
||||
gamma_beta_ratio: 0.5
|
||||
loss_type: sigmoid
|
||||
label_smoothing: 0.0
|
||||
sft_weight: 0.0
|
||||
```
|
||||
|
||||
**Loss calculation**:
|
||||
```python
|
||||
# Step 1: Compute log probs
|
||||
chosen_logps = avg_log_prob(policy(chosen)) # e.g., -1.2
|
||||
rejected_logps = avg_log_prob(policy(rejected)) # e.g., -2.5
|
||||
|
||||
# Step 2: Log ratio and margin
|
||||
pi_logratios = -1.2 - (-2.5) = 1.3
|
||||
logits = 1.3 - 0.5 = 0.8
|
||||
|
||||
# Step 3: Sigmoid loss
|
||||
loss = -log(sigmoid(2.0 * 0.8))
|
||||
= -log(sigmoid(1.6))
|
||||
= -log(0.832)
|
||||
= 0.184
|
||||
```
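
The same arithmetic reproduced in PyTorch (using the illustrative log-prob values -1.2 and -2.5 from above):

```python
import torch
import torch.nn.functional as F

beta, gamma_beta_ratio = 2.0, 0.5
chosen_logps = torch.tensor(-1.2)
rejected_logps = torch.tensor(-2.5)

logits = (chosen_logps - rejected_logps) - gamma_beta_ratio   # 1.3 - 0.5 = 0.8
loss = -F.logsigmoid(beta * logits)                            # -log σ(1.6)
print(round(loss.item(), 3))                                   # ≈ 0.184
```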
|
||||
|
||||
### Example 2: SimPO with SFT
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
beta: 2.5
|
||||
gamma_beta_ratio: 0.5
|
||||
loss_type: sigmoid
|
||||
sft_weight: 0.1
|
||||
```
|
||||
|
||||
**Loss calculation**:
|
||||
```python
|
||||
# SimPO loss (as above)
|
||||
simpo_loss = 0.184
|
||||
|
||||
# SFT loss
|
||||
sft_loss = -chosen_logps = -(-1.2) = 1.2
|
||||
|
||||
# Total loss
|
||||
total_loss = simpo_loss + 0.1 * sft_loss
|
||||
= 0.184 + 0.12
|
||||
= 0.304
|
||||
```
|
||||
|
||||
## Debugging
|
||||
|
||||
### Check Reward Margins
|
||||
|
||||
**Low margin (< 0.5)**:
|
||||
- Preferences not being learned
|
||||
- Increase beta or gamma_beta_ratio
|
||||
|
||||
**High margin (> 5.0)**:
|
||||
- May be overfitting
|
||||
- Reduce beta or learning rate
|
||||
|
||||
**Monitor**:
|
||||
```python
|
||||
reward_margin = chosen_rewards.mean() - rejected_rewards.mean()
|
||||
print(f"Reward margin: {reward_margin:.2f}")
|
||||
```
|
||||
|
||||
### Check Log Probabilities
|
||||
|
||||
**Typical values**:
|
||||
- Chosen: -1.0 to -2.0 (higher is better)
|
||||
- Rejected: -2.0 to -4.0 (lower is worse)
|
||||
|
||||
**Warning signs**:
|
||||
- Both very negative (< -10): Model not learning
|
||||
- Both very positive (> 0): Numerical instability
|
||||
|
||||
## References
|
||||
|
||||
- SimPO paper: https://arxiv.org/abs/2405.14734
|
||||
- DPO paper: https://arxiv.org/abs/2305.18290
|
||||
- Implementation: https://github.com/princeton-nlp/SimPO
|
||||
467
optional-skills/mlops/slime/SKILL.md
Normal file
|
|
@ -0,0 +1,467 @@
|
|||
---
|
||||
name: slime-rl-training
|
||||
description: Provides guidance for LLM post-training with RL using slime, a Megatron+SGLang framework. Use when training GLM models, implementing custom data generation workflows, or needing tight Megatron-LM integration for RL scaling.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [sglang-router>=0.2.3, ray, torch>=2.0.0, transformers>=4.40.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Reinforcement Learning, Megatron-LM, SGLang, GRPO, Post-Training, GLM]
|
||||
|
||||
---
|
||||
|
||||
# slime: LLM Post-Training Framework for RL Scaling
|
||||
|
||||
slime is an LLM post-training framework from Tsinghua's THUDM team, powering GLM-4.5, GLM-4.6, and GLM-4.7. It connects Megatron-LM for training with SGLang for high-throughput rollout generation.
|
||||
|
||||
## When to Use slime
|
||||
|
||||
**Choose slime when you need:**
|
||||
- Megatron-LM native training with SGLang inference
|
||||
- Custom data generation workflows with flexible data buffers
|
||||
- Training GLM, Qwen3, DeepSeek V3, or Llama 3 models
|
||||
- Research-grade framework with production backing (Z.ai)
|
||||
|
||||
**Consider alternatives when:**
|
||||
- You need enterprise-grade stability features → use **miles**
|
||||
- You want flexible backend swapping → use **verl**
|
||||
- You need PyTorch-native abstractions → use **torchforge**
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Training**: Megatron-LM with full parallelism support (TP, PP, DP, SP)
|
||||
- **Rollout**: SGLang-based high-throughput generation with router
|
||||
- **Data Buffer**: Flexible prompt management and sample storage
|
||||
- **Models**: GLM-4.x, Qwen3, DeepSeek V3/R1, Llama 3
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ Data Buffer │
|
||||
│ - Prompt initialization and management │
|
||||
│ - Custom data generation and filtering │
|
||||
│ - Rollout sample storage │
|
||||
└─────────────┬───────────────────────────┬───────────────┘
|
||||
│ │
|
||||
┌─────────────▼───────────┐ ┌─────────────▼───────────────┐
|
||||
│ Training (Megatron-LM) │ │ Rollout (SGLang + Router) │
|
||||
│ - Actor model training │ │ - Response generation │
|
||||
│ - Critic (optional) │ │ - Reward/verifier output │
|
||||
│ - Weight sync to rollout│ │ - Multi-turn support │
|
||||
└─────────────────────────┘ └─────────────────────────────┘
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Recommended: Docker
|
||||
docker pull slimerl/slime:latest
|
||||
docker run --rm --gpus all --ipc=host --shm-size=16g \
|
||||
-it slimerl/slime:latest /bin/bash
|
||||
|
||||
# Inside container
|
||||
cd /root/slime && pip install -e . --no-deps
|
||||
```
|
||||
|
||||
### From Source
|
||||
|
||||
```bash
|
||||
git clone https://github.com/THUDM/slime.git
|
||||
cd slime
|
||||
pip install -r requirements.txt
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## Quick Start: GRPO Training
|
||||
|
||||
```bash
|
||||
# Source model configuration
|
||||
source scripts/models/qwen3-4B.sh
|
||||
|
||||
# Launch training
|
||||
python train.py \
|
||||
--actor-num-nodes 1 \
|
||||
--actor-num-gpus-per-node 4 \
|
||||
--rollout-num-gpus 4 \
|
||||
--advantage-estimator grpo \
|
||||
--use-kl-loss --kl-loss-coef 0.001 \
|
||||
--rollout-batch-size 32 \
|
||||
--n-samples-per-prompt 8 \
|
||||
--global-batch-size 256 \
|
||||
--num-rollout 3000 \
|
||||
--prompt-data /path/to/data.jsonl \
|
||||
${MODEL_ARGS[@]} ${CKPT_ARGS[@]}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Workflow 1: Standard GRPO Training
|
||||
|
||||
Use this workflow for training reasoning models with group-relative advantages.
|
||||
|
||||
### Prerequisites Checklist
|
||||
- [ ] Docker environment or Megatron-LM + SGLang installed
|
||||
- [ ] Model checkpoint (HuggingFace or Megatron format)
|
||||
- [ ] Training data in JSONL format
|
||||
|
||||
### Step 1: Prepare Data
|
||||
|
||||
```python
|
||||
# data.jsonl format
|
||||
{"prompt": "What is 2 + 2?", "label": "4"}
|
||||
{"prompt": "Solve: 3x = 12", "label": "x = 4"}
|
||||
```
|
||||
|
||||
Or with chat format:
|
||||
```python
|
||||
{
|
||||
"prompt": [
|
||||
{"role": "system", "content": "You are a math tutor."},
|
||||
{"role": "user", "content": "What is 15 + 27?"}
|
||||
],
|
||||
"label": "42"
|
||||
}
|
||||
```
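
A small sketch for producing such a file from existing (question, answer) pairs; the field names must match the `--input-key`/`--label-key` arguments used at launch:

```python
import json

pairs = [("What is 2 + 2?", "4"), ("Solve: 3x = 12", "x = 4")]  # your own data here

with open("data.jsonl", "w") as f:
    for prompt, label in pairs:
        f.write(json.dumps({"prompt": prompt, "label": label}) + "\n")
```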
|
||||
|
||||
### Step 2: Configure Model
|
||||
|
||||
Choose a pre-configured model script:
|
||||
|
||||
```bash
|
||||
# List available models
|
||||
ls scripts/models/
|
||||
# glm4-9B.sh, qwen3-4B.sh, qwen3-30B-A3B.sh, deepseek-v3.sh, llama3-8B.sh, ...
|
||||
|
||||
# Source your model
|
||||
source scripts/models/qwen3-4B.sh
|
||||
```
|
||||
|
||||
### Step 3: Launch Training
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--actor-num-nodes 1 \
|
||||
--actor-num-gpus-per-node 8 \
|
||||
--rollout-num-gpus 8 \
|
||||
--advantage-estimator grpo \
|
||||
--use-kl-loss \
|
||||
--kl-loss-coef 0.001 \
|
||||
--prompt-data /path/to/train.jsonl \
|
||||
--input-key prompt \
|
||||
--label-key label \
|
||||
--apply-chat-template \
|
||||
--rollout-batch-size 32 \
|
||||
--n-samples-per-prompt 8 \
|
||||
--global-batch-size 256 \
|
||||
--num-rollout 3000 \
|
||||
--save-interval 100 \
|
||||
--eval-interval 50 \
|
||||
${MODEL_ARGS[@]}
|
||||
```
|
||||
|
||||
### Step 4: Monitor Training
|
||||
- [ ] Check TensorBoard: `tensorboard --logdir outputs/`
|
||||
- [ ] Verify reward curves are increasing
|
||||
- [ ] Monitor GPU utilization across nodes
|
||||
|
||||
---
|
||||
|
||||
## Workflow 2: Asynchronous Training
|
||||
|
||||
Use async mode for higher throughput by overlapping rollout and training.
|
||||
|
||||
### When to Use Async
|
||||
- Large models with long generation times
|
||||
- High GPU idle time in synchronous mode
|
||||
- Sufficient memory for buffering
|
||||
|
||||
### Launch Async Training
|
||||
|
||||
```bash
|
||||
python train_async.py \
|
||||
--actor-num-nodes 1 \
|
||||
--actor-num-gpus-per-node 8 \
|
||||
--rollout-num-gpus 8 \
|
||||
--advantage-estimator grpo \
|
||||
--async-buffer-size 4 \
|
||||
--prompt-data /path/to/train.jsonl \
|
||||
${MODEL_ARGS[@]}
|
||||
```
|
||||
|
||||
### Async-Specific Parameters
|
||||
|
||||
```bash
|
||||
--async-buffer-size 4 # Number of rollouts to buffer
|
||||
--update-weights-interval 2 # Sync weights every N rollouts
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Workflow 3: Multi-Turn Agentic Training
|
||||
|
||||
Use this workflow for training agents with tool use or multi-step reasoning.
|
||||
|
||||
### Prerequisites
|
||||
- [ ] Custom generate function for multi-turn logic
|
||||
- [ ] Tool/environment interface
|
||||
|
||||
### Step 1: Define Custom Generate Function
|
||||
|
||||
```python
|
||||
# custom_generate.py
|
||||
async def custom_generate(args, samples, evaluation=False):
|
||||
"""Multi-turn generation with tool calling."""
|
||||
for sample in samples:
|
||||
conversation = sample.prompt
|
||||
|
||||
for turn in range(args.max_turns):
|
||||
# Generate response
|
||||
response = await generate_single(conversation)
|
||||
|
||||
# Check for tool call
|
||||
tool_call = extract_tool_call(response)
|
||||
if tool_call:
|
||||
tool_result = execute_tool(tool_call)
|
||||
conversation.append({"role": "assistant", "content": response})
|
||||
conversation.append({"role": "tool", "content": tool_result})
|
||||
else:
|
||||
break
|
||||
|
||||
sample.response = response
|
||||
sample.reward = compute_reward(sample)
|
||||
|
||||
return samples
|
||||
```
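
The helper calls above (`generate_single`, `extract_tool_call`, `execute_tool`, `compute_reward`) are user-supplied, not part of slime. As one illustration, a minimal `extract_tool_call` assuming the model emits tool calls as JSON inside `<tool_call>` tags (a hypothetical format):

```python
import json
import re

def extract_tool_call(response: str):
    """Return the parsed tool-call dict, or None if the response contains no tool call."""
    match = re.search(r"<tool_call>(.*?)</tool_call>", response, re.DOTALL)
    if match is None:
        return None
    return json.loads(match.group(1))  # e.g. {"name": "search", "arguments": {"query": "..."}}
```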
|
||||
|
||||
### Step 2: Launch with Custom Function
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--custom-generate-function-path custom_generate.py \
|
||||
--max-turns 5 \
|
||||
--prompt-data /path/to/agent_data.jsonl \
|
||||
${MODEL_ARGS[@]}
|
||||
```
|
||||
|
||||
See `examples/search-r1/` for a complete multi-turn search example.
|
||||
|
||||
---
|
||||
|
||||
## Configuration Reference
|
||||
|
||||
### Three Argument Categories
|
||||
|
||||
slime uses three types of arguments:
|
||||
|
||||
**1. Megatron Arguments** (passed directly):
|
||||
```bash
|
||||
--tensor-model-parallel-size 2
|
||||
--pipeline-model-parallel-size 1
|
||||
--num-layers 32
|
||||
--hidden-size 4096
|
||||
```
|
||||
|
||||
**2. SGLang Arguments** (prefixed with `--sglang-`):
|
||||
```bash
|
||||
--sglang-mem-fraction-static 0.8
|
||||
--sglang-context-length 8192
|
||||
--sglang-log-level INFO
|
||||
```
|
||||
|
||||
**3. slime Arguments**:
|
||||
```bash
|
||||
# Resource allocation
|
||||
--actor-num-nodes 1
|
||||
--actor-num-gpus-per-node 8
|
||||
--rollout-num-gpus 8
|
||||
--colocate # Share GPUs between training/inference
|
||||
|
||||
# Data
|
||||
--prompt-data /path/to/data.jsonl
|
||||
--input-key prompt
|
||||
--label-key label
|
||||
|
||||
# Training loop
|
||||
--num-rollout 3000
|
||||
--rollout-batch-size 32
|
||||
--n-samples-per-prompt 8
|
||||
--global-batch-size 256
|
||||
|
||||
# Algorithm
|
||||
--advantage-estimator grpo # or: gspo, ppo, reinforce_plus_plus
|
||||
--use-kl-loss
|
||||
--kl-loss-coef 0.001
|
||||
```
|
||||
|
||||
### Key Constraints
|
||||
|
||||
```
|
||||
rollout_batch_size × n_samples_per_prompt = global_batch_size × num_steps_per_rollout
|
||||
```
|
||||
|
||||
Example: 32 × 8 = 256 × 1
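
A one-line sanity check before launching (sketch; substitute your own values):

```python
rollout_batch_size, n_samples_per_prompt = 32, 8
global_batch_size, num_steps_per_rollout = 256, 1

assert rollout_batch_size * n_samples_per_prompt == global_batch_size * num_steps_per_rollout, \
    "rollout_batch_size × n_samples_per_prompt must equal global_batch_size × num_steps_per_rollout"
```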
|
||||
|
||||
---
|
||||
|
||||
## Data Buffer System
|
||||
|
||||
slime's data buffer enables flexible data management:
|
||||
|
||||
### Basic Data Source
|
||||
|
||||
```python
|
||||
class RolloutDataSource:
|
||||
def get_samples(self, num_samples):
|
||||
"""Fetch prompts from dataset."""
|
||||
return self.dataset.sample(num_samples)
|
||||
|
||||
def add_samples(self, samples):
|
||||
"""Called after generation (no-op by default)."""
|
||||
pass
|
||||
```
|
||||
|
||||
### Buffered Data Source (Off-Policy)
|
||||
|
||||
```python
|
||||
class RolloutDataSourceWithBuffer(RolloutDataSource):
|
||||
def __init__(self):
|
||||
self.buffer = []
|
||||
|
||||
def add_samples(self, samples):
|
||||
"""Store generated samples for reuse."""
|
||||
self.buffer.extend(samples)
|
||||
|
||||
def buffer_filter(self, args, buffer, num_samples):
|
||||
"""Custom selection logic (prioritized, stratified, etc.)."""
|
||||
        # e.g. keep the highest-reward samples
        return sorted(buffer, key=lambda s: s.reward or 0.0, reverse=True)[:num_samples]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
### Issue: SGLang Engine Crash
|
||||
|
||||
**Symptoms**: Inference engine dies mid-training
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Enable fault tolerance
|
||||
--use-fault-tolerance
|
||||
|
||||
# Increase memory allocation
|
||||
--sglang-mem-fraction-static 0.85
|
||||
|
||||
# Reduce batch size
|
||||
--rollout-batch-size 16
|
||||
```
|
||||
|
||||
### Issue: Weight Sync Timeout
|
||||
|
||||
**Symptoms**: Training hangs after rollout
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Increase sync interval
|
||||
--update-weights-interval 5
|
||||
|
||||
# Use colocated mode (no network transfer)
|
||||
--colocate
|
||||
```
|
||||
|
||||
### Issue: OOM During Training
|
||||
|
||||
**Symptoms**: CUDA OOM in backward pass
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Enable gradient checkpointing
|
||||
--recompute-activations
|
||||
|
||||
# Reduce micro-batch size
|
||||
--micro-batch-size 1
|
||||
|
||||
# Enable sequence parallelism
|
||||
--sequence-parallel
|
||||
```
|
||||
|
||||
### Issue: Slow Data Loading
|
||||
|
||||
**Symptoms**: GPU idle during data fetch
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Increase data workers
|
||||
--num-data-workers 4
|
||||
|
||||
# Use streaming dataset
|
||||
--streaming-data
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model Family | Configurations |
|
||||
|--------------|----------------|
|
||||
| GLM | GLM-4.5, GLM-4.6, GLM-4.7, GLM-Z1-9B |
|
||||
| Qwen | Qwen3 (4B, 8B, 30B-A3B), Qwen3-MoE, Qwen2.5 |
|
||||
| DeepSeek | V3, V3.1, R1 |
|
||||
| Llama | Llama 3 (8B, 70B) |
|
||||
| Others | Kimi K2, Moonlight-16B |
|
||||
|
||||
Each model has pre-configured scripts in `scripts/models/`.
|
||||
|
||||
---
|
||||
|
||||
## Advanced Topics
|
||||
|
||||
### Co-location Mode
|
||||
|
||||
Share GPUs between training and inference to reduce memory:
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--colocate \
|
||||
--actor-num-gpus-per-node 8 \
|
||||
--sglang-mem-fraction-static 0.4 \
|
||||
${MODEL_ARGS[@]}
|
||||
```
|
||||
|
||||
### Custom Reward Model
|
||||
|
||||
```python
|
||||
# custom_rm.py
|
||||
class CustomRewardModel:
|
||||
def __init__(self, model_path):
|
||||
self.model = load_model(model_path)
|
||||
|
||||
def compute_reward(self, prompts, responses):
|
||||
inputs = self.tokenize(prompts, responses)
|
||||
scores = self.model(inputs)
|
||||
return scores.tolist()
|
||||
```
|
||||
|
||||
```bash
|
||||
--custom-rm-path custom_rm.py
|
||||
```
|
||||
|
||||
### Evaluation Multi-Task
|
||||
|
||||
```bash
|
||||
--eval-prompt-data aime /path/to/aime.jsonl \
|
||||
--eval-prompt-data gsm8k /path/to/gsm8k.jsonl \
|
||||
--n-samples-per-eval-prompt 16
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://thudm.github.io/slime/
|
||||
- **GitHub**: https://github.com/THUDM/slime
|
||||
- **Blog**: https://lmsys.org/blog/2025-07-09-slime/
|
||||
- **Examples**: See `examples/` directory for 14+ worked examples
|
||||
|
||||
392
optional-skills/mlops/slime/references/api-reference.md
Normal file
|
|
@ -0,0 +1,392 @@
|
|||
# slime API Reference
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
slime operates with a three-module architecture orchestrated by Ray:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ Data Buffer │
|
||||
│ - Prompt initialization and management │
|
||||
│ - Custom data generation and filtering │
|
||||
│ - Rollout sample storage │
|
||||
└─────────────┬───────────────────────────┬───────────────┘
|
||||
│ │
|
||||
┌─────────────▼───────────┐ ┌─────────────▼───────────────┐
|
||||
│ Training (Megatron-LM) │ │ Rollout (SGLang + Router) │
|
||||
│ - Actor model training │ │ - Response generation │
|
||||
│ - Critic (optional) │ │ - Reward/verifier output │
|
||||
│ - Weight sync to rollout│ │ - Multi-turn support │
|
||||
└─────────────────────────┘ └─────────────────────────────┘
|
||||
```
|
||||
|
||||
## Core Data Structures
|
||||
|
||||
### Sample Object
|
||||
|
||||
The `Sample` object is the core data structure defined in `slime/utils/types.py`:
|
||||
|
||||
```python
|
||||
from slime.utils.types import Sample
|
||||
|
||||
@dataclass
|
||||
class Sample:
|
||||
# Core fields
|
||||
group_index: Optional[int] # Group index for batching
|
||||
index: Optional[int] # Sample index
|
||||
prompt: str | list[dict] = "" # Input prompt or chat history
|
||||
tokens: list[int] = field(default_factory=list) # Token IDs
|
||||
response: str = "" # Generated response
|
||||
response_length: int = 0 # Response length in tokens
|
||||
label: Optional[str] = None # Ground truth label
|
||||
reward: Optional[float | dict] = None # RL reward signal
|
||||
loss_mask: Optional[list[int]] = None # 1=compute loss, 0=mask
|
||||
status: Status = Status.PENDING # Sample status
|
||||
metadata: dict = field(default_factory=dict) # Custom data
|
||||
|
||||
# Multimodal support
|
||||
multimodal_inputs: Optional[Any] = None # Raw multimodal data (images, videos)
|
||||
multimodal_train_inputs: Optional[Any] = None # Processed multimodal data (pixel_values)
|
||||
|
||||
# Rollout tracking
|
||||
weight_versions: list[str] = field(default_factory=list)
|
||||
rollout_log_probs: Optional[list[float]] = None # Log probs from SGLang
|
||||
rollout_routed_experts: Optional[list[list[int]]] = None # Expert routing (MoE)
|
||||
|
||||
# Control fields
|
||||
remove_sample: bool = False
|
||||
generate_function_path: Optional[str] = None
|
||||
train_metadata: Optional[dict] = None
|
||||
non_generation_time: float = 0.0
|
||||
|
||||
# Speculative decoding info (nested dataclass)
|
||||
@dataclass
|
||||
class SpecInfo:
|
||||
spec_accept_token_num: int = 0
|
||||
spec_draft_token_num: int = 0
|
||||
spec_verify_ct: int = 0
|
||||
completion_token_num: int = 0
|
||||
```
|
||||
|
||||
### Status Enum
|
||||
|
||||
```python
|
||||
class Status(Enum):
|
||||
PENDING = "pending" # Not yet processed
|
||||
COMPLETED = "completed" # Successfully generated
|
||||
TRUNCATED = "truncated" # Hit max length
|
||||
ABORTED = "aborted" # Failed generation
|
||||
FAILED = "failed" # Generation failed
|
||||
```
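
A minimal usage sketch, assuming the constructor accepts the fields shown above (defaults are elided in the listing):

```python
from slime.utils.types import Sample

sample = Sample(
    group_index=0,
    index=0,
    prompt=[{"role": "user", "content": "What is 15 + 27?"}],
    label="42",
)

# A rollout worker fills in the generation results afterwards
sample.response = "42"
sample.response_length = 1
sample.reward = 1.0 if sample.response.strip() == sample.label else 0.0
```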
|
||||
|
||||
## Configuration System
|
||||
|
||||
slime uses three categories of command-line arguments:
|
||||
|
||||
### 1. Megatron Arguments
|
||||
|
||||
All Megatron-LM arguments are supported directly:
|
||||
|
||||
```bash
|
||||
--tensor-model-parallel-size 2
|
||||
--pipeline-model-parallel-size 1
|
||||
--num-layers 32
|
||||
--hidden-size 4096
|
||||
--num-attention-heads 32
|
||||
--seq-length 4096
|
||||
--micro-batch-size 1
|
||||
--global-batch-size 256
|
||||
```
|
||||
|
||||
### 2. SGLang Arguments
|
||||
|
||||
SGLang arguments are prefixed with `--sglang-`:
|
||||
|
||||
```bash
|
||||
--sglang-mem-fraction-static 0.8 # GPU memory for KV cache
|
||||
--sglang-context-length 8192 # Maximum context length
|
||||
--sglang-log-level INFO # Logging verbosity
|
||||
--sglang-tp-size 2 # Tensor parallelism
|
||||
--sglang-disable-cuda-graph # Disable CUDA graphs
|
||||
```
|
||||
|
||||
### 3. slime-Specific Arguments
|
||||
|
||||
Defined in `slime/utils/arguments.py`:
|
||||
|
||||
```bash
|
||||
# Resource Allocation
|
||||
--actor-num-nodes 1 # Training nodes
|
||||
--actor-num-gpus-per-node 8 # GPUs per training node
|
||||
--rollout-num-gpus 8 # Total rollout GPUs
|
||||
--rollout-num-gpus-per-engine 2 # GPUs per SGLang engine
|
||||
--colocate # Share GPUs for train/inference
|
||||
|
||||
# Data Configuration
|
||||
--prompt-data /path/to/data.jsonl # Training data path
|
||||
--input-key prompt # Key for prompts in JSON
|
||||
--label-key label # Key for labels in JSON
|
||||
--apply-chat-template # Apply chat formatting
|
||||
|
||||
# Training Loop
|
||||
--num-rollout 3000 # Total rollout iterations
|
||||
--rollout-batch-size 32 # Prompts per rollout
|
||||
--n-samples-per-prompt 8 # Responses per prompt
|
||||
--global-batch-size 256 # Training batch size
|
||||
--num-steps-per-rollout 1 # Training steps per rollout
|
||||
|
||||
# RL Algorithm
|
||||
--advantage-estimator grpo # grpo, gspo, ppo, reinforce_plus_plus
|
||||
--use-kl-loss # Enable KL loss
|
||||
--kl-loss-coef 0.001 # KL coefficient
|
||||
--calculate-per-token-loss # Token-level loss
|
||||
|
||||
# Off-Policy Options
|
||||
--use-tis # Truncated Importance Sampling
|
||||
--tis-threshold 0.9 # TIS threshold
|
||||
--true-on-policy-mode # Force on-policy training
|
||||
```
|
||||
|
||||
## Data Buffer System
|
||||
|
||||
### RolloutDataSource (Base Class)
|
||||
|
||||
```python
|
||||
from slime.data import RolloutDataSource
|
||||
|
||||
class RolloutDataSource:
|
||||
def __init__(self, dataset, args):
|
||||
self.dataset = dataset
|
||||
self.args = args
|
||||
|
||||
def get_samples(self, num_samples: int) -> list[Sample]:
|
||||
"""Fetch prompts from dataset."""
|
||||
return [Sample(prompt=p) for p in self.dataset.sample(num_samples)]
|
||||
|
||||
def add_samples(self, samples: list[Sample]) -> None:
|
||||
"""Called after generation (no-op by default)."""
|
||||
pass
|
||||
```
|
||||
|
||||
### Buffered Data Source (Off-Policy)
|
||||
|
||||
```python
|
||||
from slime.data import RolloutDataSourceWithBuffer
|
||||
|
||||
class RolloutDataSourceWithBuffer(RolloutDataSource):
|
||||
def __init__(self, dataset, args):
|
||||
super().__init__(dataset, args)
|
||||
self.buffer = []
|
||||
|
||||
def add_samples(self, samples: list[Sample]) -> None:
|
||||
"""Store generated samples for reuse."""
|
||||
self.buffer.extend(samples)
|
||||
|
||||
def buffer_filter(self, args, buffer, num_samples) -> list[Sample]:
|
||||
"""Custom selection logic."""
|
||||
# Example: prioritized sampling based on reward
|
||||
sorted_buffer = sorted(buffer, key=lambda s: s.reward, reverse=True)
|
||||
return sorted_buffer[:num_samples]
|
||||
```
|
||||
|
||||
## Custom Functions
|
||||
|
||||
### Custom Generate Function
|
||||
|
||||
For multi-turn or tool-calling scenarios:
|
||||
|
||||
```python
|
||||
# custom_generate.py
|
||||
from slime.data import Sample
|
||||
|
||||
async def custom_generate(args, samples: list[Sample], evaluation: bool = False) -> list[Sample]:
|
||||
"""
|
||||
Custom generation function for multi-turn interactions.
|
||||
|
||||
Args:
|
||||
args: Training arguments
|
||||
samples: List of Sample objects with prompts
|
||||
evaluation: Whether this is an evaluation run
|
||||
|
||||
Returns:
|
||||
List of Sample objects with responses and rewards
|
||||
"""
|
||||
for sample in samples:
|
||||
conversation = sample.prompt if isinstance(sample.prompt, list) else [
|
||||
{"role": "user", "content": sample.prompt}
|
||||
]
|
||||
|
||||
for turn in range(args.max_turns):
|
||||
# Generate response
|
||||
response = await generate_single(conversation)
|
||||
|
||||
# Check for tool call
|
||||
tool_call = extract_tool_call(response)
|
||||
if tool_call:
|
||||
# Execute tool
|
||||
tool_result = await execute_tool(tool_call)
|
||||
conversation.append({"role": "assistant", "content": response})
|
||||
conversation.append({"role": "tool", "content": tool_result})
|
||||
else:
|
||||
# Final response
|
||||
sample.response = response
|
||||
break
|
||||
|
||||
# Compute reward
|
||||
sample.reward = compute_reward(sample)
|
||||
|
||||
# Set loss mask (1 for model tokens, 0 for tool responses)
|
||||
sample.loss_mask = build_loss_mask(sample)
|
||||
|
||||
return samples
|
||||
```
|
||||
|
||||
Usage:
|
||||
```bash
|
||||
python train.py \
|
||||
--custom-generate-function-path custom_generate.py \
|
||||
--max-turns 5
|
||||
```
|
||||
|
||||
### Custom Reward Function
|
||||
|
||||
```python
|
||||
# custom_rm.py
|
||||
from slime.data import Sample
|
||||
|
||||
async def reward_func(args, sample: Sample, **kwargs) -> float:
|
||||
"""
|
||||
Compute reward for a single sample.
|
||||
|
||||
Args:
|
||||
args: Training arguments
|
||||
sample: Sample object with response
|
||||
|
||||
Returns:
|
||||
Reward score (float)
|
||||
"""
|
||||
response = sample.response
|
||||
ground_truth = sample.label or sample.metadata.get("answer", "")
|
||||
|
||||
# Example: exact match reward
|
||||
if response.strip() == ground_truth.strip():
|
||||
return 1.0
|
||||
return 0.0
|
||||
|
||||
# For batched processing (more efficient)
|
||||
async def batched_custom_rm(args, samples: list[Sample]) -> list[float]:
|
||||
"""Batch reward computation."""
|
||||
rewards = []
|
||||
for sample in samples:
|
||||
reward = await reward_func(args, sample)
|
||||
rewards.append(reward)
|
||||
return rewards
|
||||
```
|
||||
|
||||
Usage:
|
||||
```bash
|
||||
python train.py \
|
||||
--custom-rm-path custom_rm.py \
|
||||
--group-rm # Enable batched processing
|
||||
```
|
||||
|
||||
## Model Configuration
|
||||
|
||||
### Pre-configured Model Scripts
|
||||
|
||||
Located in `scripts/models/`:
|
||||
|
||||
```bash
|
||||
# List available models
|
||||
ls scripts/models/
|
||||
# glm4-9B.sh, qwen3-4B.sh, qwen3-30B-A3B.sh, deepseek-v3.sh, llama3-8B.sh
|
||||
|
||||
# Source model configuration
|
||||
source scripts/models/qwen3-4B.sh
|
||||
# This sets MODEL_ARGS and CKPT_ARGS arrays
|
||||
```
|
||||
|
||||
### Example Model Script
|
||||
|
||||
```bash
|
||||
# scripts/models/qwen3-4B.sh
|
||||
export MODEL_ARGS=(
|
||||
--num-layers 36
|
||||
--hidden-size 2560
|
||||
--num-attention-heads 20
|
||||
--num-query-groups 4
|
||||
--ffn-hidden-size 6912
|
||||
--max-position-embeddings 32768
|
||||
--rotary-percent 1.0
|
||||
--rotary-base 1000000
|
||||
--swiglu
|
||||
--untie-embeddings-and-output-weights
|
||||
--no-position-embedding
|
||||
--normalization RMSNorm
|
||||
--tokenizer-type HuggingFaceTokenizer
|
||||
--bf16
|
||||
)
|
||||
|
||||
export CKPT_ARGS=(
|
||||
--hf-checkpoint /path/to/qwen3-4b-hf
|
||||
--initial-megatron-checkpoint /path/to/megatron/ckpt
|
||||
)
|
||||
```
|
||||
|
||||
## Async Training
|
||||
|
||||
### Enabling Async Mode
|
||||
|
||||
```bash
|
||||
python train_async.py \
|
||||
--actor-num-gpus-per-node 8 \
|
||||
--rollout-num-gpus 8 \
|
||||
--async-buffer-size 4 \
|
||||
--update-weights-interval 2 \
|
||||
${MODEL_ARGS[@]}
|
||||
```
|
||||
|
||||
### Async-Specific Parameters
|
||||
|
||||
```bash
|
||||
--async-buffer-size 4 # Number of rollouts to buffer
|
||||
--update-weights-interval 2 # Sync weights every N rollouts
|
||||
```
|
||||
|
||||
**Note**: Colocated mode (`--colocate`) is NOT supported with async training.
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Multi-Task Evaluation
|
||||
|
||||
```bash
|
||||
--eval-prompt-data aime /path/to/aime.jsonl \
|
||||
--eval-prompt-data gsm8k /path/to/gsm8k.jsonl \
|
||||
--n-samples-per-eval-prompt 16 \
|
||||
--eval-interval 50
|
||||
```
|
||||
|
||||
### Evaluation Configuration
|
||||
|
||||
```bash
|
||||
--eval-interval 50 # Evaluate every N rollouts
|
||||
--n-samples-per-eval-prompt 16 # Samples for evaluation
|
||||
--eval-temperature 0.0 # Greedy decoding for eval
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model Family | Configurations |
|
||||
|--------------|----------------|
|
||||
| GLM | GLM-4.5, GLM-4.6, GLM-4.7, GLM-Z1-9B |
|
||||
| Qwen | Qwen3 (4B, 8B, 30B-A3B), Qwen3-MoE, Qwen2.5 |
|
||||
| DeepSeek | V3, V3.1, R1 |
|
||||
| Llama | Llama 3 (8B, 70B) |
|
||||
| Others | Kimi K2, Moonlight-16B |
|
||||
|
||||
## Resources
|
||||
|
||||
- Documentation: https://thudm.github.io/slime/
|
||||
- GitHub: https://github.com/THUDM/slime
|
||||
- Blog: https://lmsys.org/blog/2025-07-09-slime/
|
||||
- Examples: `examples/` directory (14+ worked examples)
|
||||
386
optional-skills/mlops/slime/references/troubleshooting.md
Normal file
|
|
@ -0,0 +1,386 @@
|
|||
# slime Troubleshooting Guide
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
### SGLang Issues
|
||||
|
||||
#### Issue: SGLang Engine Crash
|
||||
|
||||
**Symptoms**: Inference engine dies mid-training, connection errors
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable fault tolerance**:
|
||||
```bash
|
||||
--use-fault-tolerance
|
||||
```
|
||||
|
||||
2. **Increase memory allocation**:
|
||||
```bash
|
||||
--sglang-mem-fraction-static 0.85 # Increase from 0.8
|
||||
```
|
||||
|
||||
3. **Reduce batch size**:
|
||||
```bash
|
||||
--rollout-batch-size 16 # Reduce from 32
|
||||
```
|
||||
|
||||
4. **Disable CUDA graphs** (for debugging):
|
||||
```bash
|
||||
--sglang-disable-cuda-graph
|
||||
```
|
||||
|
||||
#### Issue: SGLang Router Load Imbalance
|
||||
|
||||
**Symptoms**: Some SGLang engines overloaded while others idle
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Adjust routing strategy**:
|
||||
```bash
|
||||
--sglang-router-strategy round_robin
|
||||
```
|
||||
|
||||
2. **Increase number of engines**:
|
||||
```bash
|
||||
--rollout-num-gpus-per-engine 1  # More engines, fewer GPUs each
|
||||
```
|
||||
|
||||
### Weight Synchronization Issues
|
||||
|
||||
#### Issue: Weight Sync Timeout
|
||||
|
||||
**Symptoms**: Training hangs after rollout, timeout errors
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Increase sync interval** (async mode):
|
||||
```bash
|
||||
--update-weights-interval 5 # Increase from 2
|
||||
```
|
||||
|
||||
2. **Use colocated mode** (eliminates network transfer):
|
||||
```bash
|
||||
--colocate
|
||||
```
|
||||
|
||||
3. **Check network bandwidth**:
|
||||
```bash
|
||||
# Verify InfiniBand is enabled
|
||||
ibstat
|
||||
```
|
||||
|
||||
#### Issue: Weight Sync Failures in Multi-Node
|
||||
|
||||
**Symptoms**: Nodes fail to receive updated weights
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Set NCCL environment**:
|
||||
```bash
|
||||
export NCCL_DEBUG=INFO
|
||||
export NCCL_SOCKET_IFNAME=eth0
|
||||
export NCCL_IB_DISABLE=0
|
||||
```
|
||||
|
||||
2. **Increase timeout**:
|
||||
```bash
|
||||
export NCCL_TIMEOUT=1800
|
||||
```
|
||||
|
||||
### Memory Issues
|
||||
|
||||
#### Issue: OOM During Training
|
||||
|
||||
**Symptoms**: CUDA OOM in backward pass
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable gradient checkpointing**:
|
||||
```bash
|
||||
--recompute-activations
|
||||
```
|
||||
|
||||
2. **Reduce micro-batch size**:
|
||||
```bash
|
||||
--micro-batch-size 1
|
||||
```
|
||||
|
||||
3. **Enable sequence parallelism**:
|
||||
```bash
|
||||
--sequence-parallel
|
||||
```
|
||||
|
||||
4. **Reduce global batch size**:
|
||||
```bash
|
||||
--global-batch-size 128 # Reduce from 256
|
||||
```
|
||||
|
||||
#### Issue: OOM in Colocated Mode
|
||||
|
||||
**Symptoms**: OOM when both training and inference run on same GPUs
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Reduce SGLang memory**:
|
||||
```bash
|
||||
--sglang-mem-fraction-static 0.4 # Reduce from 0.8
|
||||
```
|
||||
|
||||
2. **Enable offloading**:
|
||||
```bash
|
||||
--offload-optimizer-states
|
||||
```
|
||||
|
||||
3. **Use smaller sequence length**:
|
||||
```bash
|
||||
--seq-length 2048 # Reduce from 4096
|
||||
```
|
||||
|
||||
### Data Loading Issues
|
||||
|
||||
#### Issue: Slow Data Loading
|
||||
|
||||
**Symptoms**: GPU idle during data fetch, low GPU utilization
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Increase data workers**:
|
||||
```bash
|
||||
--num-data-workers 4
|
||||
```
|
||||
|
||||
2. **Use streaming dataset**:
|
||||
```bash
|
||||
--streaming-data
|
||||
```
|
||||
|
||||
3. **Pre-tokenize data**:
|
||||
```python
|
||||
# Pre-process data offline
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("model_path")
|
||||
# Save tokenized data
|
||||
```
|
||||
|
||||
#### Issue: Data Format Errors
|
||||
|
||||
**Symptoms**: KeyError, missing fields, parsing failures
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Verify data format**:
|
||||
```python
|
||||
import json
|
||||
with open("data.jsonl") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
assert "prompt" in data, "Missing prompt field"
|
||||
assert "label" in data, "Missing label field"
|
||||
```
|
||||
|
||||
2. **Check key names**:
|
||||
```bash
|
||||
--input-key prompt # Must match your data
|
||||
--label-key label # Must match your data
|
||||
```
|
||||
|
||||
### Training Stability Issues
|
||||
|
||||
#### Issue: Loss Explosion / NaN
|
||||
|
||||
**Symptoms**: Loss becomes NaN or explodes
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Reduce learning rate**:
|
||||
```bash
|
||||
--lr 1e-6 # Reduce from 5e-6
|
||||
```
|
||||
|
||||
2. **Enable gradient clipping**:
|
||||
```bash
|
||||
--clip-grad 1.0
|
||||
```
|
||||
|
||||
3. **Check for data issues**:
|
||||
```python
|
||||
# Verify no empty prompts or responses
|
||||
for sample in dataset:
|
||||
assert len(sample["prompt"]) > 0
|
||||
```
|
||||
|
||||
4. **Use BF16 instead of FP16**:
|
||||
```bash
|
||||
--bf16 # More numerically stable
|
||||
```
|
||||
|
||||
#### Issue: Reward Collapse
|
||||
|
||||
**Symptoms**: Reward drops to zero, model outputs garbage
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Increase KL penalty**:
|
||||
```bash
|
||||
--kl-loss-coef 0.01 # Increase from 0.001
|
||||
```
|
||||
|
||||
2. **Reduce number of samples**:
|
||||
```bash
|
||||
--n-samples-per-prompt 4 # Reduce from 8
|
||||
```
|
||||
|
||||
3. **Verify reward function**:
|
||||
```python
|
||||
# Test reward function independently
|
||||
from custom_rm import reward_func
|
||||
sample = Sample(prompt="test", response="test response")
|
||||
reward = reward_func(args, sample)
|
||||
print(f"Reward: {reward}") # Should be reasonable
|
||||
```
|
||||
|
||||
### Async Training Issues
|
||||
|
||||
#### Issue: Async Training Not Supported with Colocate
|
||||
|
||||
**Symptoms**: Error when using `--colocate` with `train_async.py`
|
||||
|
||||
**Solution**: Colocated mode is NOT supported for async training. Use separate GPUs:
|
||||
```bash
|
||||
# Remove --colocate flag
|
||||
python train_async.py \
|
||||
--actor-num-gpus-per-node 4 \
|
||||
--rollout-num-gpus 4 \
|
||||
# No --colocate
|
||||
```
|
||||
|
||||
#### Issue: Stale Weights in Async Mode
|
||||
|
||||
**Symptoms**: Policy divergence, inconsistent behavior
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Reduce async buffer size**:
|
||||
```bash
|
||||
--async-buffer-size 2 # Reduce from 4
|
||||
```
|
||||
|
||||
2. **Increase weight update frequency**:
|
||||
```bash
|
||||
--update-weights-interval 1 # Sync every rollout
|
||||
```
|
||||
|
||||
### Multi-Turn Training Issues
|
||||
|
||||
#### Issue: Tool Responses Included in Loss
|
||||
|
||||
**Symptoms**: Model learns to output tool responses verbatim
|
||||
|
||||
**Solution**: Properly set loss mask in custom generate function:
|
||||
```python
|
||||
def build_loss_mask(sample):
|
||||
"""Create loss mask that excludes tool responses."""
|
||||
mask = []
|
||||
for i, token in enumerate(sample.tokens):
|
||||
if is_tool_response(token, sample.metadata):
|
||||
mask.append(0) # Don't compute loss
|
||||
else:
|
||||
mask.append(1) # Compute loss
|
||||
return mask
|
||||
```
|
||||
|
||||
#### Issue: Multi-Turn Context Too Long
|
||||
|
||||
**Symptoms**: OOM or truncation in multi-turn conversations
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Limit conversation history**:
|
||||
```python
|
||||
# In custom generate function
|
||||
conversation = sample.prompt[-10:] # Keep last 10 turns
|
||||
```
|
||||
|
||||
2. **Increase context length**:
|
||||
```bash
|
||||
--sglang-context-length 16384
|
||||
```
|
||||
|
||||
### Checkpoint Issues
|
||||
|
||||
#### Issue: Checkpoint Loading Fails
|
||||
|
||||
**Symptoms**: Cannot load saved checkpoint
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Verify checkpoint path**:
|
||||
```bash
|
||||
ls -la /path/to/checkpoint/
|
||||
```
|
||||
|
||||
2. **Check parallelism matches**:
|
||||
```bash
|
||||
# Checkpoint was saved with TP=2, must load with TP=2
|
||||
--tensor-model-parallel-size 2
|
||||
```
|
||||
|
||||
3. **Convert HuggingFace to Megatron** (if needed):
|
||||
```bash
|
||||
python tools/convert_hf_to_megatron.py \
|
||||
--hf_model_path /path/to/hf/model \
|
||||
--save_path /path/to/megatron/checkpoint
|
||||
```
|
||||
|
||||
### Debugging Tips
|
||||
|
||||
#### Enable Verbose Logging
|
||||
|
||||
```bash
|
||||
--log-level DEBUG
|
||||
export SLIME_DEBUG=1
|
||||
```
|
||||
|
||||
#### Check GPU Utilization
|
||||
|
||||
```bash
|
||||
watch -n 1 nvidia-smi
|
||||
```
|
||||
|
||||
#### Monitor Training
|
||||
|
||||
```bash
|
||||
tensorboard --logdir outputs/
|
||||
```
|
||||
|
||||
#### Test Custom Functions Independently
|
||||
|
||||
```python
|
||||
# Test reward function
|
||||
import asyncio
|
||||
from slime.utils.types import Sample
from custom_rm import reward_func
|
||||
|
||||
async def test():
|
||||
sample = Sample(prompt="test", response="test", label="expected")
|
||||
reward = await reward_func(args, sample)
|
||||
print(f"Reward: {reward}")
|
||||
|
||||
asyncio.run(test())
|
||||
```
|
||||
|
||||
## Constraint Reference
|
||||
|
||||
Key constraint to remember:
|
||||
|
||||
```
|
||||
rollout_batch_size × n_samples_per_prompt = global_batch_size × num_steps_per_rollout
|
||||
```
|
||||
|
||||
Example: `32 × 8 = 256 × 1`
|
||||
|
||||
## Resources
|
||||
|
||||
- GitHub Issues: https://github.com/THUDM/slime/issues
|
||||
- Documentation: https://thudm.github.io/slime/
|
||||
- Examples: `examples/` directory
|
||||
190
optional-skills/mlops/tensorrt-llm/SKILL.md
Normal file
|
|
@ -0,0 +1,190 @@
|
|||
---
|
||||
name: tensorrt-llm
|
||||
description: Optimizes LLM inference with NVIDIA TensorRT for maximum throughput and lowest latency. Use for production deployment on NVIDIA GPUs (A100/H100), when you need 10-100x faster inference than PyTorch, or for serving models with quantization (FP8/INT4), in-flight batching, and multi-GPU scaling.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [tensorrt-llm, torch]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Inference Serving, TensorRT-LLM, NVIDIA, Inference Optimization, High Throughput, Low Latency, Production, FP8, INT4, In-Flight Batching, Multi-GPU]
|
||||
|
||||
---
|
||||
|
||||
# TensorRT-LLM
|
||||
|
||||
NVIDIA's open-source library for optimizing LLM inference with state-of-the-art performance on NVIDIA GPUs.
|
||||
|
||||
## When to use TensorRT-LLM
|
||||
|
||||
**Use TensorRT-LLM when:**
|
||||
- Deploying on NVIDIA GPUs (A100, H100, GB200)
|
||||
- Need maximum throughput (24,000+ tokens/sec on Llama 3)
|
||||
- Require low latency for real-time applications
|
||||
- Working with quantized models (FP8, INT4, FP4)
|
||||
- Scaling across multiple GPUs or nodes
|
||||
|
||||
**Use vLLM instead when:**
|
||||
- Need simpler setup and Python-first API
|
||||
- Want PagedAttention without TensorRT compilation
|
||||
- Working with AMD GPUs or non-NVIDIA hardware
|
||||
|
||||
**Use llama.cpp instead when:**
|
||||
- Deploying on CPU or Apple Silicon
|
||||
- Need edge deployment without NVIDIA GPUs
|
||||
- Want simpler GGUF quantization format
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Docker (recommended)
|
||||
docker pull nvidia/tensorrt_llm:latest
|
||||
|
||||
# pip install
|
||||
pip install tensorrt_llm==1.2.0rc3
|
||||
|
||||
# Requires CUDA 13.0.0, TensorRT 10.13.2, Python 3.10-3.12
|
||||
```
|
||||
|
||||
### Basic inference
|
||||
|
||||
```python
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
|
||||
# Initialize model
|
||||
llm = LLM(model="meta-llama/Meta-Llama-3-8B")
|
||||
|
||||
# Configure sampling
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=100,
|
||||
temperature=0.7,
|
||||
top_p=0.9
|
||||
)
|
||||
|
||||
# Generate
|
||||
prompts = ["Explain quantum computing"]
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
for output in outputs:
|
||||
print(output.text)
|
||||
```
|
||||
|
||||
### Serving with trtllm-serve
|
||||
|
||||
```bash
|
||||
# Start server with tensor parallelism across 4 GPUs (automatic model download and compilation)
|
||||
trtllm-serve meta-llama/Meta-Llama-3-8B \
|
||||
--tp_size 4 \
|
||||
--max_batch_size 256 \
|
||||
--max_num_tokens 4096
|
||||
|
||||
# Client request
|
||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "meta-llama/Meta-Llama-3-8B",
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 100
|
||||
}'
|
||||
```
|
||||
|
||||
## Key features
|
||||
|
||||
### Performance optimizations
|
||||
- **In-flight batching**: Dynamic batching during generation
|
||||
- **Paged KV cache**: Efficient memory management
|
||||
- **Flash Attention**: Optimized attention kernels
|
||||
- **Quantization**: FP8, INT4, FP4 for 2-4× faster inference
|
||||
- **CUDA graphs**: Reduced kernel launch overhead
|
||||
|
||||
### Parallelism
|
||||
- **Tensor parallelism (TP)**: Split model across GPUs
|
||||
- **Pipeline parallelism (PP)**: Layer-wise distribution
|
||||
- **Expert parallelism**: For Mixture-of-Experts models
|
||||
- **Multi-node**: Scale beyond single machine
|
||||
|
||||
### Advanced features
|
||||
- **Speculative decoding**: Faster generation with draft models
|
||||
- **LoRA serving**: Efficient multi-adapter deployment
|
||||
- **Disaggregated serving**: Separate prefill and generation
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Quantized model (FP8)
|
||||
|
||||
```python
|
||||
from tensorrt_llm import LLM
|
||||
|
||||
# Load FP8 quantized model (2× faster, 50% memory)
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-70B",
|
||||
dtype="fp8",
|
||||
max_num_tokens=8192
|
||||
)
|
||||
|
||||
# Inference same as before
|
||||
outputs = llm.generate(["Summarize this article..."])
|
||||
```
|
||||
|
||||
### Multi-GPU deployment
|
||||
|
||||
```python
|
||||
# Tensor parallelism across 8 GPUs
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-405B",
|
||||
tensor_parallel_size=8,
|
||||
dtype="fp8"
|
||||
)
|
||||
```
|
||||
|
||||
### Batch inference
|
||||
|
||||
```python
|
||||
# Process 100 prompts efficiently
|
||||
prompts = [f"Question {i}: ..." for i in range(100)]
|
||||
|
||||
outputs = llm.generate(
|
||||
prompts,
|
||||
sampling_params=SamplingParams(max_tokens=200)
|
||||
)
|
||||
|
||||
# Automatic in-flight batching for maximum throughput
|
||||
```
|
||||
|
||||
## Performance benchmarks
|
||||
|
||||
**Meta Llama 3-8B** (H100 GPU):
|
||||
- Throughput: 24,000 tokens/sec
|
||||
- Latency: ~10ms per token
|
||||
- vs PyTorch: **100× faster**
|
||||
|
||||
**Llama 3-70B** (8× A100 80GB):
|
||||
- FP8 quantization: 2× faster than FP16
|
||||
- Memory: 50% reduction with FP8
|
||||
|
||||
## Supported models
|
||||
|
||||
- **LLaMA family**: Llama 2, Llama 3, CodeLlama
|
||||
- **GPT family**: GPT-2, GPT-J, GPT-NeoX
|
||||
- **Qwen**: Qwen, Qwen2, QwQ
|
||||
- **DeepSeek**: DeepSeek-V2, DeepSeek-V3
|
||||
- **Mixtral**: Mixtral-8x7B, Mixtral-8x22B
|
||||
- **Vision**: LLaVA, Phi-3-vision
|
||||
- **100+ models** on HuggingFace
|
||||
|
||||
## References
|
||||
|
||||
- **[Optimization Guide](references/optimization.md)** - Quantization, batching, KV cache tuning
|
||||
- **[Multi-GPU Setup](references/multi-gpu.md)** - Tensor/pipeline parallelism, multi-node
|
||||
- **[Serving Guide](references/serving.md)** - Production deployment, monitoring, autoscaling
|
||||
|
||||
## Resources
|
||||
|
||||
- **Docs**: https://nvidia.github.io/TensorRT-LLM/
|
||||
- **GitHub**: https://github.com/NVIDIA/TensorRT-LLM
|
||||
- **Models**: https://huggingface.co/models?library=tensorrt_llm
|
||||
|
||||
|
||||
298
optional-skills/mlops/tensorrt-llm/references/multi-gpu.md
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
# Multi-GPU Deployment Guide
|
||||
|
||||
Comprehensive guide to scaling TensorRT-LLM across multiple GPUs and nodes.
|
||||
|
||||
## Parallelism Strategies
|
||||
|
||||
### Tensor Parallelism (TP)
|
||||
|
||||
**What it does**: Splits model layers across GPUs horizontally.
|
||||
|
||||
**Use case**:
|
||||
- Model fits in total GPU memory but not single GPU
|
||||
- Need low latency (single forward pass)
|
||||
- GPUs on same node (NVLink required for best performance)
|
||||
|
||||
**Example** (Llama 3-70B on 4× A100):
|
||||
```python
|
||||
from tensorrt_llm import LLM
|
||||
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-70B",
|
||||
tensor_parallel_size=4, # Split across 4 GPUs
|
||||
dtype="fp16"
|
||||
)
|
||||
|
||||
# Model automatically sharded across GPUs
|
||||
# Single forward pass, low latency
|
||||
```
|
||||
|
||||
**Performance**:
|
||||
- Latency: ~Same as single GPU
|
||||
- Throughput: 4× higher (4 GPUs)
|
||||
- Communication: High (activations synced every layer)
|
||||
|
||||
### Pipeline Parallelism (PP)
|
||||
|
||||
**What it does**: Splits model layers across GPUs vertically (layer-wise).
|
||||
|
||||
**Use case**:
|
||||
- Very large models (175B+)
|
||||
- Can tolerate higher latency
|
||||
- GPUs across multiple nodes
|
||||
|
||||
**Example** (Llama 3-405B on 8× H100):
|
||||
```python
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-405B",
|
||||
tensor_parallel_size=4, # TP=4 within nodes
|
||||
pipeline_parallel_size=2, # PP=2 across nodes
|
||||
dtype="fp8"
|
||||
)
|
||||
|
||||
# Total: 8 GPUs (4×2)
|
||||
# First half of layers: Node 1 (4 GPUs with TP)
|
||||
# Second half of layers: Node 2 (4 GPUs with TP)
|
||||
```
|
||||
|
||||
**Performance**:
|
||||
- Latency: Higher (sequential through pipeline)
|
||||
- Throughput: High with micro-batching
|
||||
- Communication: Lower than TP
|
||||
|
||||
### Expert Parallelism (EP)
|
||||
|
||||
**What it does**: Distributes MoE experts across GPUs.
|
||||
|
||||
**Use case**: Mixture-of-Experts models (Mixtral, DeepSeek-V2)
|
||||
|
||||
**Example** (Mixtral-8x22B on 8× A100):
|
||||
```python
|
||||
llm = LLM(
|
||||
model="mistralai/Mixtral-8x22B",
|
||||
tensor_parallel_size=4,
|
||||
expert_parallel_size=2, # Distribute 8 experts across 2 groups
|
||||
dtype="fp8"
|
||||
)
|
||||
```
|
||||
|
||||
## Configuration Examples
|
||||
|
||||
### Small model (7-13B) - Single GPU
|
||||
|
||||
```python
|
||||
# Llama 3-8B on 1× A100 80GB
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-8B",
|
||||
dtype="fp16" # or fp8 for H100
|
||||
)
|
||||
```
|
||||
|
||||
**Resources**:
|
||||
- GPU: 1× A100 80GB
|
||||
- Memory: ~16GB model + 30GB KV cache
|
||||
- Throughput: 3,000-5,000 tokens/sec
|
||||
|
||||
### Medium model (70B) - Multi-GPU same node
|
||||
|
||||
```python
|
||||
# Llama 3-70B on 4× A100 80GB (NVLink)
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-70B",
|
||||
tensor_parallel_size=4,
|
||||
dtype="fp8" # 70GB → 35GB per GPU
|
||||
)
|
||||
```
|
||||
|
||||
**Resources**:
|
||||
- GPU: 4× A100 80GB with NVLink
|
||||
- Memory: ~35GB per GPU (FP8)
|
||||
- Throughput: 10,000-15,000 tokens/sec
|
||||
- Latency: 15-20ms per token
|
||||
|
||||
### Large model (405B) - Multi-node
|
||||
|
||||
```python
|
||||
# Llama 3-405B on 2 nodes × 8 H100 = 16 GPUs
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-405B",
|
||||
tensor_parallel_size=8, # TP within each node
|
||||
pipeline_parallel_size=2, # PP across 2 nodes
|
||||
dtype="fp8"
|
||||
)
|
||||
```
|
||||
|
||||
**Resources**:
|
||||
- GPU: 2 nodes × 8 H100 80GB
|
||||
- Memory: ~25GB per GPU (FP8)
|
||||
- Throughput: 20,000-30,000 tokens/sec
|
||||
- Network: InfiniBand recommended
|
||||
|
||||
## Server Deployment
|
||||
|
||||
### Single-node multi-GPU
|
||||
|
||||
```bash
|
||||
# Llama 3-70B on 4 GPUs (automatic TP)
|
||||
trtllm-serve meta-llama/Meta-Llama-3-70B \
|
||||
--tp_size 4 \
|
||||
--max_batch_size 256 \
|
||||
--dtype fp8
|
||||
|
||||
# Listens on http://localhost:8000
|
||||
```
|
||||
|
||||
### Multi-node with Ray
|
||||
|
||||
```bash
|
||||
# Node 1 (head node)
|
||||
ray start --head --port=6379
|
||||
|
||||
# Node 2 (worker)
|
||||
ray start --address='node1:6379'
|
||||
|
||||
# Deploy across the 2-node cluster
|
||||
trtllm-serve meta-llama/Meta-Llama-3-405B \
|
||||
--tp_size 8 \
|
||||
--pp_size 2 \
|
||||
--num_workers 2 \
|
||||
--dtype fp8
|
||||
```
|
||||
|
||||
### Kubernetes deployment
|
||||
|
||||
```yaml
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: tensorrt-llm-llama3-70b
|
||||
spec:
|
||||
replicas: 1
|
||||
template:
|
||||
spec:
|
||||
containers:
|
||||
- name: trtllm
|
||||
image: nvidia/tensorrt_llm:latest
|
||||
command:
|
||||
- trtllm-serve
|
||||
- meta-llama/Meta-Llama-3-70B
|
||||
- --tp_size=4
|
||||
- --max_batch_size=256
|
||||
resources:
|
||||
limits:
|
||||
nvidia.com/gpu: 4 # Request 4 GPUs
|
||||
```
|
||||
|
||||
## Parallelism Decision Tree
|
||||
|
||||
```
|
||||
Model size < 20GB?
|
||||
├─ YES: Single GPU (no parallelism)
|
||||
└─ NO: Model size < 80GB?
|
||||
├─ YES: TP=2 or TP=4 (same node)
|
||||
└─ NO: Model size < 320GB?
|
||||
├─ YES: TP=4 or TP=8 (same node, NVLink required)
|
||||
└─ NO: TP=8 + PP=2 (multi-node)
|
||||
```
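
To make the tree concrete, here is a minimal, illustrative Python helper encoding the same thresholds. `choose_parallelism` is a hypothetical function (not part of TensorRT-LLM), and model size is taken as the approximate weight footprint in GB:

```python
# Hypothetical helper encoding the decision tree above (not a TensorRT-LLM API).
def choose_parallelism(model_size_gb: float) -> dict:
    """Pick a TP/PP layout from the approximate weight footprint in GB."""
    if model_size_gb < 20:
        return {"tensor_parallel_size": 1, "pipeline_parallel_size": 1}  # single GPU
    if model_size_gb < 80:
        return {"tensor_parallel_size": 2, "pipeline_parallel_size": 1}  # or TP=4, same node
    if model_size_gb < 320:
        return {"tensor_parallel_size": 8, "pipeline_parallel_size": 1}  # NVLink required
    return {"tensor_parallel_size": 8, "pipeline_parallel_size": 2}      # multi-node

# Llama 3-405B at FP8 (~1 byte/param) -> roughly 405GB of weights -> TP=8 + PP=2
print(choose_parallelism(405))
```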
|
||||
|
||||
## Communication Optimization
|
||||
|
||||
### NVLink vs PCIe
|
||||
|
||||
**NVLink** (DGX A100, HGX H100):
|
||||
- Bandwidth: 600 GB/s (A100), 900 GB/s (H100)
|
||||
- Ideal for TP (high communication)
|
||||
- **Recommended for all multi-GPU setups**
|
||||
|
||||
**PCIe**:
|
||||
- Bandwidth: 64 GB/s (PCIe 4.0 x16)
|
||||
- 10× slower than NVLink
|
||||
- Avoid TP, use PP instead
|
||||
|
||||
### InfiniBand for multi-node
|
||||
|
||||
**HDR InfiniBand** (200 Gb/s):
|
||||
- Required for multi-node TP or PP
|
||||
- Latency: <1μs
|
||||
- **Essential for 405B+ models**
|
||||
|
||||
## Monitoring Multi-GPU
|
||||
|
||||
```bash
|
||||
# Monitor GPU utilization
|
||||
nvidia-smi dmon -s u
|
||||
|
||||
# Monitor memory
|
||||
nvidia-smi dmon -s m
|
||||
|
||||
# Monitor NVLink utilization
|
||||
nvidia-smi nvlink --status
|
||||
|
||||
# TensorRT-LLM built-in metrics
|
||||
curl http://localhost:8000/metrics
|
||||
```
|
||||
|
||||
**Key metrics**:
|
||||
- GPU utilization: Target 80-95%
|
||||
- Memory usage: Should be balanced across GPUs (see the balance check below)
|
||||
- NVLink traffic: High for TP, low for PP
|
||||
- Throughput: Tokens/sec across all GPUs
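
As a quick way to run the balance check above, a small script can parse per-GPU `nvidia-smi` output — a sketch, not part of TensorRT-LLM tooling:

```python
# Sketch: flag imbalanced GPU memory via nvidia-smi's standard CSV query interface.
import subprocess

def gpu_stats():
    out = subprocess.check_output(
        ["nvidia-smi",
         "--query-gpu=index,utilization.gpu,memory.used",
         "--format=csv,noheader,nounits"],
        text=True,
    )
    rows = [line.split(", ") for line in out.strip().splitlines()]
    return [(int(idx), int(util), int(mem)) for idx, util, mem in rows]

stats = gpu_stats()
mems = [mem for _, _, mem in stats]
if mems and max(mems) > 1.5 * max(min(mems), 1):
    print("Warning: GPU memory looks imbalanced:", stats)
else:
    print("GPU memory roughly balanced:", stats)
```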
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Imbalanced GPU memory
|
||||
|
||||
**Symptom**: GPU 0 has 90% memory, GPU 3 has 40%
|
||||
|
||||
**Solutions**:
|
||||
- Verify TP/PP configuration
|
||||
- Check model sharding (should be equal)
|
||||
- Restart server to reset state
|
||||
|
||||
### Low NVLink utilization
|
||||
|
||||
**Symptom**: NVLink bandwidth <100 GB/s with TP=4
|
||||
|
||||
**Solutions**:
|
||||
- Verify NVLink topology: `nvidia-smi topo -m`
|
||||
- Check for PCIe fallback
|
||||
- Ensure GPUs are on same NVSwitch
|
||||
|
||||
### OOM with multi-GPU
|
||||
|
||||
**Solutions**:
|
||||
- Increase TP size (more GPUs)
|
||||
- Reduce batch size
|
||||
- Enable FP8 quantization
|
||||
- Use pipeline parallelism
|
||||
|
||||
## Performance Scaling
|
||||
|
||||
### TP Scaling (Llama 3-70B, FP8)
|
||||
|
||||
| GPUs | TP Size | Throughput | Latency | Efficiency |
|
||||
|------|---------|------------|---------|------------|
|
||||
| 1 | 1 | OOM | - | - |
|
||||
| 2 | 2 | 6,000 tok/s | 18ms | 85% |
|
||||
| 4 | 4 | 11,000 tok/s | 16ms | 78% |
|
||||
| 8 | 8 | 18,000 tok/s | 15ms | 64% |
|
||||
|
||||
**Note**: Efficiency drops with more GPUs due to communication overhead.
|
||||
|
||||
### PP Scaling (Llama 3-405B, FP8)
|
||||
|
||||
| Nodes | TP | PP | Total GPUs | Throughput |
|
||||
|-------|----|----|------------|------------|
|
||||
| 1 | 8 | 1 | 8 | OOM |
|
||||
| 2 | 8 | 2 | 16 | 25,000 tok/s |
|
||||
| 4 | 8 | 4 | 32 | 45,000 tok/s |
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Prefer TP over PP** when possible (lower latency)
|
||||
2. **Use NVLink** for all TP deployments
|
||||
3. **Use InfiniBand** for multi-node deployments
|
||||
4. **Start with smallest TP** that fits model in memory
|
||||
5. **Monitor GPU balance** - all GPUs should have similar utilization
|
||||
6. **Test with benchmark** before production
|
||||
7. **Use FP8** on H100 for 2× speedup
|
||||
242
optional-skills/mlops/tensorrt-llm/references/optimization.md
Normal file
# TensorRT-LLM Optimization Guide
|
||||
|
||||
Comprehensive guide to optimizing LLM inference with TensorRT-LLM.
|
||||
|
||||
## Quantization
|
||||
|
||||
### FP8 Quantization (Recommended for H100)
|
||||
|
||||
**Benefits**:
|
||||
- 2× faster inference
|
||||
- 50% memory reduction
|
||||
- Minimal accuracy loss (<1% perplexity degradation)
|
||||
|
||||
**Usage**:
|
||||
```python
|
||||
from tensorrt_llm import LLM
|
||||
|
||||
# Automatic FP8 quantization
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-70B",
|
||||
dtype="fp8",
|
||||
quantization="fp8"
|
||||
)
|
||||
```
|
||||
|
||||
**Performance** (Llama 3-70B on 8× H100):
|
||||
- FP16: 5,000 tokens/sec
|
||||
- FP8: **10,000 tokens/sec** (2× speedup)
|
||||
- Memory: 140GB → 70GB
|
||||
|
||||
### INT4 Quantization (Maximum compression)
|
||||
|
||||
**Benefits**:
|
||||
- 4× memory reduction
|
||||
- 3-4× faster inference
|
||||
- Fits larger models on same hardware
|
||||
|
||||
**Usage**:
|
||||
```python
|
||||
# INT4 with AWQ calibration
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-405B",
|
||||
dtype="int4_awq",
|
||||
quantization="awq"
|
||||
)
|
||||
|
||||
# INT4 with GPTQ calibration
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-405B",
|
||||
dtype="int4_gptq",
|
||||
quantization="gptq"
|
||||
)
|
||||
```
|
||||
|
||||
**Trade-offs**:
|
||||
- Accuracy: 1-3% perplexity increase
|
||||
- Speed: 3-4× faster than FP16
|
||||
- Use case: When memory is critical
|
||||
|
||||
## In-Flight Batching
|
||||
|
||||
**What it does**: Dynamically batches requests during generation instead of waiting for all sequences to finish.
|
||||
|
||||
**Configuration**:
|
||||
```bash
|
||||
# Server configuration:
#   max_batch_size         - maximum concurrent sequences
#   max_num_tokens         - total tokens in batch
#   enable_chunked_context - split long prompts
trtllm-serve meta-llama/Meta-Llama-3-8B \
  --max_batch_size 256 \
  --max_num_tokens 4096 \
  --enable_chunked_context \
  --scheduler_policy max_utilization
|
||||
```
|
||||
|
||||
**Performance**:
|
||||
- Throughput: **4-8× higher** vs static batching
|
||||
- Latency: Lower P50/P99 for mixed workloads
|
||||
- GPU utilization: 80-95% vs 40-60%
|
||||
|
||||
## Paged KV Cache
|
||||
|
||||
**What it does**: Manages KV cache memory like OS manages virtual memory (paging).
|
||||
|
||||
**Benefits**:
|
||||
- 40-60% higher throughput
|
||||
- No memory fragmentation
|
||||
- Supports longer sequences
|
||||
|
||||
**Configuration**:
|
||||
```python
|
||||
# Automatic paged KV cache (default)
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-8B",
|
||||
kv_cache_free_gpu_mem_fraction=0.9, # Use 90% GPU mem for cache
|
||||
enable_prefix_caching=True # Cache common prefixes
|
||||
)
|
||||
```
|
||||
|
||||
## Speculative Decoding
|
||||
|
||||
**What it does**: Uses small draft model to predict multiple tokens, verified by target model in parallel.
|
||||
|
||||
**Speedup**: 2-3× faster for long generations
|
||||
|
||||
**Usage**:
|
||||
```python
|
||||
from tensorrt_llm import LLM
|
||||
|
||||
# Target model (Llama 3-70B)
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-70B",
|
||||
speculative_model="meta-llama/Meta-Llama-3-8B", # Draft model
|
||||
num_speculative_tokens=5 # Tokens to predict ahead
|
||||
)
|
||||
|
||||
# Same API, 2-3× faster
|
||||
outputs = llm.generate(prompts)
|
||||
```
|
||||
|
||||
**Best models for drafting**:
|
||||
- Target: Llama 3-70B → Draft: Llama 3-8B
|
||||
- Target: Qwen2-72B → Draft: Qwen2-7B
|
||||
- Same family, 8-10× smaller
|
||||
|
||||
## CUDA Graphs
|
||||
|
||||
**What it does**: Reduces kernel launch overhead by recording GPU operations.
|
||||
|
||||
**Benefits**:
|
||||
- 10-20% lower latency
|
||||
- More stable P99 latency
|
||||
- Better for small batch sizes
|
||||
|
||||
**Configuration** (automatic by default):
|
||||
```python
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-8B",
|
||||
enable_cuda_graph=True, # Default: True
|
||||
cuda_graph_cache_size=2 # Cache 2 graph variants
|
||||
)
|
||||
```
|
||||
|
||||
## Chunked Context
|
||||
|
||||
**What it does**: Splits long prompts into chunks to reduce memory spikes.
|
||||
|
||||
**Use case**: Prompts >8K tokens with limited GPU memory
|
||||
|
||||
**Configuration**:
|
||||
```bash
|
||||
trtllm-serve meta-llama/Meta-Llama-3-8B \
|
||||
--max_num_tokens 4096 \
|
||||
--enable_chunked_context \
|
||||
--max_chunked_prefill_length 2048 # Process 2K tokens at a time
|
||||
```
|
||||
|
||||
## Overlap Scheduling
|
||||
|
||||
**What it does**: Overlaps compute and memory operations.
|
||||
|
||||
**Benefits**:
|
||||
- 15-25% higher throughput
|
||||
- Better GPU utilization
|
||||
- Default in v1.2.0+
|
||||
|
||||
**No configuration needed** - enabled automatically.
|
||||
|
||||
## Quantization Comparison Table
|
||||
|
||||
| Method | Memory | Speed | Accuracy | Use Case |
|
||||
|--------|--------|-------|----------|----------|
|
||||
| FP16 | 1× (baseline) | 1× | Best | High accuracy needed |
|
||||
| FP8 | 0.5× | 2× | -0.5% ppl | **H100 default** |
|
||||
| INT4 AWQ | 0.25× | 3-4× | -1.5% ppl | Memory critical |
|
||||
| INT4 GPTQ | 0.25× | 3-4× | -2% ppl | Maximum speed |
|
||||
|
||||
## Tuning Workflow
|
||||
|
||||
1. **Start with defaults**:
|
||||
```python
|
||||
llm = LLM(model="meta-llama/Meta-Llama-3-70B")
|
||||
```
|
||||
|
||||
2. **Enable FP8** (if H100):
|
||||
```python
|
||||
llm = LLM(model="...", dtype="fp8")
|
||||
```
|
||||
|
||||
3. **Tune batch size**:
|
||||
```bash
|
||||
# Increase until OOM, then reduce 20%
|
||||
trtllm-serve ... --max_batch_size 256
|
||||
```
|
||||
|
||||
4. **Enable chunked context** (if long prompts):
|
||||
```bash
|
||||
--enable_chunked_context --max_chunked_prefill_length 2048
|
||||
```
|
||||
|
||||
5. **Try speculative decoding** (if latency critical):
|
||||
```python
|
||||
llm = LLM(model="...", speculative_model="...")
|
||||
```
|
||||
|
||||
## Benchmarking
|
||||
|
||||
```bash
|
||||
# Install benchmark tool
|
||||
pip install tensorrt_llm[benchmark]
|
||||
|
||||
# Run benchmark
|
||||
python benchmarks/python/benchmark.py \
|
||||
--model meta-llama/Meta-Llama-3-8B \
|
||||
--batch_size 64 \
|
||||
--input_len 128 \
|
||||
--output_len 256 \
|
||||
--dtype fp8
|
||||
```
|
||||
|
||||
**Metrics to track**:
|
||||
- Throughput (tokens/sec)
|
||||
- Latency P50/P90/P99 (ms)
|
||||
- GPU memory usage (GB)
|
||||
- GPU utilization (%)
|
||||
|
||||
## Common Issues
|
||||
|
||||
**OOM errors**:
|
||||
- Reduce `max_batch_size`
|
||||
- Reduce `max_num_tokens`
|
||||
- Enable INT4 quantization
|
||||
- Increase `tensor_parallel_size`
|
||||
|
||||
**Low throughput**:
|
||||
- Increase `max_batch_size`
|
||||
- Enable in-flight batching
|
||||
- Verify CUDA graphs enabled
|
||||
- Check GPU utilization
|
||||
|
||||
**High latency**:
|
||||
- Try speculative decoding
|
||||
- Reduce `max_batch_size` (less queueing)
|
||||
- Use FP8 instead of FP16
|
||||
470
optional-skills/mlops/tensorrt-llm/references/serving.md
Normal file
# Production Serving Guide
|
||||
|
||||
Comprehensive guide to deploying TensorRT-LLM in production environments.
|
||||
|
||||
## Server Modes
|
||||
|
||||
### trtllm-serve (Recommended)
|
||||
|
||||
**Features**:
|
||||
- OpenAI-compatible API
|
||||
- Automatic model download and compilation
|
||||
- Built-in load balancing
|
||||
- Prometheus metrics
|
||||
- Health checks
|
||||
|
||||
**Basic usage**:
|
||||
```bash
|
||||
trtllm-serve meta-llama/Meta-Llama-3-8B \
|
||||
--tp_size 1 \
|
||||
--max_batch_size 256 \
|
||||
--port 8000
|
||||
```
|
||||
|
||||
**Advanced configuration**:
|
||||
```bash
|
||||
trtllm-serve meta-llama/Meta-Llama-3-70B \
|
||||
--tp_size 4 \
|
||||
--dtype fp8 \
|
||||
--max_batch_size 256 \
|
||||
--max_num_tokens 4096 \
|
||||
--enable_chunked_context \
|
||||
--scheduler_policy max_utilization \
|
||||
--port 8000 \
|
||||
--api_key $API_KEY # Optional authentication
|
||||
```
|
||||
|
||||
### Python LLM API (For embedding)
|
||||
|
||||
```python
|
||||
from tensorrt_llm import LLM
|
||||
|
||||
class LLMService:
|
||||
def __init__(self):
|
||||
self.llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-8B",
|
||||
dtype="fp8"
|
||||
)
|
||||
|
||||
def generate(self, prompt, max_tokens=100):
|
||||
from tensorrt_llm import SamplingParams
|
||||
|
||||
params = SamplingParams(
|
||||
max_tokens=max_tokens,
|
||||
temperature=0.7
|
||||
)
|
||||
outputs = self.llm.generate([prompt], params)
|
||||
return outputs[0].text
|
||||
|
||||
# Use in FastAPI, Flask, etc
|
||||
from fastapi import FastAPI
|
||||
app = FastAPI()
|
||||
service = LLMService()
|
||||
|
||||
@app.post("/generate")
|
||||
def generate(prompt: str):
|
||||
return {"response": service.generate(prompt)}
|
||||
```
|
||||
|
||||
## OpenAI-Compatible API
|
||||
|
||||
### Chat Completions
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "meta-llama/Meta-Llama-3-8B",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Explain quantum computing"}
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
"stream": false
|
||||
}'
|
||||
```
|
||||
|
||||
**Response**:
|
||||
```json
|
||||
{
|
||||
"id": "chat-abc123",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "meta-llama/Meta-Llama-3-8B",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Quantum computing is..."
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 25,
|
||||
"completion_tokens": 150,
|
||||
"total_tokens": 175
|
||||
}
|
||||
}
|
||||
```
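
Because the endpoint is OpenAI-compatible, the standard `openai` Python client (v1+) can be pointed at it directly — a sketch using the same model and port as the example above; the `api_key` value is arbitrary unless the server was started with `--api_key`:

```python
# Sketch: call the OpenAI-compatible server with the official openai client.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

resp = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain quantum computing"},
    ],
    temperature=0.7,
    max_tokens=500,
)
print(resp.choices[0].message.content)
```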
|
||||
|
||||
### Streaming
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "meta-llama/Meta-Llama-3-8B",
|
||||
"messages": [{"role": "user", "content": "Count to 10"}],
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
**Response** (SSE stream):
|
||||
```
|
||||
data: {"choices":[{"delta":{"content":"1"}}]}
|
||||
|
||||
data: {"choices":[{"delta":{"content":", 2"}}]}
|
||||
|
||||
data: {"choices":[{"delta":{"content":", 3"}}]}
|
||||
|
||||
data: [DONE]
|
||||
```
|
||||
|
||||
### Completions
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/v1/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "meta-llama/Meta-Llama-3-8B",
|
||||
"prompt": "The capital of France is",
|
||||
"max_tokens": 10,
|
||||
"temperature": 0.0
|
||||
}'
|
||||
```
|
||||
|
||||
## Monitoring
|
||||
|
||||
### Prometheus Metrics
|
||||
|
||||
**Enable metrics**:
|
||||
```bash
|
||||
trtllm-serve meta-llama/Meta-Llama-3-8B \
|
||||
--enable_metrics \
|
||||
--metrics_port 9090
|
||||
```
|
||||
|
||||
**Key metrics**:
|
||||
```bash
|
||||
# Scrape metrics
|
||||
curl http://localhost:9090/metrics
|
||||
|
||||
# Important metrics:
|
||||
# - trtllm_request_success_total - Total successful requests
|
||||
# - trtllm_request_latency_seconds - Request latency histogram
|
||||
# - trtllm_tokens_generated_total - Total tokens generated
|
||||
# - trtllm_active_requests - Current active requests
|
||||
# - trtllm_queue_size - Requests waiting in queue
|
||||
# - trtllm_gpu_memory_usage_bytes - GPU memory usage
|
||||
# - trtllm_kv_cache_usage_ratio - KV cache utilization
|
||||
```
|
||||
|
||||
### Health Checks
|
||||
|
||||
```bash
|
||||
# Readiness probe
|
||||
curl http://localhost:8000/health/ready
|
||||
|
||||
# Liveness probe
|
||||
curl http://localhost:8000/health/live
|
||||
|
||||
# Model info
|
||||
curl http://localhost:8000/v1/models
|
||||
```
|
||||
|
||||
**Kubernetes probes**:
|
||||
```yaml
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health/live
|
||||
port: 8000
|
||||
initialDelaySeconds: 60
|
||||
periodSeconds: 10
|
||||
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health/ready
|
||||
port: 8000
|
||||
initialDelaySeconds: 30
|
||||
periodSeconds: 5
|
||||
```
|
||||
|
||||
## Production Deployment
|
||||
|
||||
### Docker Deployment
|
||||
|
||||
**Dockerfile**:
|
||||
```dockerfile
|
||||
FROM nvidia/tensorrt_llm:latest
|
||||
|
||||
# Copy any custom configs
|
||||
COPY config.yaml /app/config.yaml
|
||||
|
||||
# Expose ports
|
||||
EXPOSE 8000 9090
|
||||
|
||||
# Start server
|
||||
CMD ["trtllm-serve", "meta-llama/Meta-Llama-3-8B", \
|
||||
"--tp_size", "4", \
|
||||
"--dtype", "fp8", \
|
||||
"--max_batch_size", "256", \
|
||||
"--enable_metrics", \
|
||||
"--metrics_port", "9090"]
|
||||
```
|
||||
|
||||
**Run container**:
|
||||
```bash
|
||||
docker run --gpus all -p 8000:8000 -p 9090:9090 \
|
||||
tensorrt-llm:latest
|
||||
```
|
||||
|
||||
### Kubernetes Deployment
|
||||
|
||||
**Complete deployment**:
|
||||
```yaml
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: tensorrt-llm
|
||||
spec:
|
||||
replicas: 2 # Multiple replicas for HA
|
||||
selector:
|
||||
matchLabels:
|
||||
app: tensorrt-llm
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: tensorrt-llm
|
||||
spec:
|
||||
containers:
|
||||
- name: trtllm
|
||||
image: nvidia/tensorrt_llm:latest
|
||||
command:
|
||||
- trtllm-serve
|
||||
- meta-llama/Meta-Llama-3-70B
|
||||
- --tp_size=4
|
||||
- --dtype=fp8
|
||||
- --max_batch_size=256
|
||||
- --enable_metrics
|
||||
ports:
|
||||
- containerPort: 8000
|
||||
name: http
|
||||
- containerPort: 9090
|
||||
name: metrics
|
||||
resources:
|
||||
limits:
|
||||
nvidia.com/gpu: 4
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health/live
|
||||
port: 8000
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health/ready
|
||||
port: 8000
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: tensorrt-llm
|
||||
spec:
|
||||
selector:
|
||||
app: tensorrt-llm
|
||||
ports:
|
||||
- name: http
|
||||
port: 80
|
||||
targetPort: 8000
|
||||
- name: metrics
|
||||
port: 9090
|
||||
targetPort: 9090
|
||||
type: LoadBalancer
|
||||
```
|
||||
|
||||
### Load Balancing
|
||||
|
||||
**NGINX configuration**:
|
||||
```nginx
|
||||
upstream tensorrt_llm {
|
||||
least_conn; # Route to least busy server
|
||||
server trtllm-1:8000 max_fails=3 fail_timeout=30s;
|
||||
server trtllm-2:8000 max_fails=3 fail_timeout=30s;
|
||||
server trtllm-3:8000 max_fails=3 fail_timeout=30s;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 80;
|
||||
location / {
|
||||
proxy_pass http://tensorrt_llm;
|
||||
proxy_read_timeout 300s; # Long timeout for slow generations
|
||||
proxy_connect_timeout 10s;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Autoscaling
|
||||
|
||||
### Horizontal Pod Autoscaler (HPA)
|
||||
|
||||
```yaml
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
name: tensorrt-llm-hpa
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
name: tensorrt-llm
|
||||
minReplicas: 2
|
||||
maxReplicas: 10
|
||||
metrics:
|
||||
- type: Pods
|
||||
pods:
|
||||
metric:
|
||||
name: trtllm_active_requests
|
||||
target:
|
||||
type: AverageValue
|
||||
averageValue: "50" # Scale when avg >50 active requests
|
||||
```
|
||||
|
||||
### Custom Metrics
|
||||
|
||||
```yaml
|
||||
# Scale based on queue size
|
||||
- type: Pods
|
||||
pods:
|
||||
metric:
|
||||
name: trtllm_queue_size
|
||||
target:
|
||||
type: AverageValue
|
||||
averageValue: "10"
|
||||
```
|
||||
|
||||
## Cost Optimization
|
||||
|
||||
### GPU Selection
|
||||
|
||||
**A100 80GB** ($3-4/hour):
|
||||
- Use for: 70B models with FP8
|
||||
- Throughput: 10,000-15,000 tok/s (TP=4)
|
||||
- Cost per 1M tokens: $0.20-0.30
|
||||
|
||||
**H100 80GB** ($6-8/hour):
|
||||
- Use for: 70B models with FP8, 405B models
|
||||
- Throughput: 20,000-30,000 tok/s (TP=4)
|
||||
- Cost per 1M tokens: $0.15-0.25 (2× faster = lower cost)
|
||||
|
||||
**L4** ($0.50-1/hour):
|
||||
- Use for: 7-8B models
|
||||
- Throughput: 1,000-2,000 tok/s
|
||||
- Cost per 1M tokens: $0.25-0.50
|
||||
|
||||
### Batch Size Tuning
|
||||
|
||||
**Impact on cost**:
|
||||
- Batch size 1: 1,000 tok/s ≈ 3.6M tok/hour → at $3/hour ≈ $0.83 per 1M tokens
|
||||
- Batch size 64: 5,000 tok/s ≈ 18M tok/hour → at $3/hour ≈ $0.17 per 1M tokens
|
||||
- **5× cost reduction** with batching
|
||||
|
||||
**Recommendation**: Target batch size 32-128 for cost efficiency.
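
The per-token cost figures above follow from a simple calculation — a sketch using the assumed $3/hour GPU price and the two throughput points:

```python
# Cost per 1M generated tokens, given GPU price and sustained throughput (assumed inputs).
def cost_per_million_tokens(gpu_dollars_per_hour: float, tokens_per_sec: float) -> float:
    tokens_per_hour = tokens_per_sec * 3600
    return gpu_dollars_per_hour * 1_000_000 / tokens_per_hour

print(f"batch 1:  ${cost_per_million_tokens(3.0, 1_000):.2f} per 1M tokens")   # ~$0.83
print(f"batch 64: ${cost_per_million_tokens(3.0, 5_000):.2f} per 1M tokens")   # ~$0.17
```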
|
||||
|
||||
## Security
|
||||
|
||||
### API Authentication
|
||||
|
||||
```bash
|
||||
# Generate API key
|
||||
export API_KEY=$(openssl rand -hex 32)
|
||||
|
||||
# Start server with authentication
|
||||
trtllm-serve meta-llama/Meta-Llama-3-8B \
|
||||
--api_key $API_KEY
|
||||
|
||||
# Client request
|
||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||
-H "Authorization: Bearer $API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model": "...", "messages": [...]}'
|
||||
```
|
||||
|
||||
### Network Policies
|
||||
|
||||
```yaml
|
||||
apiVersion: networking.k8s.io/v1
|
||||
kind: NetworkPolicy
|
||||
metadata:
|
||||
name: tensorrt-llm-policy
|
||||
spec:
|
||||
podSelector:
|
||||
matchLabels:
|
||||
app: tensorrt-llm
|
||||
policyTypes:
|
||||
- Ingress
|
||||
ingress:
|
||||
- from:
|
||||
- podSelector:
|
||||
matchLabels:
|
||||
app: api-gateway # Only allow from gateway
|
||||
ports:
|
||||
- protocol: TCP
|
||||
port: 8000
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### High latency
|
||||
|
||||
**Diagnosis**:
|
||||
```bash
|
||||
# Check queue size
|
||||
curl http://localhost:9090/metrics | grep queue_size
|
||||
|
||||
# Check active requests
|
||||
curl http://localhost:9090/metrics | grep active_requests
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
- Scale horizontally (more replicas)
|
||||
- Increase batch size (if GPU underutilized)
|
||||
- Enable chunked context (if long prompts)
|
||||
- Use FP8 quantization
|
||||
|
||||
### OOM crashes
|
||||
|
||||
**Solutions**:
|
||||
- Reduce `max_batch_size`
|
||||
- Reduce `max_num_tokens`
|
||||
- Enable FP8 or INT4 quantization
|
||||
- Increase `tensor_parallel_size`
|
||||
|
||||
### Timeout errors
|
||||
|
||||
**NGINX config**:
|
||||
```nginx
|
||||
proxy_read_timeout 600s; # 10 minutes for very long generations
|
||||
proxy_send_timeout 600s;
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use FP8 on H100** for 2× speedup and 50% cost reduction
|
||||
2. **Monitor metrics** - Set up Prometheus + Grafana
|
||||
3. **Set readiness probes** - Prevent routing to unhealthy pods
|
||||
4. **Use load balancing** - Distribute load across replicas
|
||||
5. **Tune batch size** - Balance latency and throughput
|
||||
6. **Enable streaming** - Better UX for chat applications
|
||||
7. **Set up autoscaling** - Handle traffic spikes
|
||||
8. **Use persistent volumes** - Cache compiled models
|
||||
9. **Implement retries** - Handle transient failures
|
||||
10. **Monitor costs** - Track cost per token
|
||||
361
optional-skills/mlops/torchtitan/SKILL.md
Normal file
---
|
||||
name: distributed-llm-pretraining-torchtitan
|
||||
description: Provides PyTorch-native distributed LLM pretraining using torchtitan with 4D parallelism (FSDP2, TP, PP, CP). Use when pretraining Llama 3.1, DeepSeek V3, or custom models at scale from 8 to 512+ GPUs with Float8, torch.compile, and distributed checkpointing.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [torch>=2.6.0, torchtitan>=0.2.0, torchao>=0.5.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Model Architecture, Distributed Training, TorchTitan, FSDP2, Tensor Parallel, Pipeline Parallel, Context Parallel, Float8, Llama, Pretraining]
|
||||
|
||||
---
|
||||
|
||||
# TorchTitan - PyTorch Native Distributed LLM Pretraining
|
||||
|
||||
## Quick start
|
||||
|
||||
TorchTitan is PyTorch's official platform for large-scale LLM pretraining with composable 4D parallelism (FSDP2, TP, PP, CP), achieving 65%+ speedups over baselines on H100 GPUs.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
# From PyPI (stable)
|
||||
pip install torchtitan
|
||||
|
||||
# From source (latest features, requires PyTorch nightly)
|
||||
git clone https://github.com/pytorch/torchtitan
|
||||
cd torchtitan
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
**Download tokenizer**:
|
||||
```bash
|
||||
# Get HF token from https://huggingface.co/settings/tokens
|
||||
python scripts/download_hf_assets.py --repo_id meta-llama/Llama-3.1-8B --assets tokenizer --hf_token=...
|
||||
```
|
||||
|
||||
**Start training on 8 GPUs**:
|
||||
```bash
|
||||
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Pretrain Llama 3.1 8B on single node
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
Single Node Pretraining:
|
||||
- [ ] Step 1: Download tokenizer
|
||||
- [ ] Step 2: Configure training
|
||||
- [ ] Step 3: Launch training
|
||||
- [ ] Step 4: Monitor and checkpoint
|
||||
```
|
||||
|
||||
**Step 1: Download tokenizer**
|
||||
|
||||
```bash
|
||||
python scripts/download_hf_assets.py \
|
||||
--repo_id meta-llama/Llama-3.1-8B \
|
||||
--assets tokenizer \
|
||||
--hf_token=YOUR_HF_TOKEN
|
||||
```
|
||||
|
||||
**Step 2: Configure training**
|
||||
|
||||
Edit or create a TOML config file:
|
||||
|
||||
```toml
|
||||
# llama3_8b_custom.toml
|
||||
[job]
|
||||
dump_folder = "./outputs"
|
||||
description = "Llama 3.1 8B training"
|
||||
|
||||
[model]
|
||||
name = "llama3"
|
||||
flavor = "8B"
|
||||
hf_assets_path = "./assets/hf/Llama-3.1-8B"
|
||||
|
||||
[optimizer]
|
||||
name = "AdamW"
|
||||
lr = 3e-4
|
||||
|
||||
[lr_scheduler]
|
||||
warmup_steps = 200
|
||||
|
||||
[training]
|
||||
local_batch_size = 2
|
||||
seq_len = 8192
|
||||
max_norm = 1.0
|
||||
steps = 1000
|
||||
dataset = "c4"
|
||||
|
||||
[parallelism]
|
||||
data_parallel_shard_degree = -1 # Use all GPUs for FSDP
|
||||
|
||||
[activation_checkpoint]
|
||||
mode = "selective"
|
||||
selective_ac_option = "op"
|
||||
|
||||
[checkpoint]
|
||||
enable = true
|
||||
folder = "checkpoint"
|
||||
interval = 500
|
||||
```
|
||||
|
||||
**Step 3: Launch training**
|
||||
|
||||
```bash
|
||||
# 8 GPUs on single node
|
||||
CONFIG_FILE="./llama3_8b_custom.toml" ./run_train.sh
|
||||
|
||||
# Or explicitly with torchrun
|
||||
torchrun --nproc_per_node=8 \
|
||||
-m torchtitan.train \
|
||||
--job.config_file ./llama3_8b_custom.toml
|
||||
```
|
||||
|
||||
**Step 4: Monitor and checkpoint**
|
||||
|
||||
TensorBoard logs are saved to `./outputs/tb/`:
|
||||
```bash
|
||||
tensorboard --logdir ./outputs/tb
|
||||
```
|
||||
|
||||
### Workflow 2: Multi-node training with SLURM
|
||||
|
||||
```
|
||||
Multi-Node Training:
|
||||
- [ ] Step 1: Configure parallelism for scale
|
||||
- [ ] Step 2: Set up SLURM script
|
||||
- [ ] Step 3: Submit job
|
||||
- [ ] Step 4: Resume from checkpoint
|
||||
```
|
||||
|
||||
**Step 1: Configure parallelism for scale**
|
||||
|
||||
For 70B model on 256 GPUs (32 nodes):
|
||||
```toml
|
||||
[parallelism]
|
||||
data_parallel_shard_degree = 32 # FSDP across 32 ranks
|
||||
tensor_parallel_degree = 8 # TP within node
|
||||
pipeline_parallel_degree = 1 # No PP for 70B
|
||||
context_parallel_degree = 1 # Increase for long sequences
|
||||
```
|
||||
|
||||
**Step 2: Set up SLURM script**
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=llama70b
|
||||
#SBATCH --nodes=32
|
||||
#SBATCH --ntasks-per-node=8
|
||||
#SBATCH --gpus-per-node=8
|
||||
|
||||
srun torchrun \
|
||||
--nnodes=32 \
|
||||
--nproc_per_node=8 \
|
||||
--rdzv_backend=c10d \
|
||||
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
|
||||
-m torchtitan.train \
|
||||
--job.config_file ./llama3_70b.toml
|
||||
```
|
||||
|
||||
**Step 3: Submit job**
|
||||
|
||||
```bash
|
||||
sbatch multinode_trainer.slurm
|
||||
```
|
||||
|
||||
**Step 4: Resume from checkpoint**
|
||||
|
||||
Training resumes automatically if a checkpoint exists in the configured folder.
|
||||
|
||||
### Workflow 3: Enable Float8 training for H100s
|
||||
|
||||
Float8 provides 30-50% speedup on H100 GPUs.
|
||||
|
||||
```
|
||||
Float8 Training:
|
||||
- [ ] Step 1: Install torchao
|
||||
- [ ] Step 2: Configure Float8
|
||||
- [ ] Step 3: Launch with compile
|
||||
```
|
||||
|
||||
**Step 1: Install torchao**
|
||||
|
||||
```bash
|
||||
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
|
||||
```
|
||||
|
||||
**Step 2: Configure Float8**
|
||||
|
||||
Add to your TOML config:
|
||||
```toml
|
||||
[model]
|
||||
converters = ["quantize.linear.float8"]
|
||||
|
||||
[quantize.linear.float8]
|
||||
enable_fsdp_float8_all_gather = true
|
||||
precompute_float8_dynamic_scale_for_fsdp = true
|
||||
filter_fqns = ["output"] # Exclude output layer
|
||||
|
||||
[compile]
|
||||
enable = true
|
||||
components = ["model", "loss"]
|
||||
```
|
||||
|
||||
**Step 3: Launch with compile**
|
||||
|
||||
```bash
|
||||
CONFIG_FILE="./llama3_8b.toml" ./run_train.sh \
|
||||
--model.converters="quantize.linear.float8" \
|
||||
--quantize.linear.float8.enable_fsdp_float8_all_gather \
|
||||
--compile.enable
|
||||
```
|
||||
|
||||
### Workflow 4: 4D parallelism for 405B models
|
||||
|
||||
```
|
||||
4D Parallelism (FSDP + TP + PP + CP):
|
||||
- [ ] Step 1: Create seed checkpoint
|
||||
- [ ] Step 2: Configure 4D parallelism
|
||||
- [ ] Step 3: Launch on 512 GPUs
|
||||
```
|
||||
|
||||
**Step 1: Create seed checkpoint**
|
||||
|
||||
Required for consistent initialization across PP stages:
|
||||
```bash
|
||||
NGPU=1 CONFIG_FILE=./llama3_405b.toml ./run_train.sh \
|
||||
--checkpoint.enable \
|
||||
--checkpoint.create_seed_checkpoint \
|
||||
--parallelism.data_parallel_shard_degree 1 \
|
||||
--parallelism.tensor_parallel_degree 1 \
|
||||
--parallelism.pipeline_parallel_degree 1
|
||||
```
|
||||
|
||||
**Step 2: Configure 4D parallelism**
|
||||
|
||||
```toml
|
||||
[parallelism]
|
||||
data_parallel_shard_degree = 8 # FSDP
|
||||
tensor_parallel_degree = 8 # TP within node
|
||||
pipeline_parallel_degree = 8 # PP across nodes
|
||||
context_parallel_degree = 1 # CP for long sequences
|
||||
|
||||
[training]
|
||||
local_batch_size = 32
|
||||
seq_len = 8192
|
||||
```
|
||||
|
||||
**Step 3: Launch on 512 GPUs**
|
||||
|
||||
```bash
|
||||
# 64 nodes x 8 GPUs = 512 GPUs
|
||||
srun torchrun --nnodes=64 --nproc_per_node=8 \
|
||||
-m torchtitan.train \
|
||||
--job.config_file ./llama3_405b.toml
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use TorchTitan when:**
|
||||
- Pretraining LLMs from scratch (8B to 405B+)
|
||||
- Need PyTorch-native solution without third-party dependencies
|
||||
- Require composable 4D parallelism (FSDP2, TP, PP, CP)
|
||||
- Training on H100s with Float8 support
|
||||
- Want interoperable checkpoints with torchtune/HuggingFace
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Megatron-LM**: Maximum performance for NVIDIA-only deployments
|
||||
- **DeepSpeed**: Broader ZeRO optimization ecosystem, inference support
|
||||
- **Axolotl/TRL**: Fine-tuning rather than pretraining
|
||||
- **LitGPT**: Educational, smaller-scale training
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Out of memory on large models**
|
||||
|
||||
Enable activation checkpointing and reduce batch size:
|
||||
```toml
|
||||
[activation_checkpoint]
|
||||
mode = "full" # Instead of "selective"
|
||||
|
||||
[training]
|
||||
local_batch_size = 1
|
||||
```
|
||||
|
||||
Or use gradient accumulation:
|
||||
```toml
|
||||
[training]
|
||||
local_batch_size = 1
|
||||
global_batch_size = 32 # Accumulates gradients
|
||||
```
|
||||
|
||||
**Issue: TP causes high memory with async collectives**
|
||||
|
||||
Set environment variable:
|
||||
```bash
|
||||
export TORCH_NCCL_AVOID_RECORD_STREAMS=1
|
||||
```
|
||||
|
||||
**Issue: Float8 training not faster**
|
||||
|
||||
Float8 only benefits large GEMMs. Filter small layers:
|
||||
```toml
|
||||
[quantize.linear.float8]
|
||||
filter_fqns = ["attention.wk", "attention.wv", "output", "auto_filter_small_kn"]
|
||||
```
|
||||
|
||||
**Issue: Checkpoint loading fails after parallelism change**
|
||||
|
||||
Use DCP's resharding capability:
|
||||
```bash
|
||||
# Convert sharded checkpoint to single file
|
||||
python -m torch.distributed.checkpoint.format_utils \
|
||||
dcp_to_torch checkpoint/step-1000 checkpoint.pt
|
||||
```
|
||||
|
||||
**Issue: Pipeline parallelism initialization**
|
||||
|
||||
Create seed checkpoint first (see Workflow 4, Step 1).
|
||||
|
||||
## Supported models
|
||||
|
||||
| Model | Sizes | Status |
|
||||
|-------|-------|--------|
|
||||
| Llama 3.1 | 8B, 70B, 405B | Production |
|
||||
| Llama 4 | Various | Experimental |
|
||||
| DeepSeek V3 | 16B, 236B, 671B (MoE) | Experimental |
|
||||
| GPT-OSS | 20B, 120B (MoE) | Experimental |
|
||||
| Qwen 3 | Various | Experimental |
|
||||
| Flux | Diffusion | Experimental |
|
||||
|
||||
## Performance benchmarks (H100)
|
||||
|
||||
| Model | GPUs | Parallelism | TPS/GPU | Techniques |
|
||||
|-------|------|-------------|---------|------------|
|
||||
| Llama 8B | 8 | FSDP | 5,762 | Baseline |
|
||||
| Llama 8B | 8 | FSDP+compile+FP8 | 8,532 | +48% |
|
||||
| Llama 70B | 256 | FSDP+TP+AsyncTP | 876 | 2D parallel |
|
||||
| Llama 405B | 512 | FSDP+TP+PP | 128 | 3D parallel |
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**FSDP2 configuration**: See [references/fsdp.md](references/fsdp.md) for detailed FSDP2 vs FSDP1 comparison and ZeRO equivalents.
|
||||
|
||||
**Float8 training**: See [references/float8.md](references/float8.md) for tensorwise vs rowwise scaling recipes.
|
||||
|
||||
**Checkpointing**: See [references/checkpoint.md](references/checkpoint.md) for HuggingFace conversion and async checkpointing.
|
||||
|
||||
**Adding custom models**: See [references/custom-models.md](references/custom-models.md) for TrainSpec protocol.
|
||||
|
||||
## Resources
|
||||
|
||||
- GitHub: https://github.com/pytorch/torchtitan
|
||||
- Paper: https://arxiv.org/abs/2410.06511
|
||||
- ICLR 2025: https://iclr.cc/virtual/2025/poster/29620
|
||||
- PyTorch Forum: https://discuss.pytorch.org/c/distributed/torchtitan/44
|
||||
|
||||
181
optional-skills/mlops/torchtitan/references/checkpoint.md
Normal file
# Checkpointing in TorchTitan
|
||||
|
||||
TorchTitan uses PyTorch Distributed Checkpoint (DCP) for fault-tolerant, interoperable checkpointing.
|
||||
|
||||
## Basic Configuration
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
folder = "checkpoint"
|
||||
interval = 500
|
||||
```
|
||||
|
||||
## Save Model Only (Smaller Checkpoints)
|
||||
|
||||
Exclude optimizer state and training metadata:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
last_save_model_only = true
|
||||
export_dtype = "bfloat16" # Optional: export in lower precision
|
||||
```
|
||||
|
||||
## Excluding Keys from Loading
|
||||
|
||||
Partial checkpoint loading for modified settings:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
exclude_from_loading = ["data_loader", "lr_scheduler"]
|
||||
```
|
||||
|
||||
CLI equivalent:
|
||||
```bash
|
||||
--checkpoint.exclude_from_loading data_loader,lr_scheduler
|
||||
```
|
||||
|
||||
## Creating Seed Checkpoints
|
||||
|
||||
Required for Pipeline Parallelism to ensure consistent initialization:
|
||||
|
||||
```bash
|
||||
NGPU=1 CONFIG_FILE=<path_to_config> ./run_train.sh \
|
||||
--checkpoint.enable \
|
||||
--checkpoint.create_seed_checkpoint \
|
||||
--parallelism.data_parallel_replicate_degree 1 \
|
||||
--parallelism.data_parallel_shard_degree 1 \
|
||||
--parallelism.tensor_parallel_degree 1 \
|
||||
--parallelism.pipeline_parallel_degree 1 \
|
||||
--parallelism.context_parallel_degree 1 \
|
||||
--parallelism.expert_parallel_degree 1
|
||||
```
|
||||
|
||||
This initializes the model on a single device (CPU), so initialization is reproducible across any GPU count.
|
||||
|
||||
## Async Checkpointing
|
||||
|
||||
Reduce checkpoint overhead with async writes:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
async_mode = "async" # Options: "disabled", "async", "async_with_pinned_mem"
|
||||
```
|
||||
|
||||
## HuggingFace Conversion
|
||||
|
||||
### During Training
|
||||
|
||||
Save directly in HuggingFace format:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
last_save_in_hf = true
|
||||
last_save_model_only = true
|
||||
```
|
||||
|
||||
Load from HuggingFace:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
initial_load_in_hf = true
|
||||
|
||||
[model]
|
||||
hf_assets_path = "./path/to/hf/checkpoint"
|
||||
```
|
||||
|
||||
### Offline Conversion
|
||||
|
||||
Convert without running training:
|
||||
|
||||
```bash
|
||||
# HuggingFace -> TorchTitan
|
||||
python ./scripts/checkpoint_conversion/convert_from_hf.py \
|
||||
<input_dir> <output_dir> \
|
||||
--model_name llama3 \
|
||||
--model_flavor 8B
|
||||
|
||||
# TorchTitan -> HuggingFace
|
||||
python ./scripts/checkpoint_conversion/convert_to_hf.py \
|
||||
<input_dir> <output_dir> \
|
||||
--hf_assets_path ./assets/hf/Llama3.1-8B \
|
||||
--model_name llama3 \
|
||||
--model_flavor 8B
|
||||
```
|
||||
|
||||
### Example
|
||||
|
||||
```bash
|
||||
python ./scripts/convert_from_hf.py \
|
||||
~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920/ \
|
||||
./initial_load_path/ \
|
||||
--model_name llama3 \
|
||||
--model_flavor 8B
|
||||
```
|
||||
|
||||
## Converting to Single .pt File
|
||||
|
||||
Convert DCP sharded checkpoint to single PyTorch file:
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.checkpoint.format_utils \
|
||||
dcp_to_torch \
|
||||
torchtitan/outputs/checkpoint/step-1000 \
|
||||
checkpoint.pt
|
||||
```
|
||||
|
||||
## Checkpoint Structure
|
||||
|
||||
DCP saves sharded checkpoints that can be resharded for different parallelism configurations:
|
||||
|
||||
```
|
||||
checkpoint/
|
||||
├── step-500/
|
||||
│ ├── .metadata
|
||||
│ ├── __0_0.distcp
|
||||
│ ├── __0_1.distcp
|
||||
│ └── ...
|
||||
└── step-1000/
|
||||
└── ...
|
||||
```
|
||||
|
||||
## Resume Training
|
||||
|
||||
Training auto-resumes from the latest checkpoint in the configured folder. To resume from a specific step:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
load_step = 500 # Resume from step 500
|
||||
```
|
||||
|
||||
## Interoperability with TorchTune
|
||||
|
||||
Checkpoints saved with `last_save_model_only = true` can be loaded directly into [torchtune](https://github.com/pytorch/torchtune) for fine-tuning.
|
||||
|
||||
## Full Configuration Example
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
folder = "checkpoint"
|
||||
interval = 500
|
||||
load_step = -1 # -1 = latest, or specify step number
|
||||
last_save_model_only = true
|
||||
export_dtype = "bfloat16"
|
||||
async_mode = "async"
|
||||
exclude_from_loading = []
|
||||
last_save_in_hf = false
|
||||
initial_load_in_hf = false
|
||||
create_seed_checkpoint = false
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Large models**: Use `async_mode = "async"` to overlap checkpoint saves with training
|
||||
2. **Fine-tuning export**: Enable `last_save_model_only` and `export_dtype = "bfloat16"` for smaller files
|
||||
3. **Pipeline parallelism**: Always create seed checkpoint first
|
||||
4. **Debugging**: Save frequent checkpoints during development, reduce for production
|
||||
5. **HF interop**: Use conversion scripts for offline conversion, direct save/load for training workflows
|
||||
258
optional-skills/mlops/torchtitan/references/custom-models.md
Normal file
# Adding Custom Models to TorchTitan
|
||||
|
||||
This guide explains how to add a new model to TorchTitan following the established patterns.
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
torchtitan/models/your_model/
|
||||
├── model/
|
||||
│ ├── __init__.py
|
||||
│ ├── args.py # Model arguments
|
||||
│ ├── model.py # Model definition
|
||||
│ └── state_dict_adapter.py # HF conversion (optional)
|
||||
├── infra/
|
||||
│ ├── __init__.py
|
||||
│ ├── parallelize.py # TP, FSDP, compile application
|
||||
│ └── pipeline.py # PP application (optional)
|
||||
├── train_configs/
|
||||
│ ├── debug_model.toml
|
||||
│ └── your_model_XB.toml
|
||||
├── __init__.py # TrainSpec registration
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Step 1: Define Model Arguments
|
||||
|
||||
Inherit from `BaseModelArgs`:
|
||||
|
||||
```python
|
||||
# model/args.py
|
||||
from torchtitan.protocols.model import BaseModelArgs
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class YourModelArgs(BaseModelArgs):
|
||||
dim: int = 4096
|
||||
n_layers: int = 32
|
||||
n_heads: int = 32
|
||||
vocab_size: int = 128256
|
||||
|
||||
def get_nparams_and_flops(self, seq_len: int) -> tuple[int, int]:
|
||||
"""Return (num_params, flops_per_token) for throughput calculation."""
|
||||
nparams = self.vocab_size * self.dim + ... # Calculate params
|
||||
flops = 6 * nparams # Approximate: 6 * params for forward+backward
|
||||
return nparams, flops
|
||||
|
||||
def update_from_config(self, job_config) -> "YourModelArgs":
|
||||
"""Update args from training config."""
|
||||
# Override specific args from job_config if needed
|
||||
return self
|
||||
```
|
||||
|
||||
## Step 2: Define Model
|
||||
|
||||
Inherit from `ModelProtocol`:
|
||||
|
||||
```python
|
||||
# model/model.py
|
||||
import torch
import torch.nn as nn
|
||||
from torchtitan.protocols.model import ModelProtocol
|
||||
from .args import YourModelArgs
|
||||
|
||||
class YourModel(ModelProtocol):
|
||||
def __init__(self, args: YourModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||
self.layers = nn.ModuleDict({
|
||||
str(i): TransformerBlock(args) for i in range(args.n_layers)
|
||||
})
|
||||
self.norm = RMSNorm(args.dim)
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
|
||||
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
h = self.tok_embeddings(tokens)
|
||||
for layer in self.layers.values():
|
||||
h = layer(h)
|
||||
h = self.norm(h)
|
||||
return self.output(h)
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize weights recursively."""
|
||||
for module in self.modules():
|
||||
if hasattr(module, 'init_weights') and module is not self:
|
||||
module.init_weights()
|
||||
elif isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, std=0.02)
|
||||
```
|
||||
|
||||
**Important guidelines**:
|
||||
- Write single-device model code (parallelism applied externally)
|
||||
- Use `nn.ModuleDict` for layers (preserves FQNs when deleting for PP)
|
||||
- Make input/output layers optional for PP compatibility
|
||||
- Define `init_weights()` recursively
|
||||
|
||||
## Step 3: Parallelize Function
|
||||
|
||||
```python
|
||||
# infra/parallelize.py
|
||||
from torch.distributed._composable.fsdp import fully_shard
|
||||
from torch.distributed.tensor.parallel import parallelize_module
|
||||
|
||||
def parallelize_your_model(
|
||||
model: YourModel,
|
||||
world_mesh: DeviceMesh,
|
||||
parallel_dims: ParallelDims,
|
||||
job_config: JobConfig,
|
||||
):
|
||||
# Apply in this order: TP -> AC -> compile -> FSDP
|
||||
|
||||
# 1. Tensor Parallelism
|
||||
if parallel_dims.tp_enabled:
|
||||
apply_tp(model, world_mesh["tp"], job_config)
|
||||
|
||||
# 2. Activation Checkpointing
|
||||
if job_config.activation_checkpoint.mode == "full":
|
||||
apply_ac(model, job_config)
|
||||
|
||||
# 3. torch.compile
|
||||
if job_config.compile.enable:
|
||||
model = torch.compile(model)
|
||||
|
||||
# 4. FSDP
|
||||
if parallel_dims.dp_enabled:
|
||||
apply_fsdp(model, world_mesh["dp"], job_config)
|
||||
|
||||
return model
|
||||
```
|
||||
|
||||
## Step 4: Create TrainSpec
|
||||
|
||||
```python
|
||||
# __init__.py
|
||||
from torchtitan.protocols.train_spec import TrainSpec, register_train_spec
|
||||
from .model.model import YourModel
|
||||
from .model.args import YourModelArgs
|
||||
from .infra.parallelize import parallelize_your_model
|
||||
|
||||
MODEL_CONFIGS = {
|
||||
"8B": YourModelArgs(dim=4096, n_layers=32, n_heads=32),
|
||||
"70B": YourModelArgs(dim=8192, n_layers=80, n_heads=64),
|
||||
}
|
||||
|
||||
def get_train_spec(flavor: str) -> TrainSpec:
|
||||
return TrainSpec(
|
||||
model_cls=YourModel,
|
||||
model_args=MODEL_CONFIGS[flavor],
|
||||
parallelize_fn=parallelize_your_model,
|
||||
pipeline_fn=None, # Or your_pipeline_fn for PP
|
||||
build_optimizer_fn=build_optimizer, # Reuse existing
|
||||
build_lr_scheduler_fn=build_lr_scheduler, # Reuse existing
|
||||
build_dataloader_fn=build_dataloader, # Reuse existing
|
||||
build_tokenizer_fn=build_tokenizer, # Reuse existing
|
||||
build_loss_fn=build_loss, # Reuse existing
|
||||
state_dict_adapter=None, # Or YourStateDictAdapter
|
||||
)
|
||||
|
||||
# Register so train.py can find it
|
||||
register_train_spec("your_model", get_train_spec)
|
||||
```
|
||||
|
||||
## Step 5: State Dict Adapter (Optional)
|
||||
|
||||
For HuggingFace checkpoint conversion:
|
||||
|
||||
```python
|
||||
# model/state_dict_adapter.py
|
||||
from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter
|
||||
|
||||
class YourStateDictAdapter(BaseStateDictAdapter):
|
||||
def to_hf(self, state_dict: dict) -> dict:
|
||||
"""Convert torchtitan state dict to HF format."""
|
||||
hf_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
hf_key = self._convert_key_to_hf(key)
|
||||
hf_state_dict[hf_key] = value
|
||||
return hf_state_dict
|
||||
|
||||
def from_hf(self, state_dict: dict) -> dict:
|
||||
"""Convert HF state dict to torchtitan format."""
|
||||
tt_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
tt_key = self._convert_key_from_hf(key)
|
||||
tt_state_dict[tt_key] = value
|
||||
return tt_state_dict
|
||||
```
|
||||
|
||||
## Step 6: Training Config
|
||||
|
||||
```toml
|
||||
# train_configs/your_model_8b.toml
|
||||
[job]
|
||||
dump_folder = "./outputs"
|
||||
description = "Your Model 8B training"
|
||||
|
||||
[model]
|
||||
name = "your_model"
|
||||
flavor = "8B"
|
||||
|
||||
[optimizer]
|
||||
name = "AdamW"
|
||||
lr = 3e-4
|
||||
|
||||
[training]
|
||||
local_batch_size = 2
|
||||
seq_len = 8192
|
||||
steps = 1000
|
||||
dataset = "c4"
|
||||
|
||||
[parallelism]
|
||||
data_parallel_shard_degree = -1
|
||||
tensor_parallel_degree = 1
|
||||
```
|
||||
|
||||
## Step 7: Register Model
|
||||
|
||||
Add to `torchtitan/models/__init__.py`:
|
||||
|
||||
```python
|
||||
from .your_model import get_train_spec as get_your_model_train_spec
|
||||
|
||||
MODEL_REGISTRY["your_model"] = get_your_model_train_spec
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
### Numerics Test
|
||||
|
||||
Compare output with HuggingFace implementation:
|
||||
|
||||
```python
|
||||
def test_numerics():
|
||||
# Load same checkpoint into both implementations
|
||||
tt_model = YourModel(args).load_checkpoint(...)
|
||||
hf_model = HFYourModel.from_pretrained(...)
|
||||
|
||||
# Compare outputs
|
||||
input_ids = torch.randint(0, vocab_size, (1, 128))
|
||||
tt_output = tt_model(input_ids)
|
||||
hf_output = hf_model(input_ids).logits
|
||||
|
||||
torch.testing.assert_close(tt_output, hf_output, atol=1e-4, rtol=1e-4)
|
||||
```
|
||||
|
||||
### Loss Convergence
|
||||
|
||||
Compare loss curves with verified baseline (see `docs/converging.md`).
|
||||
|
||||
### Performance Benchmark
|
||||
|
||||
Add benchmark config to `benchmarks/` folder.
|
||||
|
||||
## Guiding Principles
|
||||
|
||||
1. **Readability over flexibility**: Don't over-abstract
|
||||
2. **Minimal model changes**: Parallelism applied externally
|
||||
3. **Clean, minimal codebase**: Reuse existing components where possible
|
||||
4. **Single-device semantics**: Model code should work on single GPU
|
||||
133
optional-skills/mlops/torchtitan/references/float8.md
Normal file
# Float8 Training in TorchTitan
|
||||
|
||||
Float8 training provides substantial speedups for models where GEMMs are large enough that the FP8 tensorcore speedup outweighs dynamic quantization overhead.
|
||||
|
||||
## Hardware Requirements
|
||||
|
||||
- NVIDIA H100 or newer GPUs (FP8 Tensor Cores)
|
||||
- Blackwell GPUs for MXFP8 training
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
|
||||
```
|
||||
|
||||
## Usage: Tensorwise Scaling
|
||||
|
||||
Standard Float8 with tensorwise dynamic scaling:
|
||||
|
||||
```bash
|
||||
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
|
||||
--model.converters="quantize.linear.float8" \
|
||||
--quantize.linear.float8.enable_fsdp_float8_all_gather \
|
||||
--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp \
|
||||
--compile.enable
|
||||
```
|
||||
|
||||
### Key Arguments
|
||||
|
||||
| Argument | Description |
|
||||
|----------|-------------|
|
||||
| `--model.converters="quantize.linear.float8"` | Swap `nn.Linear` with `Float8Linear` |
|
||||
| `--quantize.linear.float8.enable_fsdp_float8_all_gather` | Communicate in float8 to save bandwidth |
|
||||
| `--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp` | Single all-reduce for all AMAX/scales |
|
||||
| `--compile.enable` | Required - fuses float8 scaling/casting kernels |
|
||||
|
||||
## Usage: Rowwise Scaling
|
||||
|
||||
Higher accuracy than tensorwise scaling:
|
||||
|
||||
```bash
|
||||
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
|
||||
--model.converters="quantize.linear.float8" \
|
||||
--quantize.linear.float8.recipe_name rowwise \
|
||||
--compile.enable
|
||||
```
|
||||
|
||||
## Filtering Layers
|
||||
|
||||
Not all layers benefit from Float8. Filter small layers:
|
||||
|
||||
```bash
|
||||
--quantize.linear.float8.filter_fqns="attention.wk,attention.wv,output"
|
||||
```
|
||||
|
||||
### Auto-filtering
|
||||
|
||||
Automatically skip layers too small to benefit:
|
||||
|
||||
```bash
|
||||
--quantize.linear.float8.filter_fqns="auto_filter_small_kn"
|
||||
```
|
||||
|
||||
Thresholds based on H100 microbenchmarks where speedup > overhead.
|
||||
|
||||
## TOML Configuration
|
||||
|
||||
```toml
|
||||
[model]
|
||||
converters = ["quantize.linear.float8"]
|
||||
|
||||
[quantize.linear.float8]
|
||||
enable_fsdp_float8_all_gather = true
|
||||
precompute_float8_dynamic_scale_for_fsdp = true
|
||||
filter_fqns = ["output", "auto_filter_small_kn"]
|
||||
|
||||
[compile]
|
||||
enable = true
|
||||
components = ["model", "loss"]
|
||||
```
|
||||
|
||||
## How Float8 Works with Distributed Training
|
||||
|
||||
### Single Device
|
||||
|
||||
Cast input and weight to float8 inside forward before calling `torch._scaled_mm`:
|
||||
|
||||
```python
|
||||
# Float8 matmul requires scales
|
||||
torch._scaled_mm(input_fp8, weight_fp8, scale_a=scale_input, scale_b=scale_weight)
|
||||
```
|
||||
|
||||
### FSDP + Float8
|
||||
|
||||
1. Cast sharded high-precision weights (1/N per rank) to float8
|
||||
2. Perform float8 all-gather (saves bandwidth vs bf16/fp32)
|
||||
3. Communicate `max(abs)` across ranks for scale computation
|
||||
4. At forward start, have unsharded float8 weights ready
|
||||
|
||||
**Net benefit**: Float8 all-gather + amax communication can beat bf16/fp32 all-gather, depending on world size and message size.
|
||||
|
||||
### TP + Float8
|
||||
|
||||
- **Input**: Cast sharded input to float8, all-gather in float8
|
||||
- **Weights**: Communicate `max(abs)` for sharded weights
|
||||
- **Matmul**: Float8 input (unsharded) x float8 weight (sharded) with global scales
|
||||
|
||||
## Scaling Strategies
|
||||
|
||||
| Strategy | Status | Description |
|
||||
|----------|--------|-------------|
|
||||
| Tensorwise dynamic | Stable | Single scale per tensor |
|
||||
| Rowwise dynamic | Alpha | Scale per row, higher accuracy |
|
||||
|
||||
## Performance Gains
|
||||
|
||||
From benchmarks on H100:
|
||||
|
||||
| Configuration | TPS/GPU | vs Baseline |
|
||||
|---------------|---------|-------------|
|
||||
| FSDP only | 5,762 | - |
|
||||
| FSDP + compile | 6,667 | +16% |
|
||||
| FSDP + compile + Float8 | 8,532 | +48% |
|
||||
|
||||
## Determining Float8 Benefit
|
||||
|
||||
Check [torchao microbenchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) for forward+backward pass speedups on "layer norm => linear => sigmoid" for different M,N,K sizes.
|
||||
|
||||
Rule of thumb: GEMMs with K,N > 4096 typically benefit from Float8.
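
A trivial sketch of that rule of thumb, for deciding per-layer whether Float8 is likely worthwhile (`likely_float8_win` is an illustrative helper, not a torchao API):

```python
# Illustrative: apply the "K, N > 4096" rule of thumb to a linear layer's shape.
def likely_float8_win(in_features: int, out_features: int) -> bool:
    """True if a Linear(in_features, out_features) GEMM is big enough for Float8 gains."""
    return in_features > 4096 and out_features > 4096

print(likely_float8_win(4096, 14336))   # ~8B-scale MLP up-projection -> False (K not > 4096)
print(likely_float8_win(8192, 28672))   # ~70B-scale MLP -> True
```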
|
||||
|
||||
## MXFP8 Training (Blackwell)
|
||||
|
||||
For NVIDIA Blackwell GPUs, TorchTitan supports MXFP8 (Microscaling FP8) for both dense and MoE models. See [docs/mxfp8.md](https://github.com/pytorch/torchtitan/blob/main/docs/mxfp8.md) for details.
|
||||
126
optional-skills/mlops/torchtitan/references/fsdp.md
Normal file
# FSDP2 in TorchTitan
|
||||
|
||||
## Why FSDP2?
|
||||
|
||||
FSDP2 is a rewrite of PyTorch's Fully Sharded Data Parallel (FSDP) API, removing the `FlatParameter` abstraction for better composability and simpler implementation.
|
||||
|
||||
### Key improvements over FSDP1
|
||||
|
||||
- **DTensor-based sharding**: Sharded parameters are `DTensor`s on dim-0, enabling easy manipulation and communication-free sharded state dicts
|
||||
- **Better memory management**: Deterministic and lower GPU memory (7% reduction) by avoiding `recordStream`
|
||||
- **Simplified API**: Fewer arguments, no wrapper class
|
||||
|
||||
### Performance
|
||||
|
||||
On Llama-7B with 8x H100s, FSDP2 achieves higher MFU with 7% lower peak memory than FSDP1, matching the same loss curve.
|
||||
|
||||
## API Reference
|
||||
|
||||
```python
|
||||
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, OffloadPolicy
|
||||
|
||||
@contract(state_cls=FSDPState)
|
||||
def fully_shard(
|
||||
module: nn.Module,
|
||||
*,
|
||||
mesh: Optional[DeviceMesh] = None,
|
||||
reshard_after_forward: Union[bool, int] = True,
|
||||
mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
|
||||
offload_policy: OffloadPolicy = OffloadPolicy(),
|
||||
) -> nn.Module:
|
||||
```
|
||||
|
||||
## Sharding Strategies (ZeRO Equivalents)
|
||||
|
||||
| FSDP2 Configuration | FSDP1 Equivalent | DeepSpeed |
|
||||
|---------------------|------------------|-----------|
|
||||
| 1D mesh + `reshard_after_forward=True` | FULL_SHARD | ZeRO-3 |
|
||||
| 1D mesh + `reshard_after_forward=False` | SHARD_GRAD_OP | ZeRO-2 |
|
||||
| 2D mesh + `reshard_after_forward=True` | HYBRID_SHARD | MiCS |
|
||||
| 1D/2D mesh + `reshard_after_forward=8` (int) | - | ZeRO++ hpZ |
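
For example, the ZeRO-2-style row maps to keeping parameters gathered after forward — a minimal sketch, assuming `torch.distributed` is already initialized (same `fully_shard` import as the API reference above):

```python
# Sketch: ZeRO-2-style sharding with FSDP2 — parameters stay gathered after forward,
# so only gradients/optimizer state remain sharded. Assumes dist is initialized.
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

model = nn.Sequential(nn.Linear(8192, 8192), nn.Linear(8192, 8192))
for layer in model:
    fully_shard(layer, reshard_after_forward=False)  # SHARD_GRAD_OP equivalent
fully_shard(model, reshard_after_forward=False)
```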
|
||||
|
||||
## Meta-Device Initialization
|
||||
|
||||
FSDP2 supports materializing tensors onto GPU _after_ sharding:
|
||||
|
||||
```python
|
||||
# Initialize on meta device (no memory)
|
||||
with torch.device("meta"):
|
||||
model = Transformer()
|
||||
|
||||
# Apply FSDP2 sharding
|
||||
for module in model.modules():
|
||||
if isinstance(module, TransformerBlock):
|
||||
fully_shard(module)
|
||||
fully_shard(model)
|
||||
|
||||
# Parameters still on meta device
|
||||
for tensor in itertools.chain(model.parameters(), model.buffers()):
|
||||
assert tensor.device == torch.device("meta")
|
||||
|
||||
# Allocate sharded parameters on GPU
|
||||
model.to_empty(device="cuda")
|
||||
|
||||
# Initialize weights
|
||||
model.init_weights()
|
||||
```
|
||||
|
||||
## State Dict Differences
|
||||
|
||||
| Operation | FSDP1 | FSDP2 |
|
||||
|-----------|-------|-------|
|
||||
| `model.state_dict()` | Full state dict | Sharded state dict (no communication) |
|
||||
| `optim.state_dict()` | Local state dict | Sharded state dict (no communication) |
|
||||
| `summon_full_params()` | Supported | Use `DTensor` APIs like `full_tensor()` |
|
||||
| Gradient clipping | `FSDP.clip_grad_norm_()` | `nn.utils.clip_grad_norm_()` |
|
||||
|
||||
## Mixed Precision
|
||||
|
||||
```python
|
||||
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
|
||||
|
||||
mp_policy = MixedPrecisionPolicy(
|
||||
param_dtype=torch.bfloat16,
|
||||
reduce_dtype=torch.float32,
|
||||
output_dtype=torch.bfloat16,
|
||||
cast_forward_inputs=True,
|
||||
)
|
||||
|
||||
fully_shard(model, mp_policy=mp_policy)
|
||||
```
|
||||
|
||||
## HSDP (Hybrid Sharded Data Parallel)
|
||||
|
||||
For 2D parallelism with replication + sharding:
|
||||
|
||||
```python
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
||||
# Replicate across 4 groups, shard within 8 GPUs each
|
||||
mesh = init_device_mesh("cuda", (4, 8), mesh_dim_names=("replicate", "shard"))
|
||||
|
||||
fully_shard(model, mesh=mesh)
|
||||
```
|
||||
|
||||
## Configuration in TorchTitan
|
||||
|
||||
```toml
|
||||
[parallelism]
|
||||
# FSDP sharding degree (-1 = auto, use all available GPUs)
|
||||
data_parallel_shard_degree = -1
|
||||
|
||||
# HSDP replication degree (1 = pure FSDP, >1 = HSDP)
|
||||
data_parallel_replicate_degree = 1
|
||||
```
|
||||
|
||||
## Removed Arguments from FSDP1
|
||||
|
||||
These FSDP1 arguments are no longer needed:
|
||||
|
||||
- `auto_wrap_policy`: Apply `fully_shard` directly to modules (see the migration sketch after this list)
|
||||
- `backward_prefetch`: Always uses BACKWARD_PRE
|
||||
- `param_init_fn`: Use meta-device initialization
|
||||
- `device_id`: Uses mesh's device automatically
|
||||
- `sync_module_states`: Not needed with DTensor
|
||||
- `limit_all_gathers`: New memory management doesn't need it
|
||||
- `use_orig_params`: Always true (no FlatParameter)
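
For example, `auto_wrap_policy` is replaced by calling `fully_shard` explicitly. A minimal before/after sketch, illustrative only, assuming an initialized process group; `TransformerBlock` below is a stand-in class:

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed._composable.fsdp import fully_shard


class TransformerBlock(nn.Module):  # placeholder for a real block
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)


model = nn.Sequential(TransformerBlock(), TransformerBlock())

# FSDP1: wrapping driven by a policy object
# wrapped = FSDP(model, auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}))

# FSDP2: shard the blocks explicitly, then the root module
for block in model:
    fully_shard(block)
fully_shard(model)
```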
|
||||
96
optional-skills/research/domain-intel/SKILL.md
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
---
|
||||
name: domain-intel
|
||||
description: Passive domain reconnaissance using Python stdlib. Subdomain discovery, SSL certificate inspection, WHOIS lookups, DNS records, domain availability checks, and bulk multi-domain analysis. No API keys required.
|
||||
---
|
||||
|
||||
# Domain Intelligence — Passive OSINT
|
||||
|
||||
Passive domain reconnaissance using only Python stdlib.
|
||||
**Zero dependencies. Zero API keys. Works on Linux, macOS, and Windows.**
|
||||
|
||||
## Helper script
|
||||
|
||||
This skill includes `scripts/domain_intel.py` — a complete CLI tool for all domain intelligence operations.
|
||||
|
||||
```bash
|
||||
# Subdomain discovery via Certificate Transparency logs
|
||||
python3 SKILL_DIR/scripts/domain_intel.py subdomains example.com
|
||||
|
||||
# SSL certificate inspection (expiry, cipher, SANs, issuer)
|
||||
python3 SKILL_DIR/scripts/domain_intel.py ssl example.com
|
||||
|
||||
# WHOIS lookup (registrar, dates, name servers — 100+ TLDs)
|
||||
python3 SKILL_DIR/scripts/domain_intel.py whois example.com
|
||||
|
||||
# DNS records (A, AAAA, MX, NS, TXT, CNAME)
|
||||
python3 SKILL_DIR/scripts/domain_intel.py dns example.com
|
||||
|
||||
# Domain availability check (passive: DNS + WHOIS + SSL signals)
|
||||
python3 SKILL_DIR/scripts/domain_intel.py available coolstartup.io
|
||||
|
||||
# Bulk analysis — multiple domains, multiple checks in parallel
|
||||
python3 SKILL_DIR/scripts/domain_intel.py bulk example.com github.com google.com
|
||||
python3 SKILL_DIR/scripts/domain_intel.py bulk example.com github.com --checks ssl,dns
|
||||
```
|
||||
|
||||
`SKILL_DIR` is the directory containing this SKILL.md file. All output is structured JSON.
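
Because the output is JSON, it can also be consumed programmatically. A minimal sketch; the script path is a placeholder for the real `SKILL_DIR`:

```python
import json
import subprocess

# Placeholder path: substitute the actual SKILL_DIR for this skill
cmd = ["python3", "SKILL_DIR/scripts/domain_intel.py", "ssl", "example.com"]
out = subprocess.run(cmd, capture_output=True, text=True, check=True)

report = json.loads(out.stdout)
print(report["expiry_status"], report["days_remaining"])
```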
|
||||
|
||||
## Available commands
|
||||
|
||||
| Command | What it does | Data source |
|
||||
|---------|-------------|-------------|
|
||||
| `subdomains` | Find subdomains from certificate logs | crt.sh (HTTPS) |
|
||||
| `ssl` | Inspect TLS certificate details | Direct TCP:443 to target |
|
||||
| `whois` | Registration info, registrar, dates | WHOIS servers (TCP:43) |
|
||||
| `dns` | A, AAAA, MX, NS, TXT, CNAME records | System DNS + Google DoH |
|
||||
| `available` | Check if domain is registered | DNS + WHOIS + SSL signals |
|
||||
| `bulk` | Run multiple checks on multiple domains | All of the above |
|
||||
|
||||
## When to use this vs built-in tools
|
||||
|
||||
- **Use this skill** for infrastructure questions: subdomains, SSL certs, WHOIS, DNS records, availability
|
||||
- **Use `web_search`** for general research about what a domain/company does
|
||||
- **Use `web_extract`** to get the actual content of a webpage
|
||||
- **Use `terminal` with `curl -I`** for a simple "is this URL reachable" check
|
||||
|
||||
| Task | Better tool | Why |
|
||||
|------|-------------|-----|
|
||||
| "What does example.com do?" | `web_extract` | Gets page content, not DNS/WHOIS data |
|
||||
| "Find info about a company" | `web_search` | General research, not domain-specific |
|
||||
| "Is this website safe?" | `web_search` | Reputation checks need web context |
|
||||
| "Check if a URL is reachable" | `terminal` with `curl -I` | Simple HTTP check |
|
||||
| "Find subdomains of X" | **This skill** | Only passive source for this |
|
||||
| "When does the SSL cert expire?" | **This skill** | Built-in tools can't inspect TLS |
|
||||
| "Who registered this domain?" | **This skill** | WHOIS data not in web search |
|
||||
| "Is coolstartup.io available?" | **This skill** | Passive availability via DNS+WHOIS+SSL |
|
||||
|
||||
## Platform compatibility
|
||||
|
||||
Pure Python stdlib (`socket`, `ssl`, `urllib`, `json`, `concurrent.futures`).
|
||||
Works identically on Linux, macOS, and Windows with no dependencies.
|
||||
|
||||
- **crt.sh queries** use HTTPS (port 443) — works behind most firewalls
|
||||
- **WHOIS queries** use TCP port 43 — may be blocked on restrictive networks
|
||||
- **DNS queries** use Google DoH (HTTPS) for MX/NS/TXT — firewall-friendly
|
||||
- **SSL checks** connect to the target on port 443 — the only "active" operation
|
||||
|
||||
## Data sources
|
||||
|
||||
All queries are **passive** — no port scanning, no vulnerability testing:
|
||||
|
||||
- **crt.sh** — Certificate Transparency logs (subdomain discovery, HTTPS only)
|
||||
- **WHOIS servers** — Direct TCP to 100+ authoritative TLD registrars
|
||||
- **Google DNS-over-HTTPS** — MX, NS, TXT, CNAME resolution (firewall-friendly)
|
||||
- **System DNS** — A/AAAA record resolution
|
||||
- **SSL check** is the only "active" operation (TCP connection to target:443)
|
||||
|
||||
## Notes
|
||||
|
||||
- WHOIS queries use TCP port 43 — may be blocked on restrictive networks
|
||||
- Some WHOIS servers redact registrant info (GDPR) — mention this to the user
|
||||
- crt.sh can be slow for very popular domains (thousands of certs) — set reasonable expectations
|
||||
- The availability check is heuristic-based (3 passive signals) — not authoritative like a registrar API
|
||||
|
||||
---
|
||||
|
||||
*Contributed by [@FurkanL0](https://github.com/FurkanL0)*
|
||||
397
optional-skills/research/domain-intel/scripts/domain_intel.py
Normal file
|
|
@ -0,0 +1,397 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Domain Intelligence — Passive OSINT via Python stdlib.
|
||||
|
||||
Usage:
|
||||
python domain_intel.py subdomains example.com
|
||||
python domain_intel.py ssl example.com
|
||||
python domain_intel.py whois example.com
|
||||
python domain_intel.py dns example.com
|
||||
python domain_intel.py available example.com
|
||||
python domain_intel.py bulk example.com github.com google.com --checks ssl,dns
|
||||
|
||||
All output is structured JSON. No dependencies beyond Python stdlib.
|
||||
Works on Linux, macOS, and Windows.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
# ─── Subdomain Discovery (crt.sh) ──────────────────────────────────────────
|
||||
|
||||
def subdomains(domain, include_expired=False, limit=200):
|
||||
"""Find subdomains via Certificate Transparency logs."""
|
||||
url = f"https://crt.sh/?q=%25.{urllib.parse.quote(domain)}&output=json"
|
||||
req = urllib.request.Request(url, headers={
|
||||
"User-Agent": "domain-intel-skill/1.0", "Accept": "application/json",
|
||||
})
|
||||
with urllib.request.urlopen(req, timeout=15) as r:
|
||||
entries = json.loads(r.read().decode())
|
||||
|
||||
seen, results = set(), []
|
||||
now = datetime.now(timezone.utc)
|
||||
for e in entries:
|
||||
not_after = e.get("not_after", "")
|
||||
if not include_expired and not_after:
|
||||
try:
|
||||
dt = datetime.strptime(not_after[:19], "%Y-%m-%dT%H:%M:%S").replace(tzinfo=timezone.utc)
|
||||
if dt <= now:
|
||||
continue
|
||||
except ValueError:
|
||||
pass
|
||||
for name in e.get("name_value", "").splitlines():
|
||||
name = name.strip().lower()
|
||||
if name and name not in seen:
|
||||
seen.add(name)
|
||||
results.append({
|
||||
"subdomain": name,
|
||||
"issuer": e.get("issuer_name", ""),
|
||||
"not_after": not_after,
|
||||
})
|
||||
|
||||
results.sort(key=lambda r: (r["subdomain"].startswith("*"), r["subdomain"]))
|
||||
return {"domain": domain, "count": min(len(results), limit), "subdomains": results[:limit]}
|
||||
|
||||
|
||||
# ─── SSL Certificate Inspection ────────────────────────────────────────────
|
||||
|
||||
def check_ssl(host, port=443, timeout=10):
|
||||
"""Inspect the TLS certificate of a host."""
|
||||
def flat(rdns):
|
||||
r = {}
|
||||
for rdn in rdns:
|
||||
for item in rdn:
|
||||
if isinstance(item, (list, tuple)) and len(item) == 2:
|
||||
r[item[0]] = item[1]
|
||||
return r
|
||||
|
||||
def parse_date(s):
|
||||
for fmt in ("%b %d %H:%M:%S %Y %Z", "%b %d %H:%M:%S %Y %Z"):
|
||||
try:
|
||||
return datetime.strptime(s, fmt).replace(tzinfo=timezone.utc)
|
||||
except ValueError:
|
||||
pass
|
||||
return None
|
||||
|
||||
warning = None
|
||||
try:
|
||||
ctx = ssl.create_default_context()
|
||||
with socket.create_connection((host, port), timeout=timeout) as sock:
|
||||
with ctx.wrap_socket(sock, server_hostname=host) as s:
|
||||
cert, cipher, proto = s.getpeercert(), s.cipher(), s.version()
|
||||
except ssl.SSLCertVerificationError as e:
|
||||
warning = str(e)
|
||||
        # Retry without verification so cipher/protocol info is still returned.
        # Note: with verify_mode=CERT_NONE, getpeercert() may come back empty,
        # so certificate fields can be missing in this fallback path.
        ctx = ssl.create_default_context()
|
||||
ctx.check_hostname = False
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
with socket.create_connection((host, port), timeout=timeout) as sock:
|
||||
with ctx.wrap_socket(sock, server_hostname=host) as s:
|
||||
cert, cipher, proto = s.getpeercert(), s.cipher(), s.version()
|
||||
|
||||
not_after = parse_date(cert.get("notAfter", ""))
|
||||
now = datetime.now(timezone.utc)
|
||||
days = (not_after - now).days if not_after else None
|
||||
is_expired = days is not None and days < 0
|
||||
|
||||
if is_expired:
|
||||
status = f"EXPIRED ({abs(days)} days ago)"
|
||||
elif days is not None and days <= 14:
|
||||
status = f"CRITICAL — {days} day(s) left"
|
||||
elif days is not None and days <= 30:
|
||||
status = f"WARNING — {days} day(s) left"
|
||||
else:
|
||||
status = f"OK — {days} day(s) remaining" if days is not None else "unknown"
|
||||
|
||||
return {
|
||||
"host": host, "port": port,
|
||||
"subject": flat(cert.get("subject", [])),
|
||||
"issuer": flat(cert.get("issuer", [])),
|
||||
"subject_alt_names": [f"{t}:{v}" for t, v in cert.get("subjectAltName", [])],
|
||||
"not_before": parse_date(cert.get("notBefore", "")).isoformat() if parse_date(cert.get("notBefore", "")) else "",
|
||||
"not_after": not_after.isoformat() if not_after else "",
|
||||
"days_remaining": days, "is_expired": is_expired, "expiry_status": status,
|
||||
"tls_version": proto,
|
||||
"cipher_suite": cipher[0] if cipher else None,
|
||||
"serial_number": cert.get("serialNumber", ""),
|
||||
"verification_warning": warning,
|
||||
}
|
||||
|
||||
|
||||
# ─── WHOIS Lookup ──────────────────────────────────────────────────────────
|
||||
|
||||
WHOIS_SERVERS = {
|
||||
"com": "whois.verisign-grs.com", "net": "whois.verisign-grs.com",
|
||||
"org": "whois.pir.org", "io": "whois.nic.io", "co": "whois.nic.co",
|
||||
"ai": "whois.nic.ai", "dev": "whois.nic.google", "app": "whois.nic.google",
|
||||
"tech": "whois.nic.tech", "shop": "whois.nic.shop", "store": "whois.nic.store",
|
||||
"online": "whois.nic.online", "site": "whois.nic.site", "cloud": "whois.nic.cloud",
|
||||
"digital": "whois.nic.digital", "media": "whois.nic.media", "blog": "whois.nic.blog",
|
||||
"info": "whois.afilias.net", "biz": "whois.biz", "me": "whois.nic.me",
|
||||
"tv": "whois.nic.tv", "cc": "whois.nic.cc", "ws": "whois.website.ws",
|
||||
"uk": "whois.nic.uk", "co.uk": "whois.nic.uk", "de": "whois.denic.de",
|
||||
"nl": "whois.domain-registry.nl", "fr": "whois.nic.fr", "it": "whois.nic.it",
|
||||
"es": "whois.nic.es", "pl": "whois.dns.pl", "ru": "whois.tcinet.ru",
|
||||
"se": "whois.iis.se", "no": "whois.norid.no", "fi": "whois.fi",
|
||||
"ch": "whois.nic.ch", "at": "whois.nic.at", "be": "whois.dns.be",
|
||||
"cz": "whois.nic.cz", "br": "whois.registro.br", "ca": "whois.cira.ca",
|
||||
"mx": "whois.mx", "au": "whois.auda.org.au", "jp": "whois.jprs.jp",
|
||||
"cn": "whois.cnnic.cn", "in": "whois.inregistry.net", "kr": "whois.kr",
|
||||
"sg": "whois.sgnic.sg", "hk": "whois.hkirc.hk", "tr": "whois.nic.tr",
|
||||
"ae": "whois.aeda.net.ae", "za": "whois.registry.net.za",
|
||||
"space": "whois.nic.space", "zone": "whois.nic.zone", "ninja": "whois.nic.ninja",
|
||||
"guru": "whois.nic.guru", "rocks": "whois.nic.rocks", "live": "whois.nic.live",
|
||||
"game": "whois.nic.game", "games": "whois.nic.games",
|
||||
}
|
||||
|
||||
|
||||
def whois_lookup(domain):
|
||||
"""Query WHOIS servers for domain registration info."""
|
||||
parts = domain.split(".")
|
||||
server = WHOIS_SERVERS.get(".".join(parts[-2:])) or WHOIS_SERVERS.get(parts[-1])
|
||||
if not server:
|
||||
return {"error": f"No WHOIS server for .{parts[-1]}"}
|
||||
|
||||
try:
|
||||
with socket.create_connection((server, 43), timeout=10) as s:
|
||||
s.sendall((domain + "\r\n").encode())
|
||||
chunks = []
|
||||
while True:
|
||||
c = s.recv(4096)
|
||||
if not c:
|
||||
break
|
||||
chunks.append(c)
|
||||
raw = b"".join(chunks).decode("utf-8", errors="replace")
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
patterns = {
|
||||
"registrar": r"(?:Registrar|registrar):\s*(.+)",
|
||||
"creation_date": r"(?:Creation Date|Created|created):\s*(.+)",
|
||||
"expiration_date": r"(?:Registry Expiry Date|Expiration Date|Expiry Date):\s*(.+)",
|
||||
"updated_date": r"(?:Updated Date|Last Modified):\s*(.+)",
|
||||
"name_servers": r"(?:Name Server|nserver):\s*(.+)",
|
||||
"status": r"(?:Domain Status|status):\s*(.+)",
|
||||
"dnssec": r"DNSSEC:\s*(.+)",
|
||||
}
|
||||
result = {"domain": domain, "whois_server": server}
|
||||
for key, pat in patterns.items():
|
||||
matches = re.findall(pat, raw, re.IGNORECASE)
|
||||
if matches:
|
||||
if key in ("name_servers", "status"):
|
||||
result[key] = list(dict.fromkeys(m.strip().lower() for m in matches))
|
||||
else:
|
||||
result[key] = matches[0].strip()
|
||||
|
||||
for field in ("creation_date", "expiration_date", "updated_date"):
|
||||
if field in result:
|
||||
for fmt in ("%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%d %H:%M:%S", "%Y-%m-%d"):
|
||||
try:
|
||||
dt = datetime.strptime(result[field][:19], fmt).replace(tzinfo=timezone.utc)
|
||||
result[field] = dt.isoformat()
|
||||
if field == "expiration_date":
|
||||
days = (dt - datetime.now(timezone.utc)).days
|
||||
result["expiration_days_remaining"] = days
|
||||
result["is_expired"] = days < 0
|
||||
break
|
||||
except ValueError:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
# ─── DNS Records ───────────────────────────────────────────────────────────
|
||||
|
||||
def dns_records(domain, types=None):
|
||||
"""Resolve DNS records using system DNS + Google DoH."""
|
||||
if not types:
|
||||
types = ["A", "AAAA", "MX", "NS", "TXT", "CNAME"]
|
||||
records = {}
|
||||
|
||||
for qtype in types:
|
||||
if qtype == "A":
|
||||
try:
|
||||
records["A"] = list(dict.fromkeys(
|
||||
i[4][0] for i in socket.getaddrinfo(domain, None, socket.AF_INET)
|
||||
))
|
||||
except Exception:
|
||||
records["A"] = []
|
||||
elif qtype == "AAAA":
|
||||
try:
|
||||
records["AAAA"] = list(dict.fromkeys(
|
||||
i[4][0] for i in socket.getaddrinfo(domain, None, socket.AF_INET6)
|
||||
))
|
||||
except Exception:
|
||||
records["AAAA"] = []
|
||||
else:
|
||||
url = f"https://dns.google/resolve?name={urllib.parse.quote(domain)}&type={qtype}"
|
||||
try:
|
||||
req = urllib.request.Request(url, headers={"User-Agent": "domain-intel-skill/1.0"})
|
||||
with urllib.request.urlopen(req, timeout=10) as r:
|
||||
data = json.loads(r.read())
|
||||
records[qtype] = [
|
||||
a.get("data", "").strip().rstrip(".")
|
||||
for a in data.get("Answer", []) if a.get("data")
|
||||
]
|
||||
except Exception:
|
||||
records[qtype] = []
|
||||
|
||||
return {"domain": domain, "records": records}
|
||||
|
||||
|
||||
# ─── Domain Availability Check ─────────────────────────────────────────────
|
||||
|
||||
def check_available(domain):
|
||||
"""Check domain availability using passive signals (DNS + WHOIS + SSL)."""
|
||||
signals = {}
|
||||
|
||||
# DNS
|
||||
try:
|
||||
a = [i[4][0] for i in socket.getaddrinfo(domain, None, socket.AF_INET)]
|
||||
except Exception:
|
||||
a = []
|
||||
|
||||
try:
|
||||
ns_url = f"https://dns.google/resolve?name={urllib.parse.quote(domain)}&type=NS"
|
||||
req = urllib.request.Request(ns_url, headers={"User-Agent": "domain-intel-skill/1.0"})
|
||||
with urllib.request.urlopen(req, timeout=10) as r:
|
||||
ns = [x.get("data", "") for x in json.loads(r.read()).get("Answer", [])]
|
||||
except Exception:
|
||||
ns = []
|
||||
|
||||
signals["dns_a"] = a
|
||||
signals["dns_ns"] = ns
|
||||
dns_exists = bool(a or ns)
|
||||
|
||||
# SSL
|
||||
ssl_up = False
|
||||
try:
|
||||
ctx = ssl.create_default_context()
|
||||
ctx.check_hostname = False
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
with socket.create_connection((domain, 443), timeout=3) as s:
|
||||
with ctx.wrap_socket(s, server_hostname=domain):
|
||||
ssl_up = True
|
||||
except Exception:
|
||||
pass
|
||||
signals["ssl_reachable"] = ssl_up
|
||||
|
||||
# WHOIS (quick check)
|
||||
tld = domain.rsplit(".", 1)[-1]
|
||||
server = WHOIS_SERVERS.get(tld)
|
||||
whois_avail = None
|
||||
whois_note = ""
|
||||
if server:
|
||||
try:
|
||||
with socket.create_connection((server, 43), timeout=10) as s:
|
||||
s.sendall((domain + "\r\n").encode())
|
||||
raw = b""
|
||||
while True:
|
||||
c = s.recv(4096)
|
||||
if not c:
|
||||
break
|
||||
raw += c
|
||||
raw = raw.decode("utf-8", errors="replace").lower()
|
||||
if any(p in raw for p in ["no match", "not found", "no data found", "status: free"]):
|
||||
whois_avail = True
|
||||
whois_note = "WHOIS: not found"
|
||||
elif "registrar:" in raw or "creation date:" in raw:
|
||||
whois_avail = False
|
||||
whois_note = "WHOIS: registered"
|
||||
else:
|
||||
whois_note = "WHOIS: inconclusive"
|
||||
except Exception as e:
|
||||
whois_note = f"WHOIS error: {e}"
|
||||
|
||||
signals["whois_available"] = whois_avail
|
||||
signals["whois_note"] = whois_note
|
||||
|
||||
if not dns_exists and whois_avail is True:
|
||||
verdict, conf = "LIKELY AVAILABLE", "high"
|
||||
elif dns_exists or whois_avail is False or ssl_up:
|
||||
verdict, conf = "REGISTERED / IN USE", "high"
|
||||
elif not dns_exists and whois_avail is None:
|
||||
verdict, conf = "POSSIBLY AVAILABLE", "medium"
|
||||
else:
|
||||
verdict, conf = "UNCERTAIN", "low"
|
||||
|
||||
return {"domain": domain, "verdict": verdict, "confidence": conf, "signals": signals}
|
||||
|
||||
|
||||
# ─── Bulk Analysis ─────────────────────────────────────────────────────────
|
||||
|
||||
COMMAND_MAP = {
|
||||
"subdomains": subdomains,
|
||||
"ssl": check_ssl,
|
||||
"whois": whois_lookup,
|
||||
"dns": dns_records,
|
||||
"available": check_available,
|
||||
}
|
||||
|
||||
|
||||
def bulk_check(domains, checks=None, max_workers=5):
|
||||
"""Run multiple checks across multiple domains in parallel."""
|
||||
if not checks:
|
||||
checks = ["ssl", "whois", "dns"]
|
||||
|
||||
def run_one(d):
|
||||
entry = {"domain": d}
|
||||
for check in checks:
|
||||
fn = COMMAND_MAP.get(check)
|
||||
if fn:
|
||||
try:
|
||||
entry[check] = fn(d)
|
||||
except Exception as e:
|
||||
entry[check] = {"error": str(e)}
|
||||
return entry
|
||||
|
||||
results = []
|
||||
with ThreadPoolExecutor(max_workers=min(max_workers, 10)) as ex:
|
||||
        futures = {ex.submit(run_one, d): d for d in domains[:20]}  # cap bulk runs at 20 domains
|
||||
for f in as_completed(futures):
|
||||
results.append(f.result())
|
||||
|
||||
return {"total": len(results), "checks": checks, "results": results}
|
||||
|
||||
|
||||
# ─── CLI Entry Point ───────────────────────────────────────────────────────
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 3:
|
||||
print(__doc__)
|
||||
sys.exit(1)
|
||||
|
||||
command = sys.argv[1].lower()
|
||||
args = sys.argv[2:]
|
||||
|
||||
if command == "bulk":
|
||||
# Parse --checks flag
|
||||
checks = None
|
||||
domains = []
|
||||
i = 0
|
||||
while i < len(args):
|
||||
if args[i] == "--checks" and i + 1 < len(args):
|
||||
checks = [c.strip() for c in args[i + 1].split(",")]
|
||||
i += 2
|
||||
else:
|
||||
domains.append(args[i])
|
||||
i += 1
|
||||
result = bulk_check(domains, checks)
|
||||
elif command in COMMAND_MAP:
|
||||
result = COMMAND_MAP[command](args[0])
|
||||
else:
|
||||
print(f"Unknown command: {command}")
|
||||
print(f"Available: {', '.join(COMMAND_MAP.keys())}, bulk")
|
||||
sys.exit(1)
|
||||
|
||||
print(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
237
optional-skills/research/duckduckgo-search/SKILL.md
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
---
|
||||
name: duckduckgo-search
|
||||
description: Free web search via DuckDuckGo — text, news, images, videos. No API key needed. Prefer the `ddgs` CLI when installed; use the Python DDGS library only after verifying that `ddgs` is available in the current runtime.
|
||||
version: 1.3.0
|
||||
author: gamedevCloudy
|
||||
license: MIT
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [search, duckduckgo, web-search, free, fallback]
|
||||
related_skills: [arxiv]
|
||||
fallback_for_toolsets: [web]
|
||||
---
|
||||
|
||||
# DuckDuckGo Search
|
||||
|
||||
Free web search using DuckDuckGo. **No API key required.**
|
||||
|
||||
Preferred when `web_search` is unavailable or unsuitable (for example when `FIRECRAWL_API_KEY` is not set). Can also be used as a standalone search path when DuckDuckGo results are specifically desired.
|
||||
|
||||
## Detection Flow
|
||||
|
||||
Check what is actually available before choosing an approach:
|
||||
|
||||
```bash
|
||||
# Check CLI availability
|
||||
command -v ddgs >/dev/null && echo "DDGS_CLI=installed" || echo "DDGS_CLI=missing"
|
||||
```
|
||||
|
||||
Decision tree:
|
||||
1. If `ddgs` CLI is installed, prefer `terminal` + `ddgs`
|
||||
2. If `ddgs` CLI is missing, do not assume `execute_code` can import `ddgs`
|
||||
3. If the user wants DuckDuckGo specifically, install `ddgs` first in the relevant environment
|
||||
4. Otherwise fall back to built-in web/browser tools
|
||||
|
||||
Important runtime note:
|
||||
- Terminal and `execute_code` are separate runtimes
|
||||
- A successful shell install does not guarantee `execute_code` can import `ddgs`
|
||||
- Never assume third-party Python packages are preinstalled inside `execute_code`
|
||||
|
||||
## Installation
|
||||
|
||||
Install `ddgs` only when DuckDuckGo search is specifically needed and the runtime does not already provide it.
|
||||
|
||||
```bash
|
||||
# Python package + CLI entrypoint
|
||||
pip install ddgs
|
||||
|
||||
# Verify CLI
|
||||
ddgs --help
|
||||
```
|
||||
|
||||
If a workflow depends on Python imports, verify that same runtime can import `ddgs` before using `from ddgs import DDGS`.
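
A minimal verification sketch for the Python side (run it in the same runtime that will do the import):

```python
import importlib.util

# Only attempt the import if the package is actually present in this runtime
if importlib.util.find_spec("ddgs") is None:
    print("ddgs is not installed in this runtime; install it or use the ddgs CLI instead")
else:
    from ddgs import DDGS  # safe to import now
    print("ddgs import OK")
```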
|
||||
|
||||
## Method 1: CLI Search (Preferred)
|
||||
|
||||
Use the `ddgs` command via `terminal` when it exists. This is the preferred path because it avoids assuming the `execute_code` sandbox has the `ddgs` Python package installed.
|
||||
|
||||
```bash
|
||||
# Text search
|
||||
ddgs text -k "python async programming" -m 5
|
||||
|
||||
# News search
|
||||
ddgs news -k "artificial intelligence" -m 5
|
||||
|
||||
# Image search
|
||||
ddgs images -k "landscape photography" -m 10
|
||||
|
||||
# Video search
|
||||
ddgs videos -k "python tutorial" -m 5
|
||||
|
||||
# With region filter
|
||||
ddgs text -k "best restaurants" -m 5 -r us-en
|
||||
|
||||
# Recent results only (d=day, w=week, m=month, y=year)
|
||||
ddgs text -k "latest AI news" -m 5 -t w
|
||||
|
||||
# JSON output for parsing
|
||||
ddgs text -k "fastapi tutorial" -m 5 -o json
|
||||
```
|
||||
|
||||
### CLI Flags
|
||||
|
||||
| Flag | Description | Example |
|
||||
|------|-------------|---------|
|
||||
| `-k` | Keywords (query) — **required** | `-k "search terms"` |
|
||||
| `-m` | Max results | `-m 5` |
|
||||
| `-r` | Region | `-r us-en` |
|
||||
| `-t` | Time limit | `-t w` (week) |
|
||||
| `-s` | Safe search | `-s off` |
|
||||
| `-o` | Output format | `-o json` |
|
||||
|
||||
## Method 2: Python API (Only After Verification)
|
||||
|
||||
Use the `DDGS` class in `execute_code` or another Python runtime only after verifying that `ddgs` is installed there. Do not assume `execute_code` includes third-party packages by default.
|
||||
|
||||
Safe wording:
|
||||
- "Use `execute_code` with `ddgs` after installing or verifying the package if needed"
|
||||
|
||||
Avoid saying:
|
||||
- "`execute_code` includes `ddgs`"
|
||||
- "DuckDuckGo search works by default in `execute_code`"
|
||||
|
||||
**Important:** `max_results` must always be passed as a **keyword argument** — positional usage raises an error on all methods.
|
||||
|
||||
### Text Search
|
||||
|
||||
Best for: general research, companies, documentation.
|
||||
|
||||
```python
|
||||
from ddgs import DDGS
|
||||
|
||||
with DDGS() as ddgs:
|
||||
for r in ddgs.text("python async programming", max_results=5):
|
||||
print(r["title"])
|
||||
print(r["href"])
|
||||
print(r.get("body", "")[:200])
|
||||
print()
|
||||
```
|
||||
|
||||
Returns: `title`, `href`, `body`
|
||||
|
||||
### News Search
|
||||
|
||||
Best for: current events, breaking news, latest updates.
|
||||
|
||||
```python
|
||||
from ddgs import DDGS
|
||||
|
||||
with DDGS() as ddgs:
|
||||
for r in ddgs.news("AI regulation 2026", max_results=5):
|
||||
print(r["date"], "-", r["title"])
|
||||
print(r.get("source", ""), "|", r["url"])
|
||||
print(r.get("body", "")[:200])
|
||||
print()
|
||||
```
|
||||
|
||||
Returns: `date`, `title`, `body`, `url`, `image`, `source`
|
||||
|
||||
### Image Search
|
||||
|
||||
Best for: visual references, product images, diagrams.
|
||||
|
||||
```python
|
||||
from ddgs import DDGS
|
||||
|
||||
with DDGS() as ddgs:
|
||||
for r in ddgs.images("semiconductor chip", max_results=5):
|
||||
print(r["title"])
|
||||
print(r["image"])
|
||||
print(r.get("thumbnail", ""))
|
||||
print(r.get("source", ""))
|
||||
print()
|
||||
```
|
||||
|
||||
Returns: `title`, `image`, `thumbnail`, `url`, `height`, `width`, `source`
|
||||
|
||||
### Video Search
|
||||
|
||||
Best for: tutorials, demos, explainers.
|
||||
|
||||
```python
|
||||
from ddgs import DDGS
|
||||
|
||||
with DDGS() as ddgs:
|
||||
for r in ddgs.videos("FastAPI tutorial", max_results=5):
|
||||
print(r["title"])
|
||||
print(r.get("content", ""))
|
||||
print(r.get("duration", ""))
|
||||
print(r.get("provider", ""))
|
||||
print(r.get("published", ""))
|
||||
print()
|
||||
```
|
||||
|
||||
Returns: `title`, `content`, `description`, `duration`, `provider`, `published`, `statistics`, `uploader`
|
||||
|
||||
### Quick Reference
|
||||
|
||||
| Method | Use When | Key Fields |
|
||||
|--------|----------|------------|
|
||||
| `text()` | General research, companies | title, href, body |
|
||||
| `news()` | Current events, updates | date, title, source, body, url |
|
||||
| `images()` | Visuals, diagrams | title, image, thumbnail, url |
|
||||
| `videos()` | Tutorials, demos | title, content, duration, provider |
|
||||
|
||||
## Workflow: Search then Extract
|
||||
|
||||
DuckDuckGo returns titles, URLs, and snippets — not full page content. To get full page content, search first and then extract the most relevant URL with `web_extract`, browser tools, or curl.
|
||||
|
||||
CLI example:
|
||||
|
||||
```bash
|
||||
ddgs text -k "fastapi deployment guide" -m 3 -o json
|
||||
```
|
||||
|
||||
Python example, only after verifying `ddgs` is installed in that runtime:
|
||||
|
||||
```python
|
||||
from ddgs import DDGS
|
||||
|
||||
with DDGS() as ddgs:
|
||||
results = list(ddgs.text("fastapi deployment guide", max_results=3))
|
||||
for r in results:
|
||||
print(r["title"], "->", r["href"])
|
||||
```
|
||||
|
||||
Then extract the best URL with `web_extract` or another content-retrieval tool.
|
||||
|
||||
## Limitations
|
||||
|
||||
- **Rate limiting**: DuckDuckGo may throttle after many rapid requests. Add a short delay between searches if needed (see the retry sketch after this list).
|
||||
- **No content extraction**: `ddgs` returns snippets, not full page content. Use `web_extract`, browser tools, or curl for the full article/page.
|
||||
- **Results quality**: Generally good but less configurable than Firecrawl's search.
|
||||
- **Availability**: DuckDuckGo may block requests from some cloud IPs. If searches return empty, try different keywords or wait a few seconds.
|
||||
- **Field variability**: Return fields may vary between results or `ddgs` versions. Use `.get()` for optional fields to avoid `KeyError`.
|
||||
- **Separate runtimes**: A successful `ddgs` install in terminal does not automatically mean `execute_code` can import it.
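
A minimal retry sketch for the rate-limiting case (assumes `ddgs` is importable in this runtime; the delay values are arbitrary):

```python
import time

from ddgs import DDGS


def search_with_retry(query, max_results=5, attempts=3, delay=3.0):
    """Retry a text search a few times when DuckDuckGo returns no results."""
    for _ in range(attempts):
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=max_results))
        if results:
            return results
        time.sleep(delay)  # brief pause before retrying to ease rate limiting
    return []


print(len(search_with_retry("fastapi deployment guide")))
```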
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Problem | Likely Cause | What To Do |
|
||||
|---------|--------------|------------|
|
||||
| `ddgs: command not found` | CLI not installed in the shell environment | Install `ddgs`, or use built-in web/browser tools instead |
|
||||
| `ModuleNotFoundError: No module named 'ddgs'` | Python runtime does not have the package installed | Do not use Python DDGS there until that runtime is prepared |
|
||||
| Search returns nothing | Temporary rate limiting or poor query | Wait a few seconds, retry, or adjust the query |
|
||||
| CLI works but `execute_code` import fails | Terminal and `execute_code` are different runtimes | Keep using CLI, or separately prepare the Python runtime |
|
||||
|
||||
## Pitfalls
|
||||
|
||||
- **`max_results` is keyword-only**: `ddgs.text("query", 5)` raises an error. Use `ddgs.text("query", max_results=5)`.
|
||||
- **Do not assume the CLI exists**: Check `command -v ddgs` before using it.
|
||||
- **Do not assume `execute_code` can import `ddgs`**: `from ddgs import DDGS` may fail with `ModuleNotFoundError` unless that runtime was prepared separately.
|
||||
- **Package name**: The package is `ddgs` (previously `duckduckgo-search`). Install with `pip install ddgs`.
|
||||
- **Don't confuse `-k` and `-m`** (CLI): `-k` is for keywords, `-m` is for max results count.
|
||||
- **Empty results**: If `ddgs` returns nothing, it may be rate-limited. Wait a few seconds and retry.
|
||||
|
||||
## Validated With
|
||||
|
||||
Examples were validated against `ddgs==9.11.2`. The skill deliberately treats CLI availability and Python import availability as separate concerns so the documented workflow matches actual runtime behavior.
|
||||
28
optional-skills/research/duckduckgo-search/scripts/duckduckgo.sh
Executable file
|
|
@ -0,0 +1,28 @@
|
|||
#!/bin/bash
|
||||
# DuckDuckGo Search Helper Script
|
||||
# Wrapper around ddgs CLI with sensible defaults
|
||||
# Usage: ./duckduckgo.sh <query> [max_results]
|
||||
|
||||
set -e
|
||||
|
||||
QUERY="$1"
|
||||
MAX_RESULTS="${2:-5}"
|
||||
|
||||
if [ -z "$QUERY" ]; then
|
||||
echo "Usage: $0 <query> [max_results]"
|
||||
echo ""
|
||||
echo "Examples:"
|
||||
echo " $0 'python async programming' 5"
|
||||
echo " $0 'latest AI news' 10"
|
||||
echo ""
|
||||
echo "Requires: pip install ddgs"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if ddgs is available
|
||||
if ! command -v ddgs &> /dev/null; then
|
||||
echo "Error: ddgs not found. Install with: pip install ddgs"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ddgs text -k "$QUERY" -m "$MAX_RESULTS"
|
||||